community[patch]: add stop parameter support to volcengine maas (#19052)

- **Description:** add support for the `stop` parameter to the VolcEngine MaaS chat model, covering both the streaming and non-streaming generation paths; a usage sketch follows below
- **Dependencies:** none
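
A minimal usage sketch of the new behaviour. This is not part of the patch: the `volc_engine_maas_ak` / `volc_engine_maas_sk` credential parameters and the `invoke(..., stop=...)` pass-through are assumptions about the surrounding LangChain API, and the credential values are placeholders.

```python
from langchain_community.chat_models import VolcEngineMaasChat
from langchain_core.messages import HumanMessage

# Placeholder credentials; supply real VolcEngine MaaS keys in practice.
chat = VolcEngineMaasChat(
    model="skylark2-pro-4k",
    volc_engine_maas_ak="<access key>",
    volc_engine_maas_sk="<secret key>",
)

# With this patch, generation should stop once the "world" stop sequence is hit,
# mirroring the integration test added below.
message = chat.invoke([HumanMessage(content="repeat: hello world")], stop=["world"])
print(message.content)  # expected to end at "hello"
```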

---------

Co-authored-by: 江鹏飞 <jiangpengfei.jiangpf@bytedance.com>
Pengfei Jiang 2024-03-17 09:58:50 +08:00 committed by GitHub
parent bcc771e37c
commit 514fe80778
2 changed files with 26 additions and 0 deletions


@@ -112,6 +112,8 @@ class VolcEngineMaasChat(BaseChatModel, VolcEngineMaasBase):
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        if stop is not None:
            kwargs["stop"] = stop
        params = self._convert_prompt_msg_params(messages, **kwargs)
        for res in self.client.stream_chat(params):
            if res:
@@ -133,6 +135,8 @@ class VolcEngineMaasChat(BaseChatModel, VolcEngineMaasBase):
            for chunk in self._stream(messages, stop, run_manager, **kwargs):
                completion += chunk.text
        else:
            if stop is not None:
                kwargs["stop"] = stop
            params = self._convert_prompt_msg_params(messages, **kwargs)
            res = self.client.chat(params)
            msg = convert_dict_to_message(res)
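
Both hunks follow the same pattern: when a `stop` list is supplied, it is copied into `kwargs` before `_convert_prompt_msg_params` builds the request, so stopping applies to the streaming (`_stream`) and non-streaming (`_generate`) paths alike. A rough caller-side sketch of the streaming case, again with placeholder credentials and assuming the usual `BaseChatModel.stream(..., stop=...)` pass-through:

```python
from langchain_community.chat_models import VolcEngineMaasChat
from langchain_core.messages import HumanMessage

# Placeholder credentials; supply real VolcEngine MaaS keys in practice.
chat = VolcEngineMaasChat(
    model="skylark2-pro-4k",
    volc_engine_maas_ak="<access key>",
    volc_engine_maas_sk="<secret key>",
)

# Chunks should stop arriving once the "world" stop sequence is reached.
for chunk in chat.stream([HumanMessage(content="repeat: hello world")], stop=["world"]):
    print(chunk.content, end="", flush=True)
```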


@@ -49,6 +49,28 @@ def test_stream() -> None:
    assert isinstance(response.content, str)


def test_stop() -> None:
    """Test that stop works."""
    chat = VolcEngineMaasChat(
        model="skylark2-pro-4k", model_version="1.2", streaming=True
    )
    callback_handler = FakeCallbackHandler()
    callback_manager = CallbackManager([callback_handler])
    response = chat(
        messages=[
            HumanMessage(content="repeat: hello world"),
            AIMessage(content="hello world"),
            HumanMessage(content="repeat: hello world"),
        ],
        stream=True,
        callbacks=callback_manager,
        stop=["world"],
    )
    assert callback_handler.llm_streams > 0
    assert isinstance(response.content, str)
    assert response.content.rstrip() == "hello"


def test_multiple_messages() -> None:
    """Tests multiple messages works."""
    chat = VolcEngineMaasChat()