diff --git a/dbgpt/model/proxy/llms/zhipu.py b/dbgpt/model/proxy/llms/zhipu.py
index 6e9d36224..dc1c9b666 100644
--- a/dbgpt/model/proxy/llms/zhipu.py
+++ b/dbgpt/model/proxy/llms/zhipu.py
@@ -1,3 +1,4 @@
+import os
 from concurrent.futures import Executor
 from typing import Iterator, Optional
 
@@ -37,23 +38,37 @@ class ZhipuLLMClient(ProxyLLMClient):
         self,
         model: Optional[str] = None,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
         model_alias: Optional[str] = "zhipu_proxyllm",
         context_length: Optional[int] = 8192,
         executor: Optional[Executor] = None,
     ):
         try:
-            import zhipuai
+            from zhipuai import ZhipuAI
 
         except ImportError as exc:
-            raise ValueError(
-                "Could not import python package: zhipuai "
-                "Please install dashscope by command `pip install zhipuai"
-            ) from exc
+            if (
+                "No module named" in str(exc)
+                or "cannot find module" in str(exc).lower()
+            ):
+                raise ValueError(
+                    "The python package 'zhipuai' is not installed. "
+                    "Please install it by running `pip install zhipuai`."
+                ) from exc
+            else:
+                raise ValueError(
+                    "Could not import python package: zhipuai. "
+                    "This may be due to a version that is too low. "
+                    "Please upgrade the zhipuai package by running `pip install --upgrade zhipuai`."
+                ) from exc
         if not model:
             model = CHATGLM_DEFAULT_MODEL
-        if api_key:
-            zhipuai.api_key = api_key
+        if not api_key:
+            # Compatible with DB-GPT's config
+            api_key = os.getenv("ZHIPU_PROXY_API_KEY")
+
         self._model = model
+        self.client = ZhipuAI(api_key=api_key, base_url=api_base)
 
         super().__init__(
             model_names=[model, model_alias],
@@ -84,7 +99,6 @@ class ZhipuLLMClient(ProxyLLMClient):
         request: ModelRequest,
         message_converter: Optional[MessageConverter] = None,
     ) -> Iterator[ModelOutput]:
-        import zhipuai
 
         request = self.local_covert_message(request, message_converter)
 
@@ -92,18 +106,18 @@ class ZhipuLLMClient(ProxyLLMClient):
 
         model = request.model or self._model
         try:
-            res = zhipuai.model_api.sse_invoke(
+            response = self.client.chat.completions.create(
                 model=model,
-                prompt=messages,
+                messages=messages,
                 temperature=request.temperature,
                 # top_p=params.get("top_p"),
-                incremental=False,
+                stream=True,
            )
-            for r in res.events():
-                if r.event == "add":
-                    yield ModelOutput(text=r.data, error_code=0)
-                elif r.event == "error":
-                    yield ModelOutput(text=r.data, error_code=1)
+            partial_text = ""
+            for chunk in response:
+                delta_content = chunk.choices[0].delta.content or ""
+                partial_text += delta_content
+                yield ModelOutput(text=partial_text, error_code=0)
         except Exception as e:
             return ModelOutput(
                 text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
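
Note for reviewers: below is a minimal, standalone sketch of the zhipuai v2 streaming call pattern this patch switches to. It is not DB-GPT code; the model name `glm-4`, the prompt, and the temperature are illustrative placeholders, and the API key is read from the same `ZHIPU_PROXY_API_KEY` variable the client now falls back to.

```python
import os

from zhipuai import ZhipuAI

# Illustrative only: "glm-4" and the prompt are placeholders, not values from this patch.
client = ZhipuAI(api_key=os.getenv("ZHIPU_PROXY_API_KEY"))

response = client.chat.completions.create(
    model="glm-4",
    messages=[{"role": "user", "content": "Hello"}],
    temperature=0.7,
    stream=True,
)

# Accumulate the streamed deltas the same way sync_generate_stream does.
partial_text = ""
for chunk in response:
    delta_content = chunk.choices[0].delta.content or ""  # delta.content can be None on the final chunk
    partial_text += delta_content
print(partial_text)
```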