premai[patch]: Standardize premai params (#21513)

Thank you for contributing to LangChain!

community:premai[patch]: standardize init args

- updated `temperature` with Pydantic Field, updated the unit test.
- updated `max_tokens` with Pydantic Field, updated the unit test.
- updated `max_retries` with Pydantic Field, updated the unit test.

Related to #20085

---------

Co-authored-by: Isaac Francisco <78627776+isahers1@users.noreply.github.com>
Co-authored-by: ccurme <chester.curme@gmail.com>
This commit is contained in:
Param Singh 2024-08-29 08:01:28 -07:00 committed by GitHub
parent fcf9230257
commit 69f9acb60f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 15 additions and 3 deletions

View File

@ -271,13 +271,22 @@ class ChatPremAI(BaseChatModel, BaseModel):
If model name is other than default model then it will override the calls If model name is other than default model then it will override the calls
from the model deployed from launchpad.""" from the model deployed from launchpad."""
temperature: Optional[float] = None session_id: Optional[str] = None
"""The ID of the session to use. It helps to track the chat history."""
temperature: Optional[float] = Field(default=None)
"""Model temperature. Value should be >= 0 and <= 1.0""" """Model temperature. Value should be >= 0 and <= 1.0"""
max_tokens: Optional[int] = None top_p: Optional[float] = None
"""top_p adjusts the number of choices for each predicted tokens based on
cumulative probabilities. Value should be ranging between 0.0 and 1.0.
"""
max_tokens: Optional[int] = Field(default=None)
"""The maximum number of tokens to generate""" """The maximum number of tokens to generate"""
max_retries: int = 1 max_retries: int = Field(default=1)
"""Max number of retries to call the API""" """Max number of retries to call the API"""
system_prompt: Optional[str] = "" system_prompt: Optional[str] = ""

View File

@ -67,4 +67,7 @@ def test_premai_initialization() -> None:
ChatPremAI(model_name="prem-ai-model", api_key="xyz", project_id=8), # type: ignore[arg-type, call-arg] ChatPremAI(model_name="prem-ai-model", api_key="xyz", project_id=8), # type: ignore[arg-type, call-arg]
]: ]:
assert model.model == "prem-ai-model" assert model.model == "prem-ai-model"
assert model.temperature is None
assert model.max_tokens is None
assert model.max_retries == 1
assert cast(SecretStr, model.premai_api_key).get_secret_value() == "xyz" assert cast(SecretStr, model.premai_api_key).get_secret_value() == "xyz"