premai[patch]:Standardized model init args (#21308)

[Standardized model init args
#20085](https://github.com/langchain-ai/langchain/issues/20085)
- Enable premai chat model to be initialized with `model_name` as an
alias for `model`, `api_key` as an alias for `premai_api_key`.
- Add initialization test `test_premai_initialization`

---------

Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
nrpd25
2024-05-06 15:12:29 -07:00
committed by GitHub
parent 6f17158606
commit 95cc8e3fc3
2 changed files with 23 additions and 3 deletions

View File

@@ -34,7 +34,13 @@ from langchain_core.messages import (
SystemMessageChunk,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator
from langchain_core.pydantic_v1 import (
BaseModel,
Extra,
Field,
SecretStr,
root_validator,
)
from langchain_core.utils import get_from_dict_or_env
if TYPE_CHECKING:
@@ -170,10 +176,10 @@ class ChatPremAI(BaseChatModel, BaseModel):
project_id: int
"""The project ID in which the experiments or deployments are carried out.
You can find all your projects here: https://app.premai.io/projects/"""
premai_api_key: Optional[SecretStr] = None
premai_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
"""Prem AI API Key. Get it here: https://app.premai.io/api_keys/"""
model: Optional[str] = None
model: Optional[str] = Field(default=None, alias="model_name")
"""Name of the model. This is an optional parameter.
The default model is the one deployed from Prem's LaunchPad: https://app.premai.io/projects/8/launchpad
If model name is other than default model then it will override the calls
@@ -233,6 +239,8 @@ class ChatPremAI(BaseChatModel, BaseModel):
"""Configuration for this pydantic object."""
extra = Extra.forbid
allow_population_by_field_name = True
arbitrary_types_allowed = True
@root_validator()
def validate_environments(cls, values: Dict) -> Dict: