feat(model): Support new zhipuai SDK (#1592)

Co-authored-by: yyhhyy <yyhhyyyyyy@qq.com>
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
yyhhyy 2024-06-04 19:04:52 +08:00 committed by GitHub
parent 85bf64e9a2
commit c3c063683c

@@ -1,3 +1,4 @@
+import os
 from concurrent.futures import Executor
 from typing import Iterator, Optional
@@ -37,23 +38,37 @@ class ZhipuLLMClient(ProxyLLMClient):
         self,
         model: Optional[str] = None,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
         model_alias: Optional[str] = "zhipu_proxyllm",
         context_length: Optional[int] = 8192,
         executor: Optional[Executor] = None,
     ):
         try:
-            import zhipuai
+            from zhipuai import ZhipuAI
         except ImportError as exc:
-            raise ValueError(
-                "Could not import python package: zhipuai "
-                "Please install dashscope by command `pip install zhipuai"
-            ) from exc
+            if (
+                "No module named" in str(exc)
+                or "cannot find module" in str(exc).lower()
+            ):
+                raise ValueError(
+                    "The python package 'zhipuai' is not installed. "
+                    "Please install it by running `pip install zhipuai`."
+                ) from exc
+            else:
+                raise ValueError(
+                    "Could not import python package: zhipuai "
+                    "This may be due to a version that is too low. "
+                    "Please upgrade the zhipuai package by running `pip install --upgrade zhipuai`."
+                ) from exc
         if not model:
             model = CHATGLM_DEFAULT_MODEL
-        if api_key:
-            zhipuai.api_key = api_key
+        if not api_key:
+            # Compatible with DB-GPT's config
+            api_key = os.getenv("ZHIPU_PROXY_API_KEY")
         self._model = model
+        self.client = ZhipuAI(api_key=api_key, base_url=api_base)
 
         super().__init__(
             model_names=[model, model_alias],
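
Note on the import handling above: the old zhipuai 1.x SDK exposes a module-level model_api and has no ZhipuAI class, so on an outdated install `from zhipuai import ZhipuAI` raises an ImportError whose message (e.g. "cannot import name 'ZhipuAI' from 'zhipuai'") contains neither "No module named" nor "cannot find module", which routes the user to the upgrade branch rather than the install branch. A minimal standalone sketch of that probe, using the hypothetical helper name check_zhipuai_install:

    def check_zhipuai_install() -> str:
        try:
            # The ZhipuAI client class only exists in the new (2.x) SDK.
            from zhipuai import ZhipuAI  # noqa: F401
        except ImportError as exc:
            if (
                "No module named" in str(exc)
                or "cannot find module" in str(exc).lower()
            ):
                # The package itself is absent.
                return "missing: run `pip install zhipuai`"
            # The package is present but too old to provide ZhipuAI.
            return "outdated: run `pip install --upgrade zhipuai`"
        return "ok"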
@@ -84,7 +99,6 @@ class ZhipuLLMClient(ProxyLLMClient):
         request: ModelRequest,
         message_converter: Optional[MessageConverter] = None,
     ) -> Iterator[ModelOutput]:
-        import zhipuai
 
         request = self.local_covert_message(request, message_converter)
@@ -92,18 +106,18 @@ class ZhipuLLMClient(ProxyLLMClient):
         model = request.model or self._model
         try:
-            res = zhipuai.model_api.sse_invoke(
+            response = self.client.chat.completions.create(
                 model=model,
-                prompt=messages,
+                messages=messages,
                 temperature=request.temperature,
                 # top_p=params.get("top_p"),
-                incremental=False,
+                stream=True,
             )
-            for r in res.events():
-                if r.event == "add":
-                    yield ModelOutput(text=r.data, error_code=0)
-                elif r.event == "error":
-                    yield ModelOutput(text=r.data, error_code=1)
+            partial_text = ""
+            for chunk in response:
+                delta_content = chunk.choices[0].delta.content
+                partial_text += delta_content
+                yield ModelOutput(text=partial_text, error_code=0)
         except Exception as e:
             return ModelOutput(
                 text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
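
For context, a minimal end-to-end sketch of the new call pattern this commit adopts; the model name "glm-4" and the prompt are illustrative, and the API key is read from the same ZHIPU_PROXY_API_KEY variable the client falls back to:

    import os
    from zhipuai import ZhipuAI

    client = ZhipuAI(api_key=os.getenv("ZHIPU_PROXY_API_KEY"))
    response = client.chat.completions.create(
        model="glm-4",
        messages=[{"role": "user", "content": "Hello"}],
        stream=True,
    )
    partial_text = ""
    for chunk in response:
        # Each chunk carries an incremental delta; accumulate it the same
        # way the streaming loop in the diff does.
        partial_text += chunk.choices[0].delta.content or ""
    print(partial_text)

Note that in OpenAI-style streams the final chunk's delta.content can be None, hence the `or ""` guard in the sketch.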