diff --git a/libs/partners/mistralai/langchain_mistralai/chat_models.py b/libs/partners/mistralai/langchain_mistralai/chat_models.py index 06ee230d1f6..ac00c5daeaa 100644 --- a/libs/partners/mistralai/langchain_mistralai/chat_models.py +++ b/libs/partners/mistralai/langchain_mistralai/chat_models.py @@ -66,17 +66,19 @@ from langchain_core.output_parsers.openai_tools import ( parse_tool_call, ) from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult -from langchain_core.pydantic_v1 import ( - BaseModel, - Field, - SecretStr, - root_validator, -) from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough from langchain_core.tools import BaseTool from langchain_core.utils import secret_from_env from langchain_core.utils.function_calling import convert_to_openai_tool from langchain_core.utils.pydantic import is_basemodel_subclass +from pydantic import ( + BaseModel, + ConfigDict, + Field, + SecretStr, + model_validator, +) +from typing_extensions import Self logger = logging.getLogger(__name__) @@ -379,11 +381,10 @@ class ChatMistralAI(BaseChatModel): safe_mode: bool = False streaming: bool = False - class Config: - """Configuration for this pydantic object.""" - - allow_population_by_field_name = True - arbitrary_types_allowed = True + model_config = ConfigDict( + populate_by_name=True, + arbitrary_types_allowed=True, + ) @property def _default_params(self) -> Dict[str, Any]: @@ -469,47 +470,50 @@ class ChatMistralAI(BaseChatModel): combined = {"token_usage": overall_token_usage, "model_name": self.model} return combined - @root_validator(pre=False, skip_on_failure=True) - def validate_environment(cls, values: Dict) -> Dict: + @model_validator(mode="after") + def validate_environment(self) -> Self: """Validate api key, python package exists, temperature, and top_p.""" - api_key_str = values["mistral_api_key"].get_secret_value() + if isinstance(self.mistral_api_key, SecretStr): + api_key_str: Optional[str] = self.mistral_api_key.get_secret_value() + else: + api_key_str = self.mistral_api_key # todo: handle retries base_url_str = ( - values.get("endpoint") + self.endpoint or os.environ.get("MISTRAL_BASE_URL") or "https://api.mistral.ai/v1" ) - values["endpoint"] = base_url_str - if not values.get("client"): - values["client"] = httpx.Client( + self.endpoint = base_url_str + if not self.client: + self.client = httpx.Client( base_url=base_url_str, headers={ "Content-Type": "application/json", "Accept": "application/json", "Authorization": f"Bearer {api_key_str}", }, - timeout=values["timeout"], + timeout=self.timeout, ) # todo: handle retries and max_concurrency - if not values.get("async_client"): - values["async_client"] = httpx.AsyncClient( + if not self.async_client: + self.async_client = httpx.AsyncClient( base_url=base_url_str, headers={ "Content-Type": "application/json", "Accept": "application/json", "Authorization": f"Bearer {api_key_str}", }, - timeout=values["timeout"], + timeout=self.timeout, ) - if values["temperature"] is not None and not 0 <= values["temperature"] <= 1: + if self.temperature is not None and not 0 <= self.temperature <= 1: raise ValueError("temperature must be in the range [0.0, 1.0]") - if values["top_p"] is not None and not 0 <= values["top_p"] <= 1: + if self.top_p is not None and not 0 <= self.top_p <= 1: raise ValueError("top_p must be in the range [0.0, 1.0]") - return values + return self def _generate( self, @@ -728,7 +732,7 @@ class ChatMistralAI(BaseChatModel): from typing import Optional from langchain_mistralai import ChatMistralAI - from langchain_core.pydantic_v1 import BaseModel, Field + from pydantic import BaseModel, Field class AnswerWithJustification(BaseModel): @@ -759,7 +763,7 @@ class ChatMistralAI(BaseChatModel): .. code-block:: python from langchain_mistralai import ChatMistralAI - from langchain_core.pydantic_v1 import BaseModel + from pydantic import BaseModel class AnswerWithJustification(BaseModel): @@ -846,7 +850,7 @@ class ChatMistralAI(BaseChatModel): .. code-block:: from langchain_mistralai import ChatMistralAI - from langchain_core.pydantic_v1 import BaseModel + from pydantic import BaseModel class AnswerWithJustification(BaseModel): answer: str diff --git a/libs/partners/mistralai/langchain_mistralai/embeddings.py b/libs/partners/mistralai/langchain_mistralai/embeddings.py index ea19fdd9cae..485e9771c48 100644 --- a/libs/partners/mistralai/langchain_mistralai/embeddings.py +++ b/libs/partners/mistralai/langchain_mistralai/embeddings.py @@ -1,20 +1,22 @@ import asyncio import logging import warnings -from typing import Dict, Iterable, List +from typing import Iterable, List import httpx from langchain_core.embeddings import Embeddings -from langchain_core.pydantic_v1 import ( - BaseModel, - Field, - SecretStr, - root_validator, -) from langchain_core.utils import ( secret_from_env, ) +from pydantic import ( + BaseModel, + ConfigDict, + Field, + SecretStr, + model_validator, +) from tokenizers import Tokenizer # type: ignore +from typing_extensions import Self logger = logging.getLogger(__name__) @@ -125,41 +127,42 @@ class MistralAIEmbeddings(BaseModel, Embeddings): model: str = "mistral-embed" - class Config: - extra = "forbid" - arbitrary_types_allowed = True - allow_population_by_field_name = True + model_config = ConfigDict( + extra="forbid", + arbitrary_types_allowed=True, + populate_by_name=True, + ) - @root_validator(pre=False, skip_on_failure=True) - def validate_environment(cls, values: Dict) -> Dict: + @model_validator(mode="after") + def validate_environment(self) -> Self: """Validate configuration.""" - api_key_str = values["mistral_api_key"].get_secret_value() + api_key_str = self.mistral_api_key.get_secret_value() # todo: handle retries - if not values.get("client"): - values["client"] = httpx.Client( - base_url=values["endpoint"], + if not self.client: + self.client = httpx.Client( + base_url=self.endpoint, headers={ "Content-Type": "application/json", "Accept": "application/json", "Authorization": f"Bearer {api_key_str}", }, - timeout=values["timeout"], + timeout=self.timeout, ) # todo: handle retries and max_concurrency - if not values.get("async_client"): - values["async_client"] = httpx.AsyncClient( - base_url=values["endpoint"], + if not self.async_client: + self.async_client = httpx.AsyncClient( + base_url=self.endpoint, headers={ "Content-Type": "application/json", "Accept": "application/json", "Authorization": f"Bearer {api_key_str}", }, - timeout=values["timeout"], + timeout=self.timeout, ) - if values["tokenizer"] is None: + if self.tokenizer is None: try: - values["tokenizer"] = Tokenizer.from_pretrained( + self.tokenizer = Tokenizer.from_pretrained( "mistralai/Mixtral-8x7B-v0.1" ) except IOError: # huggingface_hub GatedRepoError @@ -169,8 +172,8 @@ class MistralAIEmbeddings(BaseModel, Embeddings): "HF_TOKEN environment variable to download the real tokenizer. " "Falling back to a dummy tokenizer that uses `len()`." ) - values["tokenizer"] = DummyTokenizer() - return values + self.tokenizer = DummyTokenizer() + return self def _get_batches(self, texts: List[str]) -> Iterable[List[str]]: """Split a list of texts into batches of less than 16k tokens diff --git a/libs/partners/mistralai/poetry.lock b/libs/partners/mistralai/poetry.lock index c52c780aca6..a64859d4b17 100644 --- a/libs/partners/mistralai/poetry.lock +++ b/libs/partners/mistralai/poetry.lock @@ -11,9 +11,6 @@ files = [ {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"}, ] -[package.dependencies] -typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.9\""} - [[package]] name = "anyio" version = "4.4.0" @@ -397,7 +394,7 @@ name = "langchain-core" version = "0.2.38" description = "Building applications with LLMs through composability" optional = false -python-versions = ">=3.8.1,<4.0" +python-versions = ">=3.9,<4.0" files = [] develop = true @@ -1082,5 +1079,5 @@ zstd = ["zstandard (>=0.18.0)"] [metadata] lock-version = "2.0" -python-versions = ">=3.8.1,<4.0" -content-hash = "11a8e8533f0ed605e14cf916957ccde5f8bf77056227fcbc152b0f644f1e45bd" +python-versions = ">=3.9,<4.0" +content-hash = "08e71710e103a4888f5d959413cfb5400301e9485027e4d0ef48a49bc82e6f10" diff --git a/libs/partners/mistralai/pyproject.toml b/libs/partners/mistralai/pyproject.toml index 6da683eb31c..dec8bcbe97b 100644 --- a/libs/partners/mistralai/pyproject.toml +++ b/libs/partners/mistralai/pyproject.toml @@ -19,11 +19,12 @@ disallow_untyped_defs = "True" "Release Notes" = "https://github.com/langchain-ai/langchain/releases?q=tag%3A%22langchain-mistralai%3D%3D0%22&expanded=true" [tool.poetry.dependencies] -python = ">=3.8.1,<4.0" +python = ">=3.9,<4.0" langchain-core = "^0.2.38" tokenizers = ">=0.15.1,<1" httpx = ">=0.25.2,<1" httpx-sse = ">=0.3.1,<1" +pydantic = ">2,<3" [tool.ruff.lint] select = [ "E", "F", "I", "T201",] diff --git a/libs/partners/mistralai/scripts/check_pydantic.sh b/libs/partners/mistralai/scripts/check_pydantic.sh deleted file mode 100755 index 06b5bb81ae2..00000000000 --- a/libs/partners/mistralai/scripts/check_pydantic.sh +++ /dev/null @@ -1,27 +0,0 @@ -#!/bin/bash -# -# This script searches for lines starting with "import pydantic" or "from pydantic" -# in tracked files within a Git repository. -# -# Usage: ./scripts/check_pydantic.sh /path/to/repository - -# Check if a path argument is provided -if [ $# -ne 1 ]; then - echo "Usage: $0 /path/to/repository" - exit 1 -fi - -repository_path="$1" - -# Search for lines matching the pattern within the specified repository -result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic') - -# Check if any matching lines were found -if [ -n "$result" ]; then - echo "ERROR: The following lines need to be updated:" - echo "$result" - echo "Please replace the code with an import from langchain_core.pydantic_v1." - echo "For example, replace 'from pydantic import BaseModel'" - echo "with 'from langchain_core.pydantic_v1 import BaseModel'" - exit 1 -fi diff --git a/libs/partners/mistralai/tests/integration_tests/test_chat_models.py b/libs/partners/mistralai/tests/integration_tests/test_chat_models.py index 9f88d9b4ea5..dc67bd93abe 100644 --- a/libs/partners/mistralai/tests/integration_tests/test_chat_models.py +++ b/libs/partners/mistralai/tests/integration_tests/test_chat_models.py @@ -9,7 +9,7 @@ from langchain_core.messages import ( BaseMessageChunk, HumanMessage, ) -from langchain_core.pydantic_v1 import BaseModel +from pydantic import BaseModel from langchain_mistralai.chat_models import ChatMistralAI diff --git a/libs/partners/mistralai/tests/unit_tests/test_chat_models.py b/libs/partners/mistralai/tests/unit_tests/test_chat_models.py index 62b86333887..a5046db47e4 100644 --- a/libs/partners/mistralai/tests/unit_tests/test_chat_models.py +++ b/libs/partners/mistralai/tests/unit_tests/test_chat_models.py @@ -15,7 +15,7 @@ from langchain_core.messages import ( SystemMessage, ToolCall, ) -from langchain_core.pydantic_v1 import SecretStr +from pydantic import SecretStr from langchain_mistralai.chat_models import ( # type: ignore[import] ChatMistralAI, diff --git a/libs/partners/mistralai/tests/unit_tests/test_embeddings.py b/libs/partners/mistralai/tests/unit_tests/test_embeddings.py index 46e65eaf5fb..41023f015e2 100644 --- a/libs/partners/mistralai/tests/unit_tests/test_embeddings.py +++ b/libs/partners/mistralai/tests/unit_tests/test_embeddings.py @@ -1,7 +1,7 @@ import os from typing import cast -from langchain_core.pydantic_v1 import SecretStr +from pydantic import SecretStr from langchain_mistralai import MistralAIEmbeddings