From a662468ddeeb8a500efaa0c87457a4fbef023e70 Mon Sep 17 00:00:00 2001
From: ligang-super <117738283+ligang-super@users.noreply.github.com>
Date: Fri, 29 Mar 2024 02:21:49 +0800
Subject: [PATCH] community[patch]: Fix the error of Baidu Qianfan not passing
 the stop parameter (#18666)

- [x] **PR title**: "community: fix baidu qianfan missing stop parameter"
- [x] **PR message**:
  - **Description:** Baidu Qianfan lost the `stop` parameter when requesting
    the service, because it was extracted from `kwargs` and never passed
    through. This bug can cause the agent to receive incorrect results.

---------

Co-authored-by: ligang33
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
Co-authored-by: Bagatur
---
 .../langchain_community/chat_models/baidu_qianfan_endpoint.py | 4 ++++
 .../langchain_community/llms/baidu_qianfan_endpoint.py        | 4 ++++
 2 files changed, 8 insertions(+)

diff --git a/libs/community/langchain_community/chat_models/baidu_qianfan_endpoint.py b/libs/community/langchain_community/chat_models/baidu_qianfan_endpoint.py
index e91c8b7c562..801364c4dff 100644
--- a/libs/community/langchain_community/chat_models/baidu_qianfan_endpoint.py
+++ b/libs/community/langchain_community/chat_models/baidu_qianfan_endpoint.py
@@ -271,6 +271,7 @@ class QianfanChatEndpoint(BaseChatModel):
                 },
             )
         params = self._convert_prompt_msg_params(messages, **kwargs)
+        params["stop"] = stop
         response_payload = self.client.do(**params)
         lc_msg = _convert_dict_to_message(response_payload)
         gen = ChatGeneration(
@@ -316,6 +317,7 @@ class QianfanChatEndpoint(BaseChatModel):
                 },
             )
         params = self._convert_prompt_msg_params(messages, **kwargs)
+        params["stop"] = stop
         response_payload = await self.client.ado(**params)
         lc_msg = _convert_dict_to_message(response_payload)
         generations = []
@@ -339,6 +341,7 @@
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
         params = self._convert_prompt_msg_params(messages, **kwargs)
+        params["stop"] = stop
         params["stream"] = True
         for res in self.client.do(**params):
             if res:
@@ -365,6 +368,7 @@
         **kwargs: Any,
     ) -> AsyncIterator[ChatGenerationChunk]:
         params = self._convert_prompt_msg_params(messages, **kwargs)
+        params["stop"] = stop
         params["stream"] = True
         async for res in await self.client.ado(**params):
             if res:
diff --git a/libs/community/langchain_community/llms/baidu_qianfan_endpoint.py b/libs/community/langchain_community/llms/baidu_qianfan_endpoint.py
index 2ef7baba68b..1a611d2656f 100644
--- a/libs/community/langchain_community/llms/baidu_qianfan_endpoint.py
+++ b/libs/community/langchain_community/llms/baidu_qianfan_endpoint.py
@@ -180,6 +180,7 @@ class QianfanLLMEndpoint(LLM):
                 completion += chunk.text
             return completion
         params = self._convert_prompt_msg_params(prompt, **kwargs)
+        params["stop"] = stop
         response_payload = self.client.do(**params)

         return response_payload["result"]
@@ -198,6 +199,7 @@
             return completion

         params = self._convert_prompt_msg_params(prompt, **kwargs)
+        params["stop"] = stop
         response_payload = await self.client.ado(**params)

         return response_payload["result"]
@@ -210,6 +212,7 @@
         **kwargs: Any,
     ) -> Iterator[GenerationChunk]:
         params = self._convert_prompt_msg_params(prompt, **{**kwargs, "stream": True})
+        params["stop"] = stop
         for res in self.client.do(**params):
             if res:
                 chunk = GenerationChunk(text=res["result"])
@@ -225,6 +228,7 @@
         **kwargs: Any,
     ) -> AsyncIterator[GenerationChunk]:
         params = self._convert_prompt_msg_params(prompt, **{**kwargs, "stream": True})
+        params["stop"] = stop
         async for res in await self.client.ado(**params):
             if res:
                 chunk = GenerationChunk(text=res["result"])
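
For reviewers unfamiliar with the bug pattern: `stop` is a named parameter of the `_generate`/`_call` methods, so it never lands in `**kwargs`, and a params builder that reads only from `kwargs` silently drops it. This matters for agents (e.g., ReAct-style agents rely on stop sequences such as `"\nObservation:"` to cut off generation). Below is a minimal, self-contained sketch of the pattern and the fix — the function and variable names here are hypothetical stand-ins, not the actual LangChain implementation:

```python
# Hypothetical sketch of the stop-parameter bug and its fix.
# Names below are illustrative only, not the real LangChain code.
from typing import Any, Dict, List, Optional


def convert_prompt_msg_params(prompt: str, **kwargs: Any) -> Dict[str, Any]:
    """Build the request payload from the prompt plus extra kwargs only.

    Since `stop` is a named parameter of the caller, it is not part of
    **kwargs and never reaches this helper.
    """
    return {"prompt": prompt, **kwargs}


def call_before_fix(
    prompt: str, stop: Optional[List[str]] = None, **kwargs: Any
) -> Dict[str, Any]:
    # Bug: `stop` is captured by the signature and silently dropped.
    return convert_prompt_msg_params(prompt, **kwargs)


def call_after_fix(
    prompt: str, stop: Optional[List[str]] = None, **kwargs: Any
) -> Dict[str, Any]:
    params = convert_prompt_msg_params(prompt, **kwargs)
    # Fix: re-inject the named argument into the request payload,
    # mirroring the added `params["stop"] = stop` lines in this patch.
    params["stop"] = stop
    return params


if __name__ == "__main__":
    assert "stop" not in call_before_fix("hi", stop=["\nObservation:"])
    assert call_after_fix("hi", stop=["\nObservation:"])["stop"] == ["\nObservation:"]
    print("stop survives only after the fix")
```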