mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-17 07:26:16 +00:00
premai[patch]:Standardized model init args (#21308)
[Standardized model init args #20085](https://github.com/langchain-ai/langchain/issues/20085) - Enable premai chat model to be initialized with `model_name` as an alias for `model`, `api_key` as an alias for `premai_api_key`. - Add initialization test `test_premai_initialization` --------- Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
@@ -34,7 +34,13 @@ from langchain_core.messages import (
|
||||
SystemMessageChunk,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator
|
||||
from langchain_core.pydantic_v1 import (
|
||||
BaseModel,
|
||||
Extra,
|
||||
Field,
|
||||
SecretStr,
|
||||
root_validator,
|
||||
)
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -170,10 +176,10 @@ class ChatPremAI(BaseChatModel, BaseModel):
|
||||
project_id: int
|
||||
"""The project ID in which the experiments or deployments are carried out.
|
||||
You can find all your projects here: https://app.premai.io/projects/"""
|
||||
premai_api_key: Optional[SecretStr] = None
|
||||
premai_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
|
||||
"""Prem AI API Key. Get it here: https://app.premai.io/api_keys/"""
|
||||
|
||||
model: Optional[str] = None
|
||||
model: Optional[str] = Field(default=None, alias="model_name")
|
||||
"""Name of the model. This is an optional parameter.
|
||||
The default model is the one deployed from Prem's LaunchPad: https://app.premai.io/projects/8/launchpad
|
||||
If model name is other than default model then it will override the calls
|
||||
@@ -233,6 +239,8 @@ class ChatPremAI(BaseChatModel, BaseModel):
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
allow_population_by_field_name = True
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@root_validator()
|
||||
def validate_environments(cls, values: Dict) -> Dict:
|
||||
|
@@ -1,5 +1,7 @@
|
||||
"""Test PremChat model"""
|
||||
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.pydantic_v1 import SecretStr
|
||||
@@ -45,3 +47,13 @@ def test_messages_to_prompt_dict_with_valid_messages() -> None:
|
||||
|
||||
assert system_message == "System Prompt"
|
||||
assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("premai")
|
||||
def test_premai_initialization() -> None:
|
||||
for model in [
|
||||
ChatPremAI(model="prem-ai-model", premai_api_key="xyz", project_id=8),
|
||||
ChatPremAI(model_name="prem-ai-model", api_key="xyz", project_id=8),
|
||||
]:
|
||||
assert model.model == "prem-ai-model"
|
||||
assert cast(SecretStr, model.premai_api_key).get_secret_value() == "xyz"
|
||||
|
Reference in New Issue
Block a user