From 95cc8e3fc32f3f6396f1cf1751d86e7958ec120e Mon Sep 17 00:00:00 2001 From: nrpd25 <36529906+Narapady@users.noreply.github.com> Date: Mon, 6 May 2024 15:12:29 -0700 Subject: [PATCH] premai[patch]:Standardized model init args (#21308) [Standardized model init args #20085](https://github.com/langchain-ai/langchain/issues/20085) - Enable premai chat model to be initialized with `model_name` as an alias for `model`, `api_key` as an alias for `premai_api_key`. - Add initialization test `test_premai_initialization` --------- Co-authored-by: Chester Curme --- .../langchain_community/chat_models/premai.py | 14 +++++++++++--- .../tests/unit_tests/chat_models/test_premai.py | 12 ++++++++++++ 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/libs/community/langchain_community/chat_models/premai.py b/libs/community/langchain_community/chat_models/premai.py index b509b8bf7db..5991506cd6b 100644 --- a/libs/community/langchain_community/chat_models/premai.py +++ b/libs/community/langchain_community/chat_models/premai.py @@ -34,7 +34,13 @@ from langchain_core.messages import ( SystemMessageChunk, ) from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult -from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator +from langchain_core.pydantic_v1 import ( + BaseModel, + Extra, + Field, + SecretStr, + root_validator, +) from langchain_core.utils import get_from_dict_or_env if TYPE_CHECKING: @@ -170,10 +176,10 @@ class ChatPremAI(BaseChatModel, BaseModel): project_id: int """The project ID in which the experiments or deployments are carried out. You can find all your projects here: https://app.premai.io/projects/""" - premai_api_key: Optional[SecretStr] = None + premai_api_key: Optional[SecretStr] = Field(default=None, alias="api_key") """Prem AI API Key. Get it here: https://app.premai.io/api_keys/""" - model: Optional[str] = None + model: Optional[str] = Field(default=None, alias="model_name") """Name of the model. This is an optional parameter. The default model is the one deployed from Prem's LaunchPad: https://app.premai.io/projects/8/launchpad If model name is other than default model then it will override the calls @@ -233,6 +239,8 @@ class ChatPremAI(BaseChatModel, BaseModel): """Configuration for this pydantic object.""" extra = Extra.forbid + allow_population_by_field_name = True + arbitrary_types_allowed = True @root_validator() def validate_environments(cls, values: Dict) -> Dict: diff --git a/libs/community/tests/unit_tests/chat_models/test_premai.py b/libs/community/tests/unit_tests/chat_models/test_premai.py index c72d4f0ec86..ac5299f76f3 100644 --- a/libs/community/tests/unit_tests/chat_models/test_premai.py +++ b/libs/community/tests/unit_tests/chat_models/test_premai.py @@ -1,5 +1,7 @@ """Test PremChat model""" +from typing import cast + import pytest from langchain_core.messages import AIMessage, HumanMessage, SystemMessage from langchain_core.pydantic_v1 import SecretStr @@ -45,3 +47,13 @@ def test_messages_to_prompt_dict_with_valid_messages() -> None: assert system_message == "System Prompt" assert result == expected + + +@pytest.mark.requires("premai") +def test_premai_initialization() -> None: + for model in [ + ChatPremAI(model="prem-ai-model", premai_api_key="xyz", project_id=8), + ChatPremAI(model_name="prem-ai-model", api_key="xyz", project_id=8), + ]: + assert model.model == "prem-ai-model" + assert cast(SecretStr, model.premai_api_key).get_secret_value() == "xyz"