mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-13 05:01:25 +00:00)
refactor: The first refactored version for sdk release (#907)
Co-authored-by: chengfangyin2 <chengfangyin3@jd.com>
dbgpt/model/proxy/__init__.py (new file, 0 lines)
dbgpt/model/proxy/data_privacy/mask/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
"""
Data masking: transform private, sensitive data into masked data, based on the sensitive data detection tool.
"""
dbgpt/model/proxy/data_privacy/mask/masking.py (new file, 3 lines)
@@ -0,0 +1,3 @@
"""
Mask the sensitive data before uploading it to the LLM inference service.
"""
dbgpt/model/proxy/data_privacy/mask/recovery.py (new file, 3 lines)
@@ -0,0 +1,3 @@
"""
Recover the masked data after LLM inference.
"""
dbgpt/model/proxy/data_privacy/sensitive_detection.py (new file, 3 lines)
@@ -0,0 +1,3 @@
"""
A tool to discover sensitive data.
"""
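These data_privacy modules are docstring-only placeholders in this commit. A minimal, self-contained sketch of the intended round trip (the helper names and the toy e-mail detector below are illustrative only and are not defined anywhere in this diff) might look like:

import re

def detect_sensitive(text: str):
    # Toy detector: treat e-mail addresses as the sensitive data (illustrative only)
    return re.findall(r"[\w.]+@[\w.]+", text)

def mask(text: str, values):
    mapping = {}
    for i, value in enumerate(values):
        placeholder = f"<MASK_{i}>"
        mapping[placeholder] = value
        text = text.replace(value, placeholder)
    return text, mapping

def recover(text: str, mapping):
    for placeholder, value in mapping.items():
        text = text.replace(placeholder, value)
    return text

original = "contact me at alice@example.com"
masked, mapping = mask(original, detect_sensitive(original))
# masked == "contact me at <MASK_0>"; recover(masked, mapping) restores the address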
dbgpt/model/proxy/llms/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
"""
Privatized (self-hosted) large models have several limitations, such as high deployment costs and limited performance.
In scenarios where data privacy requirements are relatively low, connecting to commercial large models enables
rapid, efficient, and high-quality product implementation.
"""
dbgpt/model/proxy/llms/baichuan.py (new file, 87 lines)
@@ -0,0 +1,87 @@
import hashlib
import json
import time
import requests
from typing import List
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType

BAICHUAN_DEFAULT_MODEL = "Baichuan2-53B"


def _calculate_md5(text: str) -> str:
    """Calculate md5"""
    md5 = hashlib.md5()
    md5.update(text.encode("utf-8"))
    encrypted = md5.hexdigest()
    return encrypted


def _sign(data: dict, secret_key: str, timestamp: str):
    data_str = json.dumps(data)
    signature = _calculate_md5(secret_key + data_str + timestamp)
    return signature


def baichuan_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=4096
):
    model_params = model.get_params()
    url = "https://api.baichuan-ai.com/v1/stream/chat"

    model_name = model_params.proxyllm_backend or BAICHUAN_DEFAULT_MODEL
    proxy_api_key = model_params.proxy_api_key
    proxy_api_secret = model_params.proxy_api_secret

    history = []
    messages: List[ModelMessage] = params["messages"]
    # Add history conversation
    for message in messages:
        if message.role == ModelMessageRoleType.HUMAN:
            history.append({"role": "user", "content": message.content})
        elif message.role == ModelMessageRoleType.SYSTEM:
            history.append({"role": "system", "content": message.content})
        elif message.role == ModelMessageRoleType.AI:
            history.append({"role": "assistant", "content": message.content})
        else:
            pass

    payload = {
        "model": model_name,
        "messages": history,
        "parameters": {
            "temperature": params.get("temperature"),
            "top_k": params.get("top_k", 10),
        },
    }

    timestamp = int(time.time())
    _signature = _sign(payload, proxy_api_secret, str(timestamp))

    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer " + proxy_api_key,
        "X-BC-Request-Id": params.get("request_id") or "dbgpt",
        "X-BC-Timestamp": str(timestamp),
        "X-BC-Signature": _signature,
        "X-BC-Sign-Algo": "MD5",
    }

    res = requests.post(url=url, json=payload, headers=headers, stream=True)
    print(f"Send request to {url} with real model {model_name}")

    text = ""
    for line in res.iter_lines():
        if line:
            if not line.startswith(b"data: "):
                error_message = line.decode("utf-8")
                yield error_message
            else:
                json_data = line.split(b": ", 1)[1]
                decoded_line = json_data.decode("utf-8")
                if decoded_line.lower() != "[DONE]".lower():
                    obj = json.loads(json_data)
                    if obj["data"]["messages"][0].get("content") is not None:
                        content = obj["data"]["messages"][0].get("content")
                        text += content
                        yield text
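For reference, the X-BC-Signature header above is just the MD5 digest of the secret key, the JSON-serialized payload, and the timestamp concatenated together; a minimal sanity check (placeholder values only) would be:

payload = {"model": "Baichuan2-53B", "messages": []}
ts = "1700000000"
assert _sign(payload, "my-secret-key", ts) == _calculate_md5(
    "my-secret-key" + json.dumps(payload) + ts
)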
dbgpt/model/proxy/llms/bard.py (new file, 61 lines)
@@ -0,0 +1,61 @@
import requests
from typing import List
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.model.proxy.llms.proxy_model import ProxyModel


def bard_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=2048
):
    model_params = model.get_params()
    print(f"Model: {model}, model_params: {model_params}")

    proxy_api_key = model_params.proxy_api_key
    proxy_server_url = model_params.proxy_server_url

    history = []
    messages: List[ModelMessage] = params["messages"]
    for message in messages:
        if message.role == ModelMessageRoleType.HUMAN:
            history.append({"role": "user", "content": message.content})
        elif message.role == ModelMessageRoleType.SYSTEM:
            history.append({"role": "system", "content": message.content})
        elif message.role == ModelMessageRoleType.AI:
            history.append({"role": "assistant", "content": message.content})
        else:
            pass

    # Move the last user input to the end of the history
    temp_his = history[::-1]
    last_user_input = None
    for m in temp_his:
        if m["role"] == "user":
            last_user_input = m
            break
    if last_user_input:
        history.remove(last_user_input)
        history.append(last_user_input)

    msgs = []
    for msg in history:
        if msg.get("content"):
            msgs.append(msg["content"])

    if proxy_server_url is not None:
        headers = {"Content-Type": "application/json"}
        payloads = {"input": "\n".join(msgs)}
        response = requests.post(
            proxy_server_url, headers=headers, json=payloads, stream=False
        )
        if response.ok:
            yield response.text
        else:
            yield f"bard proxy url request failed, response = {str(response)}"
    else:
        import bardapi

        response = bardapi.core.Bard(proxy_api_key).get_answer("\n".join(msgs))

        if response is not None and response.get("content") is not None:
            yield str(response["content"])
        else:
            yield f"bard response error: {str(response)}"
dbgpt/model/proxy/llms/chatgpt.py (new file, 250 lines)
@@ -0,0 +1,250 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
from typing import List
import logging
import importlib.metadata as metadata
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
import httpx

logger = logging.getLogger(__name__)


def _initialize_openai(params: ProxyModelParameters):
    try:
        import openai
    except ImportError as exc:
        raise ValueError(
            "Could not import python package: openai "
            "Please install openai by command `pip install openai` "
        ) from exc

    api_type = params.proxy_api_type or os.getenv("OPENAI_API_TYPE", "open_ai")

    api_base = params.proxy_api_base or os.getenv(
        "OPENAI_API_BASE",
        os.getenv("AZURE_OPENAI_ENDPOINT") if api_type == "azure" else None,
    )
    api_key = params.proxy_api_key or os.getenv(
        "OPENAI_API_KEY",
        os.getenv("AZURE_OPENAI_KEY") if api_type == "azure" else None,
    )
    api_version = params.proxy_api_version or os.getenv("OPENAI_API_VERSION")

    if not api_base and params.proxy_server_url:
        # Adapt previous proxy_server_url configuration
        api_base = params.proxy_server_url.split("/chat/completions")[0]
    if api_type:
        openai.api_type = api_type
    if api_base:
        openai.api_base = api_base
    if api_key:
        openai.api_key = api_key
    if api_version:
        openai.api_version = api_version
    if params.http_proxy:
        openai.proxy = params.http_proxy

    openai_params = {
        "api_type": api_type,
        "api_base": api_base,
        "api_version": api_version,
        "proxy": params.http_proxy,
    }

    return openai_params


def _initialize_openai_v1(params: ProxyModelParameters):
    try:
        from openai import OpenAI
    except ImportError as exc:
        raise ValueError(
            "Could not import python package: openai "
            "Please install openai by command `pip install openai`"
        ) from exc

    api_type = params.proxy_api_type or os.getenv("OPENAI_API_TYPE", "open_ai")

    base_url = params.proxy_api_base or os.getenv(
        "OPENAI_API_BASE",
        os.getenv("AZURE_OPENAI_ENDPOINT") if api_type == "azure" else None,
    )
    api_key = params.proxy_api_key or os.getenv(
        "OPENAI_API_KEY",
        os.getenv("AZURE_OPENAI_KEY") if api_type == "azure" else None,
    )
    api_version = params.proxy_api_version or os.getenv("OPENAI_API_VERSION")

    if not base_url and params.proxy_server_url:
        # Adapt previous proxy_server_url configuration
        base_url = params.proxy_server_url.split("/chat/completions")[0]

    proxies = params.http_proxy
    openai_params = {
        "api_key": api_key,
        "base_url": base_url,
    }
    return openai_params, api_type, api_version, proxies


def _build_request(model: ProxyModel, params):
    history = []

    model_params = model.get_params()
    logger.info(f"Model: {model}, model_params: {model_params}")

    messages: List[ModelMessage] = params["messages"]
    # Add history conversation
    for message in messages:
        if message.role == ModelMessageRoleType.HUMAN:
            history.append({"role": "user", "content": message.content})
        elif message.role == ModelMessageRoleType.SYSTEM:
            history.append({"role": "system", "content": message.content})
        elif message.role == ModelMessageRoleType.AI:
            history.append({"role": "assistant", "content": message.content})
        else:
            pass

    # Move the last user's information to the end
    temp_his = history[::-1]
    last_user_input = None
    for m in temp_his:
        if m["role"] == "user":
            last_user_input = m
            break
    if last_user_input:
        history.remove(last_user_input)
        history.append(last_user_input)

    payloads = {
        "temperature": params.get("temperature"),
        "max_tokens": params.get("max_new_tokens"),
        "stream": True,
    }
    proxyllm_backend = model_params.proxyllm_backend

    if metadata.version("openai") >= "1.0.0":
        openai_params, api_type, api_version, proxies = _initialize_openai_v1(
            model_params
        )
        proxyllm_backend = proxyllm_backend or "gpt-3.5-turbo"
        payloads["model"] = proxyllm_backend
    else:
        openai_params = _initialize_openai(model_params)
        if openai_params["api_type"] == "azure":
            # engine = "deployment_name".
            proxyllm_backend = proxyllm_backend or "gpt-35-turbo"
            payloads["engine"] = proxyllm_backend
        else:
            proxyllm_backend = proxyllm_backend or "gpt-3.5-turbo"
            payloads["model"] = proxyllm_backend

    logger.info(f"Send request to real model {proxyllm_backend}")
    return history, payloads


def chatgpt_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=2048
):
    if metadata.version("openai") >= "1.0.0":
        model_params = model.get_params()
        openai_params, api_type, api_version, proxies = _initialize_openai_v1(
            model_params
        )
        history, payloads = _build_request(model, params)
        if api_type == "azure":
            from openai import AzureOpenAI

            client = AzureOpenAI(
                api_key=openai_params["api_key"],
                api_version=api_version,
                azure_endpoint=openai_params["base_url"],
                http_client=httpx.Client(proxies=proxies),
            )
        else:
            from openai import OpenAI

            client = OpenAI(**openai_params, http_client=httpx.Client(proxies=proxies))
        res = client.chat.completions.create(messages=history, **payloads)
        text = ""
        for r in res:
            # logger.info(str(r))
            # Azure OpenAI responses may have an empty choices body in the first chunk,
            # so skip it to avoid an index-out-of-range error
            if len(r.choices) == 0:
                continue
            if r.choices[0].delta.content is not None:
                content = r.choices[0].delta.content
                text += content
                yield text

    else:
        import openai

        history, payloads = _build_request(model, params)

        res = openai.ChatCompletion.create(messages=history, **payloads)

        text = ""
        for r in res:
            if len(r.choices) == 0:
                continue
            if r["choices"][0]["delta"].get("content") is not None:
                content = r["choices"][0]["delta"]["content"]
                text += content
                yield text


async def async_chatgpt_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=2048
):
    if metadata.version("openai") >= "1.0.0":
        model_params = model.get_params()
        openai_params, api_type, api_version, proxies = _initialize_openai_v1(
            model_params
        )
        history, payloads = _build_request(model, params)
        if api_type == "azure":
            from openai import AsyncAzureOpenAI

            client = AsyncAzureOpenAI(
                api_key=openai_params["api_key"],
                api_version=api_version,
                azure_endpoint=openai_params["base_url"],
                http_client=httpx.AsyncClient(proxies=proxies),
            )
        else:
            from openai import AsyncOpenAI

            client = AsyncOpenAI(
                **openai_params, http_client=httpx.AsyncClient(proxies=proxies)
            )

        res = await client.chat.completions.create(messages=history, **payloads)
        text = ""
        async for r in res:
            # The v1 SDK streams pydantic objects, so check the choices attribute directly
            if not r.choices:
                continue
            if r.choices[0].delta.content is not None:
                content = r.choices[0].delta.content
                text += content
                yield text
    else:
        import openai

        history, payloads = _build_request(model, params)

        res = await openai.ChatCompletion.acreate(messages=history, **payloads)

        text = ""
        async for r in res:
            if not r.get("choices"):
                continue
            if r["choices"][0]["delta"].get("content") is not None:
                content = r["choices"][0]["delta"]["content"]
                text += content
                yield text
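The environment variables consulted above are OPENAI_API_TYPE, OPENAI_API_BASE, OPENAI_API_KEY, and OPENAI_API_VERSION, plus AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_KEY for the Azure path. A minimal Azure-style setup might look like the following (every value, including the API version string, is a placeholder):

import os

os.environ["OPENAI_API_TYPE"] = "azure"
os.environ["AZURE_OPENAI_ENDPOINT"] = "https://my-resource.openai.azure.com"
os.environ["AZURE_OPENAI_KEY"] = "<azure-openai-key>"
os.environ["OPENAI_API_VERSION"] = "2023-07-01-preview"  # placeholder version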
dbgpt/model/proxy/llms/claude.py (new file, 7 lines)
@@ -0,0 +1,7 @@
from dbgpt.model.proxy.llms.proxy_model import ProxyModel


def claude_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=2048
):
    yield "Claude LLM is not supported yet!"
dbgpt/model/proxy/llms/proxy_model.py (new file, 9 lines)
@@ -0,0 +1,9 @@
from dbgpt.model.parameter import ProxyModelParameters


class ProxyModel:
    def __init__(self, model_params: ProxyModelParameters) -> None:
        self._model_params = model_params

    def get_params(self) -> ProxyModelParameters:
        return self._model_params
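A minimal sketch of how these proxy generators are driven. The ProxyModelParameters fields referenced in this diff are things like proxy_api_key and proxyllm_backend; the dataclass may require further fields not shown here, and every value below is a placeholder:

from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.model.proxy.llms.chatgpt import chatgpt_generate_stream

model = ProxyModel(model_params)  # model_params: a configured ProxyModelParameters instance
params = {
    "messages": [ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello")],
    "temperature": 0.7,
    "max_new_tokens": 512,
}
for partial_text in chatgpt_generate_stream(model, None, params, device="cpu"):
    print(partial_text)  # each yield is the accumulated response so far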
dbgpt/model/proxy/llms/spark.py (new file, 153 lines)
@@ -0,0 +1,153 @@
import json
import base64
import hmac
import hashlib
from websockets.sync.client import connect
from datetime import datetime
from typing import List
from time import mktime
from urllib.parse import urlencode
from urllib.parse import urlparse
from wsgiref.handlers import format_date_time
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.model.proxy.llms.proxy_model import ProxyModel

SPARK_DEFAULT_API_VERSION = "v3"


def getlength(text):
    length = 0
    for content in text:
        temp = content["content"]
        leng = len(temp)
        length += leng
    return length


def checklen(text):
    while getlength(text) > 8192:
        del text[0]
    return text


def spark_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=2048
):
    model_params = model.get_params()
    proxy_api_version = model_params.proxyllm_backend or SPARK_DEFAULT_API_VERSION
    proxy_api_key = model_params.proxy_api_key
    proxy_api_secret = model_params.proxy_api_secret
    proxy_app_id = model_params.proxy_api_app_id

    if proxy_api_version == SPARK_DEFAULT_API_VERSION:
        url = "ws://spark-api.xf-yun.com/v3.1/chat"
        domain = "generalv3"
    else:
        url = "ws://spark-api.xf-yun.com/v2.1/chat"
        domain = "generalv2"

    messages: List[ModelMessage] = params["messages"]

    last_user_input = None
    for index in range(len(messages) - 1, -1, -1):
        print(f"index: {index}")
        if messages[index].role == ModelMessageRoleType.HUMAN:
            last_user_input = {"role": "user", "content": messages[index].content}
            del messages[index]
            break

    history = []
    # Add history conversation
    for message in messages:
        # There is no role for system in spark LLM
        if message.role in (ModelMessageRoleType.HUMAN, ModelMessageRoleType.SYSTEM):
            history.append({"role": "user", "content": message.content})
        elif message.role == ModelMessageRoleType.AI:
            history.append({"role": "assistant", "content": message.content})
        else:
            pass

    question = checklen(history + [last_user_input])

    print('last_user_input.get("content")', last_user_input.get("content"))
    data = {
        "header": {"app_id": proxy_app_id, "uid": str(params.get("request_id", 1))},
        "parameter": {
            "chat": {
                "domain": domain,
                "random_threshold": 0.5,
                "max_tokens": context_len,
                "auditing": "default",
                "temperature": params.get("temperature"),
            }
        },
        "payload": {"message": {"text": question}},
    }

    spark_api = SparkAPI(proxy_app_id, proxy_api_key, proxy_api_secret, url)
    request_url = spark_api.gen_url()
    return get_response(request_url, data)


def get_response(request_url, data):
    with connect(request_url) as ws:
        ws.send(json.dumps(data, ensure_ascii=False))
        result = ""
        while True:
            try:
                chunk = ws.recv()
                response = json.loads(chunk)
                print("look out the response: ", response)
                choices = response.get("payload", {}).get("choices", {})
                if text := choices.get("text"):
                    result += text[0]["content"]
                if choices.get("status") == 2:
                    break
            except Exception:
                break
    yield result


class SparkAPI:
    def __init__(
        self, appid: str, api_key: str, api_secret: str, spark_url: str
    ) -> None:
        self.appid = appid
        self.api_key = api_key
        self.api_secret = api_secret
        self.host = urlparse(spark_url).netloc
        self.path = urlparse(spark_url).path

        self.spark_url = spark_url

    def gen_url(self):
        # Generate an RFC 1123 formatted timestamp
        now = datetime.now()
        date = format_date_time(mktime(now.timetuple()))

        # Build the string to be signed
        signature_origin = "host: " + self.host + "\n"
        signature_origin += "date: " + date + "\n"
        signature_origin += "GET " + self.path + " HTTP/1.1"

        # Sign it with HMAC-SHA256
        signature_sha = hmac.new(
            self.api_secret.encode("utf-8"),
            signature_origin.encode("utf-8"),
            digestmod=hashlib.sha256,
        ).digest()

        signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding="utf-8")

        authorization_origin = f'api_key="{self.api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'

        authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(
            encoding="utf-8"
        )

        # Collect the authentication parameters into a dict
        v = {"authorization": authorization, "date": date, "host": self.host}
        # Append the authentication parameters as the query string of the final URL
        url = self.spark_url + "?" + urlencode(v)
        # This is the URL used to establish the connection; when following the official demo,
        # print it and compare it with the URL generated there for the same parameters
        return url
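For illustration, the signed WebSocket URL produced by gen_url simply appends authorization, date, and host as query parameters (the credentials below are placeholders):

api = SparkAPI(
    "my_app_id", "my_api_key", "my_api_secret", "ws://spark-api.xf-yun.com/v3.1/chat"
)
request_url = api.gen_url()
# e.g. ws://spark-api.xf-yun.com/v3.1/chat?authorization=...&date=...&host=spark-api.xf-yun.com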
dbgpt/model/proxy/llms/tongyi.py (new file, 73 lines)
@@ -0,0 +1,73 @@
import logging
from typing import List
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType

logger = logging.getLogger(__name__)


def __convert_2_tongyi_messages(messages: List[ModelMessage]):
    chat_round = 0
    tongyi_messages = []

    last_usr_message = ""
    system_messages = []

    for message in messages:
        if message.role == ModelMessageRoleType.HUMAN:
            last_usr_message = message.content
        elif message.role == ModelMessageRoleType.SYSTEM:
            system_messages.append(message.content)
        elif message.role == ModelMessageRoleType.AI:
            last_ai_message = message.content
            tongyi_messages.append({"role": "user", "content": last_usr_message})
            tongyi_messages.append({"role": "assistant", "content": last_ai_message})
    if len(system_messages) > 0:
        if len(system_messages) < 2:
            tongyi_messages.insert(0, {"role": "system", "content": system_messages[0]})
            # Keep the pending user input as the final message of the request
            tongyi_messages.append({"role": "user", "content": last_usr_message})
        else:
            tongyi_messages.append({"role": "user", "content": system_messages[1]})
    else:
        last_message = messages[-1]
        if last_message.role == ModelMessageRoleType.HUMAN:
            tongyi_messages.append({"role": "user", "content": last_message.content})

    return tongyi_messages


def tongyi_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=2048
):
    import dashscope
    from dashscope import Generation

    model_params = model.get_params()
    print(f"Model: {model}, model_params: {model_params}")

    proxy_api_key = model_params.proxy_api_key
    dashscope.api_key = proxy_api_key

    proxyllm_backend = model_params.proxyllm_backend
    if not proxyllm_backend:
        proxyllm_backend = Generation.Models.qwen_turbo  # By Default qwen_turbo

    messages: List[ModelMessage] = params["messages"]

    history = __convert_2_tongyi_messages(messages)
    gen = Generation()
    res = gen.call(
        proxyllm_backend,
        messages=history,
        top_p=params.get("top_p", 0.8),
        stream=True,
        result_format="message",
    )

    for r in res:
        if r:
            if r["status_code"] == 200:
                content = r["output"]["choices"][0]["message"].get("content")
                yield content
            else:
                content = r["code"] + ":" + r["message"]
                yield content
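For illustration, with a plain user/assistant history and no system prompt, __convert_2_tongyi_messages produces the following (the message contents are placeholders):

messages = [
    ModelMessage(role=ModelMessageRoleType.HUMAN, content="hi"),
    ModelMessage(role=ModelMessageRoleType.AI, content="Hello, how can I help?"),
    ModelMessage(role=ModelMessageRoleType.HUMAN, content="What is DB-GPT?"),
]
# __convert_2_tongyi_messages(messages) ->
# [
#     {"role": "user", "content": "hi"},
#     {"role": "assistant", "content": "Hello, how can I help?"},
#     {"role": "user", "content": "What is DB-GPT?"},
# ]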
dbgpt/model/proxy/llms/wenxin.py (new file, 117 lines)
@@ -0,0 +1,117 @@
import requests
import json
from typing import List
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from cachetools import cached, TTLCache


@cached(TTLCache(1, 1800))  # cache a single access token for up to 30 minutes
def _build_access_token(api_key: str, secret_key: str) -> str:
    """
    Generate an access token from the API key (AK) and secret key (SK).
    """

url = "https://aip.baidubce.com/oauth/2.0/token"
|
||||
params = {
|
||||
"grant_type": "client_credentials",
|
||||
"client_id": api_key,
|
||||
"client_secret": secret_key,
|
||||
}
|
||||
|
||||
res = requests.get(url=url, params=params)
|
||||
|
||||
if res.status_code == 200:
|
||||
return res.json().get("access_token")
|
||||
|
||||
|
||||
def __convert_2_wenxin_messages(messages: List[ModelMessage]):
|
||||
chat_round = 0
|
||||
wenxin_messages = []
|
||||
|
||||
last_usr_message = ""
|
||||
system_messages = []
|
||||
|
||||
for message in messages:
|
||||
if message.role == ModelMessageRoleType.HUMAN:
|
||||
last_usr_message = message.content
|
||||
elif message.role == ModelMessageRoleType.SYSTEM:
|
||||
system_messages.append(message.content)
|
||||
elif message.role == ModelMessageRoleType.AI:
|
||||
last_ai_message = message.content
|
||||
wenxin_messages.append({"role": "user", "content": last_usr_message})
|
||||
wenxin_messages.append({"role": "assistant", "content": last_ai_message})
|
||||
|
||||
    # Build the last user message

    if len(system_messages) > 0:
        if len(system_messages) > 1:
            end_message = system_messages[-1]
        else:
            last_message = messages[-1]
            if last_message.role == ModelMessageRoleType.HUMAN:
                end_message = system_messages[-1] + "\n" + last_message.content
            else:
                end_message = system_messages[-1]
    else:
        last_message = messages[-1]
        end_message = last_message.content
    wenxin_messages.append({"role": "user", "content": end_message})
    return wenxin_messages, system_messages


def wenxin_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=2048
):
    MODEL_VERSION = {
        "ERNIE-Bot": "completions",
        "ERNIE-Bot-turbo": "eb-instant",
    }

    model_params = model.get_params()
    model_name = model_params.proxyllm_backend
    model_version = MODEL_VERSION.get(model_name)
    if not model_version:
yield f"Unsupport model version {model_name}"
|
||||
    proxy_api_key = model_params.proxy_api_key
    proxy_api_secret = model_params.proxy_api_secret
    access_token = _build_access_token(proxy_api_key, proxy_api_secret)

    headers = {"Content-Type": "application/json", "Accept": "application/json"}

    proxy_server_url = f"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/{model_version}?access_token={access_token}"

    if not access_token:
yield "Failed to get access token. please set the correct api_key and secret key."
|
||||
    messages: List[ModelMessage] = params["messages"]

    history, systems = __convert_2_wenxin_messages(messages)
    system = ""
    if systems and len(systems) > 0:
        system = systems[0]
    payload = {
        "messages": history,
        "system": system,
        "temperature": params.get("temperature"),
        "stream": True,
    }

    text = ""
    res = requests.post(proxy_server_url, headers=headers, json=payload, stream=True)
    print(f"Send request to {proxy_server_url} with real model {model_name}")
    for line in res.iter_lines():
        if line:
            if not line.startswith(b"data: "):
                error_message = line.decode("utf-8")
                yield error_message
            else:
                json_data = line.split(b": ", 1)[1]
                decoded_line = json_data.decode("utf-8")
                if decoded_line.lower() != "[DONE]".lower():
                    obj = json.loads(json_data)
                    if obj["result"] is not None:
                        content = obj["result"]
                        text += content
                        yield text
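As a concrete example, with proxyllm_backend set to "ERNIE-Bot" the MODEL_VERSION table resolves to "completions", so the request above goes to the following endpoint (the token value is a placeholder):

model_version = MODEL_VERSION.get("ERNIE-Bot")  # -> "completions"
# https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions?access_token=<token>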
dbgpt/model/proxy/llms/zhipu.py (new file, 101 lines)
@@ -0,0 +1,101 @@
from typing import List

from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType

CHATGLM_DEFAULT_MODEL = "chatglm_pro"


def __convert_2_wenxin_messages(messages: List[ModelMessage]):
    chat_round = 0
    wenxin_messages = []

    last_usr_message = ""
    system_messages = []

    for message in messages:
        if message.role == ModelMessageRoleType.HUMAN:
            last_usr_message = message.content
        elif message.role == ModelMessageRoleType.SYSTEM:
            system_messages.append(message.content)
        elif message.role == ModelMessageRoleType.AI:
            last_ai_message = message.content
            wenxin_messages.append({"role": "user", "content": last_usr_message})
            wenxin_messages.append({"role": "assistant", "content": last_ai_message})

    # Build the last user message

    if len(system_messages) > 0:
        if len(system_messages) > 1:
            end_message = system_messages[-1]
        else:
            last_message = messages[-1]
            if last_message.role == ModelMessageRoleType.HUMAN:
                end_message = system_messages[-1] + "\n" + last_message.content
            else:
                end_message = system_messages[-1]
    else:
        last_message = messages[-1]
        end_message = last_message.content
    wenxin_messages.append({"role": "user", "content": end_message})
    return wenxin_messages, system_messages


def zhipu_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=2048
):
    """Zhipu ai, see: https://open.bigmodel.cn/dev/api#overview"""
    model_params = model.get_params()
    print(f"Model: {model}, model_params: {model_params}")

    # TODO proxy model use unified config?
    proxy_api_key = model_params.proxy_api_key
    proxyllm_backend = model_params.proxyllm_backend or CHATGLM_DEFAULT_MODEL
    import zhipuai

    zhipuai.api_key = proxy_api_key

    messages: List[ModelMessage] = params["messages"]
    # Add history conversation
    # system = ""
    # if len(messages) > 1 and messages[0].role == ModelMessageRoleType.SYSTEM:
    #     role_define = messages.pop(0)
    #     system = role_define.content
    # else:
    #     message = messages.pop(0)
    #     if message.role == ModelMessageRoleType.HUMAN:
    #         history.append({"role": "user", "content": message.content})
    # for message in messages:
    #     if message.role == ModelMessageRoleType.SYSTEM:
    #         history.append({"role": "user", "content": message.content})
    #     # elif message.role == ModelMessageRoleType.HUMAN:
    #     #     history.append({"role": "user", "content": message.content})
    #     elif message.role == ModelMessageRoleType.AI:
    #         history.append({"role": "assistant", "content": message.content})
    #     else:
    #         pass
    #
    # # temp_his = history[::-1]
    # temp_his = history
    # last_user_input = None
    # for m in temp_his:
    #     if m["role"] == "user":
    #         last_user_input = m
    #         break
    #
    # if last_user_input:
    #     history.remove(last_user_input)
    #     history.append(last_user_input)

    history, systems = __convert_2_wenxin_messages(messages)
    res = zhipuai.model_api.sse_invoke(
        model=proxyllm_backend,
        prompt=history,
        temperature=params.get("temperature"),
        top_p=params.get("top_p"),
        incremental=False,
    )
    for r in res.events():
        if r.event == "add":
            yield r.data