diff --git a/libs/partners/mistralai/langchain_mistralai/chat_models.py b/libs/partners/mistralai/langchain_mistralai/chat_models.py index 31739853f0d..804bb64f340 100644 --- a/libs/partners/mistralai/langchain_mistralai/chat_models.py +++ b/libs/partners/mistralai/langchain_mistralai/chat_models.py @@ -73,7 +73,7 @@ from langchain_core.pydantic_v1 import ( ) from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough from langchain_core.tools import BaseTool -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.utils import secret_from_env from langchain_core.utils.function_calling import convert_to_openai_tool from langchain_core.utils.pydantic import is_basemodel_subclass @@ -360,7 +360,10 @@ class ChatMistralAI(BaseChatModel): client: httpx.Client = Field(default=None) #: :meta private: async_client: httpx.AsyncClient = Field(default=None) #: :meta private: - mistral_api_key: Optional[SecretStr] = Field(default=None, alias="api_key") + mistral_api_key: Optional[SecretStr] = Field( + alias="api_key", + default_factory=secret_from_env("MISTRAL_API_KEY", default=None), + ) endpoint: str = "https://api.mistral.ai/v1" max_retries: int = 5 timeout: int = 120 @@ -465,15 +468,9 @@ class ChatMistralAI(BaseChatModel): combined = {"token_usage": overall_token_usage, "model_name": self.model} return combined - @root_validator() + @root_validator(pre=False, skip_on_failure=True) def validate_environment(cls, values: Dict) -> Dict: """Validate api key, python package exists, temperature, and top_p.""" - - values["mistral_api_key"] = convert_to_secret_str( - get_from_dict_or_env( - values, "mistral_api_key", "MISTRAL_API_KEY", default="" - ) - ) api_key_str = values["mistral_api_key"].get_secret_value() # todo: handle retries if not values.get("client"):