mistralai[patch]: Update root validator for compatibility with pydantic 2 (#25403)

This commit is contained in:
Eugene Yurtsev 2024-08-15 11:26:24 -04:00 committed by GitHub
parent 8afbab4cf6
commit 6f68c8d6ab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -73,7 +73,7 @@ from langchain_core.pydantic_v1 import (
)
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_core.utils import secret_from_env
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass
@ -360,7 +360,10 @@ class ChatMistralAI(BaseChatModel):
client: httpx.Client = Field(default=None) #: :meta private:
async_client: httpx.AsyncClient = Field(default=None) #: :meta private:
mistral_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
mistral_api_key: Optional[SecretStr] = Field(
alias="api_key",
default_factory=secret_from_env("MISTRAL_API_KEY", default=None),
)
endpoint: str = "https://api.mistral.ai/v1"
max_retries: int = 5
timeout: int = 120
@ -465,15 +468,9 @@ class ChatMistralAI(BaseChatModel):
combined = {"token_usage": overall_token_usage, "model_name": self.model}
return combined
@root_validator()
@root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate api key, python package exists, temperature, and top_p."""
values["mistral_api_key"] = convert_to_secret_str(
get_from_dict_or_env(
values, "mistral_api_key", "MISTRAL_API_KEY", default=""
)
)
api_key_str = values["mistral_api_key"].get_secret_value()
# todo: handle retries
if not values.get("client"):