diff --git a/libs/langchain/langchain/chat_models/baichuan.py b/libs/langchain/langchain/chat_models/baichuan.py index 88c95fc4c3f..df48f926b3c 100644 --- a/libs/langchain/langchain/chat_models/baichuan.py +++ b/libs/langchain/langchain/chat_models/baichuan.py @@ -98,7 +98,7 @@ class ChatBaichuan(BaseChatModel): baichuan_api_base: str = Field(default=DEFAULT_API_BASE) """Baichuan custom endpoints""" - baichuan_api_key: Optional[str] = None + baichuan_api_key: Optional[SecretStr] = None """Baichuan API Key""" baichuan_secret_key: Optional[SecretStr] = None """Baichuan Secret Key""" @@ -159,10 +159,12 @@ class ChatBaichuan(BaseChatModel): "BAICHUAN_API_BASE", DEFAULT_API_BASE, ) - values["baichuan_api_key"] = get_from_dict_or_env( - values, - "baichuan_api_key", - "BAICHUAN_API_KEY", + values["baichuan_api_key"] = convert_to_secret_str( + get_from_dict_or_env( + values, + "baichuan_api_key", + "BAICHUAN_API_KEY", + ) ) values["baichuan_secret_key"] = convert_to_secret_str( get_from_dict_or_env( @@ -254,12 +256,16 @@ class ChatBaichuan(BaseChatModel): url = f"{url}/stream" url = f"{url}/chat" + api_key = "" + if self.baichuan_api_key: + api_key = self.baichuan_api_key.get_secret_value() + res = requests.post( url=url, timeout=self.request_timeout, headers={ "Content-Type": "application/json", - "Authorization": f"Bearer {self.baichuan_api_key}", + "Authorization": f"Bearer {api_key}", "X-BC-Timestamp": str(timestamp), "X-BC-Signature": _signature( secret_key=self.baichuan_secret_key, diff --git a/libs/langchain/tests/unit_tests/chat_models/test_baichuan.py b/libs/langchain/tests/unit_tests/chat_models/test_baichuan.py index 1168f8dbeb3..93c4cbb91da 100644 --- a/libs/langchain/tests/unit_tests/chat_models/test_baichuan.py +++ b/libs/langchain/tests/unit_tests/chat_models/test_baichuan.py @@ -1,3 +1,5 @@ +from typing import cast + import pytest from langchain_core.messages import ( AIMessage, @@ -9,8 +11,10 @@ from langchain_core.messages import ( SystemMessage, ) from langchain_core.pydantic_v1 import SecretStr +from pytest import CaptureFixture, MonkeyPatch from langchain.chat_models.baichuan import ( + ChatBaichuan, _convert_delta_to_message_chunk, _convert_dict_to_message, _convert_message_to_dict, @@ -97,3 +101,49 @@ def test__signature() -> None: # https://platform.baichuan-ai.com/docs/api#4 expected_output = "24a50b2db1648e25a244c67c5ab57d3f" assert result == expected_output + + +def test_baichuan_key_masked_when_passed_from_env( + monkeypatch: MonkeyPatch, capsys: CaptureFixture +) -> None: + """Test initialization with an API key provided via an env variable""" + monkeypatch.setenv("BAICHUAN_API_KEY", "test-api-key") + monkeypatch.setenv("BAICHUAN_SECRET_KEY", "test-secret-key") + + chat = ChatBaichuan() + print(chat.baichuan_api_key, end="") + captured = capsys.readouterr() + assert captured.out == "**********" + + print(chat.baichuan_secret_key, end="") + captured = capsys.readouterr() + assert captured.out == "**********" + + +def test_baichuan_key_masked_when_passed_via_constructor( + capsys: CaptureFixture, +) -> None: + """Test initialization with an API key provided via the initializer""" + chat = ChatBaichuan( + baichuan_api_key="test-api-key", baichuan_secret_key="test-secret-key" + ) + print(chat.baichuan_api_key, end="") + captured = capsys.readouterr() + assert captured.out == "**********" + + print(chat.baichuan_secret_key, end="") + captured = capsys.readouterr() + + assert captured.out == "**********" + + +def test_uses_actual_secret_value_from_secret_str() -> None: + """Test that actual secret is retrieved using `.get_secret_value()`.""" + chat = ChatBaichuan( + baichuan_api_key="test-api-key", baichuan_secret_key="test-secret-key" + ) + assert cast(SecretStr, chat.baichuan_api_key).get_secret_value() == "test-api-key" + assert ( + cast(SecretStr, chat.baichuan_secret_key).get_secret_value() + == "test-secret-key" + )