fix(core): fix Pydantic v1 support in tools/runnable (#33698)

`BaseTool.args_schema` is documented as accepting a Pydantic v1 model,
but several code paths assumed v2 and raised when handed a v1 schema
(e.g. an `AttributeError` from calling
`model_json_schema()`/`model_fields` on a v1 model). This affected
anyone using a v1 `args_schema`, and anyone composing runnables whose
input/output schema is a v1 model.

This PR makes the tool/runnable schema-derivation code version-agnostic.

## Type contract

`TypeBaseModel` (and `PydanticBaseModel`) now include
`pydantic.v1.BaseModel`, so the type honestly reflects what tools and
runnables already accept at runtime. The public schema accessors
(`Runnable.get_input_schema`/`get_output_schema` and the
`input_schema`/`output_schema` properties) return `TypeBaseModel`.

## Version-agnostic helpers

Added to `langchain_core.utils.pydantic`, each dispatching on the
model's Pydantic version so callers don't have to:

- `model_json_schema(model)` — JSON schema for either version.
- `model_validate(model, obj)` — validation for either version.
- `get_fields(model)` — field map for either version (existing helper,
now used consistently).

Internally, direct `.model_json_schema()` / `.model_fields` calls are
replaced with these helpers (or with `get_input_jsonschema()` /
`get_output_jsonschema()`).

## Behavior change worth a close look

When deriving a schema from a v1 model (in the `RunnableParallel` input
schema and the `RunnableAssign` and `RunnableSequence` output schemas), a
**required** v1 field is now correctly carried over as required. Previously the v1
path read the field's `default` — which is `None` for a required v1
field — and silently turned required fields into optional/nullable ones;
`default_factory` fields were dropped entirely. The new
`_get_schema_field_definition` helper translates a v1 `ModelField`
faithfully (required → `...`, factory preserved) and dispatches
explicitly on the field type.

---------

Co-authored-by: Mason Daugherty <mason@langchain.dev>
Co-authored-by: Mason Daugherty <github@mdrxy.com>
This commit is contained in:
Christophe Bornet
2026-06-12 06:18:49 +02:00
committed by GitHub
parent f6d63bc9f3
commit 0392b6bae4
17 changed files with 499 additions and 102 deletions

View File

@@ -91,7 +91,7 @@ from langchain_core.utils.function_calling import (
convert_to_json_schema,
convert_to_openai_tool,
)
from langchain_core.utils.pydantic import TypeBaseModel, is_basemodel_subclass
from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_core.utils.utils import LC_ID_PREFIX, from_env
if TYPE_CHECKING:
@@ -2519,9 +2519,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
)
output_parser: JsonOutputToolsParser
if isinstance(schema, type) and is_basemodel_subclass(schema):
output_parser = PydanticToolsParser(
tools=[cast("TypeBaseModel", schema)], first_tool_only=True
)
output_parser = PydanticToolsParser(tools=[schema], first_tool_only=True)
else:
key_name = convert_to_openai_tool(schema)["function"]["name"]
output_parser = JsonOutputKeyToolsParser(

View File

@@ -6,7 +6,8 @@ import logging
from json import JSONDecodeError
from typing import Annotated, Any
from pydantic import SkipValidation, ValidationError
from pydantic import BaseModel, SkipValidation, ValidationError
from pydantic.v1 import BaseModel as BaseModelV1
from langchain_core.exceptions import OutputParserException
from langchain_core.messages import AIMessage, InvalidToolCall
@@ -17,8 +18,6 @@ from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.utils.json import parse_partial_json
from langchain_core.utils.pydantic import (
TypeBaseModel,
is_pydantic_v1_subclass,
is_pydantic_v2_subclass,
)
logger = logging.getLogger(__name__)
@@ -339,10 +338,10 @@ class PydanticToolsParser(JsonOutputToolsParser):
name_dict_v2: dict[str, TypeBaseModel] = {
tool.model_config.get("title") or tool.__name__: tool
for tool in self.tools
if is_pydantic_v2_subclass(tool)
if issubclass(tool, BaseModel)
}
name_dict_v1: dict[str, TypeBaseModel] = {
tool.__name__: tool for tool in self.tools if is_pydantic_v1_subclass(tool)
tool.__name__: tool for tool in self.tools if issubclass(tool, BaseModelV1)
}
name_dict: dict[str, TypeBaseModel] = {**name_dict_v2, **name_dict_v1}
pydantic_objects = []

View File

@@ -38,6 +38,7 @@ from typing import (
)
from pydantic import BaseModel, ConfigDict, Field, RootModel
from pydantic.fields import FieldInfo
from typing_extensions import override
from langchain_core._api import beta_decorator
@@ -97,9 +98,16 @@ from langchain_core.tracers.root_listeners import (
)
from langchain_core.utils.aiter import aclosing, atee
from langchain_core.utils.iter import safetee
from langchain_core.utils.pydantic import create_model_v2
from langchain_core.utils.pydantic import (
TypeBaseModel,
create_model_v2,
get_fields,
model_json_schema,
)
if TYPE_CHECKING:
from pydantic.v1.fields import ModelField
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
@@ -364,14 +372,14 @@ class Runnable(ABC, Generic[Input, Output]):
raise TypeError(msg)
@property
def input_schema(self) -> type[BaseModel]:
def input_schema(self) -> TypeBaseModel:
"""The type of input this `Runnable` accepts specified as a Pydantic model."""
return self.get_input_schema()
def get_input_schema(
self,
config: RunnableConfig | None = None,
) -> type[BaseModel]:
) -> TypeBaseModel:
"""Get a Pydantic model that can be used to validate input to the `Runnable`.
`Runnable` objects that leverage the `configurable_fields` and
@@ -437,10 +445,10 @@ class Runnable(ABC, Generic[Input, Output]):
!!! version-added "Added in `langchain-core` 0.3.0"
"""
return self.get_input_schema(config).model_json_schema()
return model_json_schema(self.get_input_schema(config))
@property
def output_schema(self) -> type[BaseModel]:
def output_schema(self) -> TypeBaseModel:
"""Output schema.
The type of output this `Runnable` produces specified as a Pydantic model.
@@ -450,7 +458,7 @@ class Runnable(ABC, Generic[Input, Output]):
def get_output_schema(
self,
config: RunnableConfig | None = None,
) -> type[BaseModel]:
) -> TypeBaseModel:
"""Get a Pydantic model that can be used to validate output to the `Runnable`.
`Runnable` objects that leverage the `configurable_fields` and
@@ -516,7 +524,7 @@ class Runnable(ABC, Generic[Input, Output]):
!!! version-added "Added in `langchain-core` 0.3.0"
"""
return self.get_output_schema(config).model_json_schema()
return model_json_schema(self.get_output_schema(config))
@property
def config_specs(self) -> list[ConfigurableFieldSpec]:
@@ -2953,7 +2961,7 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]):
def _seq_input_schema(
steps: list[Runnable[Any, Any]], config: RunnableConfig | None
) -> type[BaseModel]:
) -> TypeBaseModel:
# Import locally to prevent circular import
from langchain_core.runnables.passthrough import ( # noqa: PLC0415
RunnableAssign,
@@ -2971,7 +2979,7 @@ def _seq_input_schema(
"RunnableSequenceInput",
field_definitions={
k: (v.annotation, v.default)
for k, v in next_input_schema.model_fields.items()
for k, v in get_fields(next_input_schema).items()
if k not in first.mapper.steps__
},
)
@@ -2983,7 +2991,7 @@ def _seq_input_schema(
def _seq_output_schema(
steps: list[Runnable[Any, Any]], config: RunnableConfig | None
) -> type[BaseModel]:
) -> TypeBaseModel:
# Import locally to prevent circular import
from langchain_core.runnables.passthrough import ( # noqa: PLC0415
RunnableAssign,
@@ -3002,12 +3010,12 @@ def _seq_output_schema(
"RunnableSequenceOutput",
field_definitions={
**{
k: (v.annotation, v.default)
for k, v in prev_output_schema.model_fields.items()
k: _get_schema_field_definition(v)
for k, v in get_fields(prev_output_schema).items()
},
**{
k: (v.annotation, v.default)
for k, v in mapper_output_schema.model_fields.items()
k: _get_schema_field_definition(v)
for k, v in get_fields(mapper_output_schema).items()
},
},
)
@@ -3019,19 +3027,36 @@ def _seq_output_schema(
return create_model_v2(
"RunnableSequenceOutput",
field_definitions={
k: (v.annotation, v.default)
for k, v in prev_output_schema.model_fields.items()
k: _get_schema_field_definition(v)
for k, v in get_fields(prev_output_schema).items()
if k in last.keys
},
)
field = prev_output_schema.model_fields[last.keys]
field = get_fields(prev_output_schema)[last.keys]
return create_model_v2(
"RunnableSequenceOutput", root=(field.annotation, field.default)
"RunnableSequenceOutput", root=_get_schema_field_definition(field)
)
return last.get_output_schema(config)
def _get_schema_field_definition(field: FieldInfo | ModelField) -> tuple[Any, Any]:
    """Convert a Pydantic field to a field definition for `create_model_v2`.

    Handles both Pydantic v2 (`FieldInfo`) and v1 (`ModelField`) fields.

    - v1 required fields carry a `None` default with `required=True`, so they
      must be translated to the `...` sentinel rather than passed through as
      optional.
    - A field declared with `default_factory` keeps its default in the factory
      rather than in `default`, so the factory is forwarded explicitly for both
      versions instead of being dropped.

    Args:
        field: The field description taken from a model's field map (as returned
            by `get_fields`).

    Returns:
        An `(annotation, default)` tuple, where `default` is a plain default
        value, the `...` sentinel for a required v1 field, or a `Field(...)`
        wrapping the original `default_factory`.
    """
    # Only v1 `ModelField` has a `required` flag; v2 required fields already
    # expose an undefined `default`, which is passed through unchanged below.
    if not isinstance(field, FieldInfo) and field.required:
        return (field.annotation, ...)
    # Both `FieldInfo` and `ModelField` expose `default_factory`; re-wrap it so
    # the derived model calls the factory instead of losing the default.
    if field.default_factory is not None:
        return (field.annotation, Field(default_factory=field.default_factory))
    return (field.annotation, field.default)
_RUNNABLE_SEQUENCE_MIN_STEPS = 2
@@ -3212,7 +3237,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
return self.last.OutputType
@override
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
def get_input_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
"""Get the input schema of the `Runnable`.
Args:
@@ -3225,9 +3250,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
return _seq_input_schema(self.steps, config)
@override
def get_output_schema(
self, config: RunnableConfig | None = None
) -> type[BaseModel]:
def get_output_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
"""Get the output schema of the `Runnable`.
Args:
@@ -3984,7 +4007,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
return Any
@override
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
def get_input_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
"""Get the input schema of the `Runnable`.
Args:
@@ -3995,23 +4018,26 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
"""
if all(
s.get_input_schema(config).model_json_schema().get("type", "object")
== "object"
s.get_input_jsonschema(config).get("type", "object") == "object"
for s in self.steps__.values()
):
for step in self.steps__.values():
fields = step.get_input_schema(config).model_fields
step_input_schema = step.get_input_schema(config)
fields = get_fields(step_input_schema)
root_field = fields.get("root")
if root_field is not None and root_field.annotation != Any:
return super().get_input_schema(config)
root_field = fields.get("__root__")
if root_field is not None and root_field.annotation != Any:
return step_input_schema
# This is correct, but pydantic typings/mypy don't think so.
return create_model_v2(
self.get_name("Input"),
field_definitions={
k: (v.annotation, v.default)
k: _get_schema_field_definition(v)
for step in self.steps__.values()
for k, v in step.get_input_schema(config).model_fields.items()
for k, v in get_fields(step.get_input_schema(config)).items()
if k != "__root__"
},
)
@@ -4032,6 +4058,9 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
"""
fields = {k: (v.OutputType, ...) for k, v in self.steps__.items()}
# The return type is narrowed to `type[BaseModel]` (rather than the base
# class's `TypeBaseModel`) because this override always builds the schema
# with `create_model_v2`, so it is guaranteed to be a Pydantic v2 model.
return create_model_v2(self.get_name("Output"), field_definitions=fields)
@property
@@ -4932,7 +4961,7 @@ class RunnableLambda(Runnable[Input, Output]):
return Any
@override
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
def get_input_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
"""The Pydantic schema for the input to this `Runnable`.
Args:
@@ -5925,15 +5954,13 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): # type: ignore[
)
@override
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
def get_input_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
if self.custom_input_type is not None:
return super().get_input_schema(config)
return self.bound.get_input_schema(merge_configs(self.config, config))
@override
def get_output_schema(
self, config: RunnableConfig | None = None
) -> type[BaseModel]:
def get_output_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
if self.custom_output_type is not None:
return super().get_output_schema(config)
return self.bound.get_output_schema(merge_configs(self.config, config))

View File

@@ -13,7 +13,7 @@ from typing import (
cast,
)
from pydantic import BaseModel, ConfigDict
from pydantic import ConfigDict
from typing_extensions import override
from langchain_core.runnables.base import (
@@ -35,6 +35,7 @@ from langchain_core.runnables.utils import (
Output,
get_unique_config_specs,
)
from langchain_core.utils.pydantic import TypeBaseModel
_MIN_BRANCHES = 2
@@ -154,7 +155,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
return ["langchain", "schema", "runnable"]
@override
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
def get_input_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
runnables = (
[self.default]
+ [r for _, r in self.branches]
@@ -162,10 +163,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
)
for runnable in runnables:
if (
runnable.get_input_schema(config).model_json_schema().get("type")
is not None
):
if runnable.get_input_jsonschema(config).get("type") is not None:
return runnable.get_input_schema(config)
return super().get_input_schema(config)

View File

@@ -19,7 +19,7 @@ from typing import (
)
from weakref import WeakValueDictionary
from pydantic import BaseModel, ConfigDict
from pydantic import ConfigDict
from typing_extensions import override
from langchain_core.runnables.base import Runnable, RunnableSerializable
@@ -44,6 +44,7 @@ from langchain_core.runnables.utils import (
if TYPE_CHECKING:
from langchain_core.runnables.graph import Graph
from langchain_core.utils.pydantic import TypeBaseModel
class DynamicRunnable(RunnableSerializable[Input, Output]):
@@ -90,14 +91,12 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
return self.default.OutputType
@override
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
def get_input_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
runnable, config = self.prepare(config)
return runnable.get_input_schema(config)
@override
def get_output_schema(
self, config: RunnableConfig | None = None
) -> type[BaseModel]:
def get_output_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
runnable, config = self.prepare(config)
return runnable.get_output_schema(config)

View File

@@ -7,7 +7,7 @@ from collections.abc import AsyncIterator, Iterator, Sequence
from functools import wraps
from typing import TYPE_CHECKING, Any, cast
from pydantic import BaseModel, ConfigDict
from pydantic import ConfigDict
from typing_extensions import override
from langchain_core.callbacks.manager import AsyncCallbackManager, CallbackManager
@@ -28,6 +28,7 @@ from langchain_core.runnables.utils import (
coro_with_context,
get_unique_config_specs,
)
from langchain_core.utils.pydantic import TypeBaseModel
if TYPE_CHECKING:
from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun
@@ -118,13 +119,11 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
return self.runnable.OutputType
@override
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
def get_input_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
return self.runnable.get_input_schema(config)
@override
def get_output_schema(
self, config: RunnableConfig | None = None
) -> type[BaseModel]:
def get_output_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
return self.runnable.get_output_schema(config)
@property

View File

@@ -18,13 +18,15 @@ from uuid import UUID, uuid4
from langchain_core.load.serializable import to_json_not_implemented
from langchain_core.runnables.base import Runnable, RunnableSerializable
from langchain_core.utils.pydantic import _IgnoreUnserializable, is_basemodel_subclass
from langchain_core.utils.pydantic import (
TypeBaseModel,
_IgnoreUnserializable,
is_basemodel_subclass,
)
if TYPE_CHECKING:
from collections.abc import Callable, Sequence
from pydantic import BaseModel
from langchain_core.runnables.base import Runnable as RunnableType
@@ -97,7 +99,7 @@ class Node(NamedTuple):
"""The unique identifier of the node."""
name: str
"""The name of the node."""
data: type[BaseModel] | RunnableType[Any, Any] | None
data: TypeBaseModel | RunnableType[Any, Any] | None
"""The data of the node."""
metadata: dict[str, Any] | None
"""Optional metadata for the node. """
@@ -177,7 +179,7 @@ class MermaidDrawMethod(Enum):
def node_data_str(
id: str,
data: type[BaseModel] | RunnableType[Any, Any] | None,
data: TypeBaseModel | RunnableType[Any, Any] | None,
) -> str:
"""Convert the data of a node to a string.
@@ -311,7 +313,7 @@ class Graph:
def add_node(
self,
data: type[BaseModel] | RunnableType[Any, Any] | None,
data: TypeBaseModel | RunnableType[Any, Any] | None,
id: str | None = None,
*,
metadata: dict[str, Any] | None = None,

View File

@@ -11,7 +11,7 @@ from typing import (
Any,
)
from pydantic import BaseModel, RootModel
from pydantic import RootModel
from typing_extensions import override
from langchain_core.runnables.base import (
@@ -19,6 +19,7 @@ from langchain_core.runnables.base import (
Runnable,
RunnableParallel,
RunnableSerializable,
_get_schema_field_definition,
)
from langchain_core.runnables.config import (
RunnableConfig,
@@ -34,7 +35,7 @@ from langchain_core.runnables.utils import (
)
from langchain_core.utils.aiter import atee
from langchain_core.utils.iter import safetee
from langchain_core.utils.pydantic import create_model_v2
from langchain_core.utils.pydantic import TypeBaseModel, create_model_v2, get_fields
if TYPE_CHECKING:
from collections.abc import AsyncIterator, Iterator, Mapping
@@ -425,7 +426,7 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]):
return super().get_name(suffix, name=name)
@override
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
def get_input_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
map_input_schema = self.mapper.get_input_schema(config)
if not issubclass(map_input_schema, RootModel):
# ie. it's a dict
@@ -434,9 +435,11 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]):
return super().get_input_schema(config)
@override
def get_output_schema(
self, config: RunnableConfig | None = None
) -> type[BaseModel]:
def get_output_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
# The return type stays `TypeBaseModel` (rather than narrowing to
# `type[BaseModel]` as `RunnableParallel.get_output_schema` does) because
# the fallback branches return the mapper's output schema or delegate to
# `super().get_output_schema()`, either of which may be a Pydantic v1 model.
map_input_schema = self.mapper.get_input_schema(config)
map_output_schema = self.mapper.get_output_schema(config)
if not issubclass(map_input_schema, RootModel) and not issubclass(
@@ -444,11 +447,11 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]):
):
fields = {}
for name, field_info in map_input_schema.model_fields.items():
fields[name] = (field_info.annotation, field_info.default)
for name, field_info in get_fields(map_input_schema).items():
fields[name] = _get_schema_field_definition(field_info)
for name, field_info in map_output_schema.model_fields.items():
fields[name] = (field_info.annotation, field_info.default)
for name, field_info in get_fields(map_output_schema).items():
fields[name] = _get_schema_field_definition(field_info)
return create_model_v2("RunnableAssignOutput", field_definitions=fields)
if not issubclass(map_output_schema, RootModel):

