pydantic initial changes

This commit is contained in:
Sydney Runkle 2025-05-13 17:06:12 -07:00
parent 275e3b6710
commit 4e98c8aa47
32 changed files with 188 additions and 1132 deletions

View File

@ -15,7 +15,7 @@ test tests:
-u LANGCHAIN_API_KEY \ -u LANGCHAIN_API_KEY \
-u LANGSMITH_TRACING \ -u LANGSMITH_TRACING \
-u LANGCHAIN_PROJECT \ -u LANGCHAIN_PROJECT \
uv run --group test pytest -n auto --disable-socket --allow-unix-socket $(TEST_FILE) uv run --group test pytest -vv -s -n auto --disable-socket --allow-unix-socket $(TEST_FILE)
test_watch: test_watch:
env \ env \

View File

@ -23,6 +23,7 @@ from typing import (
cast, cast,
) )
from pydantic.fields import FieldInfo
from typing_extensions import ParamSpec from typing_extensions import ParamSpec
from langchain_core._api.internal import is_caller_internal from langchain_core._api.internal import is_caller_internal
@ -39,8 +40,7 @@ class LangChainPendingDeprecationWarning(PendingDeprecationWarning):
# PUBLIC API # PUBLIC API
# Last Any should be FieldInfoV1 but this leads to circular imports T = TypeVar("T", bound=Union[type, Callable[..., Any], FieldInfo])
T = TypeVar("T", bound=Union[type, Callable[..., Any], Any])
def _validate_deprecation_params( def _validate_deprecation_params(
@ -152,10 +152,6 @@ def deprecated(
_package: str = package, _package: str = package,
) -> T: ) -> T:
"""Implementation of the decorator returned by `deprecated`.""" """Implementation of the decorator returned by `deprecated`."""
from langchain_core.utils.pydantic import ( # type: ignore[attr-defined]
FieldInfoV1,
FieldInfoV2,
)
def emit_warning() -> None: def emit_warning() -> None:
"""Emit the warning.""" """Emit the warning."""
@ -228,7 +224,7 @@ def deprecated(
) )
return cast("T", obj) return cast("T", obj)
elif isinstance(obj, FieldInfoV1): elif isinstance(obj, FieldInfo):
wrapped = None wrapped = None
if not _obj_type: if not _obj_type:
_obj_type = "attribute" _obj_type = "attribute"
@ -240,28 +236,7 @@ def deprecated(
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: # noqa: ARG001 def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: # noqa: ARG001
return cast( return cast(
"T", "T",
FieldInfoV1( FieldInfo(
default=obj.default,
default_factory=obj.default_factory,
description=new_doc,
alias=obj.alias,
exclude=obj.exclude,
),
)
elif isinstance(obj, FieldInfoV2):
wrapped = None
if not _obj_type:
_obj_type = "attribute"
if not _name:
msg = f"Field {obj} must have a name to be deprecated."
raise ValueError(msg)
old_doc = obj.description
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: # noqa: ARG001
return cast(
"T",
FieldInfoV2(
default=obj.default, default=obj.default,
default_factory=obj.default_factory, default_factory=obj.default_factory,
description=new_doc, description=new_doc,

View File

@ -75,7 +75,6 @@ from langchain_core.utils.function_calling import (
convert_to_json_schema, convert_to_json_schema,
convert_to_openai_tool, convert_to_openai_tool,
) )
from langchain_core.utils.pydantic import TypeBaseModel, is_basemodel_subclass
if TYPE_CHECKING: if TYPE_CHECKING:
import uuid import uuid
@ -625,7 +624,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
stop: Optional[list[str]] = None, stop: Optional[list[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> dict: ) -> dict:
params = self.dict() params = self.model_dump()
params["stop"] = stop params["stop"] = stop
return {**params, **kwargs} return {**params, **kwargs}
@ -1298,7 +1297,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
"""Return type of chat model.""" """Return type of chat model."""
@override @override
def dict(self, **kwargs: Any) -> dict: def model_dump(self, **kwargs: Any) -> dict:
"""Return a dictionary of the LLM.""" """Return a dictionary of the LLM."""
starter_dict = dict(self._identifying_params) starter_dict = dict(self._identifying_params)
starter_dict["_type"] = self._llm_type starter_dict["_type"] = self._llm_type
@ -1456,9 +1455,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
"schema": schema, "schema": schema,
}, },
) )
if isinstance(schema, type) and is_basemodel_subclass(schema): if isinstance(schema, type) and issubclass(schema, BaseModel):
output_parser: OutputParserLike = PydanticToolsParser( output_parser: OutputParserLike = PydanticToolsParser(
tools=[cast("TypeBaseModel", schema)], first_tool_only=True tools=[schema], first_tool_only=True
) )
else: else:
key_name = convert_to_openai_tool(schema)["function"]["name"] key_name = convert_to_openai_tool(schema)["function"]["name"]

View File

@ -528,7 +528,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
else: else:
prompt = self._convert_input(input).to_string() prompt = self._convert_input(input).to_string()
config = ensure_config(config) config = ensure_config(config)
params = self.dict() params = self.model_dump()
params["stop"] = stop params["stop"] = stop
params = {**params, **kwargs} params = {**params, **kwargs}
options = {"stop": stop} options = {"stop": stop}
@ -598,7 +598,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
prompt = self._convert_input(input).to_string() prompt = self._convert_input(input).to_string()
config = ensure_config(config) config = ensure_config(config)
params = self.dict() params = self.model_dump()
params["stop"] = stop params["stop"] = stop
params = {**params, **kwargs} params = {**params, **kwargs}
options = {"stop": stop} options = {"stop": stop}
@ -941,7 +941,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
] * len(prompts) ] * len(prompts)
run_name_list = [cast("Optional[str]", run_name)] * len(prompts) run_name_list = [cast("Optional[str]", run_name)] * len(prompts)
run_ids_list = self._get_run_ids_list(run_id, prompts) run_ids_list = self._get_run_ids_list(run_id, prompts)
params = self.dict() params = self.model_dump()
params["stop"] = stop params["stop"] = stop
options = {"stop": stop} options = {"stop": stop}
( (
@ -1193,7 +1193,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
] * len(prompts) ] * len(prompts)
run_name_list = [cast("Optional[str]", run_name)] * len(prompts) run_name_list = [cast("Optional[str]", run_name)] * len(prompts)
run_ids_list = self._get_run_ids_list(run_id, prompts) run_ids_list = self._get_run_ids_list(run_id, prompts)
params = self.dict() params = self.model_dump()
params["stop"] = stop params["stop"] = stop
options = {"stop": stop} options = {"stop": stop}
( (
@ -1400,7 +1400,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
"""Return type of llm.""" """Return type of llm."""
@override @override
def dict(self, **kwargs: Any) -> dict: def model_dump(self, **kwargs: Any) -> dict:
"""Return a dictionary of the LLM.""" """Return a dictionary of the LLM."""
starter_dict = dict(self._identifying_params) starter_dict = dict(self._identifying_params)
starter_dict["_type"] = self._llm_type starter_dict["_type"] = self._llm_type
@ -1427,7 +1427,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
directory_path.mkdir(parents=True, exist_ok=True) directory_path.mkdir(parents=True, exist_ok=True)
# Fetch dictionary to save # Fetch dictionary to save
prompt_dict = self.dict() prompt_dict = self.model_dump()
if save_path.suffix == ".json": if save_path.suffix == ".json":
with save_path.open("w") as f: with save_path.open("w") as f:

View File

@ -324,9 +324,9 @@ class BaseOutputParser(
) )
raise NotImplementedError(msg) raise NotImplementedError(msg)
def dict(self, **kwargs: Any) -> dict: def model_dump(self, **kwargs: Any) -> dict:
"""Return dictionary representation of output parser.""" """Return dictionary representation of output parser."""
output_parser_dict = super().dict(**kwargs) output_parser_dict = super().model_dump(**kwargs)
with contextlib.suppress(NotImplementedError): with contextlib.suppress(NotImplementedError):
output_parser_dict["_type"] = self._type output_parser_dict["_type"] = self._type
return output_parser_dict return output_parser_dict

View File

@ -4,11 +4,10 @@ from __future__ import annotations
import json import json
from json import JSONDecodeError from json import JSONDecodeError
from typing import Annotated, Any, Optional, TypeVar, Union from typing import Any, Optional
import jsonpatch # type: ignore[import-untyped] import jsonpatch # type: ignore[import-untyped]
import pydantic from pydantic import BaseModel, SkipValidation
from pydantic import SkipValidation
from langchain_core.exceptions import OutputParserException from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers.format_instructions import JSON_FORMAT_INSTRUCTIONS from langchain_core.output_parsers.format_instructions import JSON_FORMAT_INSTRUCTIONS
@ -19,18 +18,6 @@ from langchain_core.utils.json import (
parse_json_markdown, parse_json_markdown,
parse_partial_json, parse_partial_json,
) )
from langchain_core.utils.pydantic import IS_PYDANTIC_V1
if IS_PYDANTIC_V1:
PydanticBaseModel = pydantic.BaseModel
else:
from pydantic.v1 import BaseModel
# Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore[assignment,misc]
TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel)
class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]): class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
@ -43,18 +30,16 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
describing the difference between the previous and the current object. describing the difference between the previous and the current object.
""" """
pydantic_object: Annotated[Optional[type[TBaseModel]], SkipValidation()] = None # type: ignore[valid-type] pydantic_object: SkipValidation[Optional[type[BaseModel]]] = None
"""The Pydantic object to use for validation. """The Pydantic object to use for validation.
If None, no validation is performed.""" If None, no validation is performed."""
def _diff(self, prev: Optional[Any], next: Any) -> Any: def _diff(self, prev: Optional[Any], next: Any) -> Any:
return jsonpatch.make_patch(prev, next).patch return jsonpatch.make_patch(prev, next).patch
def _get_schema(self, pydantic_object: type[TBaseModel]) -> dict[str, Any]: def _get_schema(self, pydantic_object: type[BaseModel]) -> dict[str, Any]:
if issubclass(pydantic_object, pydantic.BaseModel): if issubclass(pydantic_object, BaseModel):
return pydantic_object.model_json_schema() return pydantic_object.model_json_schema()
if issubclass(pydantic_object, pydantic.v1.BaseModel):
return pydantic_object.schema()
return None return None
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any: def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:

View File

@ -274,10 +274,7 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
pydantic_schema = self.pydantic_schema[fn_name] pydantic_schema = self.pydantic_schema[fn_name]
else: else:
pydantic_schema = self.pydantic_schema pydantic_schema = self.pydantic_schema
if hasattr(pydantic_schema, "model_validate_json"): pydantic_args = pydantic_schema.model_validate_json(_args)
pydantic_args = pydantic_schema.model_validate_json(_args)
else:
pydantic_args = pydantic_schema.parse_raw(_args)
return pydantic_args return pydantic_args

View File

