mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-17 15:35:14 +00:00
premai[patch]:Standardized model init args (#21308)
[Standardized model init args #20085](https://github.com/langchain-ai/langchain/issues/20085) - Enable premai chat model to be initialized with `model_name` as an alias for `model`, `api_key` as an alias for `premai_api_key`. - Add initialization test `test_premai_initialization` --------- Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
@@ -34,7 +34,13 @@ from langchain_core.messages import (
|
|||||||
SystemMessageChunk,
|
SystemMessageChunk,
|
||||||
)
|
)
|
||||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator
|
from langchain_core.pydantic_v1 import (
|
||||||
|
BaseModel,
|
||||||
|
Extra,
|
||||||
|
Field,
|
||||||
|
SecretStr,
|
||||||
|
root_validator,
|
||||||
|
)
|
||||||
from langchain_core.utils import get_from_dict_or_env
|
from langchain_core.utils import get_from_dict_or_env
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -170,10 +176,10 @@ class ChatPremAI(BaseChatModel, BaseModel):
|
|||||||
project_id: int
|
project_id: int
|
||||||
"""The project ID in which the experiments or deployments are carried out.
|
"""The project ID in which the experiments or deployments are carried out.
|
||||||
You can find all your projects here: https://app.premai.io/projects/"""
|
You can find all your projects here: https://app.premai.io/projects/"""
|
||||||
premai_api_key: Optional[SecretStr] = None
|
premai_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
|
||||||
"""Prem AI API Key. Get it here: https://app.premai.io/api_keys/"""
|
"""Prem AI API Key. Get it here: https://app.premai.io/api_keys/"""
|
||||||
|
|
||||||
model: Optional[str] = None
|
model: Optional[str] = Field(default=None, alias="model_name")
|
||||||
"""Name of the model. This is an optional parameter.
|
"""Name of the model. This is an optional parameter.
|
||||||
The default model is the one deployed from Prem's LaunchPad: https://app.premai.io/projects/8/launchpad
|
The default model is the one deployed from Prem's LaunchPad: https://app.premai.io/projects/8/launchpad
|
||||||
If model name is other than default model then it will override the calls
|
If model name is other than default model then it will override the calls
|
||||||
@@ -233,6 +239,8 @@ class ChatPremAI(BaseChatModel, BaseModel):
|
|||||||
"""Configuration for this pydantic object."""
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
extra = Extra.forbid
|
extra = Extra.forbid
|
||||||
|
allow_population_by_field_name = True
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
@root_validator()
|
@root_validator()
|
||||||
def validate_environments(cls, values: Dict) -> Dict:
|
def validate_environments(cls, values: Dict) -> Dict:
|
||||||
|
@@ -1,5 +1,7 @@
|
|||||||
"""Test PremChat model"""
|
"""Test PremChat model"""
|
||||||
|
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||||
from langchain_core.pydantic_v1 import SecretStr
|
from langchain_core.pydantic_v1 import SecretStr
|
||||||
@@ -45,3 +47,13 @@ def test_messages_to_prompt_dict_with_valid_messages() -> None:
|
|||||||
|
|
||||||
assert system_message == "System Prompt"
|
assert system_message == "System Prompt"
|
||||||
assert result == expected
|
assert result == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("premai")
|
||||||
|
def test_premai_initialization() -> None:
|
||||||
|
for model in [
|
||||||
|
ChatPremAI(model="prem-ai-model", premai_api_key="xyz", project_id=8),
|
||||||
|
ChatPremAI(model_name="prem-ai-model", api_key="xyz", project_id=8),
|
||||||
|
]:
|
||||||
|
assert model.model == "prem-ai-model"
|
||||||
|
assert cast(SecretStr, model.premai_api_key).get_secret_value() == "xyz"
|
||||||
|
Reference in New Issue
Block a user