mirror of https://github.com/csunny/DB-GPT.git
synced 2025-08-04 10:00:17 +00:00
feat: + tongyi and wenxin
This commit is contained in:
parent 32ae362eac
commit 14cfc34c72
examples/tongyi.py (new file, 49 lines)
@@ -0,0 +1,49 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import dashscope
import requests
from http import HTTPStatus
from dashscope import Generation


def call_with_messages():
    messages = [{'role': 'system', 'content': '你是生活助手机器人。'},  # "You are a life assistant bot."
                {'role': 'user', 'content': '如何做西红柿鸡蛋?'}]  # "How do I cook tomato and egg?"
    gen = Generation()
    responses = gen.call(
        Generation.Models.qwen_turbo,
        messages=messages,
        stream=True,
        top_p=0.8,
        result_format='message',  # return results in "message" format
    )

    for response in responses:
        # A status_code of HTTPStatus.OK indicates success; otherwise the request
        # failed and the error can be read from response.code and response.message.
        if response.status_code == HTTPStatus.OK:
            print(response.output)  # the output text
            print(response.usage)   # the usage information
        else:
            print(response.code)     # the error code
            print(response.message)  # the error message


def build_access_token(api_key: str, secret_key: str) -> str:
    """
    Generate an access token from the API key (AK) and secret key (SK).
    """
    url = "https://aip.baidubce.com/oauth/2.0/token"
    params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key}

    res = requests.get(url=url, params=params)

    if res.status_code == 200:
        return res.json().get("access_token")


if __name__ == '__main__':
    call_with_messages()
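The example never sets an API key explicitly (the DashScope SDK can also pick one up from its own environment variable), and build_access_token is defined but not called. A minimal sketch of how the two pieces might be wired up if appended to the example above; the environment variable names are taken from elsewhere in this commit and the wiring itself is illustrative, not part of the commit:

import os
import dashscope

# Hypothetical wiring: give the DashScope SDK a key before Generation.call();
# here it is read from TONGYI_PROXY_API_KEY, the variable this commit adds to Config.
dashscope.api_key = os.getenv("TONGYI_PROXY_API_KEY")
call_with_messages()

# build_access_token targets Baidu's OAuth endpoint; WEN_XIN_API_KEY and
# WEN_XIN_SECRET_KEY are the names the wenxin proxy in this commit reads.
token = build_access_token(os.getenv("WEN_XIN_API_KEY"), os.getenv("WEN_XIN_SECRET_KEY"))
print(token)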
@@ -46,10 +46,13 @@ class Config(metaclass=Singleton):
         # This is a proxy server, just for test_py. we will remove this later.
         self.proxy_api_key = os.getenv("PROXY_API_KEY")
         self.bard_proxy_api_key = os.getenv("BARD_PROXY_API_KEY")

         # In order to be compatible with the new and old model parameter design
         if self.bard_proxy_api_key:
             os.environ["bard_proxyllm_proxy_api_key"] = self.bard_proxy_api_key

+        self.tongyi_api_key = os.getenv("TONGYI_PROXY_API_KEY")
+
         self.proxy_server_url = os.getenv("PROXY_SERVER_URL")

         self.elevenlabs_api_key = os.getenv("ELEVENLABS_API_KEY")
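The new proxy backends are configured entirely through environment variables. A minimal sketch of the variables read by the code added in this commit; the names come from the diff, the values are placeholders, and they are assumed to be set before Config is instantiated:

import os

# Placeholders only; substitute real credentials.
os.environ["TONGYI_PROXY_API_KEY"] = "your-dashscope-api-key"   # read by Config and the tongyi proxy
os.environ["WEN_XIN_MODEL_VERSION"] = "ERNIE-Bot"               # or "ERNIE-Bot-turbo"
os.environ["WEN_XIN_API_KEY"] = "your-baidu-api-key"            # AK for the access-token request
os.environ["WEN_XIN_SECRET_KEY"] = "your-baidu-secret-key"      # SK for the access-token request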
@@ -1,7 +1,66 @@
+import os
+import logging
+from typing import List
 from pilot.model.proxy.llms.proxy_model import ProxyModel
+from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
+
+logger = logging.getLogger(__name__)
 
 
 def tongyi_generate_stream(
     model: ProxyModel, tokenizer, params, device, context_len=2048
 ):
-    yield "tongyi LLM was not supported!"
+    import dashscope
+    from dashscope import Generation
+
+    model_params = model.get_params()
+    print(f"Model: {model}, model_params: {model_params}")
+
+    # proxy_api_key = model_params.proxy_api_key  # TODO: set this from the environment
+    dashscope.api_key = os.getenv("TONGYI_PROXY_API_KEY")
+
+    proxyllm_backend = model_params.proxyllm_backend
+    if not proxyllm_backend:
+        proxyllm_backend = Generation.Models.qwen_turbo  # qwen_turbo by default
+
+    history = []
+
+    messages: List[ModelMessage] = params["messages"]
+    # Add the conversation history
+    for message in messages:
+        if message.role == ModelMessageRoleType.HUMAN:
+            history.append({"role": "user", "content": message.content})
+        elif message.role == ModelMessageRoleType.SYSTEM:
+            history.append({"role": "system", "content": message.content})
+        elif message.role == ModelMessageRoleType.AI:
+            history.append({"role": "assistant", "content": message.content})
+        else:
+            pass
+
+    # Move the most recent user input to the end of the message list
+    temp_his = history[::-1]
+    last_user_input = None
+    for m in temp_his:
+        if m["role"] == "user":
+            last_user_input = m
+            break
+
+    if last_user_input:
+        history.remove(last_user_input)
+        history.append(last_user_input)
+
+    print("history", history)
+    gen = Generation()
+    res = gen.call(
+        proxyllm_backend,
+        messages=history,
+        top_p=params.get("top_p", 0.8),
+        stream=True,
+        result_format='message',
+    )
+
+    for r in res:
+        if r["output"]["choices"][0]["message"].get("content") is not None:
+            content = r["output"]["choices"][0]["message"].get("content")
+            yield content
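A minimal sketch of how a caller might drive tongyi_generate_stream. The module path in the import, the ModelMessage keyword constructor, and passing None for tokenizer and device are assumptions (the proxy path above does not use them); it is illustrative only:

from pilot.model.proxy.llms.proxy_model import ProxyModel
from pilot.model.proxy.llms.tongyi import tongyi_generate_stream  # assumed module path
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType


def demo_tongyi(model: ProxyModel) -> None:
    # Hypothetical driver: `model` is assumed to be an already configured ProxyModel
    # whose params supply proxyllm_backend.
    params = {
        "messages": [
            ModelMessage(role=ModelMessageRoleType.SYSTEM, content="You are a helpful assistant."),
            ModelMessage(role=ModelMessageRoleType.HUMAN, content="How do I cook tomato and egg?"),
        ],
        "top_p": 0.8,
    }
    for chunk in tongyi_generate_stream(model, tokenizer=None, params=params, device=None):
        print(chunk)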
@@ -1,7 +1,90 @@
+import os
+import requests
+import json
+from typing import List
 from pilot.model.proxy.llms.proxy_model import ProxyModel
+from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
+from cachetools import cached, TTLCache
+
+
+@cached(TTLCache(1, 1800))
+def _build_access_token(api_key: str, secret_key: str) -> str:
+    """
+    Generate an access token from the API key (AK) and secret key (SK).
+    """
+    url = "https://aip.baidubce.com/oauth/2.0/token"
+    params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key}
+
+    res = requests.get(url=url, params=params)
+
+    if res.status_code == 200:
+        return res.json().get("access_token")
 
 
 def wenxin_generate_stream(
     model: ProxyModel, tokenizer, params, device, context_len=2048
 ):
-    yield "wenxin LLM is not supported!"
+    MODEL_VERSION = {
+        "ERNIE-Bot": "completions",
+        "ERNIE-Bot-turbo": "eb-instant",
+    }
+
+    model_params = model.get_params()
+    model_name = os.getenv("WEN_XIN_MODEL_VERSION")
+    model_version = MODEL_VERSION.get(model_name)
+    if not model_version:
+        yield f"Unsupported model version {model_name}"
+
+    proxy_api_key = os.getenv("WEN_XIN_API_KEY")
+    proxy_api_secret = os.getenv("WEN_XIN_SECRET_KEY")
+    access_token = _build_access_token(proxy_api_key, proxy_api_secret)
+
+    headers = {
+        "Content-Type": "application/json",
+        "Accept": "application/json",
+    }
+
+    proxy_server_url = f"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/{model_version}?access_token={access_token}"
+
+    if not access_token:
+        yield "Failed to get access token. Please set the correct api_key and secret_key."
+
+    messages: List[ModelMessage] = params["messages"]
+
+    history = []
+    # Add the conversation history
+    for message in messages:
+        if message.role == ModelMessageRoleType.HUMAN:
+            history.append({"role": "user", "content": message.content})
+        elif message.role == ModelMessageRoleType.SYSTEM:
+            history.append({"role": "system", "content": message.content})
+        elif message.role == ModelMessageRoleType.AI:
+            history.append({"role": "assistant", "content": message.content})
+        else:
+            pass
+
+    payload = {
+        "messages": history,
+        "temperature": params.get("temperature"),
+        "stream": True,
+    }
+
+    text = ""
+    res = requests.post(proxy_server_url, headers=headers, json=payload, stream=True)
+    print(f"Send request to {proxy_server_url} with real model {model_name}")
+    for line in res.iter_lines():
+        if line:
+            if not line.startswith(b"data: "):
+                error_message = line.decode("utf-8")
+                yield error_message
+            else:
+                json_data = line.split(b": ", 1)[1]
+                decoded_line = json_data.decode("utf-8")
+                if decoded_line.lower() != "[DONE]".lower():
+                    obj = json.loads(json_data)
+                    if obj["result"] is not None:
+                        content = obj["result"]
+                        text += content
+                        yield text
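The streaming loop above expects server-sent-event style lines of the form data: {...} with a "result" field and accumulates the partial results into text. A small self-contained sketch of that parsing logic against made-up sample lines (real ERNIE responses carry more fields; these are trimmed for illustration):

import json

sample_lines = [
    b'data: {"result": "Tomato "}',
    b'data: {"result": "and egg."}',
    b"not a data line: something went wrong",
]

text = ""
for line in sample_lines:
    if not line.startswith(b"data: "):
        # Anything that is not a data line is surfaced as an error message.
        print("error:", line.decode("utf-8"))
    else:
        obj = json.loads(line.split(b": ", 1)[1])
        if obj.get("result") is not None:
            text += obj["result"]
            print(text)  # the proxy yields the accumulated text so far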
tests/unit_tests/llms/wenxin.py (new file, 0 lines)