mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-27 14:26:48 +00:00
mistralai: update pydantic (#25995)
Migrated with gritql: https://github.com/eyurtsev/migrate-pydantic
This commit is contained in:
@@ -66,17 +66,19 @@ from langchain_core.output_parsers.openai_tools import (
|
|||||||
parse_tool_call,
|
parse_tool_call,
|
||||||
)
|
)
|
||||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||||
from langchain_core.pydantic_v1 import (
|
|
||||||
BaseModel,
|
|
||||||
Field,
|
|
||||||
SecretStr,
|
|
||||||
root_validator,
|
|
||||||
)
|
|
||||||
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
from langchain_core.utils import secret_from_env
|
from langchain_core.utils import secret_from_env
|
||||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||||
|
from pydantic import (
|
||||||
|
BaseModel,
|
||||||
|
ConfigDict,
|
||||||
|
Field,
|
||||||
|
SecretStr,
|
||||||
|
model_validator,
|
||||||
|
)
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -379,11 +381,10 @@ class ChatMistralAI(BaseChatModel):
|
|||||||
safe_mode: bool = False
|
safe_mode: bool = False
|
||||||
streaming: bool = False
|
streaming: bool = False
|
||||||
|
|
||||||
class Config:
|
model_config = ConfigDict(
|
||||||
"""Configuration for this pydantic object."""
|
populate_by_name=True,
|
||||||
|
arbitrary_types_allowed=True,
|
||||||
allow_population_by_field_name = True
|
)
|
||||||
arbitrary_types_allowed = True
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _default_params(self) -> Dict[str, Any]:
|
def _default_params(self) -> Dict[str, Any]:
|
||||||
@@ -469,47 +470,50 @@ class ChatMistralAI(BaseChatModel):
|
|||||||
combined = {"token_usage": overall_token_usage, "model_name": self.model}
|
combined = {"token_usage": overall_token_usage, "model_name": self.model}
|
||||||
return combined
|
return combined
|
||||||
|
|
||||||
@root_validator(pre=False, skip_on_failure=True)
|
@model_validator(mode="after")
|
||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(self) -> Self:
|
||||||
"""Validate api key, python package exists, temperature, and top_p."""
|
"""Validate api key, python package exists, temperature, and top_p."""
|
||||||
api_key_str = values["mistral_api_key"].get_secret_value()
|
if isinstance(self.mistral_api_key, SecretStr):
|
||||||
|
api_key_str: Optional[str] = self.mistral_api_key.get_secret_value()
|
||||||
|
else:
|
||||||
|
api_key_str = self.mistral_api_key
|
||||||
|
|
||||||
# todo: handle retries
|
# todo: handle retries
|
||||||
base_url_str = (
|
base_url_str = (
|
||||||
values.get("endpoint")
|
self.endpoint
|
||||||
or os.environ.get("MISTRAL_BASE_URL")
|
or os.environ.get("MISTRAL_BASE_URL")
|
||||||
or "https://api.mistral.ai/v1"
|
or "https://api.mistral.ai/v1"
|
||||||
)
|
)
|
||||||
values["endpoint"] = base_url_str
|
self.endpoint = base_url_str
|
||||||
if not values.get("client"):
|
if not self.client:
|
||||||
values["client"] = httpx.Client(
|
self.client = httpx.Client(
|
||||||
base_url=base_url_str,
|
base_url=base_url_str,
|
||||||
headers={
|
headers={
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"Accept": "application/json",
|
"Accept": "application/json",
|
||||||
"Authorization": f"Bearer {api_key_str}",
|
"Authorization": f"Bearer {api_key_str}",
|
||||||
},
|
},
|
||||||
timeout=values["timeout"],
|
timeout=self.timeout,
|
||||||
)
|
)
|
||||||
# todo: handle retries and max_concurrency
|
# todo: handle retries and max_concurrency
|
||||||
if not values.get("async_client"):
|
if not self.async_client:
|
||||||
values["async_client"] = httpx.AsyncClient(
|
self.async_client = httpx.AsyncClient(
|
||||||
base_url=base_url_str,
|
base_url=base_url_str,
|
||||||
headers={
|
headers={
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"Accept": "application/json",
|
"Accept": "application/json",
|
||||||
"Authorization": f"Bearer {api_key_str}",
|
"Authorization": f"Bearer {api_key_str}",
|
||||||
},
|
},
|
||||||
timeout=values["timeout"],
|
timeout=self.timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
|
if self.temperature is not None and not 0 <= self.temperature <= 1:
|
||||||
raise ValueError("temperature must be in the range [0.0, 1.0]")
|
raise ValueError("temperature must be in the range [0.0, 1.0]")
|
||||||
|
|
||||||
if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
|
if self.top_p is not None and not 0 <= self.top_p <= 1:
|
||||||
raise ValueError("top_p must be in the range [0.0, 1.0]")
|
raise ValueError("top_p must be in the range [0.0, 1.0]")
|
||||||
|
|
||||||
return values
|
return self
|
||||||
|
|
||||||
def _generate(
|
def _generate(
|
||||||
self,
|
self,
|
||||||
@@ -728,7 +732,7 @@ class ChatMistralAI(BaseChatModel):
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from langchain_mistralai import ChatMistralAI
|
from langchain_mistralai import ChatMistralAI
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
class AnswerWithJustification(BaseModel):
|
class AnswerWithJustification(BaseModel):
|
||||||
@@ -759,7 +763,7 @@ class ChatMistralAI(BaseChatModel):
|
|||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
from langchain_mistralai import ChatMistralAI
|
from langchain_mistralai import ChatMistralAI
|
||||||
from langchain_core.pydantic_v1 import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
class AnswerWithJustification(BaseModel):
|
class AnswerWithJustification(BaseModel):
|
||||||
@@ -846,7 +850,7 @@ class ChatMistralAI(BaseChatModel):
|
|||||||
.. code-block::
|
.. code-block::
|
||||||
|
|
||||||
from langchain_mistralai import ChatMistralAI
|
from langchain_mistralai import ChatMistralAI
|
||||||
from langchain_core.pydantic_v1 import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
class AnswerWithJustification(BaseModel):
|
class AnswerWithJustification(BaseModel):
|
||||||
answer: str
|
answer: str
|
||||||
|
@@ -1,20 +1,22 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Dict, Iterable, List
|
from typing import Iterable, List
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from langchain_core.pydantic_v1 import (
|
|
||||||
BaseModel,
|
|
||||||
Field,
|
|
||||||
SecretStr,
|
|
||||||
root_validator,
|
|
||||||
)
|
|
||||||
from langchain_core.utils import (
|
from langchain_core.utils import (
|
||||||
secret_from_env,
|
secret_from_env,
|
||||||
)
|
)
|
||||||
|
from pydantic import (
|
||||||
|
BaseModel,
|
||||||
|
ConfigDict,
|
||||||
|
Field,
|
||||||
|
SecretStr,
|
||||||
|
model_validator,
|
||||||
|
)
|
||||||
from tokenizers import Tokenizer # type: ignore
|
from tokenizers import Tokenizer # type: ignore
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -125,41 +127,42 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
|
|||||||
|
|
||||||
model: str = "mistral-embed"
|
model: str = "mistral-embed"
|
||||||
|
|
||||||
class Config:
|
model_config = ConfigDict(
|
||||||
extra = "forbid"
|
extra="forbid",
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed=True,
|
||||||
allow_population_by_field_name = True
|
populate_by_name=True,
|
||||||
|
)
|
||||||
|
|
||||||
@root_validator(pre=False, skip_on_failure=True)
|
@model_validator(mode="after")
|
||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(self) -> Self:
|
||||||
"""Validate configuration."""
|
"""Validate configuration."""
|
||||||
|
|
||||||
api_key_str = values["mistral_api_key"].get_secret_value()
|
api_key_str = self.mistral_api_key.get_secret_value()
|
||||||
# todo: handle retries
|
# todo: handle retries
|
||||||
if not values.get("client"):
|
if not self.client:
|
||||||
values["client"] = httpx.Client(
|
self.client = httpx.Client(
|
||||||
base_url=values["endpoint"],
|
base_url=self.endpoint,
|
||||||
headers={
|
headers={
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"Accept": "application/json",
|
"Accept": "application/json",
|
||||||
"Authorization": f"Bearer {api_key_str}",
|
"Authorization": f"Bearer {api_key_str}",
|
||||||
},
|
},
|
||||||
timeout=values["timeout"],
|
timeout=self.timeout,
|
||||||
)
|
)
|
||||||
# todo: handle retries and max_concurrency
|
# todo: handle retries and max_concurrency
|
||||||
if not values.get("async_client"):
|
if not self.async_client:
|
||||||
values["async_client"] = httpx.AsyncClient(
|
self.async_client = httpx.AsyncClient(
|
||||||
base_url=values["endpoint"],
|
base_url=self.endpoint,
|
||||||
headers={
|
headers={
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"Accept": "application/json",
|
"Accept": "application/json",
|
||||||
"Authorization": f"Bearer {api_key_str}",
|
"Authorization": f"Bearer {api_key_str}",
|
||||||
},
|
},
|
||||||
timeout=values["timeout"],
|
timeout=self.timeout,
|
||||||
)
|
)
|
||||||
if values["tokenizer"] is None:
|
if self.tokenizer is None:
|
||||||
try:
|
try:
|
||||||
values["tokenizer"] = Tokenizer.from_pretrained(
|
self.tokenizer = Tokenizer.from_pretrained(
|
||||||
"mistralai/Mixtral-8x7B-v0.1"
|
"mistralai/Mixtral-8x7B-v0.1"
|
||||||
)
|
)
|
||||||
except IOError: # huggingface_hub GatedRepoError
|
except IOError: # huggingface_hub GatedRepoError
|
||||||
@@ -169,8 +172,8 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
|
|||||||
"HF_TOKEN environment variable to download the real tokenizer. "
|
"HF_TOKEN environment variable to download the real tokenizer. "
|
||||||
"Falling back to a dummy tokenizer that uses `len()`."
|
"Falling back to a dummy tokenizer that uses `len()`."
|
||||||
)
|
)
|
||||||
values["tokenizer"] = DummyTokenizer()
|
self.tokenizer = DummyTokenizer()
|
||||||
return values
|
return self
|
||||||
|
|
||||||
def _get_batches(self, texts: List[str]) -> Iterable[List[str]]:
|
def _get_batches(self, texts: List[str]) -> Iterable[List[str]]:
|
||||||
"""Split a list of texts into batches of less than 16k tokens
|
"""Split a list of texts into batches of less than 16k tokens
|
||||||
|
9
libs/partners/mistralai/poetry.lock
generated
9
libs/partners/mistralai/poetry.lock
generated
@@ -11,9 +11,6 @@ files = [
|
|||||||
{file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"},
|
{file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
|
||||||
typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.9\""}
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "anyio"
|
name = "anyio"
|
||||||
version = "4.4.0"
|
version = "4.4.0"
|
||||||
@@ -397,7 +394,7 @@ name = "langchain-core"
|
|||||||
version = "0.2.38"
|
version = "0.2.38"
|
||||||
description = "Building applications with LLMs through composability"
|
description = "Building applications with LLMs through composability"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.8.1,<4.0"
|
python-versions = ">=3.9,<4.0"
|
||||||
files = []
|
files = []
|
||||||
develop = true
|
develop = true
|
||||||
|
|
||||||
@@ -1082,5 +1079,5 @@ zstd = ["zstandard (>=0.18.0)"]
|
|||||||
|
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = ">=3.8.1,<4.0"
|
python-versions = ">=3.9,<4.0"
|
||||||
content-hash = "11a8e8533f0ed605e14cf916957ccde5f8bf77056227fcbc152b0f644f1e45bd"
|
content-hash = "08e71710e103a4888f5d959413cfb5400301e9485027e4d0ef48a49bc82e6f10"
|
||||||
|
@@ -19,11 +19,12 @@ disallow_untyped_defs = "True"
|
|||||||
"Release Notes" = "https://github.com/langchain-ai/langchain/releases?q=tag%3A%22langchain-mistralai%3D%3D0%22&expanded=true"
|
"Release Notes" = "https://github.com/langchain-ai/langchain/releases?q=tag%3A%22langchain-mistralai%3D%3D0%22&expanded=true"
|
||||||
|
|
||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
python = ">=3.8.1,<4.0"
|
python = ">=3.9,<4.0"
|
||||||
langchain-core = "^0.2.38"
|
langchain-core = "^0.2.38"
|
||||||
tokenizers = ">=0.15.1,<1"
|
tokenizers = ">=0.15.1,<1"
|
||||||
httpx = ">=0.25.2,<1"
|
httpx = ">=0.25.2,<1"
|
||||||
httpx-sse = ">=0.3.1,<1"
|
httpx-sse = ">=0.3.1,<1"
|
||||||
|
pydantic = ">2,<3"
|
||||||
|
|
||||||
[tool.ruff.lint]
|
[tool.ruff.lint]
|
||||||
select = [ "E", "F", "I", "T201",]
|
select = [ "E", "F", "I", "T201",]
|
||||||
|
@@ -1,27 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
#
|
|
||||||
# This script searches for lines starting with "import pydantic" or "from pydantic"
|
|
||||||
# in tracked files within a Git repository.
|
|
||||||
#
|
|
||||||
# Usage: ./scripts/check_pydantic.sh /path/to/repository
|
|
||||||
|
|
||||||
# Check if a path argument is provided
|
|
||||||
if [ $# -ne 1 ]; then
|
|
||||||
echo "Usage: $0 /path/to/repository"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
repository_path="$1"
|
|
||||||
|
|
||||||
# Search for lines matching the pattern within the specified repository
|
|
||||||
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')
|
|
||||||
|
|
||||||
# Check if any matching lines were found
|
|
||||||
if [ -n "$result" ]; then
|
|
||||||
echo "ERROR: The following lines need to be updated:"
|
|
||||||
echo "$result"
|
|
||||||
echo "Please replace the code with an import from langchain_core.pydantic_v1."
|
|
||||||
echo "For example, replace 'from pydantic import BaseModel'"
|
|
||||||
echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
@@ -9,7 +9,7 @@ from langchain_core.messages import (
|
|||||||
BaseMessageChunk,
|
BaseMessageChunk,
|
||||||
HumanMessage,
|
HumanMessage,
|
||||||
)
|
)
|
||||||
from langchain_core.pydantic_v1 import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from langchain_mistralai.chat_models import ChatMistralAI
|
from langchain_mistralai.chat_models import ChatMistralAI
|
||||||
|
|
||||||
|
@@ -15,7 +15,7 @@ from langchain_core.messages import (
|
|||||||
SystemMessage,
|
SystemMessage,
|
||||||
ToolCall,
|
ToolCall,
|
||||||
)
|
)
|
||||||
from langchain_core.pydantic_v1 import SecretStr
|
from pydantic import SecretStr
|
||||||
|
|
||||||
from langchain_mistralai.chat_models import ( # type: ignore[import]
|
from langchain_mistralai.chat_models import ( # type: ignore[import]
|
||||||
ChatMistralAI,
|
ChatMistralAI,
|
||||||
|
@@ -1,7 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
from langchain_core.pydantic_v1 import SecretStr
|
from pydantic import SecretStr
|
||||||
|
|
||||||
from langchain_mistralai import MistralAIEmbeddings
|
from langchain_mistralai import MistralAIEmbeddings
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user