mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-30 14:15:49 +00:00
fix(core): fix Pydantic v1 support in tools/runnable (#33698)
`BaseTool.args_schema` is documented as accepting a Pydantic v1 model, but several code paths assumed v2 and raised when handed a v1 schema (e.g. an `AttributeError` from calling `model_json_schema()`/`model_fields` on a v1 model). This affected anyone using a v1 `args_schema`, and anyone composing runnables whose input/output schema is a v1 model. This PR makes the tool/runnable schema-derivation code version-agnostic. ## Type contract `TypeBaseModel` (and `PydanticBaseModel`) now include `pydantic.v1.BaseModel`, so the type honestly reflects what tools and runnables already accept at runtime. The public schema accessors (`Runnable.get_input_schema`/`get_output_schema` and the `input_schema`/`output_schema` properties) return `TypeBaseModel`. ## Version-agnostic helpers Added to `langchain_core.utils.pydantic`, each dispatching on the model's Pydantic version so callers don't have to: - `model_json_schema(model)` — JSON schema for either version. - `model_validate(model, obj)` — validation for either version. - `get_fields(model)` — field map for either version (existing helper, now used consistently). Internally, direct `.model_json_schema()` / `.model_fields` calls are replaced with these helpers (or with `get_input_jsonschema()` / `get_output_jsonschema()`). ## Behavior change worth a close look When deriving a schema from a v1 model (in `RunnableParallel`, `RunnableAssign`, and `RunnableSequence` output schemas), a **required** v1 field is now correctly carried over as required. Previously the v1 path read the field's `default` — which is `None` for a required v1 field — and silently turned required fields into optional/nullable ones; `default_factory` fields were dropped entirely. The new `_get_schema_field_definition` helper translates a v1 `ModelField` faithfully (required → `...`, factory preserved) and dispatches explicitly on the field type. --------- Co-authored-by: Mason Daugherty <mason@langchain.dev> Co-authored-by: Mason Daugherty <github@mdrxy.com>
This commit is contained in:
committed by
GitHub
parent
f6d63bc9f3
commit
0392b6bae4
@@ -91,7 +91,7 @@ from langchain_core.utils.function_calling import (
|
||||
convert_to_json_schema,
|
||||
convert_to_openai_tool,
|
||||
)
|
||||
from langchain_core.utils.pydantic import TypeBaseModel, is_basemodel_subclass
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
from langchain_core.utils.utils import LC_ID_PREFIX, from_env
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -2519,9 +2519,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
|
||||
)
|
||||
output_parser: JsonOutputToolsParser
|
||||
if isinstance(schema, type) and is_basemodel_subclass(schema):
|
||||
output_parser = PydanticToolsParser(
|
||||
tools=[cast("TypeBaseModel", schema)], first_tool_only=True
|
||||
)
|
||||
output_parser = PydanticToolsParser(tools=[schema], first_tool_only=True)
|
||||
else:
|
||||
key_name = convert_to_openai_tool(schema)["function"]["name"]
|
||||
output_parser = JsonOutputKeyToolsParser(
|
||||
|
||||
@@ -6,7 +6,8 @@ import logging
|
||||
from json import JSONDecodeError
|
||||
from typing import Annotated, Any
|
||||
|
||||
from pydantic import SkipValidation, ValidationError
|
||||
from pydantic import BaseModel, SkipValidation, ValidationError
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.messages import AIMessage, InvalidToolCall
|
||||
@@ -17,8 +18,6 @@ from langchain_core.outputs import ChatGeneration, Generation
|
||||
from langchain_core.utils.json import parse_partial_json
|
||||
from langchain_core.utils.pydantic import (
|
||||
TypeBaseModel,
|
||||
is_pydantic_v1_subclass,
|
||||
is_pydantic_v2_subclass,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -339,10 +338,10 @@ class PydanticToolsParser(JsonOutputToolsParser):
|
||||
name_dict_v2: dict[str, TypeBaseModel] = {
|
||||
tool.model_config.get("title") or tool.__name__: tool
|
||||
for tool in self.tools
|
||||
if is_pydantic_v2_subclass(tool)
|
||||
if issubclass(tool, BaseModel)
|
||||
}
|
||||
name_dict_v1: dict[str, TypeBaseModel] = {
|
||||
tool.__name__: tool for tool in self.tools if is_pydantic_v1_subclass(tool)
|
||||
tool.__name__: tool for tool in self.tools if issubclass(tool, BaseModelV1)
|
||||
}
|
||||
name_dict: dict[str, TypeBaseModel] = {**name_dict_v2, **name_dict_v1}
|
||||
pydantic_objects = []
|
||||
|
||||
@@ -38,6 +38,7 @@ from typing import (
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, RootModel
|
||||
from pydantic.fields import FieldInfo
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core._api import beta_decorator
|
||||
@@ -97,9 +98,16 @@ from langchain_core.tracers.root_listeners import (
|
||||
)
|
||||
from langchain_core.utils.aiter import aclosing, atee
|
||||
from langchain_core.utils.iter import safetee
|
||||
from langchain_core.utils.pydantic import create_model_v2
|
||||
from langchain_core.utils.pydantic import (
|
||||
TypeBaseModel,
|
||||
create_model_v2,
|
||||
get_fields,
|
||||
model_json_schema,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic.v1.fields import ModelField
|
||||
|
||||
from langchain_core.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
@@ -364,14 +372,14 @@ class Runnable(ABC, Generic[Input, Output]):
|
||||
raise TypeError(msg)
|
||||
|
||||
@property
|
||||
def input_schema(self) -> type[BaseModel]:
|
||||
def input_schema(self) -> TypeBaseModel:
|
||||
"""The type of input this `Runnable` accepts specified as a Pydantic model."""
|
||||
return self.get_input_schema()
|
||||
|
||||
def get_input_schema(
|
||||
self,
|
||||
config: RunnableConfig | None = None,
|
||||
) -> type[BaseModel]:
|
||||
) -> TypeBaseModel:
|
||||
"""Get a Pydantic model that can be used to validate input to the `Runnable`.
|
||||
|
||||
`Runnable` objects that leverage the `configurable_fields` and
|
||||
@@ -437,10 +445,10 @@ class Runnable(ABC, Generic[Input, Output]):
|
||||
!!! version-added "Added in `langchain-core` 0.3.0"
|
||||
|
||||
"""
|
||||
return self.get_input_schema(config).model_json_schema()
|
||||
return model_json_schema(self.get_input_schema(config))
|
||||
|
||||
@property
|
||||
def output_schema(self) -> type[BaseModel]:
|
||||
def output_schema(self) -> TypeBaseModel:
|
||||
"""Output schema.
|
||||
|
||||
The type of output this `Runnable` produces specified as a Pydantic model.
|
||||
@@ -450,7 +458,7 @@ class Runnable(ABC, Generic[Input, Output]):
|
||||
def get_output_schema(
|
||||
self,
|
||||
config: RunnableConfig | None = None,
|
||||
) -> type[BaseModel]:
|
||||
) -> TypeBaseModel:
|
||||
"""Get a Pydantic model that can be used to validate output to the `Runnable`.
|
||||
|
||||
`Runnable` objects that leverage the `configurable_fields` and
|
||||
@@ -516,7 +524,7 @@ class Runnable(ABC, Generic[Input, Output]):
|
||||
!!! version-added "Added in `langchain-core` 0.3.0"
|
||||
|
||||
"""
|
||||
return self.get_output_schema(config).model_json_schema()
|
||||
return model_json_schema(self.get_output_schema(config))
|
||||
|
||||
@property
|
||||
def config_specs(self) -> list[ConfigurableFieldSpec]:
|
||||
@@ -2953,7 +2961,7 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]):
|
||||
|
||||
def _seq_input_schema(
|
||||
steps: list[Runnable[Any, Any]], config: RunnableConfig | None
|
||||
) -> type[BaseModel]:
|
||||
) -> TypeBaseModel:
|
||||
# Import locally to prevent circular import
|
||||
from langchain_core.runnables.passthrough import ( # noqa: PLC0415
|
||||
RunnableAssign,
|
||||
@@ -2971,7 +2979,7 @@ def _seq_input_schema(
|
||||
"RunnableSequenceInput",
|
||||
field_definitions={
|
||||
k: (v.annotation, v.default)
|
||||
for k, v in next_input_schema.model_fields.items()
|
||||
for k, v in get_fields(next_input_schema).items()
|
||||
if k not in first.mapper.steps__
|
||||
},
|
||||
)
|
||||
@@ -2983,7 +2991,7 @@ def _seq_input_schema(
|
||||
|
||||
def _seq_output_schema(
|
||||
steps: list[Runnable[Any, Any]], config: RunnableConfig | None
|
||||
) -> type[BaseModel]:
|
||||
) -> TypeBaseModel:
|
||||
# Import locally to prevent circular import
|
||||
from langchain_core.runnables.passthrough import ( # noqa: PLC0415
|
||||
RunnableAssign,
|
||||
@@ -3002,12 +3010,12 @@ def _seq_output_schema(
|
||||
"RunnableSequenceOutput",
|
||||
field_definitions={
|
||||
**{
|
||||
k: (v.annotation, v.default)
|
||||
for k, v in prev_output_schema.model_fields.items()
|
||||
k: _get_schema_field_definition(v)
|
||||
for k, v in get_fields(prev_output_schema).items()
|
||||
},
|
||||
**{
|
||||
k: (v.annotation, v.default)
|
||||
for k, v in mapper_output_schema.model_fields.items()
|
||||
k: _get_schema_field_definition(v)
|
||||
for k, v in get_fields(mapper_output_schema).items()
|
||||
},
|
||||
},
|
||||
)
|
||||
@@ -3019,19 +3027,36 @@ def _seq_output_schema(
|
||||
return create_model_v2(
|
||||
"RunnableSequenceOutput",
|
||||
field_definitions={
|
||||
k: (v.annotation, v.default)
|
||||
for k, v in prev_output_schema.model_fields.items()
|
||||
k: _get_schema_field_definition(v)
|
||||
for k, v in get_fields(prev_output_schema).items()
|
||||
if k in last.keys
|
||||
},
|
||||
)
|
||||
field = prev_output_schema.model_fields[last.keys]
|
||||
field = get_fields(prev_output_schema)[last.keys]
|
||||
return create_model_v2(
|
||||
"RunnableSequenceOutput", root=(field.annotation, field.default)
|
||||
"RunnableSequenceOutput", root=_get_schema_field_definition(field)
|
||||
)
|
||||
|
||||
return last.get_output_schema(config)
|
||||
|
||||
|
||||
def _get_schema_field_definition(field: FieldInfo | ModelField) -> tuple[Any, Any]:
|
||||
"""Convert a Pydantic field to a field definition for `create_model_v2`.
|
||||
|
||||
Handles both Pydantic v2 (`FieldInfo`) and v1 (`ModelField`) fields. v1
|
||||
required fields carry a `None` default with `required=True`, so they must be
|
||||
translated to the `...` sentinel rather than passed through as optional.
|
||||
"""
|
||||
if isinstance(field, FieldInfo):
|
||||
return (field.annotation, field.default)
|
||||
|
||||
if field.required:
|
||||
return (field.annotation, ...)
|
||||
if field.default_factory is not None:
|
||||
return (field.annotation, Field(default_factory=field.default_factory))
|
||||
return (field.annotation, field.default)
|
||||
|
||||
|
||||
_RUNNABLE_SEQUENCE_MIN_STEPS = 2
|
||||
|
||||
|
||||
@@ -3212,7 +3237,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
return self.last.OutputType
|
||||
|
||||
@override
|
||||
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
|
||||
def get_input_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
|
||||
"""Get the input schema of the `Runnable`.
|
||||
|
||||
Args:
|
||||
@@ -3225,9 +3250,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
return _seq_input_schema(self.steps, config)
|
||||
|
||||
@override
|
||||
def get_output_schema(
|
||||
self, config: RunnableConfig | None = None
|
||||
) -> type[BaseModel]:
|
||||
def get_output_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
|
||||
"""Get the output schema of the `Runnable`.
|
||||
|
||||
Args:
|
||||
@@ -3984,7 +4007,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
|
||||
return Any
|
||||
|
||||
@override
|
||||
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
|
||||
def get_input_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
|
||||
"""Get the input schema of the `Runnable`.
|
||||
|
||||
Args:
|
||||
@@ -3995,23 +4018,26 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
|
||||
|
||||
"""
|
||||
if all(
|
||||
s.get_input_schema(config).model_json_schema().get("type", "object")
|
||||
== "object"
|
||||
s.get_input_jsonschema(config).get("type", "object") == "object"
|
||||
for s in self.steps__.values()
|
||||
):
|
||||
for step in self.steps__.values():
|
||||
fields = step.get_input_schema(config).model_fields
|
||||
step_input_schema = step.get_input_schema(config)
|
||||
fields = get_fields(step_input_schema)
|
||||
root_field = fields.get("root")
|
||||
if root_field is not None and root_field.annotation != Any:
|
||||
return super().get_input_schema(config)
|
||||
root_field = fields.get("__root__")
|
||||
if root_field is not None and root_field.annotation != Any:
|
||||
return step_input_schema
|
||||
|
||||
# This is correct, but pydantic typings/mypy don't think so.
|
||||
return create_model_v2(
|
||||
self.get_name("Input"),
|
||||
field_definitions={
|
||||
k: (v.annotation, v.default)
|
||||
k: _get_schema_field_definition(v)
|
||||
for step in self.steps__.values()
|
||||
for k, v in step.get_input_schema(config).model_fields.items()
|
||||
for k, v in get_fields(step.get_input_schema(config)).items()
|
||||
if k != "__root__"
|
||||
},
|
||||
)
|
||||
@@ -4032,6 +4058,9 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
|
||||
|
||||
"""
|
||||
fields = {k: (v.OutputType, ...) for k, v in self.steps__.items()}
|
||||
# The return type is narrowed to `type[BaseModel]` (rather than the base
|
||||
# class's `TypeBaseModel`) because this override always builds the schema
|
||||
# with `create_model_v2`, so it is guaranteed to be a Pydantic v2 model.
|
||||
return create_model_v2(self.get_name("Output"), field_definitions=fields)
|
||||
|
||||
@property
|
||||
@@ -4932,7 +4961,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
return Any
|
||||
|
||||
@override
|
||||
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
|
||||
def get_input_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
|
||||
"""The Pydantic schema for the input to this `Runnable`.
|
||||
|
||||
Args:
|
||||
@@ -5925,15 +5954,13 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): # type: ignore[
|
||||
)
|
||||
|
||||
@override
|
||||
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
|
||||
def get_input_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
|
||||
if self.custom_input_type is not None:
|
||||
return super().get_input_schema(config)
|
||||
return self.bound.get_input_schema(merge_configs(self.config, config))
|
||||
|
||||
@override
|
||||
def get_output_schema(
|
||||
self, config: RunnableConfig | None = None
|
||||
) -> type[BaseModel]:
|
||||
def get_output_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
|
||||
if self.custom_output_type is not None:
|
||||
return super().get_output_schema(config)
|
||||
return self.bound.get_output_schema(merge_configs(self.config, config))
|
||||
|
||||
@@ -13,7 +13,7 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import ConfigDict
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core.runnables.base import (
|
||||
@@ -35,6 +35,7 @@ from langchain_core.runnables.utils import (
|
||||
Output,
|
||||
get_unique_config_specs,
|
||||
)
|
||||
from langchain_core.utils.pydantic import TypeBaseModel
|
||||
|
||||
_MIN_BRANCHES = 2
|
||||
|
||||
@@ -154,7 +155,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
||||
return ["langchain", "schema", "runnable"]
|
||||
|
||||
@override
|
||||
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
|
||||
def get_input_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
|
||||
runnables = (
|
||||
[self.default]
|
||||
+ [r for _, r in self.branches]
|
||||
@@ -162,10 +163,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
||||
)
|
||||
|
||||
for runnable in runnables:
|
||||
if (
|
||||
runnable.get_input_schema(config).model_json_schema().get("type")
|
||||
is not None
|
||||
):
|
||||
if runnable.get_input_jsonschema(config).get("type") is not None:
|
||||
return runnable.get_input_schema(config)
|
||||
|
||||
return super().get_input_schema(config)
|
||||
|
||||
@@ -19,7 +19,7 @@ from typing import (
|
||||
)
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import ConfigDict
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core.runnables.base import Runnable, RunnableSerializable
|
||||
@@ -44,6 +44,7 @@ from langchain_core.runnables.utils import (
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.runnables.graph import Graph
|
||||
from langchain_core.utils.pydantic import TypeBaseModel
|
||||
|
||||
|
||||
class DynamicRunnable(RunnableSerializable[Input, Output]):
|
||||
@@ -90,14 +91,12 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
||||
return self.default.OutputType
|
||||
|
||||
@override
|
||||
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
|
||||
def get_input_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
|
||||
runnable, config = self.prepare(config)
|
||||
return runnable.get_input_schema(config)
|
||||
|
||||
@override
|
||||
def get_output_schema(
|
||||
self, config: RunnableConfig | None = None
|
||||
) -> type[BaseModel]:
|
||||
def get_output_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
|
||||
runnable, config = self.prepare(config)
|
||||
return runnable.get_output_schema(config)
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from collections.abc import AsyncIterator, Iterator, Sequence
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import ConfigDict
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core.callbacks.manager import AsyncCallbackManager, CallbackManager
|
||||
@@ -28,6 +28,7 @@ from langchain_core.runnables.utils import (
|
||||
coro_with_context,
|
||||
get_unique_config_specs,
|
||||
)
|
||||
from langchain_core.utils.pydantic import TypeBaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun
|
||||
@@ -118,13 +119,11 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
return self.runnable.OutputType
|
||||
|
||||
@override
|
||||
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
|
||||
def get_input_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
|
||||
return self.runnable.get_input_schema(config)
|
||||
|
||||
@override
|
||||
def get_output_schema(
|
||||
self, config: RunnableConfig | None = None
|
||||
) -> type[BaseModel]:
|
||||
def get_output_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
|
||||
return self.runnable.get_output_schema(config)
|
||||
|
||||
@property
|
||||
|
||||
@@ -18,13 +18,15 @@ from uuid import UUID, uuid4
|
||||
|
||||
from langchain_core.load.serializable import to_json_not_implemented
|
||||
from langchain_core.runnables.base import Runnable, RunnableSerializable
|
||||
from langchain_core.utils.pydantic import _IgnoreUnserializable, is_basemodel_subclass
|
||||
from langchain_core.utils.pydantic import (
|
||||
TypeBaseModel,
|
||||
_IgnoreUnserializable,
|
||||
is_basemodel_subclass,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable, Sequence
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain_core.runnables.base import Runnable as RunnableType
|
||||
|
||||
|
||||
@@ -97,7 +99,7 @@ class Node(NamedTuple):
|
||||
"""The unique identifier of the node."""
|
||||
name: str
|
||||
"""The name of the node."""
|
||||
data: type[BaseModel] | RunnableType[Any, Any] | None
|
||||
data: TypeBaseModel | RunnableType[Any, Any] | None
|
||||
"""The data of the node."""
|
||||
metadata: dict[str, Any] | None
|
||||
"""Optional metadata for the node. """
|
||||
@@ -177,7 +179,7 @@ class MermaidDrawMethod(Enum):
|
||||
|
||||
def node_data_str(
|
||||
id: str,
|
||||
data: type[BaseModel] | RunnableType[Any, Any] | None,
|
||||
data: TypeBaseModel | RunnableType[Any, Any] | None,
|
||||
) -> str:
|
||||
"""Convert the data of a node to a string.
|
||||
|
||||
@@ -311,7 +313,7 @@ class Graph:
|
||||
|
||||
def add_node(
|
||||
self,
|
||||
data: type[BaseModel] | RunnableType[Any, Any] | None,
|
||||
data: TypeBaseModel | RunnableType[Any, Any] | None,
|
||||
id: str | None = None,
|
||||
*,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
|
||||
@@ -11,7 +11,7 @@ from typing import (
|
||||
Any,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, RootModel
|
||||
from pydantic import RootModel
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core.runnables.base import (
|
||||
@@ -19,6 +19,7 @@ from langchain_core.runnables.base import (
|
||||
Runnable,
|
||||
RunnableParallel,
|
||||
RunnableSerializable,
|
||||
_get_schema_field_definition,
|
||||
)
|
||||
from langchain_core.runnables.config import (
|
||||
RunnableConfig,
|
||||
@@ -34,7 +35,7 @@ from langchain_core.runnables.utils import (
|
||||
)
|
||||
from langchain_core.utils.aiter import atee
|
||||
from langchain_core.utils.iter import safetee
|
||||
from langchain_core.utils.pydantic import create_model_v2
|
||||
from langchain_core.utils.pydantic import TypeBaseModel, create_model_v2, get_fields
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncIterator, Iterator, Mapping
|
||||
@@ -425,7 +426,7 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]):
|
||||
return super().get_name(suffix, name=name)
|
||||
|
||||
@override
|
||||
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
|
||||
def get_input_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
|
||||
map_input_schema = self.mapper.get_input_schema(config)
|
||||
if not issubclass(map_input_schema, RootModel):
|
||||
# ie. it's a dict
|
||||
@@ -434,9 +435,11 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]):
|
||||
return super().get_input_schema(config)
|
||||
|
||||
@override
|
||||
def get_output_schema(
|
||||
self, config: RunnableConfig | None = None
|
||||
) -> type[BaseModel]:
|
||||
def get_output_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
|
||||
# The return type stays `TypeBaseModel` (rather than narrowing to
|
||||
# `type[BaseModel]` as `RunnableParallel.get_output_schema` does) because
|
||||
# the fallback branches return the mapper's output schema or delegate to
|
||||
# `super().get_output_schema()`, either of which may be a Pydantic v1 model.
|
||||
map_input_schema = self.mapper.get_input_schema(config)
|
||||
map_output_schema = self.mapper.get_output_schema(config)
|
||||
if not issubclass(map_input_schema, RootModel) and not issubclass(
|
||||
@@ -444,11 +447,11 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]):
|
||||
):
|
||||
fields = {}
|
||||
|
||||
for name, field_info in map_input_schema.model_fields.items():
|
||||
fields[name] = (field_info.annotation, field_info.default)
|
||||
for name, field_info in get_fields(map_input_schema).items():
|
||||
fields[name] = _get_schema_field_definition(field_info)
|
||||
|
||||
for name, field_info in map_output_schema.model_fields.items():
|
||||
fields[name] = (field_info.annotation, field_info.default)
|
||||
for name, field_info in get_fields(map_output_schema).items():
|
||||
fields[name] = _get_schema_field_definition(field_info)
|
||||
|
||||
return create_model_v2("RunnableAssignOutput", field_definitions=fields)
|
||||
if not issubclass(map_output_schema, RootModel):
|
||||
|
||||
@@ -65,6 +65,7 @@ from langchain_core.utils.pydantic import (
|
||||
is_basemodel_subclass,
|
||||
is_pydantic_v1_subclass,
|
||||
is_pydantic_v2_subclass,
|
||||
model_json_schema,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -267,7 +268,7 @@ def create_schema_from_function(
|
||||
parse_docstring: bool = False,
|
||||
error_on_invalid_docstring: bool = False,
|
||||
include_injected: bool = True,
|
||||
) -> type[BaseModel]:
|
||||
) -> TypeBaseModel:
|
||||
"""Create a Pydantic schema from a function's signature.
|
||||
|
||||
Args:
|
||||
@@ -572,14 +573,12 @@ class ChildTool(BaseTool):
|
||||
"""
|
||||
if isinstance(self.args_schema, dict):
|
||||
json_schema = self.args_schema
|
||||
elif self.args_schema and issubclass(self.args_schema, BaseModelV1):
|
||||
json_schema = self.args_schema.schema()
|
||||
else:
|
||||
input_schema = self.tool_call_schema
|
||||
if isinstance(input_schema, dict):
|
||||
json_schema = input_schema
|
||||
else:
|
||||
json_schema = input_schema.model_json_schema()
|
||||
json_schema = model_json_schema(input_schema)
|
||||
return cast("dict[str, Any]", json_schema["properties"])
|
||||
|
||||
@property
|
||||
@@ -615,7 +614,7 @@ class ChildTool(BaseTool):
|
||||
# --- Runnable ---
|
||||
|
||||
@override
|
||||
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
|
||||
def get_input_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
|
||||
"""The tool's input schema.
|
||||
|
||||
Args:
|
||||
@@ -693,6 +692,7 @@ class ChildTool(BaseTool):
|
||||
if input_args is not None:
|
||||
if isinstance(input_args, dict):
|
||||
return tool_input
|
||||
result: BaseModel | BaseModelV1
|
||||
if issubclass(input_args, BaseModel):
|
||||
# Check args_schema for InjectedToolCallId
|
||||
for k, v in get_all_basemodel_annotations(input_args).items():
|
||||
@@ -707,8 +707,9 @@ class ChildTool(BaseTool):
|
||||
)
|
||||
raise ValueError(msg)
|
||||
tool_input[k] = tool_call_id
|
||||
result = input_args.model_validate(tool_input)
|
||||
result_dict = result.model_dump()
|
||||
result_v2 = input_args.model_validate(tool_input)
|
||||
result_dict = result_v2.model_dump()
|
||||
result = result_v2
|
||||
elif issubclass(input_args, BaseModelV1):
|
||||
# Check args_schema for InjectedToolCallId
|
||||
for k, v in get_all_basemodel_annotations(input_args).items():
|
||||
@@ -723,8 +724,9 @@ class ChildTool(BaseTool):
|
||||
)
|
||||
raise ValueError(msg)
|
||||
tool_input[k] = tool_call_id
|
||||
result = input_args.parse_obj(tool_input)
|
||||
result_dict = result.dict()
|
||||
result_v1 = input_args.parse_obj(tool_input)
|
||||
result_dict = result_v1.dict()
|
||||
result = result_v1
|
||||
else:
|
||||
msg = (
|
||||
f"args_schema must be a Pydantic BaseModel, got {self.args_schema}"
|
||||
|
||||
@@ -11,6 +11,7 @@ from langchain_core.runnables import Runnable
|
||||
from langchain_core.tools.base import ArgsSchema, BaseTool
|
||||
from langchain_core.tools.simple import Tool
|
||||
from langchain_core.tools.structured import StructuredTool
|
||||
from langchain_core.utils.pydantic import TypeBaseModel
|
||||
|
||||
|
||||
@overload
|
||||
@@ -275,7 +276,7 @@ def tool(
|
||||
if isinstance(dec_func, Runnable):
|
||||
runnable = dec_func
|
||||
|
||||
if runnable.input_schema.model_json_schema().get("type") != "object":
|
||||
if runnable.get_input_jsonschema().get("type") != "object":
|
||||
msg = "Runnable must have an object schema."
|
||||
raise ValueError(msg)
|
||||
|
||||
@@ -394,7 +395,7 @@ def tool(
|
||||
|
||||
def _get_description_from_runnable(runnable: Runnable[Any, Any]) -> str:
|
||||
"""Generate a placeholder description of a `Runnable`."""
|
||||
input_schema = runnable.input_schema.model_json_schema()
|
||||
input_schema = runnable.get_input_jsonschema()
|
||||
return f"Takes {input_schema}."
|
||||
|
||||
|
||||
@@ -420,7 +421,7 @@ def _get_schema_from_runnable_and_arg_types(
|
||||
|
||||
def convert_runnable_to_tool(
|
||||
runnable: Runnable[Any, Any],
|
||||
args_schema: type[BaseModel] | None = None,
|
||||
args_schema: TypeBaseModel | None = None,
|
||||
*,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
@@ -443,7 +444,7 @@ def convert_runnable_to_tool(
|
||||
description = description or _get_description_from_runnable(runnable)
|
||||
name = name or runnable.get_name()
|
||||
|
||||
schema = runnable.input_schema.model_json_schema()
|
||||
schema = runnable.get_input_jsonschema()
|
||||
if schema.get("type") == "string":
|
||||
return Tool(
|
||||
name=name,
|
||||
|
||||
@@ -69,8 +69,8 @@ PYDANTIC_MINOR_VERSION = PYDANTIC_VERSION.minor
|
||||
IS_PYDANTIC_V1 = False
|
||||
IS_PYDANTIC_V2 = True
|
||||
|
||||
PydanticBaseModel = BaseModel
|
||||
TypeBaseModel = type[BaseModel]
|
||||
PydanticBaseModel = BaseModel | BaseModelV1
|
||||
TypeBaseModel = type[BaseModel] | type[BaseModelV1]
|
||||
|
||||
TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel)
|
||||
|
||||
@@ -287,7 +287,7 @@ def _create_subset_model(
|
||||
*,
|
||||
descriptions: dict[str, str] | None = None,
|
||||
fn_description: str | None = None,
|
||||
) -> type[BaseModel]:
|
||||
) -> TypeBaseModel:
|
||||
"""Create subset model using the same pydantic version as the input model.
|
||||
|
||||
Returns:
|
||||
@@ -347,6 +347,49 @@ def get_fields(
|
||||
raise TypeError(msg)
|
||||
|
||||
|
||||
def model_json_schema(model: TypeBaseModel) -> dict[str, Any]:
|
||||
"""Return the JSON schema of a Pydantic model class of either major version.
|
||||
|
||||
Dispatches to the correct method for Pydantic v1 (`schema`) or v2
|
||||
(`model_json_schema`), so callers holding a `TypeBaseModel` don't have to
|
||||
branch on the model's version themselves.
|
||||
|
||||
Args:
|
||||
model: The Pydantic model class.
|
||||
|
||||
Raises:
|
||||
TypeError: If the model is not a Pydantic model class.
|
||||
"""
|
||||
if issubclass(model, BaseModel):
|
||||
return model.model_json_schema()
|
||||
if issubclass(model, BaseModelV1):
|
||||
return model.schema()
|
||||
msg = f"Expected a Pydantic model. Got {model}"
|
||||
raise TypeError(msg)
|
||||
|
||||
|
||||
def model_validate(model: TypeBaseModel, obj: Any) -> PydanticBaseModel:
|
||||
"""Validate `obj` against a Pydantic model class of either major version.
|
||||
|
||||
Dispatches to the correct method for Pydantic v1 (`parse_obj`) or v2
|
||||
(`model_validate`), so callers holding a `TypeBaseModel` don't have to
|
||||
branch on the model's version themselves.
|
||||
|
||||
Args:
|
||||
model: The Pydantic model class to validate against.
|
||||
obj: The object to validate.
|
||||
|
||||
Raises:
|
||||
TypeError: If the model is not a Pydantic model class.
|
||||
"""
|
||||
if issubclass(model, BaseModel):
|
||||
return model.model_validate(obj)
|
||||
if issubclass(model, BaseModelV1):
|
||||
return model.parse_obj(obj)
|
||||
msg = f"Expected a Pydantic model. Got {model}"
|
||||
raise TypeError(msg)
|
||||
|
||||
|
||||
_SchemaConfig = ConfigDict(
|
||||
arbitrary_types_allowed=True, frozen=True, protected_namespaces=()
|
||||
)
|
||||
|
||||
@@ -175,7 +175,7 @@ def test_pydantic_output_parser_type_inference() -> None:
|
||||
# Ignoring mypy error that appears in python 3.8, but not 3.11.
|
||||
# This seems to be functionally correct, so we'll ignore the error.
|
||||
pydantic_parser = PydanticOutputParser[SampleModel](pydantic_object=SampleModel)
|
||||
schema = pydantic_parser.get_output_schema().model_json_schema()
|
||||
schema = pydantic_parser.get_output_jsonschema()
|
||||
|
||||
assert schema == {
|
||||
"properties": {
|
||||
|
||||
@@ -1,9 +1,18 @@
|
||||
import sys
|
||||
from inspect import isclass
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
|
||||
# pydantic.v1 models cannot be exercised under the compatibility shim on
|
||||
# Python 3.14+, so v1-specific tests are skipped there.
|
||||
skip_if_no_pydantic_v1 = pytest.mark.skipif(
|
||||
sys.version_info >= (3, 14),
|
||||
reason="pydantic.v1 namespace not supported with Python 3.14+",
|
||||
)
|
||||
|
||||
|
||||
# Function to replace allOf with $ref
|
||||
def replace_all_of_with_ref(schema: Any) -> None:
|
||||
|
||||
@@ -19,7 +19,10 @@ from uuid import UUID
|
||||
import pytest
|
||||
from freezegun import freeze_time
|
||||
from packaging import version
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
from pydantic.v1 import Field as FieldV1
|
||||
from pydantic.v1 import ValidationError as ValidationErrorV1
|
||||
from pytest_mock import MockerFixture
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
from typing_extensions import TypedDict, override
|
||||
@@ -90,9 +93,17 @@ from langchain_core.tracers import (
|
||||
)
|
||||
from langchain_core.tracers._compat import pydantic_copy
|
||||
from langchain_core.tracers.context import collect_runs
|
||||
from langchain_core.utils.pydantic import PYDANTIC_VERSION
|
||||
from langchain_core.utils.pydantic import (
|
||||
PYDANTIC_VERSION,
|
||||
TypeBaseModel,
|
||||
model_validate,
|
||||
)
|
||||
from langchain_core.version import VERSION
|
||||
from tests.unit_tests.pydantic_utils import _normalize_schema, _schema
|
||||
from tests.unit_tests.pydantic_utils import (
|
||||
_normalize_schema,
|
||||
_schema,
|
||||
skip_if_no_pydantic_v1,
|
||||
)
|
||||
from tests.unit_tests.stubs import AnyStr, _any_id_ai_message, _any_id_ai_message_chunk
|
||||
|
||||
PYDANTIC_VERSION_AT_LEAST_29 = version.parse("2.9") <= PYDANTIC_VERSION
|
||||
@@ -499,7 +510,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
|
||||
|
||||
foo_ = RunnableLambda(foo)
|
||||
|
||||
assert foo_.assign(bar=lambda _: "foo").get_output_schema().model_json_schema() == {
|
||||
assert foo_.assign(bar=lambda _: "foo").get_output_jsonschema() == {
|
||||
"properties": {"bar": {"title": "Bar"}, "root": {"title": "Root"}},
|
||||
"required": ["root", "bar"],
|
||||
"title": "RunnableAssignOutput",
|
||||
@@ -5820,6 +5831,167 @@ def test_runnable_typed_dict_schema() -> None:
|
||||
other=other_runnable,
|
||||
)
|
||||
assert (
|
||||
repr(parallel.input_schema.model_validate({"foo": "Y", "bar": "Z"}))
|
||||
repr(model_validate(parallel.input_schema, {"foo": "Y", "bar": "Z"}))
|
||||
== "RunnableParallel<foo,other>Input(root={'foo': 'Y', 'bar': 'Z'})"
|
||||
)
|
||||
|
||||
|
||||
class _RunnableWithInputSchema(Runnable[Any, Any]):
|
||||
def __init__(self, input_schema: TypeBaseModel) -> None:
|
||||
self._input_schema = input_schema
|
||||
|
||||
@property
|
||||
@override
|
||||
def InputType(self) -> Any:
|
||||
return self._input_schema
|
||||
|
||||
@override
|
||||
def get_input_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
|
||||
_ = config
|
||||
return self._input_schema
|
||||
|
||||
@override
|
||||
def invoke(
|
||||
self,
|
||||
input: Any,
|
||||
config: RunnableConfig | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
return input
|
||||
|
||||
|
||||
@skip_if_no_pydantic_v1
|
||||
def test_runnable_parallel_preserves_required_v1_input_fields() -> None:
|
||||
class InputModel(BaseModelV1):
|
||||
a: int
|
||||
b: int = 2
|
||||
|
||||
parallel = RunnableParallel(foo=_RunnableWithInputSchema(InputModel))
|
||||
|
||||
schema = cast("type[BaseModel]", parallel.input_schema)
|
||||
assert schema.model_json_schema()["required"] == ["a"]
|
||||
with pytest.raises(ValidationError):
|
||||
schema.model_validate({})
|
||||
|
||||
model = schema.model_validate({"a": 1})
|
||||
assert model.model_dump() == {"a": 1, "b": 2}
|
||||
|
||||
|
||||
@skip_if_no_pydantic_v1
|
||||
def test_runnable_parallel_uses_base_schema_for_v1_root_model() -> None:
|
||||
class InputModel(BaseModelV1):
|
||||
__root__: dict[str, int]
|
||||
|
||||
parallel = RunnableParallel(foo=_RunnableWithInputSchema(InputModel))
|
||||
|
||||
schema = parallel.input_schema
|
||||
assert schema is InputModel
|
||||
assert schema.schema() == {
|
||||
"additionalProperties": {"type": "integer"},
|
||||
"title": "InputModel",
|
||||
"type": "object",
|
||||
}
|
||||
assert schema.parse_obj({"a": 1}).__root__ == {"a": 1}
|
||||
with pytest.raises(ValidationErrorV1):
|
||||
schema.parse_obj({"a": "not an int"})
|
||||
|
||||
|
||||
@skip_if_no_pydantic_v1
|
||||
def test_runnable_sequence_v1_input_schema() -> None:
|
||||
"""A `RunnableSequence` exposes a Pydantic v1 first-step input schema.
|
||||
|
||||
Regression test: deriving the sequence schema previously assumed Pydantic v2.
|
||||
"""
|
||||
|
||||
class InputModel(BaseModelV1):
|
||||
a: int
|
||||
|
||||
sequence = _RunnableWithInputSchema(InputModel) | RunnableLambda(lambda x: x)
|
||||
|
||||
assert sequence.get_input_jsonschema()["properties"] == {
|
||||
"a": {"title": "A", "type": "integer"}
|
||||
}
|
||||
|
||||
|
||||
@skip_if_no_pydantic_v1
|
||||
def test_runnable_branch_v1_input_schema() -> None:
|
||||
"""A `RunnableBranch` exposes a Pydantic v1 input schema.
|
||||
|
||||
Regression test: `get_input_schema` previously assumed Pydantic v2 when
|
||||
inspecting each branch's schema.
|
||||
"""
|
||||
|
||||
class InputModel(BaseModelV1):
|
||||
a: int
|
||||
|
||||
branch = RunnableBranch(
|
||||
(lambda _: True, _RunnableWithInputSchema(InputModel)),
|
||||
_RunnableWithInputSchema(InputModel),
|
||||
)
|
||||
|
||||
assert branch.get_input_jsonschema()["properties"] == {
|
||||
"a": {"title": "A", "type": "integer"}
|
||||
}
|
||||
|
||||
|
||||
@skip_if_no_pydantic_v1
|
||||
def test_runnable_parallel_preserves_v1_default_factory() -> None:
|
||||
"""A v1 `default_factory` field keeps its factory in the derived schema.
|
||||
|
||||
Regression test for `_get_schema_field_definition`: without the dedicated
|
||||
`default_factory` branch the factory is dropped and the field defaults to
|
||||
`None` instead of producing the factory's value.
|
||||
"""
|
||||
|
||||
class InputModel(BaseModelV1):
|
||||
a: int
|
||||
items: list[int] = FieldV1(default_factory=list)
|
||||
|
||||
parallel = RunnableParallel(foo=_RunnableWithInputSchema(InputModel))
|
||||
|
||||
schema = cast("type[BaseModel]", parallel.input_schema)
|
||||
# The factory field is optional, so only `a` is required.
|
||||
assert schema.model_json_schema()["required"] == ["a"]
|
||||
|
||||
# The factory runs when omitted, yielding `[]` rather than `None`.
|
||||
assert schema.model_validate({"a": 1}).model_dump() == {"a": 1, "items": []}
|
||||
|
||||
|
||||
@skip_if_no_pydantic_v1
|
||||
def test_runnable_sequence_v1_output_schema_with_assign() -> None:
|
||||
"""A sequence ending in `RunnableAssign` derives a v1 upstream output schema.
|
||||
|
||||
Regression test: `_seq_output_schema` previously read `.model_fields` on the
|
||||
upstream schema (v2-only) and translated required v1 fields as optional.
|
||||
"""
|
||||
|
||||
class InputModel(BaseModelV1):
|
||||
a: int
|
||||
|
||||
sequence = _RunnableWithInputSchema(InputModel) | RunnableAssign(
|
||||
RunnableParallel(bar=RunnableLambda(lambda _: "bar"))
|
||||
)
|
||||
|
||||
schema = sequence.get_output_jsonschema()
|
||||
assert set(schema["properties"]) == {"a", "bar"}
|
||||
# The required v1 field survives as required rather than becoming optional.
|
||||
assert "a" in schema["required"]
|
||||
|
||||
|
||||
@skip_if_no_pydantic_v1
|
||||
def test_runnable_sequence_v1_output_schema_with_pick() -> None:
|
||||
"""A sequence ending in `RunnablePick` derives a v1 upstream output schema.
|
||||
|
||||
Regression test: `_seq_output_schema` previously read `.model_fields` on the
|
||||
upstream schema (v2-only) for the `RunnablePick` branch.
|
||||
"""
|
||||
|
||||
class InputModel(BaseModelV1):
|
||||
a: int
|
||||
b: int
|
||||
|
||||
sequence = _RunnableWithInputSchema(InputModel) | RunnablePick(["a"])
|
||||
|
||||
schema = sequence.get_output_jsonschema()
|
||||
assert set(schema["properties"]) == {"a"}
|
||||
assert "a" in schema["required"]
|
||||
|
||||
@@ -40,6 +40,7 @@ from langchain_core.messages import ToolCall, ToolMessage
|
||||
from langchain_core.messages.tool import ToolOutputMixin
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.runnables import (
|
||||
Runnable,
|
||||
RunnableConfig,
|
||||
RunnableLambda,
|
||||
ensure_config,
|
||||
@@ -49,6 +50,7 @@ from langchain_core.tools import (
|
||||
StructuredTool,
|
||||
Tool,
|
||||
ToolException,
|
||||
convert_runnable_to_tool,
|
||||
tool,
|
||||
)
|
||||
from langchain_core.tools.base import (
|
||||
@@ -68,11 +70,16 @@ from langchain_core.utils.function_calling import (
|
||||
convert_to_openai_tool,
|
||||
)
|
||||
from langchain_core.utils.pydantic import (
|
||||
TypeBaseModel,
|
||||
_create_subset_model,
|
||||
create_model_v2,
|
||||
)
|
||||
from tests.unit_tests.fake.callbacks import FakeCallbackHandler
|
||||
from tests.unit_tests.pydantic_utils import _normalize_schema, _schema
|
||||
from tests.unit_tests.pydantic_utils import (
|
||||
_normalize_schema,
|
||||
_schema,
|
||||
skip_if_no_pydantic_v1,
|
||||
)
|
||||
|
||||
try:
|
||||
from langgraph.prebuilt import ToolRuntime # type: ignore[import-not-found]
|
||||
@@ -1964,7 +1971,7 @@ def test_tool_inherited_injected_arg() -> None:
|
||||
return y
|
||||
|
||||
tool_ = InheritedInjectedArgTool()
|
||||
assert tool_.get_input_schema().model_json_schema() == {
|
||||
assert tool_.get_input_jsonschema() == {
|
||||
"title": "FooSchema", # Matches the title from the provided schema
|
||||
"description": "foo.",
|
||||
"type": "object",
|
||||
@@ -2138,7 +2145,7 @@ def test_args_schema_explicitly_typed() -> None:
|
||||
|
||||
tool = SomeTool(name="some_tool", description="some description")
|
||||
|
||||
assert tool.get_input_schema().model_json_schema() == {
|
||||
assert tool.get_input_jsonschema() == {
|
||||
"properties": {
|
||||
"a": {"title": "A", "type": "integer"},
|
||||
"b": {"title": "B", "type": "string"},
|
||||
@@ -2160,6 +2167,92 @@ def test_args_schema_explicitly_typed() -> None:
|
||||
}
|
||||
|
||||
|
||||
@skip_if_no_pydantic_v1
|
||||
def test_get_input_jsonschema_v1_args_schema() -> None:
|
||||
"""`get_input_jsonschema()` works for a tool with a Pydantic v1 `args_schema`.
|
||||
|
||||
Regression test: previously this raised `AttributeError` because a Pydantic
|
||||
v1 model does not implement `model_json_schema`.
|
||||
"""
|
||||
|
||||
class FooV1(BaseModelV1):
|
||||
a: int
|
||||
b: str
|
||||
|
||||
class SomeTool(BaseTool):
|
||||
args_schema: type[BaseModelV1] = FooV1
|
||||
|
||||
@override
|
||||
def _run(self, *args: Any, **kwargs: Any) -> str:
|
||||
return "foo"
|
||||
|
||||
tool = SomeTool(name="some_tool", description="some description", args_schema=FooV1)
|
||||
|
||||
assert tool.get_input_jsonschema()["properties"] == {
|
||||
"a": {"title": "A", "type": "integer"},
|
||||
"b": {"title": "B", "type": "string"},
|
||||
}
|
||||
|
||||
|
||||
@skip_if_no_pydantic_v1
|
||||
def test_v1_args_schema_excludes_injected_args() -> None:
|
||||
"""A Pydantic v1 `args_schema` must not expose injected args via `.args`.
|
||||
|
||||
Injected arguments are supplied by the framework rather than the model, so
|
||||
they must be hidden from the tool-call schema regardless of the `args_schema`
|
||||
Pydantic version. Previously a v1 `args_schema` bypassed `tool_call_schema`
|
||||
and leaked injected arguments into `.args`.
|
||||
"""
|
||||
|
||||
class FooV1(BaseModelV1):
|
||||
a: int
|
||||
injected: Annotated[str, InjectedToolArg()] = "default"
|
||||
|
||||
class SomeTool(BaseTool):
|
||||
args_schema: type[BaseModelV1] = FooV1
|
||||
|
||||
@override
|
||||
def _run(self, *args: Any, **kwargs: Any) -> str:
|
||||
return "foo"
|
||||
|
||||
tool = SomeTool(name="some_tool", description="some description", args_schema=FooV1)
|
||||
|
||||
assert set(tool.args) == {"a"}
|
||||
assert "injected" not in tool.args
|
||||
|
||||
|
||||
@skip_if_no_pydantic_v1
|
||||
def test_convert_runnable_to_tool_v1_input_schema() -> None:
|
||||
"""`convert_runnable_to_tool` works for a runnable with a v1 input schema.
|
||||
|
||||
Regression test: deriving the tool schema from the runnable previously
|
||||
assumed a Pydantic v2 input schema and raised on v1.
|
||||
"""
|
||||
|
||||
class FooV1(BaseModelV1):
|
||||
a: int
|
||||
|
||||
class V1InputRunnable(Runnable[Any, Any]):
|
||||
@override
|
||||
def get_input_schema(
|
||||
self, config: RunnableConfig | None = None
|
||||
) -> TypeBaseModel:
|
||||
return FooV1
|
||||
|
||||
@override
|
||||
def invoke(
|
||||
self,
|
||||
input: Any,
|
||||
config: RunnableConfig | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
return input
|
||||
|
||||
tool = convert_runnable_to_tool(V1InputRunnable())
|
||||
|
||||
assert set(tool.args) == {"a"}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("pydantic_model", TEST_MODELS)
|
||||
def test_structured_tool_with_different_pydantic_versions(pydantic_model: Any) -> None:
|
||||
"""This should test that one can type the args schema as a Pydantic model."""
|
||||
@@ -2784,7 +2877,7 @@ def test_structured_tool_args_schema_dict(caplog: pytest.LogCaptureFixture) -> N
|
||||
assert _get_tool_call_json_schema(tool) == args_schema
|
||||
# test that the input schema is the same as the parent (Runnable) input schema
|
||||
assert (
|
||||
tool.get_input_schema().model_json_schema()
|
||||
tool.get_input_jsonschema()
|
||||
== create_model_v2(
|
||||
tool.get_name("Input"),
|
||||
root=tool.InputType,
|
||||
@@ -2822,7 +2915,7 @@ def test_simple_tool_args_schema_dict() -> None:
|
||||
assert _get_tool_call_json_schema(tool) == args_schema
|
||||
# test that the input schema is the same as the parent (Runnable) input schema
|
||||
assert (
|
||||
tool.get_input_schema().model_json_schema()
|
||||
tool.get_input_jsonschema()
|
||||
== create_model_v2(
|
||||
tool.get_name("Input"),
|
||||
root=tool.InputType,
|
||||
|
||||
@@ -14,6 +14,8 @@ from langchain_core.utils.pydantic import (
|
||||
get_fields,
|
||||
is_basemodel_instance,
|
||||
is_basemodel_subclass,
|
||||
model_json_schema,
|
||||
model_validate,
|
||||
pre_init,
|
||||
)
|
||||
|
||||
@@ -153,6 +155,56 @@ def test_fields_pydantic_v1_from_2() -> None:
|
||||
assert fields == {"x": Foo.__fields__["x"]}
|
||||
|
||||
|
||||
def test_model_json_schema_v2() -> None:
|
||||
class Foo(BaseModel):
|
||||
x: int
|
||||
|
||||
assert model_json_schema(Foo) == Foo.model_json_schema()
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info >= (3, 14),
|
||||
reason="pydantic.v1 namespace not supported with Python 3.14+",
|
||||
)
|
||||
def test_model_json_schema_v1() -> None:
|
||||
class Foo(BaseModelV1):
|
||||
x: int
|
||||
|
||||
assert model_json_schema(Foo) == Foo.schema()
|
||||
|
||||
|
||||
def test_model_json_schema_non_model() -> None:
|
||||
with pytest.raises(TypeError, match="Expected a Pydantic model"):
|
||||
model_json_schema(dict) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def test_model_validate_v2() -> None:
|
||||
class Foo(BaseModel):
|
||||
x: int
|
||||
|
||||
result = model_validate(Foo, {"x": 1})
|
||||
assert isinstance(result, Foo)
|
||||
assert result.x == 1
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info >= (3, 14),
|
||||
reason="pydantic.v1 namespace not supported with Python 3.14+",
|
||||
)
|
||||
def test_model_validate_v1() -> None:
|
||||
class Foo(BaseModelV1):
|
||||
x: int
|
||||
|
||||
result = model_validate(Foo, {"x": 1})
|
||||
assert isinstance(result, Foo)
|
||||
assert result.x == 1
|
||||
|
||||
|
||||
def test_model_validate_non_model() -> None:
|
||||
with pytest.raises(TypeError, match="Expected a Pydantic model"):
|
||||
model_validate(dict, {"x": 1}) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def test_create_model_v2() -> None:
|
||||
"""Test that create model v2 works as expected."""
|
||||
with warnings.catch_warnings(record=True) as record:
|
||||
|
||||
@@ -40,7 +40,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
|
||||
@property
|
||||
def input_keys(self) -> list[str]:
|
||||
"""Input keys for Hyde's LLM chain."""
|
||||
return self.llm_chain.input_schema.model_json_schema()["required"]
|
||||
return self.llm_chain.get_input_jsonschema()["required"]
|
||||
|
||||
@property
|
||||
def output_keys(self) -> list[str]:
|
||||
|
||||
Reference in New Issue
Block a user