From dd5139e304062a1abd1237d37b22c7dee9a866d2 Mon Sep 17 00:00:00 2001 From: Congyu <52687642+Congyuwang@users.noreply.github.com> Date: Fri, 19 Apr 2024 09:31:30 +0800 Subject: [PATCH] community[patch]: truncate zhipuai `temperature` and `top_p` parameters to [0.01, 0.99] (#20261) The ZhipuAI API only accepts a `temperature` parameter within the open interval `(0, 1)`, and if `0` is passed, it responds with status code `400`. However, 0 and 1 are often accepted by other APIs; for example, OpenAI allows the closed range `[0, 2]` for temperature. This PR truncates the `temperature` parameter passed in to `[0.01, 0.99]` to improve compatibility between langchain's ecosystem and ZhipuAI (e.g., ragas `evaluate` often generates temperature 0, which results in a lot of 400 invalid responses). The PR also truncates the `top_p` parameter since it has the same restriction. Reference: [glm-4 doc](https://open.bigmodel.cn/dev/api#glm-4) (which, unfortunately, is in Chinese). --------- Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com> --- .../chat_models/zhipuai.py | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/libs/community/langchain_community/chat_models/zhipuai.py b/libs/community/langchain_community/chat_models/zhipuai.py index 6b98ff54e2b..5d9dd0eb887 100644 --- a/libs/community/langchain_community/chat_models/zhipuai.py +++ b/libs/community/langchain_community/chat_models/zhipuai.py @@ -148,6 +148,20 @@ def _convert_delta_to_message_chunk( return default_class(content=content) +def _truncate_params(payload: Dict[str, Any]) -> None: + """Truncate temperature and top_p parameters between [0.01, 0.99]. + + ZhipuAI only support temperature / top_p between (0, 1) open interval, + so we truncate them to [0.01, 0.99]. 
+ """ + temperature = payload.get("temperature") + top_p = payload.get("top_p") + if temperature is not None: + payload["temperature"] = max(0.01, min(0.99, temperature)) + if top_p is not None: + payload["top_p"] = max(0.01, min(0.99, top_p)) + + class ChatZhipuAI(BaseChatModel): """ `ZhipuAI` large language chat models API. @@ -213,7 +227,7 @@ class ChatZhipuAI(BaseChatModel): model_name: Optional[str] = Field(default="glm-4", alias="model") """ Model name to use, see 'https://open.bigmodel.cn/dev/api#language'. - or you can use any finetune model of glm series. + Alternatively, you can use any fine-tuned model from the GLM series. """ temperature: float = 0.95 @@ -309,6 +323,7 @@ class ChatZhipuAI(BaseChatModel): "messages": message_dicts, "stream": False, } + _truncate_params(payload) headers = { "Authorization": _get_jwt_token(self.zhipuai_api_key), "Accept": "application/json", @@ -334,6 +349,7 @@ class ChatZhipuAI(BaseChatModel): raise ValueError("Did not find zhipu_api_base.") message_dicts, params = self._create_message_dicts(messages, stop) payload = {**params, **kwargs, "messages": message_dicts, "stream": True} + _truncate_params(payload) headers = { "Authorization": _get_jwt_token(self.zhipuai_api_key), "Accept": "application/json", @@ -394,6 +410,7 @@ class ChatZhipuAI(BaseChatModel): "messages": message_dicts, "stream": False, } + _truncate_params(payload) headers = { "Authorization": _get_jwt_token(self.zhipuai_api_key), "Accept": "application/json", @@ -418,6 +435,7 @@ class ChatZhipuAI(BaseChatModel): raise ValueError("Did not find zhipu_api_base.") message_dicts, params = self._create_message_dicts(messages, stop) payload = {**params, **kwargs, "messages": message_dicts, "stream": True} + _truncate_params(payload) headers = { "Authorization": _get_jwt_token(self.zhipuai_api_key), "Accept": "application/json",