From d948783a4c9c24badee436ed76d97672aa476bad Mon Sep 17 00:00:00 2001 From: MSubik <146082415+MSubik@users.noreply.github.com> Date: Thu, 23 May 2024 03:17:28 +0530 Subject: [PATCH] community[patch]: standardize init args, update for javelin sdk release. (#21980) Related to [20085](https://github.com/langchain-ai/langchain/issues/20085) Updated the Javelin chat model to standardize the initialization argument. Also fixed an existing bug, where code was initialized with incorrect call to the JavelinClient defined in the javelin_sdk, resulting in an initialization error. See related [Javelin Documentation](https://docs.getjavelin.io/docs/javelin-python/quickstart). --- .../chat_models/javelin_ai_gateway.py | 9 +++++++-- .../chat_models/test_javelin_ai_gateway.py | 14 ++++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/libs/community/langchain_community/chat_models/javelin_ai_gateway.py b/libs/community/langchain_community/chat_models/javelin_ai_gateway.py index 6c90103fae9..c45b034e493 100644 --- a/libs/community/langchain_community/chat_models/javelin_ai_gateway.py +++ b/libs/community/langchain_community/chat_models/javelin_ai_gateway.py @@ -18,7 +18,7 @@ from langchain_core.outputs import ( ChatGeneration, ChatResult, ) -from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr +from langchain_core.pydantic_v1 import BaseModel, Extra, Field, SecretStr logger = logging.getLogger(__name__) @@ -65,9 +65,14 @@ class ChatJavelinAIGateway(BaseChatModel): client: Any """javelin client.""" - javelin_api_key: Optional[SecretStr] = None + javelin_api_key: Optional[SecretStr] = Field(None, alias="api_key") """The API key for the Javelin AI Gateway.""" + class Config: + """Configuration for this pydantic object.""" + + allow_population_by_field_name = True + def __init__(self, **kwargs: Any): try: from javelin_sdk import ( diff --git a/libs/community/tests/unit_tests/chat_models/test_javelin_ai_gateway.py b/libs/community/tests/unit_tests/chat_models/test_javelin_ai_gateway.py index 7c4500d340d..c612747dd5d 100644 --- a/libs/community/tests/unit_tests/chat_models/test_javelin_ai_gateway.py +++ b/libs/community/tests/unit_tests/chat_models/test_javelin_ai_gateway.py @@ -30,3 +30,17 @@ def test_api_key_masked_when_passed_via_constructor() -> None: assert str(llm.javelin_api_key) == "**********" assert "secret-api-key" not in repr(llm.javelin_api_key) assert "secret-api-key" not in repr(llm) + + +@pytest.mark.requires("javelin_sdk") +def test_api_key_alias() -> None: + for model in [ + ChatJavelinAIGateway( + route="", + javelin_api_key="secret-api-key", + ), + ChatJavelinAIGateway( + route="", api_key="secret-api-key" + ), + ]: + assert str(model.javelin_api_key) == "**********"