community[patch]: Improve ChatBaichuan init args and role (#23878)

- **Description:** Improve ChatBaichuan init args and role
   -  ChatBaichuan adds `system` role
   - alias: `baichuan_api_base` -> `base_url`
   - `with_search_enhance` is deprecated
   - Add `max_tokens` argument
This commit is contained in:
maang-h
2024-07-16 03:17:00 +08:00
committed by GitHub
parent 56cca23745
commit 9d97de34ae
3 changed files with 37 additions and 8 deletions

View File

@@ -76,3 +76,14 @@ async def test_chat_baichuan_astream() -> None:
async for chunk in chat.astream("今天天气如何?"):
assert isinstance(chunk, AIMessage)
assert isinstance(chunk.content, str)
def test_chat_baichuan_with_system_role() -> None:
chat = ChatBaichuan() # type: ignore[call-arg]
messages = [
("system", "你是一名专业的翻译家,可以将用户的中文翻译为英文。"),
("human", "我喜欢编程。"),
]
response = chat.invoke(messages)
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)

View File

@@ -4,7 +4,6 @@ import pytest
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
ChatMessage,
FunctionMessage,
HumanMessage,
HumanMessageChunk,
@@ -54,9 +53,9 @@ def test__convert_message_to_dict_ai() -> None:
def test__convert_message_to_dict_system() -> None:
message = SystemMessage(content="foo")
with pytest.raises(TypeError) as e:
_convert_message_to_dict(message)
assert "Got unknown type" in str(e)
result = _convert_message_to_dict(message)
expected_output = {"role": "system", "content": "foo"}
assert result == expected_output
def test__convert_message_to_dict_function() -> None:
@@ -83,7 +82,7 @@ def test__convert_dict_to_message_ai() -> None:
def test__convert_dict_to_message_other_role() -> None:
message_dict = {"role": "system", "content": "foo"}
result = _convert_dict_to_message(message_dict)
expected_output = ChatMessage(role="system", content="foo")
expected_output = SystemMessage(content="foo")
assert result == expected_output
@@ -134,3 +133,11 @@ def test_uses_actual_secret_value_from_secret_str() -> None:
cast(SecretStr, chat.baichuan_secret_key).get_secret_value()
== "test-secret-key"
)
def test_chat_baichuan_with_base_url() -> None:
chat = ChatBaichuan( # type: ignore[call-arg]
api_key="your-api-key", # type: ignore[arg-type]
base_url="https://exmaple.com", # type: ignore[arg-type]
)
assert chat.baichuan_api_base == "https://exmaple.com"