multiple: pydantic 2 compatibility, v0.3 (#26443)

Signed-off-by: ChengZi <chen.zhang@zilliz.com>
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
Co-authored-by: Dan O'Donovan <dan.odonovan@gmail.com>
Co-authored-by: Tom Daniel Grande <tomdgrande@gmail.com>
Co-authored-by: Grande <Tom.Daniel.Grande@statsbygg.no>
Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: ccurme <chester.curme@gmail.com>
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
Co-authored-by: Tomaz Bratanic <bratanic.tomaz@gmail.com>
Co-authored-by: ZhangShenao <15201440436@163.com>
Co-authored-by: Friso H. Kingma <fhkingma@gmail.com>
Co-authored-by: ChengZi <chen.zhang@zilliz.com>
Co-authored-by: Nuno Campos <nuno@langchain.dev>
Co-authored-by: Morgante Pell <morgantep@google.com>
This commit is contained in:
Erick Friis
2024-09-13 14:38:45 -07:00
committed by GitHub
parent d9813bdbbc
commit c2a3021bb0
1402 changed files with 38318 additions and 30410 deletions

View File

@@ -66,17 +66,19 @@ from langchain_core.output_parsers.openai_tools import (
parse_tool_call,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import (
BaseModel,
Field,
SecretStr,
root_validator,
)
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils import secret_from_env
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass
from pydantic import (
BaseModel,
ConfigDict,
Field,
SecretStr,
model_validator,
)
from typing_extensions import Self
logger = logging.getLogger(__name__)
@@ -381,11 +383,10 @@ class ChatMistralAI(BaseChatModel):
safe_mode: bool = False
streaming: bool = False
class Config:
"""Configuration for this pydantic object."""
allow_population_by_field_name = True
arbitrary_types_allowed = True
model_config = ConfigDict(
populate_by_name=True,
arbitrary_types_allowed=True,
)
@property
def _default_params(self) -> Dict[str, Any]:
@@ -471,47 +472,50 @@ class ChatMistralAI(BaseChatModel):
combined = {"token_usage": overall_token_usage, "model_name": self.model}
return combined
@root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict:
@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate api key, python package exists, temperature, and top_p."""
api_key_str = values["mistral_api_key"].get_secret_value()
if isinstance(self.mistral_api_key, SecretStr):
api_key_str: Optional[str] = self.mistral_api_key.get_secret_value()
else:
api_key_str = self.mistral_api_key
# todo: handle retries
base_url_str = (
values.get("endpoint")
self.endpoint
or os.environ.get("MISTRAL_BASE_URL")
or "https://api.mistral.ai/v1"
)
values["endpoint"] = base_url_str
if not values.get("client"):
values["client"] = httpx.Client(
self.endpoint = base_url_str
if not self.client:
self.client = httpx.Client(
base_url=base_url_str,
headers={
"Content-Type": "application/json",
"Accept": "application/json",
"Authorization": f"Bearer {api_key_str}",
},
timeout=values["timeout"],
timeout=self.timeout,
)
# todo: handle retries and max_concurrency
if not values.get("async_client"):
values["async_client"] = httpx.AsyncClient(
if not self.async_client:
self.async_client = httpx.AsyncClient(
base_url=base_url_str,
headers={
"Content-Type": "application/json",
"Accept": "application/json",
"Authorization": f"Bearer {api_key_str}",
},
timeout=values["timeout"],
timeout=self.timeout,
)
if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
if self.temperature is not None and not 0 <= self.temperature <= 1:
raise ValueError("temperature must be in the range [0.0, 1.0]")
if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
if self.top_p is not None and not 0 <= self.top_p <= 1:
raise ValueError("top_p must be in the range [0.0, 1.0]")
return values
return self
def _generate(
self,
@@ -730,7 +734,7 @@ class ChatMistralAI(BaseChatModel):
from typing import Optional
from langchain_mistralai import ChatMistralAI
from langchain_core.pydantic_v1 import BaseModel, Field
from pydantic import BaseModel, Field
class AnswerWithJustification(BaseModel):
@@ -761,7 +765,7 @@ class ChatMistralAI(BaseChatModel):
.. code-block:: python
from langchain_mistralai import ChatMistralAI
from langchain_core.pydantic_v1 import BaseModel
from pydantic import BaseModel
class AnswerWithJustification(BaseModel):
@@ -848,7 +852,7 @@ class ChatMistralAI(BaseChatModel):
.. code-block::
from langchain_mistralai import ChatMistralAI
from langchain_core.pydantic_v1 import BaseModel
from pydantic import BaseModel
class AnswerWithJustification(BaseModel):
answer: str

View File

@@ -1,20 +1,22 @@
import asyncio
import logging
import warnings
from typing import Dict, Iterable, List
from typing import Iterable, List
import httpx
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import (
BaseModel,
Field,
SecretStr,
root_validator,
)
from langchain_core.utils import (
secret_from_env,
)
from pydantic import (
BaseModel,
ConfigDict,
Field,
SecretStr,
model_validator,
)
from tokenizers import Tokenizer # type: ignore
from typing_extensions import Self
logger = logging.getLogger(__name__)
@@ -125,41 +127,42 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
model: str = "mistral-embed"
class Config:
extra = "forbid"
arbitrary_types_allowed = True
allow_population_by_field_name = True
model_config = ConfigDict(
extra="forbid",
arbitrary_types_allowed=True,
populate_by_name=True,
)
@root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict:
@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate configuration."""
api_key_str = values["mistral_api_key"].get_secret_value()
api_key_str = self.mistral_api_key.get_secret_value()
# todo: handle retries
if not values.get("client"):
values["client"] = httpx.Client(
base_url=values["endpoint"],
if not self.client:
self.client = httpx.Client(
base_url=self.endpoint,
headers={
"Content-Type": "application/json",
"Accept": "application/json",
"Authorization": f"Bearer {api_key_str}",
},
timeout=values["timeout"],
timeout=self.timeout,
)
# todo: handle retries and max_concurrency
if not values.get("async_client"):
values["async_client"] = httpx.AsyncClient(
base_url=values["endpoint"],
if not self.async_client:
self.async_client = httpx.AsyncClient(
base_url=self.endpoint,
headers={
"Content-Type": "application/json",
"Accept": "application/json",
"Authorization": f"Bearer {api_key_str}",
},
timeout=values["timeout"],
timeout=self.timeout,
)
if values["tokenizer"] is None:
if self.tokenizer is None:
try:
values["tokenizer"] = Tokenizer.from_pretrained(
self.tokenizer = Tokenizer.from_pretrained(
"mistralai/Mixtral-8x7B-v0.1"
)
except IOError: # huggingface_hub GatedRepoError
@@ -169,8 +172,8 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
"HF_TOKEN environment variable to download the real tokenizer. "
"Falling back to a dummy tokenizer that uses `len()`."
)
values["tokenizer"] = DummyTokenizer()
return values
self.tokenizer = DummyTokenizer()
return self
def _get_batches(self, texts: List[str]) -> Iterable[List[str]]:
"""Split a list of texts into batches of less than 16k tokens