mistralai[patch]: Update more @root_validators for pydantic 2 compatibility (#25446)

Update @root_validators in mistralai integration for pydantic 2 compatibility
This commit is contained in:
Eugene Yurtsev 2024-08-15 12:44:42 -04:00 committed by GitHub
parent 6910b0b3aa
commit 2ef9d12372
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,7 +1,7 @@
import asyncio
import logging
import warnings
from typing import Dict, Iterable, List, Optional
from typing import Dict, Iterable, List
import httpx
from langchain_core.embeddings import Embeddings
@ -11,7 +11,9 @@ from langchain_core.pydantic_v1 import (
SecretStr,
root_validator,
)
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_core.utils import (
secret_from_env,
)
from tokenizers import Tokenizer # type: ignore
logger = logging.getLogger(__name__)
@ -111,7 +113,10 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
client: httpx.Client = Field(default=None) #: :meta private:
async_client: httpx.AsyncClient = Field(default=None) #: :meta private:
mistral_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
mistral_api_key: SecretStr = Field(
alias="api_key",
default_factory=secret_from_env("MISTRAL_API_KEY", default=""),
)
endpoint: str = "https://api.mistral.ai/v1/"
max_retries: int = 5
timeout: int = 120
@ -125,15 +130,10 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
arbitrary_types_allowed = True
allow_population_by_field_name = True
@root_validator()
@root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate configuration."""
values["mistral_api_key"] = convert_to_secret_str(
get_from_dict_or_env(
values, "mistral_api_key", "MISTRAL_API_KEY", default=""
)
)
api_key_str = values["mistral_api_key"].get_secret_value()
# todo: handle retries
if not values.get("client"):