From 514fe807784f520449914a64ffc983538fa6743d Mon Sep 17 00:00:00 2001 From: Pengfei Jiang Date: Sun, 17 Mar 2024 09:58:50 +0800 Subject: [PATCH] community[patch]: add stop parameter support to volcengine maas (#19052) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - **Description:** add stop parameter to volcengine maas model - **Dependencies:** no --------- Co-authored-by: 江鹏飞 --- .../chat_models/volcengine_maas.py | 4 ++++ .../chat_models/test_volcengine_maas.py | 22 +++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/libs/community/langchain_community/chat_models/volcengine_maas.py b/libs/community/langchain_community/chat_models/volcengine_maas.py index df348971de3..687a540bb30 100644 --- a/libs/community/langchain_community/chat_models/volcengine_maas.py +++ b/libs/community/langchain_community/chat_models/volcengine_maas.py @@ -112,6 +112,8 @@ class VolcEngineMaasChat(BaseChatModel, VolcEngineMaasBase): run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: + if stop is not None: + kwargs["stop"] = stop params = self._convert_prompt_msg_params(messages, **kwargs) for res in self.client.stream_chat(params): if res: @@ -133,6 +135,8 @@ class VolcEngineMaasChat(BaseChatModel, VolcEngineMaasBase): for chunk in self._stream(messages, stop, run_manager, **kwargs): completion += chunk.text else: + if stop is not None: + kwargs["stop"] = stop params = self._convert_prompt_msg_params(messages, **kwargs) res = self.client.chat(params) msg = convert_dict_to_message(res) diff --git a/libs/community/tests/integration_tests/chat_models/test_volcengine_maas.py b/libs/community/tests/integration_tests/chat_models/test_volcengine_maas.py index 2aa84b30de4..4701c504e73 100644 --- a/libs/community/tests/integration_tests/chat_models/test_volcengine_maas.py +++ b/libs/community/tests/integration_tests/chat_models/test_volcengine_maas.py @@ -49,6 +49,28 @@ def test_stream() -> None: assert isinstance(response.content, str) +def test_stop() -> None: + """Test that stop works.""" + chat = VolcEngineMaasChat( + model="skylark2-pro-4k", model_version="1.2", streaming=True + ) + callback_handler = FakeCallbackHandler() + callback_manager = CallbackManager([callback_handler]) + response = chat( + messages=[ + HumanMessage(content="repeat: hello world"), + AIMessage(content="hello world"), + HumanMessage(content="repeat: hello world"), + ], + stream=True, + callbacks=callback_manager, + stop=["world"], + ) + assert callback_handler.llm_streams > 0 + assert isinstance(response.content, str) + assert response.content.rstrip() == "hello" + + def test_multiple_messages() -> None: """Tests multiple messages works.""" chat = VolcEngineMaasChat()