premai[patch]:Standardized model init args (#21308)

[Standardized model init args
#20085](https://github.com/langchain-ai/langchain/issues/20085)
- Enable premai chat model to be initialized with `model_name` as an
alias for `model`, `api_key` as an alias for `premai_api_key`.
- Add initialization test `test_premai_initialization`

---------

Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
nrpd25
2024-05-06 15:12:29 -07:00
committed by GitHub
parent 6f17158606
commit 95cc8e3fc3
2 changed files with 23 additions and 3 deletions

View File

@@ -34,7 +34,13 @@ from langchain_core.messages import (
SystemMessageChunk, SystemMessageChunk,
) )
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator from langchain_core.pydantic_v1 import (
BaseModel,
Extra,
Field,
SecretStr,
root_validator,
)
from langchain_core.utils import get_from_dict_or_env from langchain_core.utils import get_from_dict_or_env
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -170,10 +176,10 @@ class ChatPremAI(BaseChatModel, BaseModel):
project_id: int project_id: int
"""The project ID in which the experiments or deployments are carried out. """The project ID in which the experiments or deployments are carried out.
You can find all your projects here: https://app.premai.io/projects/""" You can find all your projects here: https://app.premai.io/projects/"""
premai_api_key: Optional[SecretStr] = None premai_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
"""Prem AI API Key. Get it here: https://app.premai.io/api_keys/""" """Prem AI API Key. Get it here: https://app.premai.io/api_keys/"""
model: Optional[str] = None model: Optional[str] = Field(default=None, alias="model_name")
"""Name of the model. This is an optional parameter. """Name of the model. This is an optional parameter.
The default model is the one deployed from Prem's LaunchPad: https://app.premai.io/projects/8/launchpad The default model is the one deployed from Prem's LaunchPad: https://app.premai.io/projects/8/launchpad
If model name is other than default model then it will override the calls If model name is other than default model then it will override the calls
@@ -233,6 +239,8 @@ class ChatPremAI(BaseChatModel, BaseModel):
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
extra = Extra.forbid extra = Extra.forbid
allow_population_by_field_name = True
arbitrary_types_allowed = True
@root_validator() @root_validator()
def validate_environments(cls, values: Dict) -> Dict: def validate_environments(cls, values: Dict) -> Dict:

View File

@@ -1,5 +1,7 @@
"""Test PremChat model""" """Test PremChat model"""
from typing import cast
import pytest import pytest
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.pydantic_v1 import SecretStr from langchain_core.pydantic_v1 import SecretStr
@@ -45,3 +47,13 @@ def test_messages_to_prompt_dict_with_valid_messages() -> None:
assert system_message == "System Prompt" assert system_message == "System Prompt"
assert result == expected assert result == expected
@pytest.mark.requires("premai")
def test_premai_initialization() -> None:
for model in [
ChatPremAI(model="prem-ai-model", premai_api_key="xyz", project_id=8),
ChatPremAI(model_name="prem-ai-model", api_key="xyz", project_id=8),
]:
assert model.model == "prem-ai-model"
assert cast(SecretStr, model.premai_api_key).get_secret_value() == "xyz"