baichuan[patch]:standardize chat init args (#21298)

Thank you for contributing to LangChain!

community:baichuan[patch]: standardize init args

updated `baichuan_api_key` so that aliased to `api_key`. Added test that
it continues to set the same underlying attribute. Test checks for
`SecretStr`

updated `temperature` with Pydantic Field, added unit test. 

Related to https://github.com/langchain-ai/langchain/issues/20085
This commit is contained in:
Param Singh 2024-05-06 11:33:57 -07:00 committed by GitHub
parent 62559b20b3
commit fee91d43b7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 5 additions and 5 deletions

View File

@ -89,7 +89,7 @@ class ChatBaichuan(BaseChatModel):
baichuan_api_base: str = Field(default=DEFAULT_API_BASE)
"""Baichuan custom endpoints"""
baichuan_api_key: Optional[SecretStr] = None
baichuan_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
"""Baichuan API Key"""
baichuan_secret_key: Optional[SecretStr] = None
"""[DEPRECATED, keeping it for for backward compatibility] Baichuan Secret Key"""
@ -100,7 +100,7 @@ class ChatBaichuan(BaseChatModel):
model = "Baichuan2-Turbo-192K"
"""model name of Baichuan, default is `Baichuan2-Turbo-192K`,
other options include `Baichuan2-Turbo`"""
temperature: float = 0.3
temperature: Optional[float] = Field(default=0.3)
"""What sampling temperature to use."""
top_k: int = 5
"""What search sampling control to use."""

View File

@ -25,9 +25,7 @@ def test_initialization() -> None:
"""Test chat model initialization."""
for model in [
ChatBaichuan(
model="Baichuan2-Turbo-192K", baichuan_api_key="test-api-key", timeout=40
),
ChatBaichuan(model="Baichuan2-Turbo-192K", api_key="test-api-key", timeout=40),
ChatBaichuan(
model="Baichuan2-Turbo-192K",
baichuan_api_key="test-api-key",
@ -35,7 +33,9 @@ def test_initialization() -> None:
),
]:
assert model.model == "Baichuan2-Turbo-192K"
assert isinstance(model.baichuan_api_key, SecretStr)
assert model.request_timeout == 40
assert model.temperature == 0.3
def test__convert_message_to_dict_human() -> None: