赞
踩
首先来说LangChain是什么?不了解的可以点击下面的链接来查看下。
然后在介绍一下星火认知大模型相关:
讯飞星火认知大模型
感兴趣的小伙伴可以了解一下,国内比较成熟的类GPT(我自己定义的,也不知道对不对)模型。
说一下大概需求,首先我是要用到功能是文章摘要,之前接入的是OpenAI的api接口(langchain中已经封装好了相关内容),其实只对模型的好用程度来说OpenAI确实要相较于市面上其他的模型都要更智能一点,哪怕是对中文来说,而且因为是自己调试,都没用到GPT-4,仅是3.5系列模型都更加优秀一些。但是对于开发者来说尤其是在公司进行开发还是有一些弊端的。首先是收费,对于企业合作不知道收费具体怎么样,但对开发者自己来说确实收费还是比较高的(而且我一直也没搞懂这个收费是个怎么个收法,虽然他们说是按照token,但是总感觉有时候收费少了有时候收费多了)。其次是网络环境,比较优秀的解法是外部亚马逊服务器部署相关服务。最后是token数量,这个是比较硬伤的东西。
于是国内的大模型就成了我较好的选择,我的需求不仅仅是简单的问答,而是需要结合prompt来使用,同时因为我的输入内容比较大需要借助langchain内部的map_reduce来对我的整个提问流程进行一个整合,所以进行了星火Spark 接入langchain。话不多说,上代码。
- import _thread as thread
- import base64
- import datetime
- import hashlib
- import hmac
- import json
- import ssl
- import websocket
- import langchain
- import logging
- from config import SPARK_APPID, SPARK_API_KEY, SPARK_API_SECRET
- from urllib.parse import urlparse
- from datetime import datetime
- from time import mktime
- from urllib.parse import urlencode
- from wsgiref.handlers import format_date_time
- from typing import Optional, List, Dict, Mapping, Any
- from langchain.llms.base import LLM
- from langchain.cache import InMemoryCache
-
- logging.basicConfig(level=logging.INFO)
- # 启动llm的缓存
- langchain.llm_cache = InMemoryCache()
- result_list = []
-
-
- def _construct_query(prompt, temperature, max_tokens):
- data = {
- "header": {
- "app_id": SPARK_APPID,
- "uid": "1234"
- },
- "parameter": {
- "chat": {
- "domain": "general",
- "random_threshold": temperature,
- "max_tokens": max_tokens,
- "auditing": "default"
- }
- },
- "payload": {
- "message": {
- "text": [
- {"role": "user", "content": prompt}
- ]
- }
- }
- }
- return data
-
-
- def _run(ws, *args):
- data = json.dumps(
- _construct_query(prompt=ws.question, temperature=ws.temperature, max_tokens=ws.max_tokens))
- # print (data)
- ws.send(data)
-
-
- def on_error(ws, error):
- print("error:", error)
-
-
- def on_close(ws):
- print("closed...")
-
-
- def on_open(ws):
- thread.start_new_thread(_run, (ws,))
-
-
- def on_message(ws, message):
- data = json.loads(message)
- code = data['header']['code']
- # print(data)
- if code != 0:
- print(f'请求错误: {code}, {data}')
- ws.close()
- else:
- choices = data["payload"]["choices"]
- status = choices["status"]
- content = choices["text"][0]["content"]
- result_list.append(content)
- if status == 2:
- ws.close()
- setattr(ws, "content", "".join(result_list))
- print(result_list)
- result_list.clear()
-
-
- class Spark(LLM):
- '''
- 根据源码解析在通过LLMS包装的时候主要重构两个部分的代码
- _call 模型调用主要逻辑,输入问题,输出模型相应结果
- _identifying_params 返回模型描述信息,通常返回一个字典,字典中包括模型的主要参数
- '''
-
- gpt_url = "ws://spark-api.xf-yun.com/v1.1/chat" # spark官方模型提供api接口
- host = urlparse(gpt_url).netloc # host目标机器解析
- path = urlparse(gpt_url).path # 路径目标解析
- max_tokens = 1024
- temperature = 0.5
-
- # ws = websocket.WebSocketApp(url='')
-
- @property
- def _llm_type(self) -> str:
- # 模型简介
- return "Spark"
-
- def _get_url(self):
- # 获取请求路径
- now = datetime.now()
- date = format_date_time(mktime(now.timetuple()))
-
- signature_origin = "host: " + self.host + "\n"
- signature_origin += "date: " + date + "\n"
- signature_origin += "GET " + self.path + " HTTP/1.1"
-
- signature_sha = hmac.new(SPARK_API_SECRET.encode('utf-8'), signature_origin.encode('utf-8'),
- digestmod=hashlib.sha256).digest()
-
- signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
-
- authorization_origin = f'api_key="{SPARK_API_KEY}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
-
- authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
-
- v = {
- "authorization": authorization,
- "date": date,
- "host": self.host
- }
- url = self.gpt_url + '?' + urlencode(v)
- return url
-
- def _post(self, prompt):
- #模型请求响应
- websocket.enableTrace(False)
- wsUrl = self._get_url()
- ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error,
- on_close=on_close, on_open=on_open)
- ws.question = prompt
- setattr(ws, "temperature", self.temperature)
- setattr(ws, "max_tokens", self.max_tokens)
- ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
- return ws.content if hasattr(ws, "content") else ""
-
- def _call(self, prompt: str,
- stop: Optional[List[str]] = None) -> str:
- # 启动关键的函数
- content = self._post(prompt)
- # content = "这是一个测试"
- return content
-
- @property
- def _identifying_params(self) -> Mapping[str, Any]:
- """
- Get the identifying parameters.
- """
- _param_dict = {
- "url": self.gpt_url
- }
- return _param_dict
-
-
- if __name__ == "__main__":
- llm = Spark(temperature=0.9)
- # data =json.dumps(llm._construct_query(prompt="你好啊", temperature=llm.temperature, max_tokens=llm.max_tokens))
- # print (data)
- # print (type(data))
- result = llm("你好啊", stop=["you"])
- print(result)

有问题的小伙伴欢迎留言。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。