View File

@@ -65,6 +65,7 @@ from langchain_core.utils.pydantic import (
is_basemodel_subclass,
is_pydantic_v1_subclass,
is_pydantic_v2_subclass,
model_json_schema,
)
if TYPE_CHECKING:
@@ -267,7 +268,7 @@ def create_schema_from_function(
parse_docstring: bool = False,
error_on_invalid_docstring: bool = False,
include_injected: bool = True,
) -> type[BaseModel]:
) -> TypeBaseModel:
"""Create a Pydantic schema from a function's signature.
Args:
@@ -572,14 +573,12 @@ class ChildTool(BaseTool):
"""
if isinstance(self.args_schema, dict):
json_schema = self.args_schema
elif self.args_schema and issubclass(self.args_schema, BaseModelV1):
json_schema = self.args_schema.schema()
else:
input_schema = self.tool_call_schema
if isinstance(input_schema, dict):
json_schema = input_schema
else:
json_schema = input_schema.model_json_schema()
json_schema = model_json_schema(input_schema)
return cast("dict[str, Any]", json_schema["properties"])
@property
@@ -615,7 +614,7 @@ class ChildTool(BaseTool):
# --- Runnable ---
@override
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
def get_input_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
"""The tool's input schema.
Args:
@@ -693,6 +692,7 @@ class ChildTool(BaseTool):
if input_args is not None:
if isinstance(input_args, dict):
return tool_input
result: BaseModel | BaseModelV1
if issubclass(input_args, BaseModel):
# Check args_schema for InjectedToolCallId
for k, v in get_all_basemodel_annotations(input_args).items():
@@ -707,8 +707,9 @@ class ChildTool(BaseTool):
)
raise ValueError(msg)
tool_input[k] = tool_call_id
result = input_args.model_validate(tool_input)
result_dict = result.model_dump()
result_v2 = input_args.model_validate(tool_input)
result_dict = result_v2.model_dump()
result = result_v2
elif issubclass(input_args, BaseModelV1):
# Check args_schema for InjectedToolCallId
for k, v in get_all_basemodel_annotations(input_args).items():
@@ -723,8 +724,9 @@ class ChildTool(BaseTool):
)
raise ValueError(msg)
tool_input[k] = tool_call_id
result = input_args.parse_obj(tool_input)
result_dict = result.dict()
result_v1 = input_args.parse_obj(tool_input)
result_dict = result_v1.dict()
result = result_v1
else:
msg = (
f"args_schema must be a Pydantic BaseModel, got {self.args_schema}"

View File

@@ -11,6 +11,7 @@ from langchain_core.runnables import Runnable
from langchain_core.tools.base import ArgsSchema, BaseTool
from langchain_core.tools.simple import Tool
from langchain_core.tools.structured import StructuredTool
from langchain_core.utils.pydantic import TypeBaseModel
@overload
@@ -275,7 +276,7 @@ def tool(
if isinstance(dec_func, Runnable):
runnable = dec_func
if runnable.input_schema.model_json_schema().get("type") != "object":
if runnable.get_input_jsonschema().get("type") != "object":
msg = "Runnable must have an object schema."
raise ValueError(msg)
@@ -394,7 +395,7 @@ def tool(
def _get_description_from_runnable(runnable: Runnable[Any, Any]) -> str:
"""Generate a placeholder description of a `Runnable`."""
input_schema = runnable.input_schema.model_json_schema()
input_schema = runnable.get_input_jsonschema()
return f"Takes {input_schema}."
@@ -420,7 +421,7 @@ def _get_schema_from_runnable_and_arg_types(
def convert_runnable_to_tool(
runnable: Runnable[Any, Any],
args_schema: type[BaseModel] | None = None,
args_schema: TypeBaseModel | None = None,
*,
name: str | None = None,
description: str | None = None,
@@ -443,7 +444,7 @@ def convert_runnable_to_tool(
description = description or _get_description_from_runnable(runnable)
name = name or runnable.get_name()
schema = runnable.input_schema.model_json_schema()
schema = runnable.get_input_jsonschema()
if schema.get("type") == "string":
return Tool(
name=name,

View File

@@ -69,8 +69,8 @@ PYDANTIC_MINOR_VERSION = PYDANTIC_VERSION.minor
IS_PYDANTIC_V1 = False
IS_PYDANTIC_V2 = True
PydanticBaseModel = BaseModel
TypeBaseModel = type[BaseModel]
PydanticBaseModel = BaseModel | BaseModelV1
TypeBaseModel = type[BaseModel] | type[BaseModelV1]
TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel)
@@ -287,7 +287,7 @@ def _create_subset_model(
*,
descriptions: dict[str, str] | None = None,
fn_description: str | None = None,
) -> type[BaseModel]:
) -> TypeBaseModel:
"""Create subset model using the same pydantic version as the input model.
Returns:
@@ -347,6 +347,49 @@ def get_fields(
raise TypeError(msg)
def model_json_schema(model: TypeBaseModel) -> dict[str, Any]:
    """Build the JSON schema for a Pydantic model class of either major version.

    Pydantic v2 classes are asked for `model_json_schema()` and Pydantic v1
    classes for `schema()`, so code holding a `TypeBaseModel` never needs to
    check which version it has.

    Args:
        model: The Pydantic model class.

    Returns:
        The JSON schema of `model` as a dictionary.

    Raises:
        TypeError: If the model is not a Pydantic model class.
    """
    if issubclass(model, BaseModel):
        json_schema = model.model_json_schema()
    elif issubclass(model, BaseModelV1):
        json_schema = model.schema()
    else:
        msg = f"Expected a Pydantic model. Got {model}"
        raise TypeError(msg)
    return json_schema
def model_validate(model: TypeBaseModel, obj: Any) -> PydanticBaseModel:
    """Validate `obj` with a Pydantic model class of either major version.

    Pydantic v2 classes validate through `model_validate` and Pydantic v1
    classes through `parse_obj`, so code holding a `TypeBaseModel` never needs
    to check which version it has.

    Args:
        model: The Pydantic model class to validate against.
        obj: The object to validate.

    Returns:
        The validated instance of `model`.

    Raises:
        TypeError: If the model is not a Pydantic model class.
    """
    if issubclass(model, BaseModel):
        validated = model.model_validate(obj)
    elif issubclass(model, BaseModelV1):
        validated = model.parse_obj(obj)
    else:
        msg = f"Expected a Pydantic model. Got {model}"
        raise TypeError(msg)
    return validated
_SchemaConfig = ConfigDict(
arbitrary_types_allowed=True, frozen=True, protected_namespaces=()
)

View File

@@ -175,7 +175,7 @@ def test_pydantic_output_parser_type_inference() -> None:
# Ignoring mypy error that appears in python 3.8, but not 3.11.
# This seems to be functionally correct, so we'll ignore the error.
pydantic_parser = PydanticOutputParser[SampleModel](pydantic_object=SampleModel)
schema = pydantic_parser.get_output_schema().model_json_schema()
schema = pydantic_parser.get_output_jsonschema()
assert schema == {
"properties": {

View File

@@ -1,9 +1,18 @@
import sys
from inspect import isclass
from typing import Any
import pytest
from pydantic import BaseModel
from pydantic.v1 import BaseModel as BaseModelV1
# pydantic.v1 models cannot be exercised under the compatibility shim on
# Python 3.14+, so v1-specific tests are skipped there.
skip_if_no_pydantic_v1 = pytest.mark.skipif(
sys.version_info >= (3, 14),
reason="pydantic.v1 namespace not supported with Python 3.14+",
)
# Function to replace allOf with $ref
def replace_all_of_with_ref(schema: Any) -> None:

View File

@@ -19,7 +19,10 @@ from uuid import UUID
import pytest
from freezegun import freeze_time
from packaging import version
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, ValidationError
from pydantic.v1 import BaseModel as BaseModelV1
from pydantic.v1 import Field as FieldV1
from pydantic.v1 import ValidationError as ValidationErrorV1
from pytest_mock import MockerFixture
from syrupy.assertion import SnapshotAssertion
from typing_extensions import TypedDict, override
@@ -90,9 +93,17 @@ from langchain_core.tracers import (
)
from langchain_core.tracers._compat import pydantic_copy
from langchain_core.tracers.context import collect_runs
from langchain_core.utils.pydantic import PYDANTIC_VERSION
from langchain_core.utils.pydantic import (
PYDANTIC_VERSION,
TypeBaseModel,
model_validate,
)
from langchain_core.version import VERSION
from tests.unit_tests.pydantic_utils import _normalize_schema, _schema
from tests.unit_tests.pydantic_utils import (
_normalize_schema,
_schema,
skip_if_no_pydantic_v1,
)
from tests.unit_tests.stubs import AnyStr, _any_id_ai_message, _any_id_ai_message_chunk
PYDANTIC_VERSION_AT_LEAST_29 = version.parse("2.9") <= PYDANTIC_VERSION
@@ -499,7 +510,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
foo_ = RunnableLambda(foo)
assert foo_.assign(bar=lambda _: "foo").get_output_schema().model_json_schema() == {
assert foo_.assign(bar=lambda _: "foo").get_output_jsonschema() == {
"properties": {"bar": {"title": "Bar"}, "root": {"title": "Root"}},
"required": ["root", "bar"],
"title": "RunnableAssignOutput",
@@ -5820,6 +5831,167 @@ def test_runnable_typed_dict_schema() -> None:
other=other_runnable,
)
assert (
repr(parallel.input_schema.model_validate({"foo": "Y", "bar": "Z"}))
repr(model_validate(parallel.input_schema, {"foo": "Y", "bar": "Z"}))
== "RunnableParallel<foo,other>Input(root={'foo': 'Y', 'bar': 'Z'})"
)
class _RunnableWithInputSchema(Runnable[Any, Any]):
    """Minimal `Runnable` test double with a caller-supplied input schema.

    `InputType` and `get_input_schema` both return the model class passed to the
    constructor, and `invoke` returns its input unchanged. This lets tests feed
    a specific (e.g. Pydantic v1) schema into composed runnables.
    """

    def __init__(self, input_schema: TypeBaseModel) -> None:
        # Stored privately because `input_schema` is already a property on
        # `Runnable`.
        self._input_schema = input_schema

    @property
    @override
    def InputType(self) -> Any:
        return self._input_schema

    @override
    def get_input_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
        # `config` is intentionally unused: the schema is fixed at construction.
        _ = config
        return self._input_schema

    @override
    def invoke(
        self,
        input: Any,
        config: RunnableConfig | None = None,
        **kwargs: Any,
    ) -> Any:
        # Identity: the tests only inspect schemas, not the invocation result.
        return input
@skip_if_no_pydantic_v1
def test_runnable_parallel_preserves_required_v1_input_fields() -> None:
    """A required v1 field stays required in the `RunnableParallel` input schema.

    `a` has no default and `b` defaults to `2`; the derived schema must list only
    `a` as required, reject an empty payload, and fill in `b`'s default.
    """

    class InputModel(BaseModelV1):
        a: int
        b: int = 2

    parallel = RunnableParallel(foo=_RunnableWithInputSchema(InputModel))
    # The derived schema is exercised through the Pydantic v2 API below, so the
    # `TypeBaseModel` return type is narrowed for the type checker.
    schema = cast("type[BaseModel]", parallel.input_schema)
    assert schema.model_json_schema()["required"] == ["a"]
    # Omitting the required field `a` must fail validation.
    with pytest.raises(ValidationError):
        schema.model_validate({})
    # The optional field keeps its declared default.
    model = schema.model_validate({"a": 1})
    assert model.model_dump() == {"a": 1, "b": 2}
@skip_if_no_pydantic_v1
def test_runnable_parallel_uses_base_schema_for_v1_root_model() -> None:
    """A v1 `__root__` step schema is returned unchanged as the parallel's input.

    When the only step's input schema is a Pydantic v1 custom-root model, the
    `RunnableParallel` input schema is that very class, and it still validates
    through the v1 API.
    """

    class InputModel(BaseModelV1):
        __root__: dict[str, int]

    parallel = RunnableParallel(foo=_RunnableWithInputSchema(InputModel))
    schema = parallel.input_schema
    # Identity check: no new model is derived for a v1 root schema.
    assert schema is InputModel
    assert schema.schema() == {
        "additionalProperties": {"type": "integer"},
        "title": "InputModel",
        "type": "object",
    }
    assert schema.parse_obj({"a": 1}).__root__ == {"a": 1}
    # Values that are not integers must be rejected by the v1 validator.
    with pytest.raises(ValidationErrorV1):
        schema.parse_obj({"a": "not an int"})
@skip_if_no_pydantic_v1
def test_runnable_sequence_v1_input_schema() -> None:
    """A `RunnableSequence` exposes a Pydantic v1 first-step input schema.

    Regression test: deriving the sequence schema previously assumed Pydantic v2.
    """

    class InputModel(BaseModelV1):
        a: int

    # The v1-schema runnable is the first step of the sequence.
    sequence = _RunnableWithInputSchema(InputModel) | RunnableLambda(lambda x: x)
    # Read through `get_input_jsonschema()` so no v2-only method is called on
    # the v1 model.
    assert sequence.get_input_jsonschema()["properties"] == {
        "a": {"title": "A", "type": "integer"}
    }
@skip_if_no_pydantic_v1
def test_runnable_branch_v1_input_schema() -> None:
    """A `RunnableBranch` exposes a Pydantic v1 input schema.

    Regression test: `get_input_schema` previously assumed Pydantic v2 when
    inspecting each branch's schema.
    """

    class InputModel(BaseModelV1):
        a: int

    # Both the conditional branch and the default use the same v1 input schema.
    branch = RunnableBranch(
        (lambda _: True, _RunnableWithInputSchema(InputModel)),
        _RunnableWithInputSchema(InputModel),
    )
    assert branch.get_input_jsonschema()["properties"] == {
        "a": {"title": "A", "type": "integer"}
    }
@skip_if_no_pydantic_v1
def test_runnable_parallel_preserves_v1_default_factory() -> None:
    """A v1 `default_factory` field keeps its factory in the derived schema.

    Regression test for `_get_schema_field_definition`: without the dedicated
    `default_factory` branch the factory is dropped and the field defaults to
    `None` instead of producing the factory's value.
    """

    class InputModel(BaseModelV1):
        a: int
        items: list[int] = FieldV1(default_factory=list)

    parallel = RunnableParallel(foo=_RunnableWithInputSchema(InputModel))
    # The derived schema is exercised through the Pydantic v2 API below, so the
    # `TypeBaseModel` return type is narrowed for the type checker.
    schema = cast("type[BaseModel]", parallel.input_schema)
    # The factory field is optional, so only `a` is required.
    assert schema.model_json_schema()["required"] == ["a"]
    # The factory runs when omitted, yielding `[]` rather than `None`.
    assert schema.model_validate({"a": 1}).model_dump() == {"a": 1, "items": []}
@skip_if_no_pydantic_v1
def test_runnable_sequence_v1_output_schema_with_assign() -> None:
    """The output schema of `v1 step | RunnableAssign` merges the v1 fields.

    Regression test: `_seq_output_schema` used to read `.model_fields` (v2-only)
    on the upstream schema and turned required v1 fields into optional ones.
    """

    class InputModel(BaseModelV1):
        a: int

    upstream = _RunnableWithInputSchema(InputModel)
    add_bar = RunnableAssign(RunnableParallel(bar=RunnableLambda(lambda _: "bar")))
    output_schema = (upstream | add_bar).get_output_jsonschema()

    assert set(output_schema["properties"]) == {"a", "bar"}
    # A required v1 field must stay required instead of becoming optional.
    assert "a" in output_schema["required"]
@skip_if_no_pydantic_v1
def test_runnable_sequence_v1_output_schema_with_pick() -> None:
    """The output schema of `v1 step | RunnablePick` keeps only the picked field.

    Regression test: the `RunnablePick` branch of `_seq_output_schema` used to
    read `.model_fields` (v2-only) on the upstream schema.
    """

    class InputModel(BaseModelV1):
        a: int
        b: int

    upstream = _RunnableWithInputSchema(InputModel)
    output_schema = (upstream | RunnablePick(["a"])).get_output_jsonschema()

    assert set(output_schema["properties"]) == {"a"}
    assert "a" in output_schema["required"]

View File

@@ -40,6 +40,7 @@ from langchain_core.messages import ToolCall, ToolMessage
from langchain_core.messages.tool import ToolOutputMixin
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import (
Runnable,
RunnableConfig,
RunnableLambda,
ensure_config,
@@ -49,6 +50,7 @@ from langchain_core.tools import (
StructuredTool,
Tool,
ToolException,
convert_runnable_to_tool,
tool,
)
from langchain_core.tools.base import (
@@ -68,11 +70,16 @@ from langchain_core.utils.function_calling import (
convert_to_openai_tool,
)
from langchain_core.utils.pydantic import (
TypeBaseModel,
_create_subset_model,
create_model_v2,
)
from tests.unit_tests.fake.callbacks import FakeCallbackHandler
from tests.unit_tests.pydantic_utils import _normalize_schema, _schema
from tests.unit_tests.pydantic_utils import (
_normalize_schema,
_schema,
skip_if_no_pydantic_v1,
)
try:
from langgraph.prebuilt import ToolRuntime # type: ignore[import-not-found]
@@ -1964,7 +1971,7 @@ def test_tool_inherited_injected_arg() -> None:
return y
tool_ = InheritedInjectedArgTool()
assert tool_.get_input_schema().model_json_schema() == {
assert tool_.get_input_jsonschema() == {
"title": "FooSchema", # Matches the title from the provided schema
"description": "foo.",
"type": "object",
@@ -2138,7 +2145,7 @@ def test_args_schema_explicitly_typed() -> None:
tool = SomeTool(name="some_tool", description="some description")
assert tool.get_input_schema().model_json_schema() == {
assert tool.get_input_jsonschema() == {
"properties": {
"a": {"title": "A", "type": "integer"},
"b": {"title": "B", "type": "string"},
@@ -2160,6 +2167,92 @@ def test_args_schema_explicitly_typed() -> None:
}
@skip_if_no_pydantic_v1
def test_get_input_jsonschema_v1_args_schema() -> None:
    """A tool whose `args_schema` is a Pydantic v1 model yields a JSON schema.

    Regression test: `get_input_jsonschema()` used to raise `AttributeError`
    here, since a Pydantic v1 model has no `model_json_schema`.
    """

    class FooV1(BaseModelV1):
        a: int
        b: str

    class SomeTool(BaseTool):
        args_schema: type[BaseModelV1] = FooV1

        @override
        def _run(self, *args: Any, **kwargs: Any) -> str:
            return "foo"

    some_tool = SomeTool(
        name="some_tool",
        description="some description",
        args_schema=FooV1,
    )

    expected_properties = {
        "a": {"title": "A", "type": "integer"},
        "b": {"title": "B", "type": "string"},
    }
    assert some_tool.get_input_jsonschema()["properties"] == expected_properties
@skip_if_no_pydantic_v1
def test_v1_args_schema_excludes_injected_args() -> None:
    """`.args` hides injected arguments for a Pydantic v1 `args_schema` too.

    The framework, not the model, provides injected arguments, so they have to
    stay out of the tool-call schema whatever Pydantic version `args_schema`
    uses. A v1 `args_schema` used to skip `tool_call_schema`, which leaked the
    injected arguments into `.args`.
    """

    class FooV1(BaseModelV1):
        a: int
        injected: Annotated[str, InjectedToolArg()] = "default"

    class SomeTool(BaseTool):
        args_schema: type[BaseModelV1] = FooV1

        @override
        def _run(self, *args: Any, **kwargs: Any) -> str:
            return "foo"

    some_tool = SomeTool(
        name="some_tool",
        description="some description",
        args_schema=FooV1,
    )

    assert set(some_tool.args) == {"a"}
    assert "injected" not in some_tool.args
@skip_if_no_pydantic_v1
def test_convert_runnable_to_tool_v1_input_schema() -> None:
    """A runnable with a v1 input schema can be converted into a tool.

    Regression test: `convert_runnable_to_tool` used to assume a Pydantic v2
    input schema when deriving the tool schema, and raised on v1.
    """

    class FooV1(BaseModelV1):
        a: int

    class V1InputRunnable(Runnable[Any, Any]):
        # Report the v1 model directly as this runnable's input schema.
        @override
        def get_input_schema(
            self, config: RunnableConfig | None = None
        ) -> TypeBaseModel:
            return FooV1

        # Identity runnable: the schema, not the behavior, is under test.
        @override
        def invoke(
            self,
            input: Any,
            config: RunnableConfig | None = None,
            **kwargs: Any,
        ) -> Any:
            return input

    converted = convert_runnable_to_tool(V1InputRunnable())
    assert set(converted.args) == {"a"}
@pytest.mark.parametrize("pydantic_model", TEST_MODELS)
def test_structured_tool_with_different_pydantic_versions(pydantic_model: Any) -> None:
"""This should test that one can type the args schema as a Pydantic model."""
@@ -2784,7 +2877,7 @@ def test_structured_tool_args_schema_dict(caplog: pytest.LogCaptureFixture) -> N
assert _get_tool_call_json_schema(tool) == args_schema
# test that the input schema is the same as the parent (Runnable) input schema
assert (
tool.get_input_schema().model_json_schema()
tool.get_input_jsonschema()
== create_model_v2(
tool.get_name("Input"),
root=tool.InputType,
@@ -2822,7 +2915,7 @@ def test_simple_tool_args_schema_dict() -> None:
assert _get_tool_call_json_schema(tool) == args_schema
# test that the input schema is the same as the parent (Runnable) input schema
assert (
tool.get_input_schema().model_json_schema()
tool.get_input_jsonschema()
== create_model_v2(
tool.get_name("Input"),
root=tool.InputType,

View File

@@ -14,6 +14,8 @@ from langchain_core.utils.pydantic import (
get_fields,
is_basemodel_instance,
is_basemodel_subclass,
model_json_schema,
model_validate,
pre_init,
)
@@ -153,6 +155,56 @@ def test_fields_pydantic_v1_from_2() -> None:
assert fields == {"x": Foo.__fields__["x"]}
def test_model_json_schema_v2() -> None:
    """`model_json_schema` matches a v2 model's own `model_json_schema()`."""

    class Foo(BaseModel):
        x: int

    from_helper = model_json_schema(Foo)
    assert from_helper == Foo.model_json_schema()
@pytest.mark.skipif(
    sys.version_info >= (3, 14),
    reason="pydantic.v1 namespace not supported with Python 3.14+",
)
def test_model_json_schema_v1() -> None:
    """`model_json_schema` matches a v1 model's own `schema()`."""

    class Foo(BaseModelV1):
        x: int

    from_helper = model_json_schema(Foo)
    assert from_helper == Foo.schema()
def test_model_json_schema_non_model() -> None:
    """`model_json_schema` raises `TypeError` for a non-Pydantic class."""
    expected_message = "Expected a Pydantic model"
    with pytest.raises(TypeError, match=expected_message):
        model_json_schema(dict)  # type: ignore[arg-type]
def test_model_validate_v2() -> None:
    """`model_validate` builds an instance of a v2 model from a dict."""

    class Foo(BaseModel):
        x: int

    validated = model_validate(Foo, {"x": 1})

    assert isinstance(validated, Foo)
    assert validated.x == 1
@pytest.mark.skipif(
    sys.version_info >= (3, 14),
    reason="pydantic.v1 namespace not supported with Python 3.14+",
)
def test_model_validate_v1() -> None:
    """`model_validate` builds an instance of a v1 model from a dict."""

    class Foo(BaseModelV1):
        x: int

    validated = model_validate(Foo, {"x": 1})

    assert isinstance(validated, Foo)
    assert validated.x == 1
def test_model_validate_non_model() -> None:
    """`model_validate` raises `TypeError` for a non-Pydantic class."""
    expected_message = "Expected a Pydantic model"
    with pytest.raises(TypeError, match=expected_message):
        model_validate(dict, {"x": 1})  # type: ignore[arg-type]
def test_create_model_v2() -> None:
"""Test that create model v2 works as expected."""
with warnings.catch_warnings(record=True) as record:

View File

@@ -40,7 +40,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
@property
def input_keys(self) -> list[str]:
    """Input keys for Hyde's LLM chain.

    Returns:
        The `required` entries of the LLM chain's input JSON schema.
    """
    # Use `get_input_jsonschema()` instead of calling `.model_json_schema()`
    # on `input_schema`: the schema may be a Pydantic v1 model, which does not
    # implement `model_json_schema`. The stale v2-only return that preceded
    # this one also made it unreachable, so only this return is kept.
    return self.llm_chain.get_input_jsonschema()["required"]
@property
def output_keys(self) -> list[str]: