mistralai: update pydantic (#25995)

Migrated with gritql: https://github.com/eyurtsev/migrate-pydantic
This commit is contained in:
ccurme
2024-09-04 13:26:17 -04:00
committed by GitHub
parent 4799ad95d0
commit 7cee7fbfad
8 changed files with 69 additions and 91 deletions

View File

@@ -66,17 +66,19 @@ from langchain_core.output_parsers.openai_tools import (
parse_tool_call,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import (
BaseModel,
Field,
SecretStr,
root_validator,
)
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils import secret_from_env
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass
from pydantic import (
BaseModel,
ConfigDict,
Field,
SecretStr,
model_validator,
)
from typing_extensions import Self
logger = logging.getLogger(__name__)
@@ -379,11 +381,10 @@ class ChatMistralAI(BaseChatModel):
safe_mode: bool = False
streaming: bool = False
class Config:
"""Configuration for this pydantic object."""
allow_population_by_field_name = True
arbitrary_types_allowed = True
model_config = ConfigDict(
populate_by_name=True,
arbitrary_types_allowed=True,
)
@property
def _default_params(self) -> Dict[str, Any]:
@@ -469,47 +470,50 @@ class ChatMistralAI(BaseChatModel):
combined = {"token_usage": overall_token_usage, "model_name": self.model}
return combined
@root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict:
@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate api key, python package exists, temperature, and top_p."""
api_key_str = values["mistral_api_key"].get_secret_value()
if isinstance(self.mistral_api_key, SecretStr):
api_key_str: Optional[str] = self.mistral_api_key.get_secret_value()
else:
api_key_str = self.mistral_api_key
# todo: handle retries
base_url_str = (
values.get("endpoint")
self.endpoint
or os.environ.get("MISTRAL_BASE_URL")
or "https://api.mistral.ai/v1"
)
values["endpoint"] = base_url_str
if not values.get("client"):
values["client"] = httpx.Client(
self.endpoint = base_url_str
if not self.client:
self.client = httpx.Client(
base_url=base_url_str,
headers={
"Content-Type": "application/json",
"Accept": "application/json",
"Authorization": f"Bearer {api_key_str}",
},
timeout=values["timeout"],
timeout=self.timeout,
)
# todo: handle retries and max_concurrency
if not values.get("async_client"):
values["async_client"] = httpx.AsyncClient(
if not self.async_client:
self.async_client = httpx.AsyncClient(
base_url=base_url_str,
headers={
"Content-Type": "application/json",
"Accept": "application/json",
"Authorization": f"Bearer {api_key_str}",
},
timeout=values["timeout"],
timeout=self.timeout,
)
if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
if self.temperature is not None and not 0 <= self.temperature <= 1:
raise ValueError("temperature must be in the range [0.0, 1.0]")
if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
if self.top_p is not None and not 0 <= self.top_p <= 1:
raise ValueError("top_p must be in the range [0.0, 1.0]")
return values
return self
def _generate(
self,
@@ -728,7 +732,7 @@ class ChatMistralAI(BaseChatModel):
from typing import Optional
from langchain_mistralai import ChatMistralAI
from langchain_core.pydantic_v1 import BaseModel, Field
from pydantic import BaseModel, Field
class AnswerWithJustification(BaseModel):
@@ -759,7 +763,7 @@ class ChatMistralAI(BaseChatModel):
.. code-block:: python
from langchain_mistralai import ChatMistralAI
from langchain_core.pydantic_v1 import BaseModel
from pydantic import BaseModel
class AnswerWithJustification(BaseModel):
@@ -846,7 +850,7 @@ class ChatMistralAI(BaseChatModel):
.. code-block::
from langchain_mistralai import ChatMistralAI
from langchain_core.pydantic_v1 import BaseModel
from pydantic import BaseModel
class AnswerWithJustification(BaseModel):
answer: str

View File

@@ -1,20 +1,22 @@
import asyncio
import logging
import warnings
from typing import Dict, Iterable, List
from typing import Iterable, List
import httpx
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import (
BaseModel,
Field,
SecretStr,
root_validator,
)
from langchain_core.utils import (
secret_from_env,
)
from pydantic import (
BaseModel,
ConfigDict,
Field,
SecretStr,
model_validator,
)
from tokenizers import Tokenizer # type: ignore
from typing_extensions import Self
logger = logging.getLogger(__name__)
@@ -125,41 +127,42 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
model: str = "mistral-embed"
class Config:
extra = "forbid"
arbitrary_types_allowed = True
allow_population_by_field_name = True
model_config = ConfigDict(
extra="forbid",
arbitrary_types_allowed=True,
populate_by_name=True,
)
@root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict:
@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate configuration."""
api_key_str = values["mistral_api_key"].get_secret_value()
api_key_str = self.mistral_api_key.get_secret_value()
# todo: handle retries
if not values.get("client"):
values["client"] = httpx.Client(
base_url=values["endpoint"],
if not self.client:
self.client = httpx.Client(
base_url=self.endpoint,
headers={
"Content-Type": "application/json",
"Accept": "application/json",
"Authorization": f"Bearer {api_key_str}",
},
timeout=values["timeout"],
timeout=self.timeout,
)
# todo: handle retries and max_concurrency
if not values.get("async_client"):
values["async_client"] = httpx.AsyncClient(
base_url=values["endpoint"],
if not self.async_client:
self.async_client = httpx.AsyncClient(
base_url=self.endpoint,
headers={
"Content-Type": "application/json",
"Accept": "application/json",
"Authorization": f"Bearer {api_key_str}",
},
timeout=values["timeout"],
timeout=self.timeout,
)
if values["tokenizer"] is None:
if self.tokenizer is None:
try:
values["tokenizer"] = Tokenizer.from_pretrained(
self.tokenizer = Tokenizer.from_pretrained(
"mistralai/Mixtral-8x7B-v0.1"
)
except IOError: # huggingface_hub GatedRepoError
@@ -169,8 +172,8 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
"HF_TOKEN environment variable to download the real tokenizer. "
"Falling back to a dummy tokenizer that uses `len()`."
)
values["tokenizer"] = DummyTokenizer()
return values
self.tokenizer = DummyTokenizer()
return self
def _get_batches(self, texts: List[str]) -> Iterable[List[str]]:
"""Split a list of texts into batches of less than 16k tokens

View File

@@ -11,9 +11,6 @@ files = [
{file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"},
]
[package.dependencies]
typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.9\""}
[[package]]
name = "anyio"
version = "4.4.0"
@@ -397,7 +394,7 @@ name = "langchain-core"
version = "0.2.38"
description = "Building applications with LLMs through composability"
optional = false
python-versions = ">=3.8.1,<4.0"
python-versions = ">=3.9,<4.0"
files = []
develop = true
@@ -1082,5 +1079,5 @@ zstd = ["zstandard (>=0.18.0)"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "11a8e8533f0ed605e14cf916957ccde5f8bf77056227fcbc152b0f644f1e45bd"
python-versions = ">=3.9,<4.0"
content-hash = "08e71710e103a4888f5d959413cfb5400301e9485027e4d0ef48a49bc82e6f10"

View File

@@ -19,11 +19,12 @@ disallow_untyped_defs = "True"
"Release Notes" = "https://github.com/langchain-ai/langchain/releases?q=tag%3A%22langchain-mistralai%3D%3D0%22&expanded=true"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
python = ">=3.9,<4.0"
langchain-core = "^0.2.38"
tokenizers = ">=0.15.1,<1"
httpx = ">=0.25.2,<1"
httpx-sse = ">=0.3.1,<1"
pydantic = ">2,<3"
[tool.ruff.lint]
select = [ "E", "F", "I", "T201",]

View File

@@ -1,27 +0,0 @@
#!/bin/bash
#
# This script searches for lines starting with "import pydantic" or "from pydantic"
# in tracked files within a Git repository.
#
# Usage: ./scripts/check_pydantic.sh /path/to/repository
# Check if a path argument is provided
if [ $# -ne 1 ]; then
echo "Usage: $0 /path/to/repository"
exit 1
fi
repository_path="$1"
# Search for lines matching the pattern within the specified repository
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')
# Check if any matching lines were found
if [ -n "$result" ]; then
echo "ERROR: The following lines need to be updated:"
echo "$result"
echo "Please replace the code with an import from langchain_core.pydantic_v1."
echo "For example, replace 'from pydantic import BaseModel'"
echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
exit 1
fi

View File

@@ -9,7 +9,7 @@ from langchain_core.messages import (
BaseMessageChunk,
HumanMessage,
)
from langchain_core.pydantic_v1 import BaseModel
from pydantic import BaseModel
from langchain_mistralai.chat_models import ChatMistralAI

View File

@@ -15,7 +15,7 @@ from langchain_core.messages import (
SystemMessage,
ToolCall,
)
from langchain_core.pydantic_v1 import SecretStr
from pydantic import SecretStr
from langchain_mistralai.chat_models import ( # type: ignore[import]
ChatMistralAI,

View File

@@ -1,7 +1,7 @@
import os
from typing import cast
from langchain_core.pydantic_v1 import SecretStr
from pydantic import SecretStr
from langchain_mistralai import MistralAIEmbeddings