premai[patch]:Standardized model init args (#21308)

[Standardized model init args
#20085](https://github.com/langchain-ai/langchain/issues/20085)
- Enable premai chat model to be initialized with `model_name` as an
alias for `model`, `api_key` as an alias for `premai_api_key`.
- Add initialization test `test_premai_initialization`

---------

Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
nrpd25
2024-05-06 15:12:29 -07:00
committed by GitHub
parent 6f17158606
commit 95cc8e3fc3
2 changed files with 23 additions and 3 deletions

View File

@@ -34,7 +34,13 @@ from langchain_core.messages import (
SystemMessageChunk,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator
from langchain_core.pydantic_v1 import (
BaseModel,
Extra,
Field,
SecretStr,
root_validator,
)
from langchain_core.utils import get_from_dict_or_env
if TYPE_CHECKING:
@@ -170,10 +176,10 @@ class ChatPremAI(BaseChatModel, BaseModel):
project_id: int
"""The project ID in which the experiments or deployments are carried out.
You can find all your projects here: https://app.premai.io/projects/"""
premai_api_key: Optional[SecretStr] = None
premai_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
"""Prem AI API Key. Get it here: https://app.premai.io/api_keys/"""
model: Optional[str] = None
model: Optional[str] = Field(default=None, alias="model_name")
"""Name of the model. This is an optional parameter.
The default model is the one deployed from Prem's LaunchPad: https://app.premai.io/projects/8/launchpad
If model name is other than default model then it will override the calls
@@ -233,6 +239,8 @@ class ChatPremAI(BaseChatModel, BaseModel):
"""Configuration for this pydantic object."""
extra = Extra.forbid
allow_population_by_field_name = True
arbitrary_types_allowed = True
@root_validator()
def validate_environments(cls, values: Dict) -> Dict:

View File

@@ -1,5 +1,7 @@
"""Test PremChat model"""
from typing import cast
import pytest
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.pydantic_v1 import SecretStr
@@ -45,3 +47,13 @@ def test_messages_to_prompt_dict_with_valid_messages() -> None:
assert system_message == "System Prompt"
assert result == expected
@pytest.mark.requires("premai")
def test_premai_initialization() -> None:
for model in [
ChatPremAI(model="prem-ai-model", premai_api_key="xyz", project_id=8),
ChatPremAI(model_name="prem-ai-model", api_key="xyz", project_id=8),
]:
assert model.model == "prem-ai-model"
assert cast(SecretStr, model.premai_api_key).get_secret_value() == "xyz"