From 9d97de34ae050eb8b054caa98777e7776f6012e7 Mon Sep 17 00:00:00 2001 From: maang-h <55082429+maang-h@users.noreply.github.com> Date: Tue, 16 Jul 2024 03:17:00 +0800 Subject: [PATCH] community[patch]: Improve ChatBaichuan init args and role (#23878) - **Description:** Improve ChatBaichuan init args and role - ChatBaichuan adds `system` role - alias: `baichuan_api_base` -> `base_url` - `with_search_enhance` is deprecated - Add `max_tokens` argument --- .../langchain_community/chat_models/baichuan.py | 17 ++++++++++++++--- .../chat_models/test_baichuan.py | 11 +++++++++++ .../unit_tests/chat_models/test_baichuan.py | 17 ++++++++++++----- 3 files changed, 37 insertions(+), 8 deletions(-) diff --git a/libs/community/langchain_community/chat_models/baichuan.py b/libs/community/langchain_community/chat_models/baichuan.py index ede68a14e66..b96f1cd4c2a 100644 --- a/libs/community/langchain_community/chat_models/baichuan.py +++ b/libs/community/langchain_community/chat_models/baichuan.py @@ -22,6 +22,8 @@ from langchain_core.messages import ( ChatMessageChunk, HumanMessage, HumanMessageChunk, + SystemMessage, + SystemMessageChunk, ) from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.pydantic_v1 import Field, SecretStr, root_validator @@ -44,6 +46,8 @@ def _convert_message_to_dict(message: BaseMessage) -> dict: message_dict = {"role": "user", "content": message.content} elif isinstance(message, AIMessage): message_dict = {"role": "assistant", "content": message.content} + elif isinstance(message, SystemMessage): + message_dict = {"role": "system", "content": message.content} else: raise TypeError(f"Got unknown type {message}") @@ -56,6 +60,8 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: return HumanMessage(content=_dict["content"]) elif role == "assistant": return AIMessage(content=_dict.get("content", "") or "") + elif role == "system": + return SystemMessage(content=_dict.get("content", "")) else: return ChatMessage(content=_dict["content"], role=role) @@ -70,6 +76,8 @@ def _convert_delta_to_message_chunk( return HumanMessageChunk(content=content) elif role == "assistant" or default_class == AIMessageChunk: return AIMessageChunk(content=content) + elif role == "system" or default_class == SystemMessageChunk: + return SystemMessageChunk(content=content) elif role or default_class == ChatMessageChunk: return ChatMessageChunk(content=content, role=role) # type: ignore[arg-type] else: @@ -113,7 +121,7 @@ class ChatBaichuan(BaseChatModel): def lc_serializable(self) -> bool: return True - baichuan_api_base: str = Field(default=DEFAULT_API_BASE) + baichuan_api_base: str = Field(default=DEFAULT_API_BASE, alias="base_url") """Baichuan custom endpoints""" baichuan_api_key: SecretStr = Field(alias="api_key") """Baichuan API Key""" @@ -121,6 +129,8 @@ class ChatBaichuan(BaseChatModel): """[DEPRECATED, keeping it for for backward compatibility] Baichuan Secret Key""" streaming: bool = False """Whether to stream the results or not.""" + max_tokens: Optional[int] = None + """Maximum number of tokens to generate.""" request_timeout: int = Field(default=60, alias="timeout") """request timeout for chat http requests""" model: str = "Baichuan2-Turbo-192K" @@ -133,7 +143,8 @@ class ChatBaichuan(BaseChatModel): top_p: float = 0.85 """What probability mass to use.""" with_search_enhance: bool = False - """Whether to use search enhance, default is False.""" + """[DEPRECATED, keeping it for for backward compatibility], + Whether to use search enhance, default is False.""" model_kwargs: Dict[str, Any] = Field(default_factory=dict) """Holds any model parameters valid for API call not explicitly specified.""" @@ -193,8 +204,8 @@ class ChatBaichuan(BaseChatModel): "temperature": self.temperature, "top_p": self.top_p, "top_k": self.top_k, - "with_search_enhance": self.with_search_enhance, "stream": self.streaming, + "max_tokens": self.max_tokens, } return {**normal_params, **self.model_kwargs} diff --git a/libs/community/tests/integration_tests/chat_models/test_baichuan.py b/libs/community/tests/integration_tests/chat_models/test_baichuan.py index 3ffcf2e9dda..74d8d62df53 100644 --- a/libs/community/tests/integration_tests/chat_models/test_baichuan.py +++ b/libs/community/tests/integration_tests/chat_models/test_baichuan.py @@ -76,3 +76,14 @@ async def test_chat_baichuan_astream() -> None: async for chunk in chat.astream("今天天气如何?"): assert isinstance(chunk, AIMessage) assert isinstance(chunk.content, str) + + +def test_chat_baichuan_with_system_role() -> None: + chat = ChatBaichuan() # type: ignore[call-arg] + messages = [ + ("system", "你是一名专业的翻译家,可以将用户的中文翻译为英文。"), + ("human", "我喜欢编程。"), + ] + response = chat.invoke(messages) + assert isinstance(response, AIMessage) + assert isinstance(response.content, str) diff --git a/libs/community/tests/unit_tests/chat_models/test_baichuan.py b/libs/community/tests/unit_tests/chat_models/test_baichuan.py index def23d56e5f..4dd44dd1811 100644 --- a/libs/community/tests/unit_tests/chat_models/test_baichuan.py +++ b/libs/community/tests/unit_tests/chat_models/test_baichuan.py @@ -4,7 +4,6 @@ import pytest from langchain_core.messages import ( AIMessage, AIMessageChunk, - ChatMessage, FunctionMessage, HumanMessage, HumanMessageChunk, @@ -54,9 +53,9 @@ def test__convert_message_to_dict_ai() -> None: def test__convert_message_to_dict_system() -> None: message = SystemMessage(content="foo") - with pytest.raises(TypeError) as e: - _convert_message_to_dict(message) - assert "Got unknown type" in str(e) + result = _convert_message_to_dict(message) + expected_output = {"role": "system", "content": "foo"} + assert result == expected_output def test__convert_message_to_dict_function() -> None: @@ -83,7 +82,7 @@ def test__convert_dict_to_message_ai() -> None: def test__convert_dict_to_message_other_role() -> None: message_dict = {"role": "system", "content": "foo"} result = _convert_dict_to_message(message_dict) - expected_output = ChatMessage(role="system", content="foo") + expected_output = SystemMessage(content="foo") assert result == expected_output @@ -134,3 +133,11 @@ def test_uses_actual_secret_value_from_secret_str() -> None: cast(SecretStr, chat.baichuan_secret_key).get_secret_value() == "test-secret-key" ) + + +def test_chat_baichuan_with_base_url() -> None: + chat = ChatBaichuan( # type: ignore[call-arg] + api_key="your-api-key", # type: ignore[arg-type] + base_url="https://exmaple.com", # type: ignore[arg-type] + ) + assert chat.baichuan_api_base == "https://exmaple.com"