community[patch]: Improve ChatBaichuan init args and role (#23878)

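- **Description:** Improve ChatBaichuan init args and role handling
   - Add support for the `system` role in ChatBaichuan message conversion
   - Add `base_url` as an alias for `baichuan_api_base`
   - Deprecate `with_search_enhance` (it is no longer sent in the request parameters)
   - Add `max_tokens` argument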
This commit is contained in:
maang-h 2024-07-16 03:17:00 +08:00 committed by GitHub
parent 56cca23745
commit 9d97de34ae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 37 additions and 8 deletions

View File

@ -22,6 +22,8 @@ from langchain_core.messages import (
ChatMessageChunk, ChatMessageChunk,
HumanMessage, HumanMessage,
HumanMessageChunk, HumanMessageChunk,
SystemMessage,
SystemMessageChunk,
) )
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
@ -44,6 +46,8 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
message_dict = {"role": "user", "content": message.content} message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage): elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content} message_dict = {"role": "assistant", "content": message.content}
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
else: else:
raise TypeError(f"Got unknown type {message}") raise TypeError(f"Got unknown type {message}")
@ -56,6 +60,8 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
return HumanMessage(content=_dict["content"]) return HumanMessage(content=_dict["content"])
elif role == "assistant": elif role == "assistant":
return AIMessage(content=_dict.get("content", "") or "") return AIMessage(content=_dict.get("content", "") or "")
elif role == "system":
return SystemMessage(content=_dict.get("content", ""))
else: else:
return ChatMessage(content=_dict["content"], role=role) return ChatMessage(content=_dict["content"], role=role)
@ -70,6 +76,8 @@ def _convert_delta_to_message_chunk(
return HumanMessageChunk(content=content) return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk: elif role == "assistant" or default_class == AIMessageChunk:
return AIMessageChunk(content=content) return AIMessageChunk(content=content)
elif role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content)
elif role or default_class == ChatMessageChunk: elif role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role) # type: ignore[arg-type] return ChatMessageChunk(content=content, role=role) # type: ignore[arg-type]
else: else:
@ -113,7 +121,7 @@ class ChatBaichuan(BaseChatModel):
def lc_serializable(self) -> bool: def lc_serializable(self) -> bool:
return True return True
baichuan_api_base: str = Field(default=DEFAULT_API_BASE) baichuan_api_base: str = Field(default=DEFAULT_API_BASE, alias="base_url")
"""Baichuan custom endpoints""" """Baichuan custom endpoints"""
baichuan_api_key: SecretStr = Field(alias="api_key") baichuan_api_key: SecretStr = Field(alias="api_key")
"""Baichuan API Key""" """Baichuan API Key"""
@ -121,6 +129,8 @@ class ChatBaichuan(BaseChatModel):
"""[DEPRECATED, keeping it for for backward compatibility] Baichuan Secret Key""" """[DEPRECATED, keeping it for for backward compatibility] Baichuan Secret Key"""
streaming: bool = False streaming: bool = False
"""Whether to stream the results or not.""" """Whether to stream the results or not."""
max_tokens: Optional[int] = None
"""Maximum number of tokens to generate."""
request_timeout: int = Field(default=60, alias="timeout") request_timeout: int = Field(default=60, alias="timeout")
"""request timeout for chat http requests""" """request timeout for chat http requests"""
model: str = "Baichuan2-Turbo-192K" model: str = "Baichuan2-Turbo-192K"
@ -133,7 +143,8 @@ class ChatBaichuan(BaseChatModel):
top_p: float = 0.85 top_p: float = 0.85
"""What probability mass to use.""" """What probability mass to use."""
with_search_enhance: bool = False with_search_enhance: bool = False
"""Whether to use search enhance, default is False.""" """[DEPRECATED, keeping it for for backward compatibility],
Whether to use search enhance, default is False."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict) model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for API call not explicitly specified.""" """Holds any model parameters valid for API call not explicitly specified."""
@ -193,8 +204,8 @@ class ChatBaichuan(BaseChatModel):
"temperature": self.temperature, "temperature": self.temperature,
"top_p": self.top_p, "top_p": self.top_p,
"top_k": self.top_k, "top_k": self.top_k,
"with_search_enhance": self.with_search_enhance,
"stream": self.streaming, "stream": self.streaming,
"max_tokens": self.max_tokens,
} }
return {**normal_params, **self.model_kwargs} return {**normal_params, **self.model_kwargs}

View File

@ -76,3 +76,14 @@ async def test_chat_baichuan_astream() -> None:
async for chunk in chat.astream("今天天气如何?"): async for chunk in chat.astream("今天天气如何?"):
assert isinstance(chunk, AIMessage) assert isinstance(chunk, AIMessage)
assert isinstance(chunk.content, str) assert isinstance(chunk.content, str)
def test_chat_baichuan_with_system_role() -> None:
    """Invoke ChatBaichuan with a system + human message pair.

    The conversation is given as ``(role, content)`` tuples; the test only
    checks that the reply is an ``AIMessage`` whose content is a string.
    """
    conversation = [
        ("system", "你是一名专业的翻译家,可以将用户的中文翻译为英文。"),
        ("human", "我喜欢编程。"),
    ]
    llm = ChatBaichuan()  # type: ignore[call-arg]
    reply = llm.invoke(conversation)
    assert isinstance(reply, AIMessage)
    assert isinstance(reply.content, str)

View File

@ -4,7 +4,6 @@ import pytest
from langchain_core.messages import ( from langchain_core.messages import (
AIMessage, AIMessage,
AIMessageChunk, AIMessageChunk,
ChatMessage,
FunctionMessage, FunctionMessage,
HumanMessage, HumanMessage,
HumanMessageChunk, HumanMessageChunk,
@ -54,9 +53,9 @@ def test__convert_message_to_dict_ai() -> None:
def test__convert_message_to_dict_system() -> None: def test__convert_message_to_dict_system() -> None:
message = SystemMessage(content="foo") message = SystemMessage(content="foo")
with pytest.raises(TypeError) as e: result = _convert_message_to_dict(message)
_convert_message_to_dict(message) expected_output = {"role": "system", "content": "foo"}
assert "Got unknown type" in str(e) assert result == expected_output
def test__convert_message_to_dict_function() -> None: def test__convert_message_to_dict_function() -> None:
@ -83,7 +82,7 @@ def test__convert_dict_to_message_ai() -> None:
def test__convert_dict_to_message_other_role() -> None: def test__convert_dict_to_message_other_role() -> None:
message_dict = {"role": "system", "content": "foo"} message_dict = {"role": "system", "content": "foo"}
result = _convert_dict_to_message(message_dict) result = _convert_dict_to_message(message_dict)
expected_output = ChatMessage(role="system", content="foo") expected_output = SystemMessage(content="foo")
assert result == expected_output assert result == expected_output
@ -134,3 +133,11 @@ def test_uses_actual_secret_value_from_secret_str() -> None:
cast(SecretStr, chat.baichuan_secret_key).get_secret_value() cast(SecretStr, chat.baichuan_secret_key).get_secret_value()
== "test-secret-key" == "test-secret-key"
) )
def test_chat_baichuan_with_base_url() -> None:
    """Check that the ``base_url`` init argument populates ``baichuan_api_base``."""
    endpoint = "https://exmaple.com"
    model = ChatBaichuan(  # type: ignore[call-arg]
        api_key="your-api-key",  # type: ignore[arg-type]
        base_url=endpoint,  # type: ignore[arg-type]
    )
    assert model.baichuan_api_base == endpoint