mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-26 13:59:49 +00:00
mistralai: update pydantic (#25995)
Migrated with gritql: https://github.com/eyurtsev/migrate-pydantic
This commit is contained in:
@@ -66,17 +66,19 @@ from langchain_core.output_parsers.openai_tools import (
|
||||
parse_tool_call,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import (
|
||||
BaseModel,
|
||||
Field,
|
||||
SecretStr,
|
||||
root_validator,
|
||||
)
|
||||
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import secret_from_env
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
SecretStr,
|
||||
model_validator,
|
||||
)
|
||||
from typing_extensions import Self
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -379,11 +381,10 @@ class ChatMistralAI(BaseChatModel):
|
||||
safe_mode: bool = False
|
||||
streaming: bool = False
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
allow_population_by_field_name = True
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
@@ -469,47 +470,50 @@ class ChatMistralAI(BaseChatModel):
|
||||
combined = {"token_usage": overall_token_usage, "model_name": self.model}
|
||||
return combined
|
||||
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="after")
|
||||
def validate_environment(self) -> Self:
|
||||
"""Validate api key, python package exists, temperature, and top_p."""
|
||||
api_key_str = values["mistral_api_key"].get_secret_value()
|
||||
if isinstance(self.mistral_api_key, SecretStr):
|
||||
api_key_str: Optional[str] = self.mistral_api_key.get_secret_value()
|
||||
else:
|
||||
api_key_str = self.mistral_api_key
|
||||
|
||||
# todo: handle retries
|
||||
base_url_str = (
|
||||
values.get("endpoint")
|
||||
self.endpoint
|
||||
or os.environ.get("MISTRAL_BASE_URL")
|
||||
or "https://api.mistral.ai/v1"
|
||||
)
|
||||
values["endpoint"] = base_url_str
|
||||
if not values.get("client"):
|
||||
values["client"] = httpx.Client(
|
||||
self.endpoint = base_url_str
|
||||
if not self.client:
|
||||
self.client = httpx.Client(
|
||||
base_url=base_url_str,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
"Authorization": f"Bearer {api_key_str}",
|
||||
},
|
||||
timeout=values["timeout"],
|
||||
timeout=self.timeout,
|
||||
)
|
||||
# todo: handle retries and max_concurrency
|
||||
if not values.get("async_client"):
|
||||
values["async_client"] = httpx.AsyncClient(
|
||||
if not self.async_client:
|
||||
self.async_client = httpx.AsyncClient(
|
||||
base_url=base_url_str,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
"Authorization": f"Bearer {api_key_str}",
|
||||
},
|
||||
timeout=values["timeout"],
|
||||
timeout=self.timeout,
|
||||
)
|
||||
|
||||
if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
|
||||
if self.temperature is not None and not 0 <= self.temperature <= 1:
|
||||
raise ValueError("temperature must be in the range [0.0, 1.0]")
|
||||
|
||||
if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
|
||||
if self.top_p is not None and not 0 <= self.top_p <= 1:
|
||||
raise ValueError("top_p must be in the range [0.0, 1.0]")
|
||||
|
||||
return values
|
||||
return self
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
@@ -728,7 +732,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
from typing import Optional
|
||||
|
||||
from langchain_mistralai import ChatMistralAI
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class AnswerWithJustification(BaseModel):
|
||||
@@ -759,7 +763,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_mistralai import ChatMistralAI
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AnswerWithJustification(BaseModel):
|
||||
@@ -846,7 +850,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
.. code-block::
|
||||
|
||||
from langchain_mistralai import ChatMistralAI
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from pydantic import BaseModel
|
||||
|
||||
class AnswerWithJustification(BaseModel):
|
||||
answer: str
|
||||
|
@@ -1,20 +1,22 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import warnings
|
||||
from typing import Dict, Iterable, List
|
||||
from typing import Iterable, List
|
||||
|
||||
import httpx
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import (
|
||||
BaseModel,
|
||||
Field,
|
||||
SecretStr,
|
||||
root_validator,
|
||||
)
|
||||
from langchain_core.utils import (
|
||||
secret_from_env,
|
||||
)
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
SecretStr,
|
||||
model_validator,
|
||||
)
|
||||
from tokenizers import Tokenizer # type: ignore
|
||||
from typing_extensions import Self
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -125,41 +127,42 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
model: str = "mistral-embed"
|
||||
|
||||
class Config:
|
||||
extra = "forbid"
|
||||
arbitrary_types_allowed = True
|
||||
allow_population_by_field_name = True
|
||||
model_config = ConfigDict(
|
||||
extra="forbid",
|
||||
arbitrary_types_allowed=True,
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="after")
|
||||
def validate_environment(self) -> Self:
|
||||
"""Validate configuration."""
|
||||
|
||||
api_key_str = values["mistral_api_key"].get_secret_value()
|
||||
api_key_str = self.mistral_api_key.get_secret_value()
|
||||
# todo: handle retries
|
||||
if not values.get("client"):
|
||||
values["client"] = httpx.Client(
|
||||
base_url=values["endpoint"],
|
||||
if not self.client:
|
||||
self.client = httpx.Client(
|
||||
base_url=self.endpoint,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
"Authorization": f"Bearer {api_key_str}",
|
||||
},
|
||||
timeout=values["timeout"],
|
||||
timeout=self.timeout,
|
||||
)
|
||||
# todo: handle retries and max_concurrency
|
||||
if not values.get("async_client"):
|
||||
values["async_client"] = httpx.AsyncClient(
|
||||
base_url=values["endpoint"],
|
||||
if not self.async_client:
|
||||
self.async_client = httpx.AsyncClient(
|
||||
base_url=self.endpoint,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
"Authorization": f"Bearer {api_key_str}",
|
||||
},
|
||||
timeout=values["timeout"],
|
||||
timeout=self.timeout,
|
||||
)
|
||||
if values["tokenizer"] is None:
|
||||
if self.tokenizer is None:
|
||||
try:
|
||||
values["tokenizer"] = Tokenizer.from_pretrained(
|
||||
self.tokenizer = Tokenizer.from_pretrained(
|
||||
"mistralai/Mixtral-8x7B-v0.1"
|
||||
)
|
||||
except IOError: # huggingface_hub GatedRepoError
|
||||
@@ -169,8 +172,8 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
|
||||
"HF_TOKEN environment variable to download the real tokenizer. "
|
||||
"Falling back to a dummy tokenizer that uses `len()`."
|
||||
)
|
||||
values["tokenizer"] = DummyTokenizer()
|
||||
return values
|
||||
self.tokenizer = DummyTokenizer()
|
||||
return self
|
||||
|
||||
def _get_batches(self, texts: List[str]) -> Iterable[List[str]]:
|
||||
"""Split a list of texts into batches of less than 16k tokens
|
||||
|
9
libs/partners/mistralai/poetry.lock
generated
9
libs/partners/mistralai/poetry.lock
generated
@@ -11,9 +11,6 @@ files = [
|
||||
{file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.9\""}
|
||||
|
||||
[[package]]
|
||||
name = "anyio"
|
||||
version = "4.4.0"
|
||||
@@ -397,7 +394,7 @@ name = "langchain-core"
|
||||
version = "0.2.38"
|
||||
description = "Building applications with LLMs through composability"
|
||||
optional = false
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
python-versions = ">=3.9,<4.0"
|
||||
files = []
|
||||
develop = true
|
||||
|
||||
@@ -1082,5 +1079,5 @@ zstd = ["zstandard (>=0.18.0)"]
|
||||
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
content-hash = "11a8e8533f0ed605e14cf916957ccde5f8bf77056227fcbc152b0f644f1e45bd"
|
||||
python-versions = ">=3.9,<4.0"
|
||||
content-hash = "08e71710e103a4888f5d959413cfb5400301e9485027e4d0ef48a49bc82e6f10"
|
||||
|
@@ -19,11 +19,12 @@ disallow_untyped_defs = "True"
|
||||
"Release Notes" = "https://github.com/langchain-ai/langchain/releases?q=tag%3A%22langchain-mistralai%3D%3D0%22&expanded=true"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.8.1,<4.0"
|
||||
python = ">=3.9,<4.0"
|
||||
langchain-core = "^0.2.38"
|
||||
tokenizers = ">=0.15.1,<1"
|
||||
httpx = ">=0.25.2,<1"
|
||||
httpx-sse = ">=0.3.1,<1"
|
||||
pydantic = ">2,<3"
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [ "E", "F", "I", "T201",]
|
||||
|
@@ -1,27 +0,0 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# This script searches for lines starting with "import pydantic" or "from pydantic"
|
||||
# in tracked files within a Git repository.
|
||||
#
|
||||
# Usage: ./scripts/check_pydantic.sh /path/to/repository
|
||||
|
||||
# Check if a path argument is provided
|
||||
if [ $# -ne 1 ]; then
|
||||
echo "Usage: $0 /path/to/repository"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
repository_path="$1"
|
||||
|
||||
# Search for lines matching the pattern within the specified repository
|
||||
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')
|
||||
|
||||
# Check if any matching lines were found
|
||||
if [ -n "$result" ]; then
|
||||
echo "ERROR: The following lines need to be updated:"
|
||||
echo "$result"
|
||||
echo "Please replace the code with an import from langchain_core.pydantic_v1."
|
||||
echo "For example, replace 'from pydantic import BaseModel'"
|
||||
echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
|
||||
exit 1
|
||||
fi
|
@@ -9,7 +9,7 @@ from langchain_core.messages import (
|
||||
BaseMessageChunk,
|
||||
HumanMessage,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain_mistralai.chat_models import ChatMistralAI
|
||||
|
||||
|
@@ -15,7 +15,7 @@ from langchain_core.messages import (
|
||||
SystemMessage,
|
||||
ToolCall,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import SecretStr
|
||||
from pydantic import SecretStr
|
||||
|
||||
from langchain_mistralai.chat_models import ( # type: ignore[import]
|
||||
ChatMistralAI,
|
||||
|
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.pydantic_v1 import SecretStr
|
||||
from pydantic import SecretStr
|
||||
|
||||
from langchain_mistralai import MistralAIEmbeddings
|
||||
|
||||
|
Reference in New Issue
Block a user