From 2ef9d12372786337fd32b18b69cdb91bf8fbd539 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Thu, 15 Aug 2024 12:44:42 -0400 Subject: [PATCH] mistralai[patch]: Update more @root_validators for pydantic 2 compatibility (#25446) Update @root_validators in mistralai integration for pydantic 2 compatibility --- .../langchain_mistralai/embeddings.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/libs/partners/mistralai/langchain_mistralai/embeddings.py b/libs/partners/mistralai/langchain_mistralai/embeddings.py index dea5f56f398..ea19fdd9cae 100644 --- a/libs/partners/mistralai/langchain_mistralai/embeddings.py +++ b/libs/partners/mistralai/langchain_mistralai/embeddings.py @@ -1,7 +1,7 @@ import asyncio import logging import warnings -from typing import Dict, Iterable, List, Optional +from typing import Dict, Iterable, List import httpx from langchain_core.embeddings import Embeddings @@ -11,7 +11,9 @@ from langchain_core.pydantic_v1 import ( SecretStr, root_validator, ) -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.utils import ( + secret_from_env, +) from tokenizers import Tokenizer # type: ignore logger = logging.getLogger(__name__) @@ -111,7 +113,10 @@ class MistralAIEmbeddings(BaseModel, Embeddings): client: httpx.Client = Field(default=None) #: :meta private: async_client: httpx.AsyncClient = Field(default=None) #: :meta private: - mistral_api_key: Optional[SecretStr] = Field(default=None, alias="api_key") + mistral_api_key: SecretStr = Field( + alias="api_key", + default_factory=secret_from_env("MISTRAL_API_KEY", default=""), + ) endpoint: str = "https://api.mistral.ai/v1/" max_retries: int = 5 timeout: int = 120 @@ -125,15 +130,10 @@ class MistralAIEmbeddings(BaseModel, Embeddings): arbitrary_types_allowed = True allow_population_by_field_name = True - @root_validator() + @root_validator(pre=False, skip_on_failure=True) def validate_environment(cls, values: Dict) -> Dict: """Validate configuration.""" - values["mistral_api_key"] = convert_to_secret_str( - get_from_dict_or_env( - values, "mistral_api_key", "MISTRAL_API_KEY", default="" - ) - ) api_key_str = values["mistral_api_key"].get_secret_value() # todo: handle retries if not values.get("client"):