@ -4,9 +4,9 @@ import copy
import json import json
import logging import logging
from json import JSONDecodeError from json import JSONDecodeError
from typing import Annotated, Any, Optional from typing import Any, Optional
from pydantic import SkipValidation, ValidationError from pydantic import BaseModel, SkipValidation, ValidationError
from langchain_core.exceptions import OutputParserException from langchain_core.exceptions import OutputParserException
from langchain_core.messages import AIMessage, InvalidToolCall from langchain_core.messages import AIMessage, InvalidToolCall
@ -15,7 +15,6 @@ from langchain_core.messages.tool import tool_call as create_tool_call
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
from langchain_core.outputs import ChatGeneration, Generation from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.utils.json import parse_partial_json from langchain_core.utils.json import parse_partial_json
from langchain_core.utils.pydantic import TypeBaseModel
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -264,7 +263,7 @@ _MAX_TOKENS_ERROR = (
class PydanticToolsParser(JsonOutputToolsParser): class PydanticToolsParser(JsonOutputToolsParser):
"""Parse tools from OpenAI response.""" """Parse tools from OpenAI response."""
tools: Annotated[list[TypeBaseModel], SkipValidation()] tools: SkipValidation[list[type[BaseModel]]]
"""The tools to parse.""" """The tools to parse."""
# TODO: Support more granular streaming of objects. Currently only streams once all # TODO: Support more granular streaming of objects. Currently only streams once all

View File

@ -1,45 +1,33 @@
"""Output parsers using Pydantic.""" """Output parsers using Pydantic."""
import json import json
from typing import Annotated, Generic, Optional from typing import Generic, Optional, TypeVar
import pydantic from pydantic import BaseModel, SkipValidation, ValidationError
from pydantic import SkipValidation
from typing_extensions import override from typing_extensions import override
from langchain_core.exceptions import OutputParserException from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import JsonOutputParser from langchain_core.output_parsers import JsonOutputParser
from langchain_core.outputs import Generation from langchain_core.outputs import Generation
from langchain_core.utils.pydantic import (
IS_PYDANTIC_V2, BaseModelT = TypeVar("BaseModelT", bound=BaseModel)
PydanticBaseModel,
TBaseModel,
)
class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]): class PydanticOutputParser(JsonOutputParser, Generic[BaseModelT]):
"""Parse an output using a pydantic model.""" """Parse an output using a pydantic model."""
pydantic_object: Annotated[type[TBaseModel], SkipValidation()] pydantic_object: SkipValidation[type[BaseModelT]]
"""The pydantic model to parse.""" """The pydantic model to parse."""
def _parse_obj(self, obj: dict) -> TBaseModel: def _parse_obj(self, obj: dict) -> BaseModelT:
if IS_PYDANTIC_V2: try:
try: if issubclass(self.pydantic_object, BaseModel):
if issubclass(self.pydantic_object, pydantic.BaseModel): return self.pydantic_object.model_validate(obj)
return self.pydantic_object.model_validate(obj) msg = f"Unsupported model version for PydanticOutputParser: \
if issubclass(self.pydantic_object, pydantic.v1.BaseModel): {self.pydantic_object.__class__}"
return self.pydantic_object.parse_obj(obj) raise OutputParserException(msg)
msg = f"Unsupported model version for PydanticOutputParser: \ except ValidationError as e:
{self.pydantic_object.__class__}" raise self._parser_exception(e, obj) from e
raise OutputParserException(msg)
except (pydantic.ValidationError, pydantic.v1.ValidationError) as e:
raise self._parser_exception(e, obj) from e
else: # pydantic v1
try:
return self.pydantic_object.parse_obj(obj)
except pydantic.ValidationError as e:
raise self._parser_exception(e, obj) from e
def _parser_exception( def _parser_exception(
self, e: Exception, json_object: dict self, e: Exception, json_object: dict
@ -51,7 +39,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
def parse_result( def parse_result(
self, result: list[Generation], *, partial: bool = False self, result: list[Generation], *, partial: bool = False
) -> Optional[TBaseModel]: ) -> Optional[BaseModelT]:
"""Parse the result of an LLM call to a pydantic object. """Parse the result of an LLM call to a pydantic object.
Args: Args:
@ -72,7 +60,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
return None return None
raise raise
def parse(self, text: str) -> TBaseModel: def parse(self, text: str) -> BaseModelT:
"""Parse the output of an LLM call to a pydantic object. """Parse the output of an LLM call to a pydantic object.
Args: Args:
@ -109,7 +97,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
@property @property
@override @override
def OutputType(self) -> type[TBaseModel]: def OutputType(self) -> type[BaseModelT]:
"""Return the pydantic model.""" """Return the pydantic model."""
return self.pydantic_object return self.pydantic_object
@ -126,7 +114,5 @@ Here is the output schema:
# Re-exporting types for backwards compatibility # Re-exporting types for backwards compatibility
__all__ = [ __all__ = [
"PydanticBaseModel",
"PydanticOutputParser", "PydanticOutputParser",
"TBaseModel",
] ]

View File

@ -125,7 +125,7 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
chunk_gen = ChatGenerationChunk(message=chunk) chunk_gen = ChatGenerationChunk(message=chunk)
elif isinstance(chunk, BaseMessage): elif isinstance(chunk, BaseMessage):
chunk_gen = ChatGenerationChunk( chunk_gen = ChatGenerationChunk(
message=BaseMessageChunk(**chunk.dict()) message=BaseMessageChunk(**chunk.model_dump())
) )
else: else:
chunk_gen = GenerationChunk(text=chunk) chunk_gen = GenerationChunk(text=chunk)
@ -151,7 +151,7 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
chunk_gen = ChatGenerationChunk(message=chunk) chunk_gen = ChatGenerationChunk(message=chunk)
elif isinstance(chunk, BaseMessage): elif isinstance(chunk, BaseMessage):
chunk_gen = ChatGenerationChunk( chunk_gen = ChatGenerationChunk(
message=BaseMessageChunk(**chunk.dict()) message=BaseMessageChunk(**chunk.model_dump())
) )
else: else:
chunk_gen = GenerationChunk(text=chunk) chunk_gen = GenerationChunk(text=chunk)

View File

@ -331,7 +331,7 @@ class BasePromptTemplate(
"""Return the prompt type key.""" """Return the prompt type key."""
raise NotImplementedError raise NotImplementedError
def dict(self, **kwargs: Any) -> dict: def model_dump(self, **kwargs: Any) -> dict:
"""Return dictionary representation of prompt. """Return dictionary representation of prompt.
Args: Args:
@ -369,7 +369,7 @@ class BasePromptTemplate(
raise ValueError(msg) raise ValueError(msg)
# Fetch dictionary to save # Fetch dictionary to save
prompt_dict = self.dict() prompt_dict = self.model_dump()
if "_type" not in prompt_dict: if "_type" not in prompt_dict:
msg = f"Prompt {self} does not support saving." msg = f"Prompt {self} does not support saving."
raise NotImplementedError(msg) raise NotImplementedError(msg)

View File

@ -6,7 +6,6 @@ from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Annotated,
Any, Any,
Optional, Optional,
TypedDict, TypedDict,
@ -886,7 +885,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
""" # noqa: E501 """ # noqa: E501
messages: Annotated[list[MessageLike], SkipValidation()] messages: SkipValidation[list[MessageLike]]
"""List of messages consisting of either message prompt templates or messages.""" """List of messages consisting of either message prompt templates or messages."""
validate_template: bool = False validate_template: bool = False
"""Whether or not to try validating the template.""" """Whether or not to try validating the template."""

View File

@ -98,7 +98,7 @@ def _get_jinja2_variables_from_template(template: str) -> set[str]:
raise ImportError(msg) from e raise ImportError(msg) from e
env = Environment() # noqa: S701 env = Environment() # noqa: S701
ast = env.parse(template) ast = env.parse(template)
return meta.find_undeclared_variables(ast) return meta.find_undeclared_variables(ast) # type: ignore[no-untyped-call]
def mustache_formatter(template: str, /, **kwargs: Any) -> str: def mustache_formatter(template: str, /, **kwargs: Any) -> str:

View File

@ -1,45 +0,0 @@
"""Pydantic v1 compatibility shim."""
from importlib import metadata
from langchain_core._api.deprecation import warn_deprecated
# Create namespaces for pydantic v1 and v2.
# This code must stay at the top of the file before other modules may
# attempt to import pydantic since it adds pydantic_v1 and pydantic_v2 to sys.modules.
#
# This hack is done for the following reasons:
# * Langchain will attempt to remain compatible with both pydantic v1 and v2 since
# both dependencies and dependents may be stuck on either version of v1 or v2.
# * Creating namespaces for pydantic v1 and v2 should allow us to write code that
# unambiguously uses either v1 or v2 API.
# * This change is easier to roll out and roll back.
try:
from pydantic.v1 import * # noqa: F403
except ImportError:
from pydantic import * # type: ignore[assignment,no-redef] # noqa: F403
try:
_PYDANTIC_MAJOR_VERSION: int = int(metadata.version("pydantic").split(".")[0])
except metadata.PackageNotFoundError:
_PYDANTIC_MAJOR_VERSION = 0
warn_deprecated(
"0.3.0",
removal="1.0.0",
alternative="pydantic.v1 or pydantic",
message=(
"As of langchain-core 0.3.0, LangChain uses pydantic v2 internally. "
"The langchain_core.pydantic_v1 module was a "
"compatibility shim for pydantic v1, and should no longer be used. "
"Please update the code to import from Pydantic directly.\n\n"
"For example, replace imports like: "
"`from langchain_core.pydantic_v1 import BaseModel`\n"
"with: `from pydantic import BaseModel`\n"
"or the v1 compatibility namespace if you are working in a code base "
"that has not been fully upgraded to pydantic 2 yet. "
"\tfrom pydantic.v1 import BaseModel\n"
),
)

View File

@ -1,26 +0,0 @@
"""Pydantic v1 compatibility shim."""
from langchain_core._api import warn_deprecated
try:
from pydantic.v1.dataclasses import * # noqa: F403
except ImportError:
from pydantic.dataclasses import * # type: ignore[no-redef] # noqa: F403
warn_deprecated(
"0.3.0",
removal="1.0.0",
alternative="pydantic.v1 or pydantic",
message=(
"As of langchain-core 0.3.0, LangChain uses pydantic v2 internally. "
"The langchain_core.pydantic_v1 module was a "
"compatibility shim for pydantic v1, and should no longer be used. "
"Please update the code to import from Pydantic directly.\n\n"
"For example, replace imports like: "
"`from langchain_core.pydantic_v1 import BaseModel`\n"
"with: `from pydantic import BaseModel`\n"
"or the v1 compatibility namespace if you are working in a code base "
"that has not been fully upgraded to pydantic 2 yet. "
"\tfrom pydantic.v1 import BaseModel\n"
),
)

View File

@ -1,26 +0,0 @@
"""Pydantic v1 compatibility shim."""
from langchain_core._api import warn_deprecated
try:
from pydantic.v1.main import * # noqa: F403
except ImportError:
from pydantic.main import * # type: ignore[assignment,no-redef] # noqa: F403
warn_deprecated(
"0.3.0",
removal="1.0.0",
alternative="pydantic.v1 or pydantic",
message=(
"As of langchain-core 0.3.0, LangChain uses pydantic v2 internally. "
"The langchain_core.pydantic_v1 module was a "
"compatibility shim for pydantic v1, and should no longer be used. "
"Please update the code to import from Pydantic directly.\n\n"
"For example, replace imports like: "
"`from langchain_core.pydantic_v1 import BaseModel`\n"
"with: `from pydantic import BaseModel`\n"
"or the v1 compatibility namespace if you are working in a code base "
"that has not been fully upgraded to pydantic 2 yet. "
"\tfrom pydantic.v1 import BaseModel\n"
),
)

View File

@ -2,7 +2,6 @@
from __future__ import annotations from __future__ import annotations
import inspect
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum from enum import Enum
@ -19,13 +18,13 @@ from typing import (
) )
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from langchain_core.utils.pydantic import _IgnoreUnserializable, is_basemodel_subclass from pydantic import BaseModel
from langchain_core.utils.pydantic import _IgnoreUnserializable
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Sequence from collections.abc import Sequence
from pydantic import BaseModel
from langchain_core.runnables.base import Runnable as RunnableType from langchain_core.runnables.base import Runnable as RunnableType
@ -233,7 +232,7 @@ def node_data_json(
"name": node_data_str(node.id, node.data), "name": node_data_str(node.id, node.data),
}, },
} }
elif inspect.isclass(node.data) and is_basemodel_subclass(node.data): elif isinstance(node.data, type) and issubclass(node.data, BaseModel):
json = ( json = (
{ {
"type": "schema", "type": "schema",

View File

@ -33,9 +33,6 @@ from pydantic import (
model_validator, model_validator,
validate_arguments, validate_arguments,
) )
from pydantic.v1 import BaseModel as BaseModelV1
from pydantic.v1 import ValidationError as ValidationErrorV1
from pydantic.v1 import validate_arguments as validate_arguments_v1
from typing_extensions import override from typing_extensions import override
from langchain_core._api import deprecated from langchain_core._api import deprecated
@ -59,14 +56,7 @@ from langchain_core.utils.function_calling import (
_parse_google_docstring, _parse_google_docstring,
_py_38_safe_origin, _py_38_safe_origin,
) )
from langchain_core.utils.pydantic import ( from langchain_core.utils.pydantic import _create_subset_model, get_fields
TypeBaseModel,
_create_subset_model,
get_fields,
is_basemodel_subclass,
is_pydantic_v1_subclass,
is_pydantic_v2_subclass,
)
if TYPE_CHECKING: if TYPE_CHECKING:
import uuid import uuid
@ -165,36 +155,6 @@ def _infer_arg_descriptions(
return description, arg_descriptions return description, arg_descriptions
def _is_pydantic_annotation(annotation: Any, pydantic_version: str = "v2") -> bool:
"""Determine if a type annotation is a Pydantic model."""
base_model_class = BaseModelV1 if pydantic_version == "v1" else BaseModel
try:
return issubclass(annotation, base_model_class)
except TypeError:
return False
def _function_annotations_are_pydantic_v1(
signature: inspect.Signature, func: Callable
) -> bool:
"""Determine if all Pydantic annotations in a function signature are from V1."""
any_v1_annotations = any(
_is_pydantic_annotation(parameter.annotation, pydantic_version="v1")
for parameter in signature.parameters.values()
)
any_v2_annotations = any(
_is_pydantic_annotation(parameter.annotation, pydantic_version="v2")
for parameter in signature.parameters.values()
)
if any_v1_annotations and any_v2_annotations:
msg = (
f"Function {func} contains a mix of Pydantic v1 and v2 annotations. "
"Only one version of Pydantic annotations per function is supported."
)
raise NotImplementedError(msg)
return any_v1_annotations and not any_v2_annotations
class _SchemaConfig: class _SchemaConfig:
"""Configuration for the pydantic model. """Configuration for the pydantic model.
@ -241,16 +201,13 @@ def create_schema_from_function(
""" """
sig = inspect.signature(func) sig = inspect.signature(func)
if _function_annotations_are_pydantic_v1(sig, func): # https://docs.pydantic.dev/latest/usage/validation_decorator/
validated = validate_arguments_v1(func, config=_SchemaConfig) # type: ignore[call-overload] with warnings.catch_warnings():
else: # We are using deprecated functionality here.
# https://docs.pydantic.dev/latest/usage/validation_decorator/ # This code should be re-written to simply construct a pydantic model
with warnings.catch_warnings(): # using inspect.signature and create_model.
# We are using deprecated functionality here. warnings.simplefilter("ignore", category=PydanticDeprecationWarning)
# This code should be re-written to simply construct a pydantic model validated = validate_arguments(func, config=_SchemaConfig) # type: ignore[operator]
# using inspect.signature and create_model.
warnings.simplefilter("ignore", category=PydanticDeprecationWarning)
validated = validate_arguments(func, config=_SchemaConfig) # type: ignore[operator]
# Let's ignore `self` and `cls` arguments for class and instance methods # Let's ignore `self` and `cls` arguments for class and instance methods
# If qualified name has a ".", then it likely belongs in a class namespace # If qualified name has a ".", then it likely belongs in a class namespace
@ -321,7 +278,7 @@ class ToolException(Exception): # noqa: N818
""" """
ArgsSchema = Union[TypeBaseModel, dict[str, Any]] ArgsSchema = Union[type[BaseModel], dict[str, Any]]
class BaseTool(RunnableSerializable[Union[str, dict, ToolCall], Any]): class BaseTool(RunnableSerializable[Union[str, dict, ToolCall], Any]):
@ -361,7 +318,7 @@ class ChildTool(BaseTool):
You can provide few-shot examples as a part of the description. You can provide few-shot examples as a part of the description.
""" """
args_schema: Annotated[Optional[ArgsSchema], SkipValidation()] = Field( args_schema: SkipValidation[Optional[ArgsSchema]] = Field(
default=None, description="The tool schema." default=None, description="The tool schema."
) )
"""Pydantic model class to validate and parse the tool's input arguments. """Pydantic model class to validate and parse the tool's input arguments.
@ -370,8 +327,6 @@ class ChildTool(BaseTool):
- A subclass of pydantic.BaseModel. - A subclass of pydantic.BaseModel.
or or
- A subclass of pydantic.v1.BaseModel if accessing v1 namespace in pydantic 2
or
- a JSON schema dict - a JSON schema dict
""" """
return_direct: bool = False return_direct: bool = False
@ -414,7 +369,7 @@ class ChildTool(BaseTool):
"""Handle the content of the ToolException thrown.""" """Handle the content of the ToolException thrown."""
handle_validation_error: Optional[ handle_validation_error: Optional[
Union[bool, str, Callable[[Union[ValidationError, ValidationErrorV1]], str]] Union[bool, str, Callable[[ValidationError], str]]
] = False ] = False
"""Handle the content of the ValidationError thrown.""" """Handle the content of the ValidationError thrown."""
@ -431,7 +386,7 @@ class ChildTool(BaseTool):
if ( if (
"args_schema" in kwargs "args_schema" in kwargs
and kwargs["args_schema"] is not None and kwargs["args_schema"] is not None
and not is_basemodel_subclass(kwargs["args_schema"]) and not issubclass(kwargs["args_schema"], BaseModel)
and not isinstance(kwargs["args_schema"], dict) and not isinstance(kwargs["args_schema"], dict)
): ):
msg = ( msg = (
@ -543,10 +498,7 @@ class ChildTool(BaseTool):
) )
raise ValueError(msg) raise ValueError(msg)
key_ = next(iter(get_fields(input_args).keys())) key_ = next(iter(get_fields(input_args).keys()))
if hasattr(input_args, "model_validate"): input_args.model_validate({key_: tool_input})
input_args.model_validate({key_: tool_input})
else:
input_args.parse_obj({key_: tool_input})
return tool_input return tool_input
if input_args is not None: if input_args is not None:
if isinstance(input_args, dict): if isinstance(input_args, dict):
@ -569,24 +521,6 @@ class ChildTool(BaseTool):
tool_input[k] = tool_call_id tool_input[k] = tool_call_id
result = input_args.model_validate(tool_input) result = input_args.model_validate(tool_input)
result_dict = result.model_dump() result_dict = result.model_dump()
elif issubclass(input_args, BaseModelV1):
for k, v in get_all_basemodel_annotations(input_args).items():
if (
_is_injected_arg_type(v, injected_type=InjectedToolCallId)
and k not in tool_input
):
if tool_call_id is None:
msg = (
"When tool includes an InjectedToolCallId "
"argument, tool must always be invoked with a full "
"model ToolCall of the form: {'args': {...}, "
"'name': '...', 'type': 'tool_call', "
"'tool_call_id': '...'}"
)
raise ValueError(msg)
tool_input[k] = tool_call_id
result = input_args.parse_obj(tool_input)
result_dict = result.dict()
else: else:
msg = ( msg = (
f"args_schema must be a Pydantic BaseModel, got {self.args_schema}" f"args_schema must be a Pydantic BaseModel, got {self.args_schema}"
@ -643,7 +577,7 @@ class ChildTool(BaseTool):
if ( if (
self.args_schema is not None self.args_schema is not None
and isinstance(self.args_schema, type) and isinstance(self.args_schema, type)
and is_basemodel_subclass(self.args_schema) and issubclass(self.args_schema, BaseModel)
and not get_fields(self.args_schema) and not get_fields(self.args_schema)
): ):
# StructuredTool with no args # StructuredTool with no args
@ -754,7 +688,7 @@ class ChildTool(BaseTool):
content, artifact = response content, artifact = response
else: else:
content = response content = response
except (ValidationError, ValidationErrorV1) as e: except ValidationError as e:
if not self.handle_validation_error: if not self.handle_validation_error:
error_to_raise = e error_to_raise = e
else: else:
@ -901,11 +835,9 @@ def _is_tool_call(x: Any) -> bool:
def _handle_validation_error( def _handle_validation_error(
e: Union[ValidationError, ValidationErrorV1], e: ValidationError,
*, *,
flag: Union[ flag: Union[Literal[True], str, Callable[[ValidationError], str]],
Literal[True], str, Callable[[Union[ValidationError, ValidationErrorV1]], str]
],
) -> str: ) -> str:
if isinstance(flag, bool): if isinstance(flag, bool):
content = "Tool input validation error" content = "Tool input validation error"
@ -1067,7 +999,7 @@ def _is_injected_arg_type(
def get_all_basemodel_annotations( def get_all_basemodel_annotations(
cls: Union[TypeBaseModel, Any], *, default_to_bound: bool = True cls: type[BaseModel], *, default_to_bound: bool = True
) -> dict[str, type]: ) -> dict[str, type]:
"""Get all annotations from a Pydantic BaseModel and its parents. """Get all annotations from a Pydantic BaseModel and its parents.
@ -1075,58 +1007,17 @@ def get_all_basemodel_annotations(
cls: The Pydantic BaseModel class. cls: The Pydantic BaseModel class.
default_to_bound: Whether to default to the bound of a TypeVar if it exists. default_to_bound: Whether to default to the bound of a TypeVar if it exists.
""" """
# cls has no subscript: cls = FooBar fields = cls.model_fields
if isinstance(cls, type): alias_map = {field.alias: name for name, field in fields.items() if field.alias}
# Gather pydantic field objects (v2: model_fields / v1: __fields__)
fields = getattr(cls, "model_fields", {}) or getattr(cls, "__fields__", {})
alias_map = {field.alias: name for name, field in fields.items() if field.alias}
annotations: dict[str, type] = {} annotations: dict[str, type] = {}
for name, param in inspect.signature(cls).parameters.items(): for name, param in inspect.signature(cls).parameters.items():
# Exclude hidden init args added by pydantic Config. For example if # Exclude hidden init args added by pydantic's ConfigDict. For example if
# BaseModel(extra="allow") then "extra_data" will part of init sig. # BaseModel(extra="allow") then "extra_data" will part of init sig.
if fields and name not in fields and name not in alias_map: if fields and name not in fields and name not in alias_map:
continue continue
field_name = alias_map.get(name, name) field_name = alias_map.get(name, name)
annotations[field_name] = param.annotation annotations[field_name] = param.annotation
orig_bases: tuple = getattr(cls, "__orig_bases__", ())
# cls has subscript: cls = FooBar[int]
else:
annotations = get_all_basemodel_annotations(
get_origin(cls), default_to_bound=False
)
orig_bases = (cls,)
# Pydantic v2 automatically resolves inherited generics, Pydantic v1 does not.
if not (isinstance(cls, type) and is_pydantic_v2_subclass(cls)):
# if cls = FooBar inherits from Baz[str], orig_bases will contain Baz[str]
# if cls = FooBar inherits from Baz, orig_bases will contain Baz
# if cls = FooBar[int], orig_bases will contain FooBar[int]
for parent in orig_bases:
# if class = FooBar inherits from Baz, parent = Baz
if isinstance(parent, type) and is_pydantic_v1_subclass(parent):
annotations.update(
get_all_basemodel_annotations(parent, default_to_bound=False)
)
continue
parent_origin = get_origin(parent)
# if class = FooBar inherits from non-pydantic class
if not parent_origin:
continue
# if class = FooBar inherits from Baz[str]:
# parent = Baz[str],
# parent_origin = Baz,
# generic_type_vars = (type vars in Baz)
# generic_map = {type var in Baz: str}
generic_type_vars: tuple = getattr(parent_origin, "__parameters__", ())
generic_map = dict(zip(generic_type_vars, get_args(parent)))
for field in getattr(parent_origin, "__annotations__", {}):
annotations[field] = _replace_type_vars(
annotations[field], generic_map, default_to_bound=default_to_bound
)
return { return {
k: _replace_type_vars(v, default_to_bound=default_to_bound) k: _replace_type_vars(v, default_to_bound=default_to_bound)

View File

@ -7,7 +7,6 @@ from collections.abc import Awaitable
from inspect import signature from inspect import signature
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Annotated,
Any, Any,
Callable, Callable,
Literal, Literal,
@ -15,7 +14,7 @@ from typing import (
Union, Union,
) )
from pydantic import Field, SkipValidation from pydantic import BaseModel, Field, SkipValidation
from typing_extensions import override from typing_extensions import override
from langchain_core.callbacks import ( from langchain_core.callbacks import (
@ -30,7 +29,6 @@ from langchain_core.tools.base import (
_get_runnable_config_param, _get_runnable_config_param,
create_schema_from_function, create_schema_from_function,
) )
from langchain_core.utils.pydantic import is_basemodel_subclass
if TYPE_CHECKING: if TYPE_CHECKING:
from langchain_core.messages import ToolCall from langchain_core.messages import ToolCall
@ -40,9 +38,7 @@ class StructuredTool(BaseTool):
"""Tool that can operate on any number of inputs.""" """Tool that can operate on any number of inputs."""
description: str = "" description: str = ""
args_schema: Annotated[ArgsSchema, SkipValidation()] = Field( args_schema: SkipValidation[ArgsSchema] = Field(..., description="The tool schema.")
..., description="The tool schema."
)
"""The input arguments' schema.""" """The input arguments' schema."""
func: Optional[Callable[..., Any]] = None func: Optional[Callable[..., Any]] = None
"""The function to run when the tool is called.""" """The function to run when the tool is called."""
@ -196,7 +192,7 @@ class StructuredTool(BaseTool):
if description is None and not parse_docstring: if description is None and not parse_docstring:
description_ = source_function.__doc__ or None description_ = source_function.__doc__ or None
if description_ is None and args_schema: if description_ is None and args_schema:
if isinstance(args_schema, type) and is_basemodel_subclass(args_schema): if isinstance(args_schema, type) and issubclass(args_schema, BaseModel):
description_ = args_schema.__doc__ or None description_ = args_schema.__doc__ or None
elif isinstance(args_schema, dict): elif isinstance(args_schema, dict):
description_ = args_schema.get("description") description_ = args_schema.get("description")

View File

@ -2,8 +2,8 @@
from __future__ import annotations from __future__ import annotations
import datetime
import warnings import warnings
from datetime import datetime, timezone
from typing import Any, Optional from typing import Any, Optional
from uuid import UUID from uuid import UUID
@ -32,7 +32,7 @@ def RunTypeEnum() -> type[RunTypeEnumDep]: # noqa: N802
class TracerSessionV1Base(BaseModelV1): class TracerSessionV1Base(BaseModelV1):
"""Base class for TracerSessionV1.""" """Base class for TracerSessionV1."""
start_time: datetime.datetime = FieldV1(default_factory=datetime.datetime.utcnow) start_time: datetime = FieldV1(default_factory=lambda: datetime.now(timezone.utc))
name: Optional[str] = None name: Optional[str] = None
extra: Optional[dict[str, Any]] = None extra: Optional[dict[str, Any]] = None
@ -69,8 +69,8 @@ class BaseRun(BaseModelV1):
uuid: str uuid: str
parent_uuid: Optional[str] = None parent_uuid: Optional[str] = None
start_time: datetime.datetime = FieldV1(default_factory=datetime.datetime.utcnow) start_time: datetime = FieldV1(default_factory=lambda: datetime.now(timezone.utc))
end_time: datetime.datetime = FieldV1(default_factory=datetime.datetime.utcnow) end_time: datetime = FieldV1(default_factory=lambda: datetime.now(timezone.utc))
extra: Optional[dict[str, Any]] = None extra: Optional[dict[str, Any]] = None
execution_order: int execution_order: int
child_execution_order: int child_execution_order: int

View File

@ -20,13 +20,11 @@ from typing import (
) )
from pydantic import BaseModel from pydantic import BaseModel
from pydantic.v1 import BaseModel as BaseModelV1
from typing_extensions import TypedDict, get_args, get_origin, is_typeddict from typing_extensions import TypedDict, get_args, get_origin, is_typeddict
from langchain_core._api import beta, deprecated from langchain_core._api import beta, deprecated
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
from langchain_core.utils.json_schema import dereference_refs from langchain_core.utils.json_schema import dereference_refs
from langchain_core.utils.pydantic import is_basemodel_subclass
if TYPE_CHECKING: if TYPE_CHECKING:
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
@ -150,9 +148,7 @@ def _convert_pydantic_to_openai_function(
The function description. The function description.
""" """
if hasattr(model, "model_json_schema"): if hasattr(model, "model_json_schema"):
schema = model.model_json_schema() # Pydantic 2 schema = model.model_json_schema()
elif hasattr(model, "schema"):
schema = model.schema() # Pydantic 1
else: else:
msg = "Model must be a Pydantic model." msg = "Model must be a Pydantic model."
raise TypeError(msg) raise TypeError(msg)
@ -249,6 +245,7 @@ def _convert_typed_dict_to_openai_function(typed_dict: type) -> FunctionDescript
"type[BaseModel]", "type[BaseModel]",
_convert_any_typed_dicts_to_pydantic(typed_dict, visited=visited), _convert_any_typed_dicts_to_pydantic(typed_dict, visited=visited),
) )
print(model)
return _convert_pydantic_to_openai_function(model) return _convert_pydantic_to_openai_function(model)
@ -336,7 +333,7 @@ def _format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
return _convert_json_schema_to_openai_function( return _convert_json_schema_to_openai_function(
tool.tool_call_schema, name=tool.name, description=tool.description tool.tool_call_schema, name=tool.name, description=tool.description
) )
if issubclass(tool.tool_call_schema, (BaseModel, BaseModelV1)): if issubclass(tool.tool_call_schema, BaseModel):
return _convert_pydantic_to_openai_function( return _convert_pydantic_to_openai_function(
tool.tool_call_schema, name=tool.name, description=tool.description tool.tool_call_schema, name=tool.name, description=tool.description
) )
@ -466,7 +463,7 @@ def convert_to_openai_function(
oai_function["description"] = function_copy.pop("description") oai_function["description"] = function_copy.pop("description")
if function_copy and "properties" in function_copy: if function_copy and "properties" in function_copy:
oai_function["parameters"] = function_copy oai_function["parameters"] = function_copy
elif isinstance(function, type) and is_basemodel_subclass(function): elif isinstance(function, type) and issubclass(function, BaseModel):
oai_function = cast("dict", _convert_pydantic_to_openai_function(function)) oai_function = cast("dict", _convert_pydantic_to_openai_function(function))
elif is_typeddict(function): elif is_typeddict(function):
oai_function = cast( oai_function = cast(

View File

@ -2,173 +2,39 @@
from __future__ import annotations from __future__ import annotations
import inspect
import textwrap import textwrap
import warnings import warnings
from contextlib import nullcontext from contextlib import nullcontext
from functools import lru_cache, wraps from functools import lru_cache, wraps
from types import GenericAlias
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Callable, Callable,
Optional, Optional,
TypeVar,
Union,
cast, cast,
overload,
) )
import pydantic
from packaging import version from packaging import version
from pydantic import ( from pydantic import BaseModel, ConfigDict, Field, RootModel, model_validator
BaseModel,
ConfigDict,
PydanticDeprecationWarning,
RootModel,
root_validator,
)
from pydantic import ( from pydantic import (
create_model as _create_model_base, create_model as _create_model_base,
) )
from pydantic.fields import FieldInfo as FieldInfoV2 from pydantic.fields import FieldInfo
from pydantic.json_schema import ( from pydantic.json_schema import (
DEFAULT_REF_TEMPLATE, DEFAULT_REF_TEMPLATE,
GenerateJsonSchema, GenerateJsonSchema,
JsonSchemaMode, JsonSchemaMode,
JsonSchemaValue, JsonSchemaValue,
) )
from pydantic.version import VERSION as PYDANTIC_VERSION_STRING
from typing_extensions import override from typing_extensions import override
if TYPE_CHECKING: if TYPE_CHECKING:
from pydantic_core import core_schema from pydantic_core import core_schema
try: PYDANTIC_VERSION = version.parse(PYDANTIC_VERSION_STRING)
import pydantic
PYDANTIC_VERSION = version.parse(pydantic.__version__)
except ImportError:
PYDANTIC_VERSION = version.parse("0.0.0")
def get_pydantic_major_version() -> int:
"""DEPRECATED - Get the major version of Pydantic.
Use PYDANTIC_VERSION.major instead.
"""
warnings.warn(
"get_pydantic_major_version is deprecated. Use PYDANTIC_VERSION.major instead.",
DeprecationWarning,
stacklevel=2,
)
return PYDANTIC_VERSION.major
PYDANTIC_MAJOR_VERSION = PYDANTIC_VERSION.major
PYDANTIC_MINOR_VERSION = PYDANTIC_VERSION.minor PYDANTIC_MINOR_VERSION = PYDANTIC_VERSION.minor
IS_PYDANTIC_V1 = PYDANTIC_VERSION.major == 1
IS_PYDANTIC_V2 = PYDANTIC_VERSION.major == 2
if IS_PYDANTIC_V1:
from pydantic.fields import FieldInfo as FieldInfoV1
PydanticBaseModel = pydantic.BaseModel
TypeBaseModel = type[BaseModel]
elif IS_PYDANTIC_V2:
from pydantic.v1.fields import FieldInfo as FieldInfoV1 # type: ignore[assignment]
# Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore[assignment,misc]
TypeBaseModel = Union[type[BaseModel], type[pydantic.BaseModel]] # type: ignore[misc]
else:
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
raise ValueError(msg)
TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel)
def is_pydantic_v1_subclass(cls: type) -> bool:
"""Check if the installed Pydantic version is 1.x-like."""
if IS_PYDANTIC_V1:
return True
if IS_PYDANTIC_V2:
from pydantic.v1 import BaseModel as BaseModelV1
if issubclass(cls, BaseModelV1):
return True
return False
def is_pydantic_v2_subclass(cls: type) -> bool:
"""Check if the installed Pydantic version is 1.x-like."""
from pydantic import BaseModel
return IS_PYDANTIC_V2 and issubclass(cls, BaseModel)
def is_basemodel_subclass(cls: type) -> bool:
"""Check if the given class is a subclass of Pydantic BaseModel.
Check if the given class is a subclass of any of the following:
* pydantic.BaseModel in Pydantic 1.x
* pydantic.BaseModel in Pydantic 2.x
* pydantic.v1.BaseModel in Pydantic 2.x
"""
# Before we can use issubclass on the cls we need to check if it is a class
if not inspect.isclass(cls) or isinstance(cls, GenericAlias):
return False
if IS_PYDANTIC_V1:
from pydantic import BaseModel as BaseModelV1Proper
if issubclass(cls, BaseModelV1Proper):
return True
elif IS_PYDANTIC_V2:
from pydantic import BaseModel as BaseModelV2
from pydantic.v1 import BaseModel as BaseModelV1
if issubclass(cls, BaseModelV2):
return True
if issubclass(cls, BaseModelV1):
return True
else:
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
raise ValueError(msg)
return False
def is_basemodel_instance(obj: Any) -> bool:
"""Check if the given class is an instance of Pydantic BaseModel.
Check if the given class is an instance of any of the following:
* pydantic.BaseModel in Pydantic 1.x
* pydantic.BaseModel in Pydantic 2.x
* pydantic.v1.BaseModel in Pydantic 2.x
"""
if IS_PYDANTIC_V1:
from pydantic import BaseModel as BaseModelV1Proper
if isinstance(obj, BaseModelV1Proper):
return True
elif IS_PYDANTIC_V2:
from pydantic import BaseModel as BaseModelV2
from pydantic.v1 import BaseModel as BaseModelV1
if isinstance(obj, BaseModelV2):
return True
if isinstance(obj, BaseModelV1):
return True
else:
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
raise ValueError(msg)
return False
# How to type hint this? # How to type hint this?
def pre_init(func: Callable) -> Any: def pre_init(func: Callable) -> Any:
@ -180,50 +46,35 @@ def pre_init(func: Callable) -> Any:
Returns: Returns:
Any: The decorated function. Any: The decorated function.
""" """
with warnings.catch_warnings():
warnings.filterwarnings(action="ignore", category=PydanticDeprecationWarning)
@root_validator(pre=True) @model_validator(mode="before")
@wraps(func) @wraps(func)
def wrapper(cls: type[BaseModel], values: dict[str, Any]) -> dict[str, Any]: def wrapper(cls: type[BaseModel], values: dict[str, Any]) -> dict[str, Any]:
"""Decorator to run a function before model initialization. """Decorator to run a function before model initialization.
Args: Args:
cls (Type[BaseModel]): The model class. cls (Type[BaseModel]): The model class.
values (dict[str, Any]): The values to initialize the model with. values (dict[str, Any]): The values to initialize the model with.
Returns: Returns:
dict[str, Any]: The values to initialize the model with. dict[str, Any]: The values to initialize the model with.
""" """
# Insert default values # Insert default values
fields = cls.model_fields fields = cls.model_fields
for name, field_info in fields.items(): for name, field_info in fields.items():
# Check if allow_population_by_field_name is enabled if cls.model_config.get("populate_by_name") and field_info.alias in values:
# If yes, then set the field name to the alias values[name] = values.pop(field_info.alias)
if (
hasattr(cls, "Config")
and hasattr(cls.Config, "allow_population_by_field_name")
and cls.Config.allow_population_by_field_name
and field_info.alias in values
):
values[name] = values.pop(field_info.alias)
if (
hasattr(cls, "model_config")
and cls.model_config.get("populate_by_name")
and field_info.alias in values
):
values[name] = values.pop(field_info.alias)
if ( if (
name not in values or values[name] is None name not in values or values[name] is None
) and not field_info.is_required(): ) and not field_info.is_required():
if field_info.default_factory is not None: if field_info.default_factory is not None:
values[name] = field_info.default_factory() # type: ignore[call-arg] values[name] = field_info.default_factory() # type: ignore[call-arg]
else: else:
values[name] = field_info.default values[name] = field_info.default
# Call the decorated function # Call the decorated function
return func(cls, values) return func(cls, values)
return wrapper return wrapper
@ -241,55 +92,15 @@ class _IgnoreUnserializable(GenerateJsonSchema):
return {} return {}
def _create_subset_model_v1(
name: str,
model: type[BaseModel],
field_names: list,
*,
descriptions: Optional[dict] = None,
fn_description: Optional[str] = None,
) -> type[BaseModel]:
"""Create a pydantic model with only a subset of model's fields."""
if IS_PYDANTIC_V1:
from pydantic import create_model
elif IS_PYDANTIC_V2:
from pydantic.v1 import create_model # type: ignore[no-redef]
else:
msg = f"Unsupported pydantic version: {PYDANTIC_VERSION.major}"
raise NotImplementedError(msg)
fields = {}
for field_name in field_names:
# Using pydantic v1 so can access __fields__ as a dict.
field = model.__fields__[field_name] # type: ignore[index]
t = (
# this isn't perfect but should work for most functions
field.outer_type_
if field.required and not field.allow_none
else Optional[field.outer_type_]
)
if descriptions and field_name in descriptions:
field.field_info.description = descriptions[field_name]
fields[field_name] = (t, field.field_info)
rtn = create_model(name, **fields) # type: ignore[call-overload]
rtn.__doc__ = textwrap.dedent(fn_description or model.__doc__ or "")
return rtn
def _create_subset_model_v2( def _create_subset_model_v2(
name: str, name: str,
model: type[pydantic.BaseModel], model: type[BaseModel],
field_names: list[str], field_names: list[str],
*, *,
descriptions: Optional[dict] = None, descriptions: Optional[dict] = None,
fn_description: Optional[str] = None, fn_description: Optional[str] = None,
) -> type[pydantic.BaseModel]: ) -> type[BaseModel]:
"""Create a pydantic model with a subset of the model fields.""" """Create a pydantic model with a subset of the model fields."""
from pydantic import create_model
from pydantic.fields import FieldInfo
descriptions_ = descriptions or {} descriptions_ = descriptions or {}
fields = {} fields = {}
for field_name in field_names: for field_name in field_names:
@ -300,7 +111,7 @@ def _create_subset_model_v2(
field_info.metadata = field.metadata field_info.metadata = field.metadata
fields[field_name] = (field.annotation, field_info) fields[field_name] = (field.annotation, field_info)
rtn = create_model( # type: ignore[call-overload] rtn = _create_model_base(
name, **fields, __config__=ConfigDict(arbitrary_types_allowed=True) name, **fields, __config__=ConfigDict(arbitrary_types_allowed=True)
) )
@ -319,89 +130,34 @@ def _create_subset_model_v2(
return rtn return rtn
# Private functionality to create a subset model that's compatible across
# different versions of pydantic.
# Handles pydantic versions 1.x and 2.x. including v1 of pydantic in 2.x.
# However, can't find a way to type hint this.
def _create_subset_model( def _create_subset_model(
name: str, name: str,
model: TypeBaseModel, model: type[BaseModel],
field_names: list[str], field_names: list[str],
*, *,
descriptions: Optional[dict] = None, descriptions: Optional[dict] = None,
fn_description: Optional[str] = None, fn_description: Optional[str] = None,
) -> type[BaseModel]: ) -> type[BaseModel]:
"""Create subset model using the same pydantic version as the input model.""" """Create subset model using the same pydantic version as the input model."""
if IS_PYDANTIC_V1: return _create_subset_model_v2(
return _create_subset_model_v1( name,
name, model,
model, field_names,
field_names, descriptions=descriptions,
descriptions=descriptions, fn_description=fn_description,
fn_description=fn_description, )
)
if IS_PYDANTIC_V2:
from pydantic.v1 import BaseModel as BaseModelV1
if issubclass(model, BaseModelV1):
return _create_subset_model_v1(
name,
model,
field_names,
descriptions=descriptions,
fn_description=fn_description,
)
return _create_subset_model_v2(
name,
model,
field_names,
descriptions=descriptions,
fn_description=fn_description,
)
msg = f"Unsupported pydantic version: {PYDANTIC_VERSION.major}"
raise NotImplementedError(msg)
if IS_PYDANTIC_V2: def get_fields(
from pydantic import BaseModel as BaseModelV2 model: type[BaseModel],
from pydantic.v1 import BaseModel as BaseModelV1 ) -> dict[str, FieldInfo]:
"""Get the field names of a Pydantic model."""
@overload try:
def get_fields(model: type[BaseModelV2]) -> dict[str, FieldInfoV2]: ... return model.model_fields
except AttributeError as exc:
@overload
def get_fields(model: BaseModelV2) -> dict[str, FieldInfoV2]: ...
@overload
def get_fields(model: type[BaseModelV1]) -> dict[str, FieldInfoV1]: ...
@overload
def get_fields(model: BaseModelV1) -> dict[str, FieldInfoV1]: ...
def get_fields(
model: Union[type[Union[BaseModelV2, BaseModelV1]], BaseModelV2, BaseModelV1],
) -> Union[dict[str, FieldInfoV2], dict[str, FieldInfoV1]]:
"""Get the field names of a Pydantic model."""
if hasattr(model, "model_fields"):
return model.model_fields
if hasattr(model, "__fields__"):
return model.__fields__ # type: ignore[return-value]
msg = f"Expected a Pydantic model. Got {type(model)}" msg = f"Expected a Pydantic model. Got {type(model)}"
raise TypeError(msg) raise TypeError(msg) from exc
elif IS_PYDANTIC_V1:
from pydantic import BaseModel as BaseModelV1_
def get_fields( # type: ignore[no-redef]
model: Union[type[BaseModelV1_], BaseModelV1_],
) -> dict[str, FieldInfoV1]:
"""Get the field names of a Pydantic model."""
return model.__fields__ # type: ignore[return-value]
else:
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
raise ValueError(msg)
_SchemaConfig = ConfigDict( _SchemaConfig = ConfigDict(
arbitrary_types_allowed=True, frozen=True, protected_namespaces=() arbitrary_types_allowed=True, frozen=True, protected_namespaces=()
@ -458,17 +214,6 @@ def _create_root_model(
if default_ is not NO_DEFAULT: if default_ is not NO_DEFAULT:
base_class_attributes["root"] = default_ base_class_attributes["root"] = default_
with warnings.catch_warnings(): with warnings.catch_warnings():
try:
if (
isinstance(type_, type)
and not isinstance(type_, GenericAlias)
and issubclass(type_, BaseModelV1)
):
warnings.filterwarnings(
action="ignore", category=PydanticDeprecationWarning
)
except TypeError:
pass
custom_root_type = type(name, (RootModel,), base_class_attributes) custom_root_type = type(name, (RootModel,), base_class_attributes)
return cast("type[BaseModel]", custom_root_type) return cast("type[BaseModel]", custom_root_type)
@ -545,9 +290,6 @@ _RESERVED_NAMES = {key for key in dir(BaseModel) if not key.startswith("_")}
def _remap_field_definitions(field_definitions: dict[str, Any]) -> dict[str, Any]: def _remap_field_definitions(field_definitions: dict[str, Any]) -> dict[str, Any]:
"""This remaps fields to avoid colliding with internal pydantic fields.""" """This remaps fields to avoid colliding with internal pydantic fields."""
from pydantic import Field
from pydantic.fields import FieldInfo
remapped = {} remapped = {}
for key, value in field_definitions.items(): for key, value in field_definitions.items():
if key.startswith("_") or key in _RESERVED_NAMES: if key.startswith("_") or key in _RESERVED_NAMES:
@ -595,7 +337,7 @@ def create_model_v2(
root: Type for a root model (RootModel) root: Type for a root model (RootModel)
Returns: Returns:
Type[BaseModel]: The created model. type[BaseModel]: The created model.
""" """
field_definitions = field_definitions or {} field_definitions = field_definitions or {}

View File

@ -15,10 +15,6 @@ from pydantic import SecretStr
from requests import HTTPError, Response from requests import HTTPError, Response
from typing_extensions import override from typing_extensions import override
from langchain_core.utils.pydantic import (
is_pydantic_v1_subclass,
)
def xor_args(*arg_groups: tuple[str, ...]) -> Callable: def xor_args(*arg_groups: tuple[str, ...]) -> Callable:
"""Validate specified keyword args are mutually exclusive.". """Validate specified keyword args are mutually exclusive.".
@ -206,16 +202,10 @@ def get_pydantic_field_names(pydantic_cls: Any) -> set[str]:
set[str]: Field names. set[str]: Field names.
""" """
all_required_field_names = set() all_required_field_names = set()
if is_pydantic_v1_subclass(pydantic_cls): for name, field in pydantic_cls.model_fields.items():
for field in pydantic_cls.__fields__.values(): all_required_field_names.add(name)
all_required_field_names.add(field.name) if field.alias:
if field.has_alias: all_required_field_names.add(field.alias)
all_required_field_names.add(field.alias)
else: # Assuming pydantic 2 for now
for name, field in pydantic_cls.model_fields.items():
all_required_field_names.add(name)
if field.alias:
all_required_field_names.add(field.alias)
return all_required_field_names return all_required_field_names

View File

@ -125,7 +125,7 @@ filterwarnings = [ "ignore::langchain_core._api.beta_decorator.LangChainBetaWarn
asyncio_default_fixture_loop_scope = "function" asyncio_default_fixture_loop_scope = "function"
[tool.ruff.lint.pep8-naming] [tool.ruff.lint.pep8-naming]
classmethod-decorators = [ "classmethod", "langchain_core.utils.pydantic.pre_init", "pydantic.field_validator", "pydantic.v1.root_validator",] classmethod-decorators = [ "classmethod", "langchain_core.utils.pydantic.pre_init", "pydantic.field_validator"]
[tool.ruff.lint.per-file-ignores] [tool.ruff.lint.per-file-ignores]
"langchain_core/utils/mustache.py" = [ "PLW0603",] "langchain_core/utils/mustache.py" = [ "PLW0603",]

View File

@ -16,10 +16,6 @@ from langchain_core.output_parsers.openai_tools import (
PydanticToolsParser, PydanticToolsParser,
) )
from langchain_core.outputs import ChatGeneration from langchain_core.outputs import ChatGeneration
from langchain_core.utils.pydantic import (
IS_PYDANTIC_V1,
IS_PYDANTIC_V2,
)
STREAMED_MESSAGES: list = [ STREAMED_MESSAGES: list = [
AIMessageChunk(content=""), AIMessageChunk(content=""),
@ -532,87 +528,13 @@ async def test_partial_pydantic_output_parser_async() -> None:
assert actual == EXPECTED_STREAMED_PYDANTIC assert actual == EXPECTED_STREAMED_PYDANTIC
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="This test is for pydantic 2") def test_parse_with_different_pydantic() -> None:
def test_parse_with_different_pydantic_2_v1() -> None: """Test with BaseModel"""
"""Test with pydantic.v1.BaseModel from pydantic 2."""
import pydantic
class Forecast(pydantic.v1.BaseModel): class Forecast(BaseModel):
temperature: int temperature: int
forecast: str forecast: str
# Can't get pydantic to work here due to the odd typing of tryig to support
# both v1 and v2 in the same codebase.
parser = PydanticToolsParser(tools=[Forecast]) # type: ignore[list-item]
message = AIMessage(
content="",
tool_calls=[
{
"id": "call_OwL7f5PE",
"name": "Forecast",
"args": {"temperature": 20, "forecast": "Sunny"},
}
],
)
generation = ChatGeneration(
message=message,
)
assert parser.parse_result([generation]) == [
Forecast(
temperature=20,
forecast="Sunny",
)
]
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="This test is for pydantic 2")
def test_parse_with_different_pydantic_2_proper() -> None:
"""Test with pydantic.BaseModel from pydantic 2."""
import pydantic
class Forecast(pydantic.BaseModel):
temperature: int
forecast: str
# Can't get pydantic to work here due to the odd typing of tryig to support
# both v1 and v2 in the same codebase.
parser = PydanticToolsParser(tools=[Forecast])
message = AIMessage(
content="",
tool_calls=[
{
"id": "call_OwL7f5PE",
"name": "Forecast",
"args": {"temperature": 20, "forecast": "Sunny"},
}
],
)
generation = ChatGeneration(
message=message,
)
assert parser.parse_result([generation]) == [
Forecast(
temperature=20,
forecast="Sunny",
)
]
@pytest.mark.skipif(not IS_PYDANTIC_V1, reason="This test is for pydantic 1")
def test_parse_with_different_pydantic_1_proper() -> None:
"""Test with pydantic.BaseModel from pydantic 1."""
import pydantic
class Forecast(pydantic.BaseModel):
temperature: int
forecast: str
# Can't get pydantic to work here due to the odd typing of tryig to support
# both v1 and v2 in the same codebase.
parser = PydanticToolsParser(tools=[Forecast]) parser = PydanticToolsParser(tools=[Forecast])
message = AIMessage( message = AIMessage(
content="", content="",

View File

@ -3,35 +3,23 @@
from enum import Enum from enum import Enum
from typing import Literal, Optional from typing import Literal, Optional
import pydantic
import pytest import pytest
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from pydantic.v1 import BaseModel as V1BaseModel
from langchain_core.exceptions import OutputParserException from langchain_core.exceptions import OutputParserException
from langchain_core.language_models import ParrotFakeChatModel from langchain_core.language_models import ParrotFakeChatModel
from langchain_core.output_parsers import PydanticOutputParser from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.output_parsers.json import JsonOutputParser from langchain_core.output_parsers.json import JsonOutputParser
from langchain_core.prompts.prompt import PromptTemplate from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.utils.pydantic import TBaseModel
class ForecastV2(pydantic.BaseModel): class Forecast(BaseModel):
temperature: int temperature: int
f_or_c: Literal["F", "C"] f_or_c: Literal["F", "C"]
forecast: str forecast: str
class ForecastV1(V1BaseModel): def test_pydantic_parser_chaining() -> None:
temperature: int
f_or_c: Literal["F", "C"]
forecast: str
@pytest.mark.parametrize("pydantic_object", [ForecastV2, ForecastV1])
def test_pydantic_parser_chaining(
pydantic_object: TBaseModel,
) -> None:
prompt = PromptTemplate( prompt = PromptTemplate(
template="""{{ template="""{{
"temperature": 20, "temperature": 20,
@ -43,18 +31,17 @@ def test_pydantic_parser_chaining(
model = ParrotFakeChatModel() model = ParrotFakeChatModel()
parser = PydanticOutputParser(pydantic_object=pydantic_object) # type: ignore[arg-type,var-annotated] parser = PydanticOutputParser(pydantic_object=Forecast)
chain = prompt | model | parser chain = prompt | model | parser
res = chain.invoke({}) res = chain.invoke({})
assert type(res) is pydantic_object assert type(res) is Forecast
assert res.f_or_c == "C" assert res.f_or_c == "C"
assert res.temperature == 20 assert res.temperature == 20
assert res.forecast == "Sunny" assert res.forecast == "Sunny"
@pytest.mark.parametrize("pydantic_object", [ForecastV2, ForecastV1]) def test_pydantic_parser_validation() -> None:
def test_pydantic_parser_validation(pydantic_object: TBaseModel) -> None:
bad_prompt = PromptTemplate( bad_prompt = PromptTemplate(
template="""{{ template="""{{
"temperature": "oof", "temperature": "oof",
@ -66,17 +53,14 @@ def test_pydantic_parser_validation(pydantic_object: TBaseModel) -> None:
model = ParrotFakeChatModel() model = ParrotFakeChatModel()
parser = PydanticOutputParser(pydantic_object=pydantic_object) # type: ignore[arg-type,var-annotated] parser = PydanticOutputParser(pydantic_object=Forecast)
chain = bad_prompt | model | parser chain = bad_prompt | model | parser
with pytest.raises(OutputParserException): with pytest.raises(OutputParserException):
chain.invoke({}) chain.invoke({})
# JSON output parser tests # JSON output parser tests
@pytest.mark.parametrize("pydantic_object", [ForecastV2, ForecastV1]) def test_json_parser_chaining() -> None:
def test_json_parser_chaining(
pydantic_object: TBaseModel,
) -> None:
prompt = PromptTemplate( prompt = PromptTemplate(
template="""{{ template="""{{
"temperature": 20, "temperature": 20,
@ -88,7 +72,7 @@ def test_json_parser_chaining(
model = ParrotFakeChatModel() model = ParrotFakeChatModel()
parser = JsonOutputParser(pydantic_object=pydantic_object) # type: ignore[arg-type] parser = JsonOutputParser(pydantic_object=Forecast)
chain = prompt | model | parser chain = prompt | model | parser
res = chain.invoke({}) res = chain.invoke({})

View File

@ -1,6 +1,5 @@
from functools import partial from functools import partial
from inspect import isclass from typing import Any, Union
from typing import Any, Union, cast
from pydantic import BaseModel from pydantic import BaseModel
@ -10,15 +9,14 @@ from langchain_core.load.load import loads
from langchain_core.messages import HumanMessage from langchain_core.messages import HumanMessage
from langchain_core.prompts.structured import StructuredPrompt from langchain_core.prompts.structured import StructuredPrompt
from langchain_core.runnables.base import Runnable, RunnableLambda from langchain_core.runnables.base import Runnable, RunnableLambda
from langchain_core.utils.pydantic import is_basemodel_subclass
def _fake_runnable( def _fake_runnable(
_: Any, *, schema: Union[dict, type[BaseModel]], value: Any = 42, **_kwargs: Any _: Any, *, schema: Union[dict, type[BaseModel]], value: Any = 42, **_kwargs: Any
) -> Union[BaseModel, dict]: ) -> Union[BaseModel, dict]:
if isclass(schema) and is_basemodel_subclass(schema): if isinstance(schema, type) and issubclass(schema, BaseModel):
return schema(name="yo", value=value) return schema(name="yo", value=value)
params = cast("dict", schema)["parameters"] params = schema["parameters"]
return {k: 1 if k != "value" else value for k, v in params.items()} return {k: 1 if k != "value" else value for k, v in params.items()}

View File

@ -2,8 +2,6 @@ from typing import Any, Union
from pydantic import BaseModel from pydantic import BaseModel
from langchain_core.utils.pydantic import is_basemodel_subclass
# Function to replace allOf with $ref # Function to replace allOf with $ref
def replace_all_of_with_ref(schema: Any) -> None: def replace_all_of_with_ref(schema: Any) -> None:
@ -31,8 +29,6 @@ def replace_all_of_with_ref(schema: Any) -> None:
def remove_all_none_default(schema: Any) -> None: def remove_all_none_default(schema: Any) -> None:
"""Removing all none defaults. """Removing all none defaults.
Pydantic v1 did not generate these, but Pydantic v2 does.
The None defaults usually represent **NotRequired** fields, and the None value The None defaults usually represent **NotRequired** fields, and the None value
is actually **incorrect** as a value since the fields do not allow a None value. is actually **incorrect** as a value since the fields do not allow a None value.
@ -75,13 +71,9 @@ def _remove_enum(obj: Any) -> None:
def _schema(obj: Any) -> dict: def _schema(obj: Any) -> dict:
"""Return the schema of the object.""" """Return the schema of the object."""
if not is_basemodel_subclass(obj): if not (isinstance(obj, type) and issubclass(obj, BaseModel)):
msg = f"Object must be a Pydantic BaseModel subclass. Got {type(obj)}" msg = f"Object must be a Pydantic BaseModel subclass. Got {type(obj)}"
raise TypeError(msg) raise TypeError(msg)
# Remap to old style schema
if not hasattr(obj, "model_json_schema"): # V1 model
return obj.schema()
schema_ = obj.model_json_schema(ref_template="#/definitions/{model}") schema_ = obj.model_json_schema(ref_template="#/definitions/{model}")
if "$defs" in schema_: if "$defs" in schema_:
schema_["definitions"] = schema_["$defs"] schema_["definitions"] = schema_["$defs"]

View File

@ -21,9 +21,7 @@ from typing import (
) )
import pytest import pytest
from pydantic import BaseModel, Field, ValidationError from pydantic import BaseModel, ConfigDict, Field, ValidationError
from pydantic.v1 import BaseModel as BaseModelV1
from pydantic.v1 import ValidationError as ValidationErrorV1
from typing_extensions import TypedDict from typing_extensions import TypedDict
from langchain_core import tools from langchain_core import tools
@ -52,7 +50,6 @@ from langchain_core.tools import (
tool, tool,
) )
from langchain_core.tools.base import ( from langchain_core.tools.base import (
ArgsSchema,
InjectedToolArg, InjectedToolArg,
InjectedToolCallId, InjectedToolCallId,
SchemaAnnotationError, SchemaAnnotationError,
@ -65,8 +62,6 @@ from langchain_core.utils.function_calling import (
convert_to_openai_tool, convert_to_openai_tool,
) )
from langchain_core.utils.pydantic import ( from langchain_core.utils.pydantic import (
IS_PYDANTIC_V1,
IS_PYDANTIC_V2,
_create_subset_model, _create_subset_model,
create_model_v2, create_model_v2,
) )
@ -79,9 +74,7 @@ def _get_tool_call_json_schema(tool: BaseTool) -> dict:
if isinstance(tool_schema, dict): if isinstance(tool_schema, dict):
return tool_schema return tool_schema
if hasattr(tool_schema, "model_json_schema"): return tool_schema.model_json_schema()
return tool_schema.model_json_schema()
return tool_schema.schema()
def test_unnamed_decorator() -> None: def test_unnamed_decorator() -> None:
@ -106,14 +99,6 @@ class _MockSchema(BaseModel):
arg3: Optional[dict] = None arg3: Optional[dict] = None
class _MockSchemaV1(BaseModelV1):
"""Return the arguments directly."""
arg1: int
arg2: bool
arg3: Optional[dict] = None
class _MockStructuredTool(BaseTool): class _MockStructuredTool(BaseTool):
name: str = "structured_api" name: str = "structured_api"
args_schema: type[BaseModel] = _MockSchema args_schema: type[BaseModel] = _MockSchema
@ -205,13 +190,6 @@ def test_decorator_with_specified_schema() -> None:
assert isinstance(tool_func, BaseTool) assert isinstance(tool_func, BaseTool)
assert tool_func.args_schema == _MockSchema assert tool_func.args_schema == _MockSchema
@tool(args_schema=cast("ArgsSchema", _MockSchemaV1))
def tool_func_v1(*, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
return f"{arg1} {arg2} {arg3}"
assert isinstance(tool_func_v1, BaseTool)
assert tool_func_v1.args_schema == _MockSchemaV1
def test_decorated_function_schema_equivalent() -> None: def test_decorated_function_schema_equivalent() -> None:
"""Test that a BaseTool without a schema meets expectations.""" """Test that a BaseTool without a schema meets expectations."""
@ -345,50 +323,6 @@ def test_structured_tool_types_parsed() -> None:
assert result == expected assert result == expected
def test_structured_tool_types_parsed_pydantic_v1() -> None:
"""Test the non-primitive types are correctly passed to structured tools."""
class SomeBaseModel(BaseModelV1):
foo: str
class AnotherBaseModel(BaseModelV1):
bar: str
@tool
def structured_tool(some_base_model: SomeBaseModel) -> AnotherBaseModel:
"""Return the arguments directly."""
return AnotherBaseModel(bar=some_base_model.foo)
assert isinstance(structured_tool, StructuredTool)
expected = AnotherBaseModel(bar="baz")
for arg in [
SomeBaseModel(foo="baz"),
SomeBaseModel(foo="baz").dict(),
]:
args = {"some_base_model": arg}
result = structured_tool.run(args)
assert result == expected
def test_structured_tool_types_parsed_pydantic_mixed() -> None:
"""Test handling of tool with mixed Pydantic version arguments."""
class SomeBaseModel(BaseModelV1):
foo: str
class AnotherBaseModel(BaseModel):
bar: str
with pytest.raises(NotImplementedError):
@tool
def structured_tool(
some_base_model: SomeBaseModel, another_base_model: AnotherBaseModel
) -> None:
"""Return the arguments directly."""
def test_base_tool_inheritance_base_schema() -> None: def test_base_tool_inheritance_base_schema() -> None:
"""Test schema is correctly inferred when inheriting from BaseTool.""" """Test schema is correctly inferred when inheriting from BaseTool."""
@ -867,7 +801,7 @@ def test_validation_error_handling_callable() -> None:
"""Test that validation errors are handled correctly.""" """Test that validation errors are handled correctly."""
expected = "foo bar" expected = "foo bar"
def handling(e: Union[ValidationError, ValidationErrorV1]) -> str: def handling(e: ValidationError) -> str:
return expected return expected
_tool = _MockStructuredTool(handle_validation_error=handling) _tool = _MockStructuredTool(handle_validation_error=handling)
@ -884,9 +818,7 @@ def test_validation_error_handling_callable() -> None:
], ],
) )
def test_validation_error_handling_non_validation_error( def test_validation_error_handling_non_validation_error(
handler: Union[ handler: Union[bool, str, Callable[[ValidationError], str]],
bool, str, Callable[[Union[ValidationError, ValidationErrorV1]], str]
],
) -> None: ) -> None:
"""Test that validation errors are handled correctly.""" """Test that validation errors are handled correctly."""
@ -932,7 +864,7 @@ async def test_async_validation_error_handling_callable() -> None:
"""Test that validation errors are handled correctly.""" """Test that validation errors are handled correctly."""
expected = "foo bar" expected = "foo bar"
def handling(e: Union[ValidationError, ValidationErrorV1]) -> str: def handling(e: ValidationError) -> str:
return expected return expected
_tool = _MockStructuredTool(handle_validation_error=handling) _tool = _MockStructuredTool(handle_validation_error=handling)
@ -949,9 +881,7 @@ async def test_async_validation_error_handling_callable() -> None:
], ],
) )
async def test_async_validation_error_handling_non_validation_error( async def test_async_validation_error_handling_non_validation_error(
handler: Union[ handler: Union[bool, str, Callable[[ValidationError], str]],
bool, str, Callable[[Union[ValidationError, ValidationErrorV1]], str]
],
) -> None: ) -> None:
"""Test that validation errors are handled correctly.""" """Test that validation errors are handled correctly."""
@ -1812,38 +1742,18 @@ def test_fn_injected_arg_with_schema(tool_: Callable) -> None:
} }
def generate_models() -> list[Any]: class FooProper(BaseModel):
"""Generate a list of base models depending on the pydantic version.""" a: int
b: str
class FooProper(BaseModel):
a: int
b: str
return [FooProper]
def generate_backwards_compatible_v1() -> list[Any]: TEST_MODELS = [FooProper]
"""Generate a model with pydantic 2 from the v1 namespace."""
from pydantic.v1 import BaseModel as BaseModelV1
class FooV1Namespace(BaseModelV1):
a: int
b: str
return [FooV1Namespace]
# This generates a list of models that can be used for testing that our APIs
# behave well with either pydantic 1 proper,
# pydantic v1 from pydantic 2,
# or pydantic 2 proper.
TEST_MODELS = generate_models() + generate_backwards_compatible_v1()
@pytest.mark.parametrize("pydantic_model", TEST_MODELS) @pytest.mark.parametrize("pydantic_model", TEST_MODELS)
def test_args_schema_as_pydantic(pydantic_model: Any) -> None: def test_args_schema_as_pydantic(pydantic_model: Any) -> None:
class SomeTool(BaseTool): class SomeTool(BaseTool):
args_schema: type[pydantic_model] = pydantic_model args_schema: type[BaseModel] = pydantic_model
def _run(self, *args: Any, **kwargs: Any) -> str: def _run(self, *args: Any, **kwargs: Any) -> str:
return "foo" return "foo"
@ -1853,11 +1763,7 @@ def test_args_schema_as_pydantic(pydantic_model: Any) -> None:
) )
input_schema = tool.get_input_schema() input_schema = tool.get_input_schema()
input_json_schema = ( input_json_schema = input_schema.model_json_schema()
input_schema.model_json_schema()
if hasattr(input_schema, "model_json_schema")
else input_schema.schema()
)
assert input_json_schema == { assert input_json_schema == {
"properties": { "properties": {
"a": {"title": "A", "type": "integer"}, "a": {"title": "A", "type": "integer"},
@ -1882,22 +1788,13 @@ def test_args_schema_as_pydantic(pydantic_model: Any) -> None:
def test_args_schema_explicitly_typed() -> None: def test_args_schema_explicitly_typed() -> None:
"""This should test that one can type the args schema as a pydantic model. """This should test that one can type the args schema as a pydantic model."""
Please note that this will test using pydantic 2 even though BaseTool
is a pydantic 1 model!
"""
# Check with whatever pydantic model is passed in and not via v1 namespace
from pydantic import BaseModel
class Foo(BaseModel): class Foo(BaseModel):
a: int a: int
b: str b: str
class SomeTool(BaseTool): class SomeTool(BaseTool):
# type ignoring here since we're allowing overriding a type
# signature of pydantic.v1.BaseModel with pydantic.BaseModel
# for pydantic 2!
args_schema: type[BaseModel] = Foo args_schema: type[BaseModel] = Foo
def _run(self, *args: Any, **kwargs: Any) -> str: def _run(self, *args: Any, **kwargs: Any) -> str:
@ -1944,11 +1841,7 @@ def test_structured_tool_with_different_pydantic_versions(pydantic_model: Any) -
assert foo_tool.invoke({"a": 5, "b": "hello"}) == "foo" assert foo_tool.invoke({"a": 5, "b": "hello"}) == "foo"
args_schema = cast("BaseModel", foo_tool.args_schema) args_schema = cast("BaseModel", foo_tool.args_schema)
args_json_schema = ( args_json_schema = args_schema.model_json_schema()
args_schema.model_json_schema()
if hasattr(args_schema, "model_json_schema")
else args_schema.schema()
)
assert args_json_schema == { assert args_json_schema == {
"properties": { "properties": {
"a": {"title": "A", "type": "integer"}, "a": {"title": "A", "type": "integer"},
@ -1960,11 +1853,7 @@ def test_structured_tool_with_different_pydantic_versions(pydantic_model: Any) -
} }
input_schema = foo_tool.get_input_schema() input_schema = foo_tool.get_input_schema()
input_json_schema = ( input_json_schema = input_schema.model_json_schema()
input_schema.model_json_schema()
if hasattr(input_schema, "model_json_schema")
else input_schema.schema()
)
assert input_json_schema == { assert input_json_schema == {
"properties": { "properties": {
"a": {"title": "A", "type": "integer"}, "a": {"title": "A", "type": "integer"},
@ -2020,81 +1909,12 @@ def test__is_message_content_type(obj: Any, *, expected: bool) -> None:
assert _is_message_content_type(obj) is expected assert _is_message_content_type(obj) is expected
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Testing pydantic v2.") def test__get_all_basemodel_annotations_v2() -> None:
@pytest.mark.parametrize("use_v1_namespace", [True, False])
def test__get_all_basemodel_annotations_v2(*, use_v1_namespace: bool) -> None:
A = TypeVar("A") A = TypeVar("A")
if use_v1_namespace: class ModelA(BaseModel, Generic[A]):
from pydantic.v1 import BaseModel as BaseModel1
class ModelA(BaseModel1, Generic[A], extra="allow"):
a: A
else:
from pydantic import BaseModel as BaseModel2
from pydantic import ConfigDict
class ModelA(BaseModel2, Generic[A]): # type: ignore[no-redef]
a: A
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
class ModelB(ModelA[str]):
b: Annotated[ModelA[dict[str, Any]], "foo"]
class Mixin:
def foo(self) -> str:
return "foo"
class ModelC(Mixin, ModelB):
c: dict
expected = {"a": str, "b": Annotated[ModelA[dict[str, Any]], "foo"], "c": dict}
actual = get_all_basemodel_annotations(ModelC)
assert actual == expected
expected = {"a": str, "b": Annotated[ModelA[dict[str, Any]], "foo"]}
actual = get_all_basemodel_annotations(ModelB)
assert actual == expected
expected = {"a": Any}
actual = get_all_basemodel_annotations(ModelA)
assert actual == expected
expected = {"a": int}
actual = get_all_basemodel_annotations(ModelA[int])
assert actual == expected
D = TypeVar("D", bound=Union[str, int])
class ModelD(ModelC, Generic[D]):
d: Optional[D]
expected = {
"a": str,
"b": Annotated[ModelA[dict[str, Any]], "foo"],
"c": dict,
"d": Union[str, int, None],
}
actual = get_all_basemodel_annotations(ModelD)
assert actual == expected
expected = {
"a": str,
"b": Annotated[ModelA[dict[str, Any]], "foo"],
"c": dict,
"d": Union[int, None],
}
actual = get_all_basemodel_annotations(ModelD[int])
assert actual == expected
@pytest.mark.skipif(not IS_PYDANTIC_V1, reason="Testing pydantic v1.")
def test__get_all_basemodel_annotations_v1() -> None:
A = TypeVar("A")
class ModelA(BaseModel, Generic[A], extra="allow"):
a: A a: A
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
class ModelB(ModelA[str]): class ModelB(ModelA[str]):
b: Annotated[ModelA[dict[str, Any]], "foo"] b: Annotated[ModelA[dict[str, Any]], "foo"]
@ -2226,14 +2046,9 @@ def test_create_retriever_tool() -> None:
) )
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Testing pydantic v2.") def test_tool_args_schema_pydantic_with_metadata() -> None:
def test_tool_args_schema_pydantic_v2_with_metadata() -> None: class Foo(BaseModel):
from pydantic import BaseModel as BaseModelV2 x: list[int] = Field(
from pydantic import Field as FieldV2
from pydantic import ValidationError as ValidationErrorV2
class Foo(BaseModelV2):
x: list[int] = FieldV2(
description="List of integers", min_length=10, max_length=15 description="List of integers", min_length=10, max_length=15
) )
@ -2260,7 +2075,7 @@ def test_tool_args_schema_pydantic_v2_with_metadata() -> None:
} }
assert foo.invoke({"x": [0] * 10}) assert foo.invoke({"x": [0] * 10})
with pytest.raises(ValidationErrorV2): with pytest.raises(ValidationError):
foo.invoke({"x": [0] * 9}) foo.invoke({"x": [0] * 9})

View File

@ -746,7 +746,7 @@ def test_tool_outputs() -> None:
[ExtensionsAnnotated, TypingAnnotated], [ExtensionsAnnotated, TypingAnnotated],
ids=["typing_extensions.Annotated", "typing.Annotated"], ids=["typing_extensions.Annotated", "typing.Annotated"],
) )
def test__convert_typed_dict_to_openai_function( def test_convert_typed_dict_to_openai_function(
typed_dict: TypeAlias, annotated: TypeAlias typed_dict: TypeAlias, annotated: TypeAlias
) -> None: ) -> None:
class SubTool(typed_dict): # type: ignore[misc] class SubTool(typed_dict): # type: ignore[misc]
@ -985,7 +985,7 @@ def test__convert_typed_dict_to_openai_function(
@pytest.mark.parametrize("typed_dict", [ExtensionsTypedDict, TypingTypedDict]) @pytest.mark.parametrize("typed_dict", [ExtensionsTypedDict, TypingTypedDict])
def test__convert_typed_dict_to_openai_function_fail(typed_dict: type) -> None: def test__convert_typed_dict_to_openai_function_fail(typed_dict: type) -> None:
class Tool(typed_dict): # type: ignore[misc] class Tool(typed_dict): # type: ignore[misc]
arg1: typing.MutableSet # Pydantic 2 supports this, but pydantic v1 does not. arg1: typing.MutableSet
# Error should be raised since we're using v1 code path here # Error should be raised since we're using v1 code path here
with pytest.raises(TypeError): with pytest.raises(TypeError):

View File

@ -3,18 +3,12 @@
import warnings import warnings
from typing import Any, Optional from typing import Any, Optional
import pytest from pydantic import BaseModel, ConfigDict
from pydantic import ConfigDict
from langchain_core.utils.pydantic import ( from langchain_core.utils.pydantic import (
IS_PYDANTIC_V1,
IS_PYDANTIC_V2,
PYDANTIC_VERSION,
_create_subset_model_v2, _create_subset_model_v2,
create_model_v2, create_model_v2,
get_fields, get_fields,
is_basemodel_instance,
is_basemodel_subclass,
pre_init, pre_init,
) )
@ -94,52 +88,6 @@ def test_with_aliases() -> None:
assert foo.z == 2 assert foo.z == 2
def test_is_basemodel_subclass() -> None:
"""Test pydantic."""
if IS_PYDANTIC_V1:
from pydantic import BaseModel as BaseModelV1Proper
assert is_basemodel_subclass(BaseModelV1Proper)
elif IS_PYDANTIC_V2:
from pydantic import BaseModel as BaseModelV2
from pydantic.v1 import BaseModel as BaseModelV1
assert is_basemodel_subclass(BaseModelV2)
assert is_basemodel_subclass(BaseModelV1)
else:
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
raise ValueError(msg)
def test_is_basemodel_instance() -> None:
"""Test pydantic."""
if IS_PYDANTIC_V1:
from pydantic import BaseModel as BaseModelV1Proper
class FooV1(BaseModelV1Proper):
x: int
assert is_basemodel_instance(FooV1(x=5))
elif IS_PYDANTIC_V2:
from pydantic import BaseModel as BaseModelV2
from pydantic.v1 import BaseModel as BaseModelV1
class Foo(BaseModelV2):
x: int
assert is_basemodel_instance(Foo(x=5))
class Bar(BaseModelV1):
x: int
assert is_basemodel_instance(Bar(x=5))
else:
msg = f"Unsupported Pydantic version: {PYDANTIC_VERSION.major}"
raise ValueError(msg)
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Only tests Pydantic v2")
def test_with_field_metadata() -> None: def test_with_field_metadata() -> None:
"""Test pydantic with field metadata.""" """Test pydantic with field metadata."""
from pydantic import BaseModel as BaseModelV2 from pydantic import BaseModel as BaseModelV2
@ -168,21 +116,7 @@ def test_with_field_metadata() -> None:
} }
@pytest.mark.skipif(not IS_PYDANTIC_V1, reason="Only tests Pydantic v1")
def test_fields_pydantic_v1() -> None:
from pydantic import BaseModel
class Foo(BaseModel):
x: int
fields = get_fields(Foo)
assert fields == {"x": Foo.model_fields["x"]}
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Only tests Pydantic v2")
def test_fields_pydantic_v2_proper() -> None: def test_fields_pydantic_v2_proper() -> None:
from pydantic import BaseModel
class Foo(BaseModel): class Foo(BaseModel):
x: int x: int
@ -190,17 +124,6 @@ def test_fields_pydantic_v2_proper() -> None:
assert fields == {"x": Foo.model_fields["x"]} assert fields == {"x": Foo.model_fields["x"]}
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Only tests Pydantic v2")
def test_fields_pydantic_v1_from_2() -> None:
from pydantic.v1 import BaseModel
class Foo(BaseModel):
x: int
fields = get_fields(Foo)
assert fields == {"x": Foo.__fields__["x"]}
def test_create_model_v2() -> None: def test_create_model_v2() -> None:
"""Test that create model v2 works as expected.""" """Test that create model v2 works as expected."""
with warnings.catch_warnings(record=True) as record: with warnings.catch_warnings(record=True) as record:

View File

@ -6,7 +6,7 @@ from typing import Any, Callable, Optional, Union
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from pydantic import SecretStr from pydantic import BaseModel, Field, SecretStr
from langchain_core import utils from langchain_core import utils
from langchain_core.utils import ( from langchain_core.utils import (
@ -16,10 +16,6 @@ from langchain_core.utils import (
guard_import, guard_import,
) )
from langchain_core.utils._merge import merge_dicts from langchain_core.utils._merge import merge_dicts
from langchain_core.utils.pydantic import (
IS_PYDANTIC_V1,
IS_PYDANTIC_V2,
)
from langchain_core.utils.utils import secret_from_env from langchain_core.utils.utils import secret_from_env
@ -214,39 +210,7 @@ def test_guard_import_failure(
guard_import(module_name, pip_name=pip_name, package=package) guard_import(module_name, pip_name=pip_name, package=package)
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Requires pydantic 2") def test_get_pydantic_field_names() -> None:
def test_get_pydantic_field_names_v1_in_2() -> None:
from pydantic.v1 import BaseModel as PydanticV1BaseModel
from pydantic.v1 import Field
class PydanticV1Model(PydanticV1BaseModel):
field1: str
field2: int
alias_field: int = Field(alias="aliased_field")
result = get_pydantic_field_names(PydanticV1Model)
expected = {"field1", "field2", "aliased_field", "alias_field"}
assert result == expected
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Requires pydantic 2")
def test_get_pydantic_field_names_v2_in_2() -> None:
from pydantic import BaseModel, Field
class PydanticModel(BaseModel):
field1: str
field2: int
alias_field: int = Field(alias="aliased_field")
result = get_pydantic_field_names(PydanticModel)
expected = {"field1", "field2", "aliased_field", "alias_field"}
assert result == expected
@pytest.mark.skipif(not IS_PYDANTIC_V1, reason="Requires pydantic 1")
def test_get_pydantic_field_names_v1() -> None:
from pydantic import BaseModel, Field
class PydanticModel(BaseModel): class PydanticModel(BaseModel):
field1: str field1: str
field2: int field2: int