mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-22 14:49:29 +00:00
community[patch]: truncate zhipuai temperature
and top_p
parameters to [0.01, 0.99] (#20261)
The ZhipuAI API only accepts the `temperature` parameter within the open interval `(0, 1)`, and if `0` is passed, it responds with status code `400`. However, 0 and 1 are often accepted by other APIs; for example, OpenAI allows the closed range `[0, 2]` for temperature. This PR truncates the temperature parameter to `[0.01, 0.99]` to improve compatibility between langchain's ecosystem and ZhipuAI (e.g., ragas `evaluate` often generates temperature 0, which results in a lot of 400 invalid responses). The PR also truncates the `top_p` parameter, since it has the same restriction. Reference: [glm-4 doc](https://open.bigmodel.cn/dev/api#glm-4) (which unfortunately is in Chinese though). --------- Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
This commit is contained in:
parent
d5c22b80a5
commit
dd5139e304
@ -148,6 +148,20 @@ def _convert_delta_to_message_chunk(
|
|||||||
return default_class(content=content)
|
return default_class(content=content)
|
||||||
|
|
||||||
|
|
||||||
|
def _truncate_params(payload: Dict[str, Any]) -> None:
|
||||||
|
"""Truncate temperature and top_p parameters between [0.01, 0.99].
|
||||||
|
|
||||||
|
ZhipuAI only support temperature / top_p between (0, 1) open interval,
|
||||||
|
so we truncate them to [0.01, 0.99].
|
||||||
|
"""
|
||||||
|
temperature = payload.get("temperature")
|
||||||
|
top_p = payload.get("top_p")
|
||||||
|
if temperature is not None:
|
||||||
|
payload["temperature"] = max(0.01, min(0.99, temperature))
|
||||||
|
if top_p is not None:
|
||||||
|
payload["top_p"] = max(0.01, min(0.99, top_p))
|
||||||
|
|
||||||
|
|
||||||
class ChatZhipuAI(BaseChatModel):
|
class ChatZhipuAI(BaseChatModel):
|
||||||
"""
|
"""
|
||||||
`ZhipuAI` large language chat models API.
|
`ZhipuAI` large language chat models API.
|
||||||
@ -213,7 +227,7 @@ class ChatZhipuAI(BaseChatModel):
|
|||||||
model_name: Optional[str] = Field(default="glm-4", alias="model")
|
model_name: Optional[str] = Field(default="glm-4", alias="model")
|
||||||
"""
|
"""
|
||||||
Model name to use, see 'https://open.bigmodel.cn/dev/api#language'.
|
Model name to use, see 'https://open.bigmodel.cn/dev/api#language'.
|
||||||
or you can use any finetune model of glm series.
|
Alternatively, you can use any fine-tuned model from the GLM series.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
temperature: float = 0.95
|
temperature: float = 0.95
|
||||||
@ -309,6 +323,7 @@ class ChatZhipuAI(BaseChatModel):
|
|||||||
"messages": message_dicts,
|
"messages": message_dicts,
|
||||||
"stream": False,
|
"stream": False,
|
||||||
}
|
}
|
||||||
|
_truncate_params(payload)
|
||||||
headers = {
|
headers = {
|
||||||
"Authorization": _get_jwt_token(self.zhipuai_api_key),
|
"Authorization": _get_jwt_token(self.zhipuai_api_key),
|
||||||
"Accept": "application/json",
|
"Accept": "application/json",
|
||||||
@ -334,6 +349,7 @@ class ChatZhipuAI(BaseChatModel):
|
|||||||
raise ValueError("Did not find zhipu_api_base.")
|
raise ValueError("Did not find zhipu_api_base.")
|
||||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||||
payload = {**params, **kwargs, "messages": message_dicts, "stream": True}
|
payload = {**params, **kwargs, "messages": message_dicts, "stream": True}
|
||||||
|
_truncate_params(payload)
|
||||||
headers = {
|
headers = {
|
||||||
"Authorization": _get_jwt_token(self.zhipuai_api_key),
|
"Authorization": _get_jwt_token(self.zhipuai_api_key),
|
||||||
"Accept": "application/json",
|
"Accept": "application/json",
|
||||||
@ -394,6 +410,7 @@ class ChatZhipuAI(BaseChatModel):
|
|||||||
"messages": message_dicts,
|
"messages": message_dicts,
|
||||||
"stream": False,
|
"stream": False,
|
||||||
}
|
}
|
||||||
|
_truncate_params(payload)
|
||||||
headers = {
|
headers = {
|
||||||
"Authorization": _get_jwt_token(self.zhipuai_api_key),
|
"Authorization": _get_jwt_token(self.zhipuai_api_key),
|
||||||
"Accept": "application/json",
|
"Accept": "application/json",
|
||||||
@ -418,6 +435,7 @@ class ChatZhipuAI(BaseChatModel):
|
|||||||
raise ValueError("Did not find zhipu_api_base.")
|
raise ValueError("Did not find zhipu_api_base.")
|
||||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||||
payload = {**params, **kwargs, "messages": message_dicts, "stream": True}
|
payload = {**params, **kwargs, "messages": message_dicts, "stream": True}
|
||||||
|
_truncate_params(payload)
|
||||||
headers = {
|
headers = {
|
||||||
"Authorization": _get_jwt_token(self.zhipuai_api_key),
|
"Authorization": _get_jwt_token(self.zhipuai_api_key),
|
||||||
"Accept": "application/json",
|
"Accept": "application/json",
|
||||||
|
Loading…
Reference in New Issue
Block a user