mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-31 15:47:05 +00:00
feat(model): Support Baichuan API (#1009)
Co-authored-by: BaiChuanHelper <wintergyc@WinterGYCs-MacBook-Pro.local>
This commit is contained in:
parent
5ddc5345fa
commit
e75314f27d
@ -119,6 +119,7 @@ DB-GPT是一个开源的数据库领域大模型框架。目的是构建大模
|
|||||||
|
|
||||||
- 支持在线代理模型
|
- 支持在线代理模型
|
||||||
- [x] [OpenAI·ChatGPT](https://api.openai.com/)
|
- [x] [OpenAI·ChatGPT](https://api.openai.com/)
|
||||||
|
- [x] [百川·Baichuan](https://platform.baichuan-ai.com/)
|
||||||
- [x] [阿里·通义](https://www.aliyun.com/product/dashscope)
|
- [x] [阿里·通义](https://www.aliyun.com/product/dashscope)
|
||||||
- [x] [百度·文心](https://cloud.baidu.com/product/wenxinworkshop?track=dingbutonglan)
|
- [x] [百度·文心](https://cloud.baidu.com/product/wenxinworkshop?track=dingbutonglan)
|
||||||
- [x] [智谱·ChatGLM](http://open.bigmodel.cn/)
|
- [x] [智谱·ChatGLM](http://open.bigmodel.cn/)
|
||||||
|
@ -88,12 +88,10 @@ class Config(metaclass=Singleton):
|
|||||||
|
|
||||||
# baichuan proxy
|
# baichuan proxy
|
||||||
self.bc_proxy_api_key = os.getenv("BAICHUAN_PROXY_API_KEY")
|
self.bc_proxy_api_key = os.getenv("BAICHUAN_PROXY_API_KEY")
|
||||||
self.bc_proxy_api_secret = os.getenv("BAICHUAN_PROXY_API_SECRET")
|
self.bc_model_name = os.getenv("BAICHUN_MODEL_NAME", "Baichuan2-Turbo-192k")
|
||||||
self.bc_model_version = os.getenv("BAICHUN_MODEL_NAME")
|
|
||||||
if self.bc_proxy_api_key and self.bc_proxy_api_secret:
|
if self.bc_proxy_api_key and self.bc_proxy_api_secret:
|
||||||
os.environ["bc_proxyllm_proxy_api_key"] = self.bc_proxy_api_key
|
os.environ["bc_proxyllm_proxy_api_key"] = self.bc_proxy_api_key
|
||||||
os.environ["bc_proxyllm_proxy_api_secret"] = self.bc_proxy_api_secret
|
os.environ["bc_proxyllm_proxyllm_backend"] = self.bc_model_name
|
||||||
os.environ["bc_proxyllm_proxyllm_backend"] = self.bc_model_version
|
|
||||||
|
|
||||||
# gemini proxy
|
# gemini proxy
|
||||||
self.gemini_proxy_api_key = os.getenv("GEMINI_PROXY_API_KEY")
|
self.gemini_proxy_api_key = os.getenv("GEMINI_PROXY_API_KEY")
|
||||||
|
@ -1,74 +1,51 @@
|
|||||||
import hashlib
|
|
||||||
import json
|
|
||||||
import time
|
|
||||||
import requests
|
import requests
|
||||||
|
import json
|
||||||
from typing import List
|
from typing import List
|
||||||
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
|
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
|
||||||
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
|
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
|
||||||
|
from dbgpt.model.parameter import ProxyModelParameters
|
||||||
|
|
||||||
BAICHUAN_DEFAULT_MODEL = "Baichuan2-53B"
|
BAICHUAN_DEFAULT_MODEL = "Baichuan2-Turbo-192k"
|
||||||
|
|
||||||
|
|
||||||
def _calculate_md5(text: str) -> str:
|
def baichuan_generate_stream(model: ProxyModel, tokenizer=None, params=None, device=None, context_len=4096):
|
||||||
"""Calculate md5"""
|
url = "https://api.baichuan-ai.com/v1/chat/completions"
|
||||||
md5 = hashlib.md5()
|
|
||||||
md5.update(text.encode("utf-8"))
|
|
||||||
encrypted = md5.hexdigest()
|
|
||||||
return encrypted
|
|
||||||
|
|
||||||
|
|
||||||
def _sign(data: dict, secret_key: str, timestamp: str):
|
|
||||||
data_str = json.dumps(data)
|
|
||||||
signature = _calculate_md5(secret_key + data_str + timestamp)
|
|
||||||
return signature
|
|
||||||
|
|
||||||
|
|
||||||
def baichuan_generate_stream(
|
|
||||||
model: ProxyModel, tokenizer, params, device, context_len=4096
|
|
||||||
):
|
|
||||||
model_params = model.get_params()
|
model_params = model.get_params()
|
||||||
url = "https://api.baichuan-ai.com/v1/stream/chat"
|
|
||||||
|
|
||||||
model_name = model_params.proxyllm_backend or BAICHUAN_DEFAULT_MODEL
|
model_name = model_params.proxyllm_backend or BAICHUAN_DEFAULT_MODEL
|
||||||
proxy_api_key = model_params.proxy_api_key
|
proxy_api_key = model_params.proxy_api_key
|
||||||
proxy_api_secret = model_params.proxy_api_secret
|
|
||||||
|
|
||||||
history = []
|
history = []
|
||||||
messages: List[ModelMessage] = params["messages"]
|
messages: List[ModelMessage] = params["messages"]
|
||||||
|
|
||||||
# Add history conversation
|
# Add history conversation
|
||||||
for message in messages:
|
for message in messages:
|
||||||
if message.role == ModelMessageRoleType.HUMAN:
|
if message.role == ModelMessageRoleType.HUMAN:
|
||||||
history.append({"role": "user", "content": message.content})
|
history.append({"role": "user", "content": message.content})
|
||||||
elif message.role == ModelMessageRoleType.SYSTEM:
|
elif message.role == ModelMessageRoleType.SYSTEM:
|
||||||
history.append({"role": "system", "content": message.content})
|
# As of today, system message is not supported.
|
||||||
|
history.append({"role": "user", "content": message.content})
|
||||||
elif message.role == ModelMessageRoleType.AI:
|
elif message.role == ModelMessageRoleType.AI:
|
||||||
history.append({"role": "assistant", "content": "message.content"})
|
history.append({"role": "assistant", "content": message.content})
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"model": model_name,
|
"model": model_name,
|
||||||
"messages": history,
|
"messages": history,
|
||||||
"parameters": {
|
"temperature": params.get("temperature", 0.3),
|
||||||
"temperature": params.get("temperature"),
|
"top_k": params.get("top_k", 5),
|
||||||
"top_k": params.get("top_k", 10),
|
"top_p": params.get("top_p", 0.85),
|
||||||
},
|
"stream": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
timestamp = int(time.time())
|
|
||||||
_signature = _sign(payload, proxy_api_secret, str(timestamp))
|
|
||||||
|
|
||||||
headers = {
|
headers = {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"Authorization": "Bearer " + proxy_api_key,
|
"Authorization": "Bearer " + proxy_api_key,
|
||||||
"X-BC-Request-Id": params.get("request_id") or "dbgpt",
|
|
||||||
"X-BC-Timestamp": str(timestamp),
|
|
||||||
"X-BC-Signature": _signature,
|
|
||||||
"X-BC-Sign-Algo": "MD5",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
res = requests.post(url=url, json=payload, headers=headers, stream=True)
|
print(f"Sending request to {url} with model {model_name}")
|
||||||
print(f"Send request to {url} with real model {model_name}")
|
res = requests.post(url=url, json=payload, headers=headers)
|
||||||
|
|
||||||
text = ""
|
text = ""
|
||||||
for line in res.iter_lines():
|
for line in res.iter_lines():
|
||||||
@ -81,7 +58,27 @@ def baichuan_generate_stream(
|
|||||||
decoded_line = json_data.decode("utf-8")
|
decoded_line = json_data.decode("utf-8")
|
||||||
if decoded_line.lower() != "[DONE]".lower():
|
if decoded_line.lower() != "[DONE]".lower():
|
||||||
obj = json.loads(json_data)
|
obj = json.loads(json_data)
|
||||||
if obj["data"]["messages"][0].get("content") is not None:
|
if obj["choices"][0]["delta"].get("content") is not None:
|
||||||
content = obj["data"]["messages"][0].get("content")
|
content = obj["choices"][0]["delta"].get("content")
|
||||||
text += content
|
text += content
|
||||||
yield text
|
yield text
|
||||||
|
|
||||||
|
def main():
|
||||||
|
model_params = ProxyModelParameters(
|
||||||
|
model_name="not-used",
|
||||||
|
model_path="not-used",
|
||||||
|
proxy_server_url="not-used",
|
||||||
|
proxy_api_key="YOUR_BAICHUAN_API_KEY",
|
||||||
|
proxyllm_backend="Baichuan2-Turbo-192k"
|
||||||
|
)
|
||||||
|
final_text = ""
|
||||||
|
for part in baichuan_generate_stream(
|
||||||
|
model=ProxyModel(model_params=model_params),
|
||||||
|
params={"messages": [ModelMessage(
|
||||||
|
role=ModelMessageRoleType.HUMAN,
|
||||||
|
content="背诵《论语》第一章")]}):
|
||||||
|
final_text = part
|
||||||
|
print(final_text)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
@ -67,6 +67,7 @@ In DB-GPT, seamless support for FastChat, vLLM and llama.cpp is directly provide
|
|||||||
|
|
||||||
#### Proxy Models
|
#### Proxy Models
|
||||||
- [OpenAI·ChatGPT](https://api.openai.com/)
|
- [OpenAI·ChatGPT](https://api.openai.com/)
|
||||||
|
- [百川·Baichuan](https://platform.baichuan-ai.com/)
|
||||||
- [Alibaba·通义](https://www.aliyun.com/product/dashscope)
|
- [Alibaba·通义](https://www.aliyun.com/product/dashscope)
|
||||||
- [Google·Bard](https://bard.google.com/)
|
- [Google·Bard](https://bard.google.com/)
|
||||||
- [Baidu·文心](https://cloud.baidu.com/product/wenxinworkshop?track=dingbutonglan)
|
- [Baidu·文心](https://cloud.baidu.com/product/wenxinworkshop?track=dingbutonglan)
|
||||||
|
Loading…
Reference in New Issue
Block a user