mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-13 08:27:03 +00:00
parent
9b86fb3fcb
commit
f97ab84c6b
@ -30,8 +30,8 @@ from langchain_core.messages import (
|
||||
SystemMessageChunk,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import Field, root_validator
|
||||
from langchain_core.utils import get_pydantic_field_names
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
||||
from langchain_core.utils import convert_to_secret_str, get_pydantic_field_names
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
retry,
|
||||
@ -172,7 +172,7 @@ class JinaChat(BaseChatModel):
|
||||
"""What sampling temperature to use."""
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any model parameters valid for `create` call not explicitly specified."""
|
||||
jinachat_api_key: Optional[str] = None
|
||||
jinachat_api_key: Optional[SecretStr] = None
|
||||
"""Base URL path for API requests,
|
||||
leave blank if not using a proxy or service emulator."""
|
||||
request_timeout: Optional[Union[float, Tuple[float, float]]] = None
|
||||
@ -218,8 +218,8 @@ class JinaChat(BaseChatModel):
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
values["jinachat_api_key"] = get_from_dict_or_env(
|
||||
values, "jinachat_api_key", "JINACHAT_API_KEY"
|
||||
values["jinachat_api_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(values, "jinachat_api_key", "JINACHAT_API_KEY")
|
||||
)
|
||||
try:
|
||||
import openai
|
||||
@ -395,7 +395,8 @@ class JinaChat(BaseChatModel):
|
||||
def _invocation_params(self) -> Mapping[str, Any]:
|
||||
"""Get the parameters used to invoke the model."""
|
||||
jinachat_creds: Dict[str, Any] = {
|
||||
"api_key": self.jinachat_api_key,
|
||||
"api_key": self.jinachat_api_key
|
||||
and self.jinachat_api_key.get_secret_value(),
|
||||
"api_base": "https://api.chat.jina.ai/v1",
|
||||
"model": "jinachat",
|
||||
}
|
||||
|
@ -1,15 +1,52 @@
|
||||
"""Test JinaChat wrapper."""
|
||||
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.outputs import ChatGeneration, LLMResult
|
||||
from langchain_core.pydantic_v1 import SecretStr
|
||||
from pytest import CaptureFixture, MonkeyPatch
|
||||
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
from langchain.chat_models.jinachat import JinaChat
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||
|
||||
|
||||
def test_jinachat_api_key_is_secret_string() -> None:
|
||||
llm = JinaChat(jinachat_api_key="secret-api-key")
|
||||
assert isinstance(llm.jinachat_api_key, SecretStr)
|
||||
|
||||
|
||||
def test_jinachat_api_key_masked_when_passed_from_env(
|
||||
monkeypatch: MonkeyPatch, capsys: CaptureFixture
|
||||
) -> None:
|
||||
"""Test initialization with an API key provided via an env variable"""
|
||||
monkeypatch.setenv("JINACHAT_API_KEY", "secret-api-key")
|
||||
llm = JinaChat()
|
||||
print(llm.jinachat_api_key, end="")
|
||||
captured = capsys.readouterr()
|
||||
|
||||
assert captured.out == "**********"
|
||||
|
||||
|
||||
def test_jinachat_api_key_masked_when_passed_via_constructor(
|
||||
capsys: CaptureFixture,
|
||||
) -> None:
|
||||
"""Test initialization with an API key provided via the initializer"""
|
||||
llm = JinaChat(jinachat_api_key="secret-api-key")
|
||||
print(llm.jinachat_api_key, end="")
|
||||
captured = capsys.readouterr()
|
||||
|
||||
assert captured.out == "**********"
|
||||
|
||||
|
||||
def test_uses_actual_secret_value_from_secretstr() -> None:
|
||||
"""Test that actual secret is retrieved using `.get_secret_value()`."""
|
||||
llm = JinaChat(jinachat_api_key="secret-api-key")
|
||||
assert cast(SecretStr, llm.jinachat_api_key).get_secret_value() == "secret-api-key"
|
||||
|
||||
|
||||
def test_jinachat() -> None:
|
||||
"""Test JinaChat wrapper."""
|
||||
chat = JinaChat(max_tokens=10)
|
||||
|
Loading…
Reference in New Issue
Block a user