Merge pull request #13907

* feat: mask api_key for jina
This commit is contained in:
chyroc 2023-11-29 10:24:50 +08:00 committed by GitHub
parent 9b86fb3fcb
commit f97ab84c6b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 44 additions and 6 deletions

View File

@ -30,8 +30,8 @@ from langchain_core.messages import (
SystemMessageChunk,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.utils import get_pydantic_field_names
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_pydantic_field_names
from tenacity import (
before_sleep_log,
retry,
@ -172,7 +172,7 @@ class JinaChat(BaseChatModel):
"""What sampling temperature to use."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
jinachat_api_key: Optional[str] = None
jinachat_api_key: Optional[SecretStr] = None
"""Base URL path for API requests,
leave blank if not using a proxy or service emulator."""
request_timeout: Optional[Union[float, Tuple[float, float]]] = None
@ -218,8 +218,8 @@ class JinaChat(BaseChatModel):
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["jinachat_api_key"] = get_from_dict_or_env(
values, "jinachat_api_key", "JINACHAT_API_KEY"
values["jinachat_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "jinachat_api_key", "JINACHAT_API_KEY")
)
try:
import openai
@ -395,7 +395,8 @@ class JinaChat(BaseChatModel):
def _invocation_params(self) -> Mapping[str, Any]:
"""Get the parameters used to invoke the model."""
jinachat_creds: Dict[str, Any] = {
"api_key": self.jinachat_api_key,
"api_key": self.jinachat_api_key
and self.jinachat_api_key.get_secret_value(),
"api_base": "https://api.chat.jina.ai/v1",
"model": "jinachat",
}

View File

@ -1,15 +1,52 @@
"""Test JinaChat wrapper."""
from typing import cast
import pytest
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_core.outputs import ChatGeneration, LLMResult
from langchain_core.pydantic_v1 import SecretStr
from pytest import CaptureFixture, MonkeyPatch
from langchain.callbacks.manager import CallbackManager
from langchain.chat_models.jinachat import JinaChat
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
def test_jinachat_api_key_is_secret_string() -> None:
llm = JinaChat(jinachat_api_key="secret-api-key")
assert isinstance(llm.jinachat_api_key, SecretStr)
def test_jinachat_api_key_masked_when_passed_from_env(
monkeypatch: MonkeyPatch, capsys: CaptureFixture
) -> None:
"""Test initialization with an API key provided via an env variable"""
monkeypatch.setenv("JINACHAT_API_KEY", "secret-api-key")
llm = JinaChat()
print(llm.jinachat_api_key, end="")
captured = capsys.readouterr()
assert captured.out == "**********"
def test_jinachat_api_key_masked_when_passed_via_constructor(
capsys: CaptureFixture,
) -> None:
"""Test initialization with an API key provided via the initializer"""
llm = JinaChat(jinachat_api_key="secret-api-key")
print(llm.jinachat_api_key, end="")
captured = capsys.readouterr()
assert captured.out == "**********"
def test_uses_actual_secret_value_from_secretstr() -> None:
"""Test that actual secret is retrieved using `.get_secret_value()`."""
llm = JinaChat(jinachat_api_key="secret-api-key")
assert cast(SecretStr, llm.jinachat_api_key).get_secret_value() == "secret-api-key"
def test_jinachat() -> None:
"""Test JinaChat wrapper."""
chat = JinaChat(max_tokens=10)