mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-19 13:23:35 +00:00
community[patch]: Improve ChatBaichuan init args and role (#23878)
- **Description:** Improve ChatBaichuan init args and role - ChatBaichuan adds `system` role - alias: `baichuan_api_base` -> `base_url` - `with_search_enhance` is deprecated - Add `max_tokens` argument
This commit is contained in:
parent
56cca23745
commit
9d97de34ae
@ -22,6 +22,8 @@ from langchain_core.messages import (
|
|||||||
ChatMessageChunk,
|
ChatMessageChunk,
|
||||||
HumanMessage,
|
HumanMessage,
|
||||||
HumanMessageChunk,
|
HumanMessageChunk,
|
||||||
|
SystemMessage,
|
||||||
|
SystemMessageChunk,
|
||||||
)
|
)
|
||||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||||
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
||||||
@ -44,6 +46,8 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
|
|||||||
message_dict = {"role": "user", "content": message.content}
|
message_dict = {"role": "user", "content": message.content}
|
||||||
elif isinstance(message, AIMessage):
|
elif isinstance(message, AIMessage):
|
||||||
message_dict = {"role": "assistant", "content": message.content}
|
message_dict = {"role": "assistant", "content": message.content}
|
||||||
|
elif isinstance(message, SystemMessage):
|
||||||
|
message_dict = {"role": "system", "content": message.content}
|
||||||
else:
|
else:
|
||||||
raise TypeError(f"Got unknown type {message}")
|
raise TypeError(f"Got unknown type {message}")
|
||||||
|
|
||||||
@ -56,6 +60,8 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
|||||||
return HumanMessage(content=_dict["content"])
|
return HumanMessage(content=_dict["content"])
|
||||||
elif role == "assistant":
|
elif role == "assistant":
|
||||||
return AIMessage(content=_dict.get("content", "") or "")
|
return AIMessage(content=_dict.get("content", "") or "")
|
||||||
|
elif role == "system":
|
||||||
|
return SystemMessage(content=_dict.get("content", ""))
|
||||||
else:
|
else:
|
||||||
return ChatMessage(content=_dict["content"], role=role)
|
return ChatMessage(content=_dict["content"], role=role)
|
||||||
|
|
||||||
@ -70,6 +76,8 @@ def _convert_delta_to_message_chunk(
|
|||||||
return HumanMessageChunk(content=content)
|
return HumanMessageChunk(content=content)
|
||||||
elif role == "assistant" or default_class == AIMessageChunk:
|
elif role == "assistant" or default_class == AIMessageChunk:
|
||||||
return AIMessageChunk(content=content)
|
return AIMessageChunk(content=content)
|
||||||
|
elif role == "system" or default_class == SystemMessageChunk:
|
||||||
|
return SystemMessageChunk(content=content)
|
||||||
elif role or default_class == ChatMessageChunk:
|
elif role or default_class == ChatMessageChunk:
|
||||||
return ChatMessageChunk(content=content, role=role) # type: ignore[arg-type]
|
return ChatMessageChunk(content=content, role=role) # type: ignore[arg-type]
|
||||||
else:
|
else:
|
||||||
@ -113,7 +121,7 @@ class ChatBaichuan(BaseChatModel):
|
|||||||
def lc_serializable(self) -> bool:
|
def lc_serializable(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
baichuan_api_base: str = Field(default=DEFAULT_API_BASE)
|
baichuan_api_base: str = Field(default=DEFAULT_API_BASE, alias="base_url")
|
||||||
"""Baichuan custom endpoints"""
|
"""Baichuan custom endpoints"""
|
||||||
baichuan_api_key: SecretStr = Field(alias="api_key")
|
baichuan_api_key: SecretStr = Field(alias="api_key")
|
||||||
"""Baichuan API Key"""
|
"""Baichuan API Key"""
|
||||||
@ -121,6 +129,8 @@ class ChatBaichuan(BaseChatModel):
|
|||||||
"""[DEPRECATED, keeping it for for backward compatibility] Baichuan Secret Key"""
|
"""[DEPRECATED, keeping it for for backward compatibility] Baichuan Secret Key"""
|
||||||
streaming: bool = False
|
streaming: bool = False
|
||||||
"""Whether to stream the results or not."""
|
"""Whether to stream the results or not."""
|
||||||
|
max_tokens: Optional[int] = None
|
||||||
|
"""Maximum number of tokens to generate."""
|
||||||
request_timeout: int = Field(default=60, alias="timeout")
|
request_timeout: int = Field(default=60, alias="timeout")
|
||||||
"""request timeout for chat http requests"""
|
"""request timeout for chat http requests"""
|
||||||
model: str = "Baichuan2-Turbo-192K"
|
model: str = "Baichuan2-Turbo-192K"
|
||||||
@ -133,7 +143,8 @@ class ChatBaichuan(BaseChatModel):
|
|||||||
top_p: float = 0.85
|
top_p: float = 0.85
|
||||||
"""What probability mass to use."""
|
"""What probability mass to use."""
|
||||||
with_search_enhance: bool = False
|
with_search_enhance: bool = False
|
||||||
"""Whether to use search enhance, default is False."""
|
"""[DEPRECATED, keeping it for for backward compatibility],
|
||||||
|
Whether to use search enhance, default is False."""
|
||||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||||
"""Holds any model parameters valid for API call not explicitly specified."""
|
"""Holds any model parameters valid for API call not explicitly specified."""
|
||||||
|
|
||||||
@ -193,8 +204,8 @@ class ChatBaichuan(BaseChatModel):
|
|||||||
"temperature": self.temperature,
|
"temperature": self.temperature,
|
||||||
"top_p": self.top_p,
|
"top_p": self.top_p,
|
||||||
"top_k": self.top_k,
|
"top_k": self.top_k,
|
||||||
"with_search_enhance": self.with_search_enhance,
|
|
||||||
"stream": self.streaming,
|
"stream": self.streaming,
|
||||||
|
"max_tokens": self.max_tokens,
|
||||||
}
|
}
|
||||||
|
|
||||||
return {**normal_params, **self.model_kwargs}
|
return {**normal_params, **self.model_kwargs}
|
||||||
|
@ -76,3 +76,14 @@ async def test_chat_baichuan_astream() -> None:
|
|||||||
async for chunk in chat.astream("今天天气如何?"):
|
async for chunk in chat.astream("今天天气如何?"):
|
||||||
assert isinstance(chunk, AIMessage)
|
assert isinstance(chunk, AIMessage)
|
||||||
assert isinstance(chunk.content, str)
|
assert isinstance(chunk.content, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_baichuan_with_system_role() -> None:
|
||||||
|
chat = ChatBaichuan() # type: ignore[call-arg]
|
||||||
|
messages = [
|
||||||
|
("system", "你是一名专业的翻译家,可以将用户的中文翻译为英文。"),
|
||||||
|
("human", "我喜欢编程。"),
|
||||||
|
]
|
||||||
|
response = chat.invoke(messages)
|
||||||
|
assert isinstance(response, AIMessage)
|
||||||
|
assert isinstance(response.content, str)
|
||||||
|
@ -4,7 +4,6 @@ import pytest
|
|||||||
from langchain_core.messages import (
|
from langchain_core.messages import (
|
||||||
AIMessage,
|
AIMessage,
|
||||||
AIMessageChunk,
|
AIMessageChunk,
|
||||||
ChatMessage,
|
|
||||||
FunctionMessage,
|
FunctionMessage,
|
||||||
HumanMessage,
|
HumanMessage,
|
||||||
HumanMessageChunk,
|
HumanMessageChunk,
|
||||||
@ -54,9 +53,9 @@ def test__convert_message_to_dict_ai() -> None:
|
|||||||
|
|
||||||
def test__convert_message_to_dict_system() -> None:
|
def test__convert_message_to_dict_system() -> None:
|
||||||
message = SystemMessage(content="foo")
|
message = SystemMessage(content="foo")
|
||||||
with pytest.raises(TypeError) as e:
|
result = _convert_message_to_dict(message)
|
||||||
_convert_message_to_dict(message)
|
expected_output = {"role": "system", "content": "foo"}
|
||||||
assert "Got unknown type" in str(e)
|
assert result == expected_output
|
||||||
|
|
||||||
|
|
||||||
def test__convert_message_to_dict_function() -> None:
|
def test__convert_message_to_dict_function() -> None:
|
||||||
@ -83,7 +82,7 @@ def test__convert_dict_to_message_ai() -> None:
|
|||||||
def test__convert_dict_to_message_other_role() -> None:
|
def test__convert_dict_to_message_other_role() -> None:
|
||||||
message_dict = {"role": "system", "content": "foo"}
|
message_dict = {"role": "system", "content": "foo"}
|
||||||
result = _convert_dict_to_message(message_dict)
|
result = _convert_dict_to_message(message_dict)
|
||||||
expected_output = ChatMessage(role="system", content="foo")
|
expected_output = SystemMessage(content="foo")
|
||||||
assert result == expected_output
|
assert result == expected_output
|
||||||
|
|
||||||
|
|
||||||
@ -134,3 +133,11 @@ def test_uses_actual_secret_value_from_secret_str() -> None:
|
|||||||
cast(SecretStr, chat.baichuan_secret_key).get_secret_value()
|
cast(SecretStr, chat.baichuan_secret_key).get_secret_value()
|
||||||
== "test-secret-key"
|
== "test-secret-key"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_baichuan_with_base_url() -> None:
|
||||||
|
chat = ChatBaichuan( # type: ignore[call-arg]
|
||||||
|
api_key="your-api-key", # type: ignore[arg-type]
|
||||||
|
base_url="https://exmaple.com", # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
assert chat.baichuan_api_base == "https://exmaple.com"
|
||||||
|
Loading…
Reference in New Issue
Block a user