Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-25 11:29:29 +00:00)
feat(model): Support new zhipuai SDK (#1592)

Co-authored-by: yyhhyy <yyhhyyyyyy@qq.com>
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>

parent 85bf64e9a2
commit c3c063683c
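
This change migrates `ZhipuLLMClient` from the legacy v1 `zhipuai.model_api.sse_invoke` interface to the v2 SDK, which exposes an OpenAI-style client object. A minimal before/after sketch of the two SDK surfaces, based only on the calls visible in the diff below (`"glm-4"`, the API key, and the prompt are placeholders, not values from this commit):

```python
# Legacy v1 SDK (removed below): module-level API key + SSE event stream.
#   import zhipuai
#   zhipuai.api_key = "your-api-key"
#   res = zhipuai.model_api.sse_invoke(model="...", prompt=[...], incremental=False)
#   for r in res.events(): ...  # events tagged "add" / "error"

# New v2 SDK (adopted below): an OpenAI-style client object.
from zhipuai import ZhipuAI

client = ZhipuAI(api_key="your-api-key")
response = client.chat.completions.create(
    model="glm-4",
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,  # yields incremental delta chunks
)
```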
```diff
@@ -1,3 +1,4 @@
+import os
 from concurrent.futures import Executor
 from typing import Iterator, Optional
 
```
```diff
@@ -37,23 +38,37 @@ class ZhipuLLMClient(ProxyLLMClient):
         self,
         model: Optional[str] = None,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
         model_alias: Optional[str] = "zhipu_proxyllm",
         context_length: Optional[int] = 8192,
         executor: Optional[Executor] = None,
     ):
         try:
-            import zhipuai
+            from zhipuai import ZhipuAI
+
         except ImportError as exc:
-            raise ValueError(
-                "Could not import python package: zhipuai "
-                "Please install dashscope by command `pip install zhipuai"
-            ) from exc
+            if (
+                "No module named" in str(exc)
+                or "cannot find module" in str(exc).lower()
+            ):
+                raise ValueError(
+                    "The python package 'zhipuai' is not installed. "
+                    "Please install it by running `pip install zhipuai`."
+                ) from exc
+            else:
+                raise ValueError(
+                    "Could not import python package: zhipuai "
+                    "This may be due to a version that is too low. "
+                    "Please upgrade the zhipuai package by running `pip install --upgrade zhipuai`."
+                ) from exc
         if not model:
             model = CHATGLM_DEFAULT_MODEL
-        if api_key:
-            zhipuai.api_key = api_key
+        if not api_key:
+            # Compatible with DB-GPT's config
+            api_key = os.getenv("ZHIPU_PROXY_API_KEY")
+
         self._model = model
+        self.client = ZhipuAI(api_key=api_key, base_url=api_base)
 
         super().__init__(
             model_names=[model, model_alias],
```
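
With the new constructor, an explicit `api_key` argument takes precedence; otherwise the key is read from the `ZHIPU_PROXY_API_KEY` environment variable before the `ZhipuAI` client is created, and a missing `model` falls back to `CHATGLM_DEFAULT_MODEL`. A usage sketch (the import path follows DB-GPT's proxy-LLM layout and should be treated as an assumption, not something stated in this diff):

```python
# Usage sketch for the updated constructor; the import path is assumed.
import os

from dbgpt.model.proxy.llms.zhipu import ZhipuLLMClient

# Pass the key explicitly ...
client = ZhipuLLMClient(model="glm-4", api_key="your-api-key")

# ... or rely on the environment-variable fallback added in this diff;
# model=None falls back to CHATGLM_DEFAULT_MODEL.
os.environ["ZHIPU_PROXY_API_KEY"] = "your-api-key"
client = ZhipuLLMClient()
```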
```diff
@@ -84,7 +99,6 @@ class ZhipuLLMClient(ProxyLLMClient):
         request: ModelRequest,
         message_converter: Optional[MessageConverter] = None,
     ) -> Iterator[ModelOutput]:
-        import zhipuai
 
         request = self.local_covert_message(request, message_converter)
 
```
```diff
@@ -92,18 +106,18 @@ class ZhipuLLMClient(ProxyLLMClient):
 
         model = request.model or self._model
         try:
-            res = zhipuai.model_api.sse_invoke(
+            response = self.client.chat.completions.create(
                 model=model,
-                prompt=messages,
+                messages=messages,
                 temperature=request.temperature,
                 # top_p=params.get("top_p"),
-                incremental=False,
+                stream=True,
             )
-            for r in res.events():
-                if r.event == "add":
-                    yield ModelOutput(text=r.data, error_code=0)
-                elif r.event == "error":
-                    yield ModelOutput(text=r.data, error_code=1)
+            partial_text = ""
+            for chunk in response:
+                delta_content = chunk.choices[0].delta.content
+                partial_text += delta_content
+                yield ModelOutput(text=partial_text, error_code=0)
         except Exception as e:
             return ModelOutput(
                 text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
```
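
The streaming semantics change here: the old `sse_invoke(..., incremental=False)` delivered full text in each `"add"` event, while the v2 SDK streams deltas, so the new loop accumulates `delta.content` into `partial_text` and yields the cumulative text with every chunk. A standalone sketch of that accumulation pattern (model and prompt are placeholders; the `None`-delta guard is an extra precaution not present in the diff):

```python
# Delta-accumulation pattern matching the loop above.
from zhipuai import ZhipuAI

client = ZhipuAI(api_key="your-api-key")
stream = client.chat.completions.create(
    model="glm-4",
    messages=[{"role": "user", "content": "Explain streaming."}],
    stream=True,
)

partial_text = ""
for chunk in stream:
    delta = chunk.choices[0].delta.content
    if delta:  # the final chunk's delta can be None
        partial_text += delta
        print(partial_text)  # cumulative text, as in ModelOutput(text=partial_text)
```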