mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-04 12:18:24 +00:00
community[patch]: add stop parameter support to volcengine maas (#19052)
- **Description:** add stop parameter to volcengine maas model - **Dependencies:** no --------- Co-authored-by: 江鹏飞 <jiangpengfei.jiangpf@bytedance.com>
This commit is contained in:
parent
bcc771e37c
commit
514fe80778
@ -112,6 +112,8 @@ class VolcEngineMaasChat(BaseChatModel, VolcEngineMaasBase):
|
|||||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Iterator[ChatGenerationChunk]:
|
) -> Iterator[ChatGenerationChunk]:
|
||||||
|
if stop is not None:
|
||||||
|
kwargs["stop"] = stop
|
||||||
params = self._convert_prompt_msg_params(messages, **kwargs)
|
params = self._convert_prompt_msg_params(messages, **kwargs)
|
||||||
for res in self.client.stream_chat(params):
|
for res in self.client.stream_chat(params):
|
||||||
if res:
|
if res:
|
||||||
@ -133,6 +135,8 @@ class VolcEngineMaasChat(BaseChatModel, VolcEngineMaasBase):
|
|||||||
for chunk in self._stream(messages, stop, run_manager, **kwargs):
|
for chunk in self._stream(messages, stop, run_manager, **kwargs):
|
||||||
completion += chunk.text
|
completion += chunk.text
|
||||||
else:
|
else:
|
||||||
|
if stop is not None:
|
||||||
|
kwargs["stop"] = stop
|
||||||
params = self._convert_prompt_msg_params(messages, **kwargs)
|
params = self._convert_prompt_msg_params(messages, **kwargs)
|
||||||
res = self.client.chat(params)
|
res = self.client.chat(params)
|
||||||
msg = convert_dict_to_message(res)
|
msg = convert_dict_to_message(res)
|
||||||
|
@ -49,6 +49,28 @@ def test_stream() -> None:
|
|||||||
assert isinstance(response.content, str)
|
assert isinstance(response.content, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_stop() -> None:
|
||||||
|
"""Test that stop works."""
|
||||||
|
chat = VolcEngineMaasChat(
|
||||||
|
model="skylark2-pro-4k", model_version="1.2", streaming=True
|
||||||
|
)
|
||||||
|
callback_handler = FakeCallbackHandler()
|
||||||
|
callback_manager = CallbackManager([callback_handler])
|
||||||
|
response = chat(
|
||||||
|
messages=[
|
||||||
|
HumanMessage(content="repeat: hello world"),
|
||||||
|
AIMessage(content="hello world"),
|
||||||
|
HumanMessage(content="repeat: hello world"),
|
||||||
|
],
|
||||||
|
stream=True,
|
||||||
|
callbacks=callback_manager,
|
||||||
|
stop=["world"],
|
||||||
|
)
|
||||||
|
assert callback_handler.llm_streams > 0
|
||||||
|
assert isinstance(response.content, str)
|
||||||
|
assert response.content.rstrip() == "hello"
|
||||||
|
|
||||||
|
|
||||||
def test_multiple_messages() -> None:
|
def test_multiple_messages() -> None:
|
||||||
"""Tests multiple messages works."""
|
"""Tests multiple messages works."""
|
||||||
chat = VolcEngineMaasChat()
|
chat = VolcEngineMaasChat()
|
||||||
|
Loading…
Reference in New Issue
Block a user