mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
multiple: pydantic 2 compatibility, v0.3 (#26443)
Signed-off-by: ChengZi <chen.zhang@zilliz.com> Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com> Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com> Co-authored-by: Dan O'Donovan <dan.odonovan@gmail.com> Co-authored-by: Tom Daniel Grande <tomdgrande@gmail.com> Co-authored-by: Grande <Tom.Daniel.Grande@statsbygg.no> Co-authored-by: Bagatur <baskaryan@gmail.com> Co-authored-by: ccurme <chester.curme@gmail.com> Co-authored-by: Harrison Chase <hw.chase.17@gmail.com> Co-authored-by: Tomaz Bratanic <bratanic.tomaz@gmail.com> Co-authored-by: ZhangShenao <15201440436@163.com> Co-authored-by: Friso H. Kingma <fhkingma@gmail.com> Co-authored-by: ChengZi <chen.zhang@zilliz.com> Co-authored-by: Nuno Campos <nuno@langchain.dev> Co-authored-by: Morgante Pell <morgantep@google.com>
This commit is contained in:
@@ -66,17 +66,19 @@ from langchain_core.output_parsers.openai_tools import (
|
||||
parse_tool_call,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import (
|
||||
BaseModel,
|
||||
Field,
|
||||
SecretStr,
|
||||
root_validator,
|
||||
)
|
||||
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import secret_from_env
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
SecretStr,
|
||||
model_validator,
|
||||
)
|
||||
from typing_extensions import Self
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -381,11 +383,10 @@ class ChatMistralAI(BaseChatModel):
|
||||
safe_mode: bool = False
|
||||
streaming: bool = False
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
allow_population_by_field_name = True
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
@@ -471,47 +472,50 @@ class ChatMistralAI(BaseChatModel):
|
||||
combined = {"token_usage": overall_token_usage, "model_name": self.model}
|
||||
return combined
|
||||
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="after")
|
||||
def validate_environment(self) -> Self:
|
||||
"""Validate api key, python package exists, temperature, and top_p."""
|
||||
api_key_str = values["mistral_api_key"].get_secret_value()
|
||||
if isinstance(self.mistral_api_key, SecretStr):
|
||||
api_key_str: Optional[str] = self.mistral_api_key.get_secret_value()
|
||||
else:
|
||||
api_key_str = self.mistral_api_key
|
||||
|
||||
# todo: handle retries
|
||||
base_url_str = (
|
||||
values.get("endpoint")
|
||||
self.endpoint
|
||||
or os.environ.get("MISTRAL_BASE_URL")
|
||||
or "https://api.mistral.ai/v1"
|
||||
)
|
||||
values["endpoint"] = base_url_str
|
||||
if not values.get("client"):
|
||||
values["client"] = httpx.Client(
|
||||
self.endpoint = base_url_str
|
||||
if not self.client:
|
||||
self.client = httpx.Client(
|
||||
base_url=base_url_str,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
"Authorization": f"Bearer {api_key_str}",
|
||||
},
|
||||
timeout=values["timeout"],
|
||||
timeout=self.timeout,
|
||||
)
|
||||
# todo: handle retries and max_concurrency
|
||||
if not values.get("async_client"):
|
||||
values["async_client"] = httpx.AsyncClient(
|
||||
if not self.async_client:
|
||||
self.async_client = httpx.AsyncClient(
|
||||
base_url=base_url_str,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
"Authorization": f"Bearer {api_key_str}",
|
||||
},
|
||||
timeout=values["timeout"],
|
||||
timeout=self.timeout,
|
||||
)
|
||||
|
||||
if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
|
||||
if self.temperature is not None and not 0 <= self.temperature <= 1:
|
||||
raise ValueError("temperature must be in the range [0.0, 1.0]")
|
||||
|
||||
if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
|
||||
if self.top_p is not None and not 0 <= self.top_p <= 1:
|
||||
raise ValueError("top_p must be in the range [0.0, 1.0]")
|
||||
|
||||
return values
|
||||
return self
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
@@ -730,7 +734,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
from typing import Optional
|
||||
|
||||
from langchain_mistralai import ChatMistralAI
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class AnswerWithJustification(BaseModel):
|
||||
@@ -761,7 +765,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_mistralai import ChatMistralAI
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AnswerWithJustification(BaseModel):
|
||||
@@ -848,7 +852,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
.. code-block::
|
||||
|
||||
from langchain_mistralai import ChatMistralAI
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from pydantic import BaseModel
|
||||
|
||||
class AnswerWithJustification(BaseModel):
|
||||
answer: str
|
||||
|
||||
@@ -1,20 +1,22 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import warnings
|
||||
from typing import Dict, Iterable, List
|
||||
from typing import Iterable, List
|
||||
|
||||
import httpx
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import (
|
||||
BaseModel,
|
||||
Field,
|
||||
SecretStr,
|
||||
root_validator,
|
||||
)
|
||||
from langchain_core.utils import (
|
||||
secret_from_env,
|
||||
)
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
SecretStr,
|
||||
model_validator,
|
||||
)
|
||||
from tokenizers import Tokenizer # type: ignore
|
||||
from typing_extensions import Self
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -125,41 +127,42 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
model: str = "mistral-embed"
|
||||
|
||||
class Config:
|
||||
extra = "forbid"
|
||||
arbitrary_types_allowed = True
|
||||
allow_population_by_field_name = True
|
||||
model_config = ConfigDict(
|
||||
extra="forbid",
|
||||
arbitrary_types_allowed=True,
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="after")
|
||||
def validate_environment(self) -> Self:
|
||||
"""Validate configuration."""
|
||||
|
||||
api_key_str = values["mistral_api_key"].get_secret_value()
|
||||
api_key_str = self.mistral_api_key.get_secret_value()
|
||||
# todo: handle retries
|
||||
if not values.get("client"):
|
||||
values["client"] = httpx.Client(
|
||||
base_url=values["endpoint"],
|
||||
if not self.client:
|
||||
self.client = httpx.Client(
|
||||
base_url=self.endpoint,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
"Authorization": f"Bearer {api_key_str}",
|
||||
},
|
||||
timeout=values["timeout"],
|
||||
timeout=self.timeout,
|
||||
)
|
||||
# todo: handle retries and max_concurrency
|
||||
if not values.get("async_client"):
|
||||
values["async_client"] = httpx.AsyncClient(
|
||||
base_url=values["endpoint"],
|
||||
if not self.async_client:
|
||||
self.async_client = httpx.AsyncClient(
|
||||
base_url=self.endpoint,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
"Authorization": f"Bearer {api_key_str}",
|
||||
},
|
||||
timeout=values["timeout"],
|
||||
timeout=self.timeout,
|
||||
)
|
||||
if values["tokenizer"] is None:
|
||||
if self.tokenizer is None:
|
||||
try:
|
||||
values["tokenizer"] = Tokenizer.from_pretrained(
|
||||
self.tokenizer = Tokenizer.from_pretrained(
|
||||
"mistralai/Mixtral-8x7B-v0.1"
|
||||
)
|
||||
except IOError: # huggingface_hub GatedRepoError
|
||||
@@ -169,8 +172,8 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
|
||||
"HF_TOKEN environment variable to download the real tokenizer. "
|
||||
"Falling back to a dummy tokenizer that uses `len()`."
|
||||
)
|
||||
values["tokenizer"] = DummyTokenizer()
|
||||
return values
|
||||
self.tokenizer = DummyTokenizer()
|
||||
return self
|
||||
|
||||
def _get_batches(self, texts: List[str]) -> Iterable[List[str]]:
|
||||
"""Split a list of texts into batches of less than 16k tokens
|
||||
|
||||
Reference in New Issue
Block a user