From 69f9acb60f04c6ea851971415a51d453dc41f0e6 Mon Sep 17 00:00:00 2001 From: Param Singh Date: Thu, 29 Aug 2024 08:01:28 -0700 Subject: [PATCH] premai[patch]: Standardize premai params (#21513) Thank you for contributing to LangChain! community:premai[patch]: standardize init args - updated `temperature` with Pydantic Field, updated the unit test. - updated `max_tokens` with Pydantic Field, updated the unit test. - updated `max_retries` with Pydantic Field, updated the unit test. Related to #20085 --------- Co-authored-by: Isaac Francisco <78627776+isahers1@users.noreply.github.com> Co-authored-by: ccurme --- .../langchain_community/chat_models/premai.py | 15 ++++++++++++--- .../tests/unit_tests/chat_models/test_premai.py | 3 +++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/libs/community/langchain_community/chat_models/premai.py b/libs/community/langchain_community/chat_models/premai.py index 904496ca253..5498ac9e400 100644 --- a/libs/community/langchain_community/chat_models/premai.py +++ b/libs/community/langchain_community/chat_models/premai.py @@ -271,13 +271,22 @@ class ChatPremAI(BaseChatModel, BaseModel): If model name is other than default model then it will override the calls from the model deployed from launchpad.""" - temperature: Optional[float] = None + session_id: Optional[str] = None + """The ID of the session to use. It helps to track the chat history.""" + + temperature: Optional[float] = Field(default=None) """Model temperature. Value should be >= 0 and <= 1.0""" - max_tokens: Optional[int] = None + top_p: Optional[float] = None + """top_p adjusts the number of choices for each predicted tokens based on + cumulative probabilities. Value should be ranging between 0.0 and 1.0. + """ + + max_tokens: Optional[int] = Field(default=None) + """The maximum number of tokens to generate""" - max_retries: int = 1 + max_retries: int = Field(default=1) """Max number of retries to call the API""" system_prompt: Optional[str] = "" diff --git a/libs/community/tests/unit_tests/chat_models/test_premai.py b/libs/community/tests/unit_tests/chat_models/test_premai.py index b4275fca47b..22327e3aa2f 100644 --- a/libs/community/tests/unit_tests/chat_models/test_premai.py +++ b/libs/community/tests/unit_tests/chat_models/test_premai.py @@ -67,4 +67,7 @@ def test_premai_initialization() -> None: ChatPremAI(model_name="prem-ai-model", api_key="xyz", project_id=8), # type: ignore[arg-type, call-arg] ]: assert model.model == "prem-ai-model" + assert model.temperature is None + assert model.max_tokens is None + assert model.max_retries == 1 assert cast(SecretStr, model.premai_api_key).get_secret_value() == "xyz"