mistralai[patch]: Update more @root_validators for pydantic 2 compatibility (#25446)

Update @root_validators in mistralai integration for pydantic 2 compatibility
This commit is contained in:
Eugene Yurtsev 2024-08-15 12:44:42 -04:00 committed by GitHub
parent 6910b0b3aa
commit 2ef9d12372
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,7 +1,7 @@
import asyncio import asyncio
import logging import logging
import warnings import warnings
from typing import Dict, Iterable, List, Optional from typing import Dict, Iterable, List
import httpx import httpx
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
@ -11,7 +11,9 @@ from langchain_core.pydantic_v1 import (
SecretStr, SecretStr,
root_validator, root_validator,
) )
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env from langchain_core.utils import (
secret_from_env,
)
from tokenizers import Tokenizer # type: ignore from tokenizers import Tokenizer # type: ignore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -111,7 +113,10 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
client: httpx.Client = Field(default=None) #: :meta private: client: httpx.Client = Field(default=None) #: :meta private:
async_client: httpx.AsyncClient = Field(default=None) #: :meta private: async_client: httpx.AsyncClient = Field(default=None) #: :meta private:
mistral_api_key: Optional[SecretStr] = Field(default=None, alias="api_key") mistral_api_key: SecretStr = Field(
alias="api_key",
default_factory=secret_from_env("MISTRAL_API_KEY", default=""),
)
endpoint: str = "https://api.mistral.ai/v1/" endpoint: str = "https://api.mistral.ai/v1/"
max_retries: int = 5 max_retries: int = 5
timeout: int = 120 timeout: int = 120
@ -125,15 +130,10 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
arbitrary_types_allowed = True arbitrary_types_allowed = True
allow_population_by_field_name = True allow_population_by_field_name = True
@root_validator() @root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate configuration.""" """Validate configuration."""
values["mistral_api_key"] = convert_to_secret_str(
get_from_dict_or_env(
values, "mistral_api_key", "MISTRAL_API_KEY", default=""
)
)
api_key_str = values["mistral_api_key"].get_secret_value() api_key_str = values["mistral_api_key"].get_secret_value()
# todo: handle retries # todo: handle retries
if not values.get("client"): if not values.get("client"):