mirror of
https://github.com/hwchase17/langchain.git
synced 2025-04-28 11:55:21 +00:00
mistralai[patch]: Update more @root_validators for pydantic 2 compatibility (#25446)
Update @root_validators in mistralai integration for pydantic 2 compatibility
This commit is contained in:
parent
6910b0b3aa
commit
2ef9d12372
@ -1,7 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Dict, Iterable, List, Optional
|
from typing import Dict, Iterable, List
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
@ -11,7 +11,9 @@ from langchain_core.pydantic_v1 import (
|
|||||||
SecretStr,
|
SecretStr,
|
||||||
root_validator,
|
root_validator,
|
||||||
)
|
)
|
||||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
from langchain_core.utils import (
|
||||||
|
secret_from_env,
|
||||||
|
)
|
||||||
from tokenizers import Tokenizer # type: ignore
|
from tokenizers import Tokenizer # type: ignore
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -111,7 +113,10 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
|
|||||||
|
|
||||||
client: httpx.Client = Field(default=None) #: :meta private:
|
client: httpx.Client = Field(default=None) #: :meta private:
|
||||||
async_client: httpx.AsyncClient = Field(default=None) #: :meta private:
|
async_client: httpx.AsyncClient = Field(default=None) #: :meta private:
|
||||||
mistral_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
|
mistral_api_key: SecretStr = Field(
|
||||||
|
alias="api_key",
|
||||||
|
default_factory=secret_from_env("MISTRAL_API_KEY", default=""),
|
||||||
|
)
|
||||||
endpoint: str = "https://api.mistral.ai/v1/"
|
endpoint: str = "https://api.mistral.ai/v1/"
|
||||||
max_retries: int = 5
|
max_retries: int = 5
|
||||||
timeout: int = 120
|
timeout: int = 120
|
||||||
@ -125,15 +130,10 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
|
|||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
allow_population_by_field_name = True
|
allow_population_by_field_name = True
|
||||||
|
|
||||||
@root_validator()
|
@root_validator(pre=False, skip_on_failure=True)
|
||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
"""Validate configuration."""
|
"""Validate configuration."""
|
||||||
|
|
||||||
values["mistral_api_key"] = convert_to_secret_str(
|
|
||||||
get_from_dict_or_env(
|
|
||||||
values, "mistral_api_key", "MISTRAL_API_KEY", default=""
|
|
||||||
)
|
|
||||||
)
|
|
||||||
api_key_str = values["mistral_api_key"].get_secret_value()
|
api_key_str = values["mistral_api_key"].get_secret_value()
|
||||||
# todo: handle retries
|
# todo: handle retries
|
||||||
if not values.get("client"):
|
if not values.get("client"):
|
||||||
|
Loading…
Reference in New Issue
Block a user