secretStr for baichuan chat model api key (#13946)

Merge pull request #13946
* secretStr for baichuan chat model api key
This commit is contained in:
nhywieza 2023-11-29 10:20:23 +08:00 committed by GitHub
parent aff1dba252
commit 9b86fb3fcb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 62 additions and 6 deletions

View File

@ -98,7 +98,7 @@ class ChatBaichuan(BaseChatModel):
baichuan_api_base: str = Field(default=DEFAULT_API_BASE)
"""Baichuan custom endpoints"""
baichuan_api_key: Optional[str] = None
baichuan_api_key: Optional[SecretStr] = None
"""Baichuan API Key"""
baichuan_secret_key: Optional[SecretStr] = None
"""Baichuan Secret Key"""
@ -159,10 +159,12 @@ class ChatBaichuan(BaseChatModel):
"BAICHUAN_API_BASE",
DEFAULT_API_BASE,
)
values["baichuan_api_key"] = get_from_dict_or_env(
values,
"baichuan_api_key",
"BAICHUAN_API_KEY",
values["baichuan_api_key"] = convert_to_secret_str(
get_from_dict_or_env(
values,
"baichuan_api_key",
"BAICHUAN_API_KEY",
)
)
values["baichuan_secret_key"] = convert_to_secret_str(
get_from_dict_or_env(
@ -254,12 +256,16 @@ class ChatBaichuan(BaseChatModel):
url = f"{url}/stream"
url = f"{url}/chat"
api_key = ""
if self.baichuan_api_key:
api_key = self.baichuan_api_key.get_secret_value()
res = requests.post(
url=url,
timeout=self.request_timeout,
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {self.baichuan_api_key}",
"Authorization": f"Bearer {api_key}",
"X-BC-Timestamp": str(timestamp),
"X-BC-Signature": _signature(
secret_key=self.baichuan_secret_key,

View File

@ -1,3 +1,5 @@
from typing import cast
import pytest
from langchain_core.messages import (
AIMessage,
@ -9,8 +11,10 @@ from langchain_core.messages import (
SystemMessage,
)
from langchain_core.pydantic_v1 import SecretStr
from pytest import CaptureFixture, MonkeyPatch
from langchain.chat_models.baichuan import (
ChatBaichuan,
_convert_delta_to_message_chunk,
_convert_dict_to_message,
_convert_message_to_dict,
@ -97,3 +101,49 @@ def test__signature() -> None:
# https://platform.baichuan-ai.com/docs/api#4
expected_output = "24a50b2db1648e25a244c67c5ab57d3f"
assert result == expected_output
def test_baichuan_key_masked_when_passed_from_env(
monkeypatch: MonkeyPatch, capsys: CaptureFixture
) -> None:
"""Test initialization with an API key provided via an env variable"""
monkeypatch.setenv("BAICHUAN_API_KEY", "test-api-key")
monkeypatch.setenv("BAICHUAN_SECRET_KEY", "test-secret-key")
chat = ChatBaichuan()
print(chat.baichuan_api_key, end="")
captured = capsys.readouterr()
assert captured.out == "**********"
print(chat.baichuan_secret_key, end="")
captured = capsys.readouterr()
assert captured.out == "**********"
def test_baichuan_key_masked_when_passed_via_constructor(
capsys: CaptureFixture,
) -> None:
"""Test initialization with an API key provided via the initializer"""
chat = ChatBaichuan(
baichuan_api_key="test-api-key", baichuan_secret_key="test-secret-key"
)
print(chat.baichuan_api_key, end="")
captured = capsys.readouterr()
assert captured.out == "**********"
print(chat.baichuan_secret_key, end="")
captured = capsys.readouterr()
assert captured.out == "**********"
def test_uses_actual_secret_value_from_secret_str() -> None:
"""Test that actual secret is retrieved using `.get_secret_value()`."""
chat = ChatBaichuan(
baichuan_api_key="test-api-key", baichuan_secret_key="test-secret-key"
)
assert cast(SecretStr, chat.baichuan_api_key).get_secret_value() == "test-api-key"
assert (
cast(SecretStr, chat.baichuan_secret_key).get_secret_value()
== "test-secret-key"
)