diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py index 7c21d12110f..a0e8f42bef7 100644 --- a/libs/core/langchain_core/language_models/chat_models.py +++ b/libs/core/langchain_core/language_models/chat_models.py @@ -91,7 +91,7 @@ from langchain_core.utils.function_calling import ( convert_to_json_schema, convert_to_openai_tool, ) -from langchain_core.utils.pydantic import TypeBaseModel, is_basemodel_subclass +from langchain_core.utils.pydantic import is_basemodel_subclass from langchain_core.utils.utils import LC_ID_PREFIX, from_env if TYPE_CHECKING: @@ -2519,9 +2519,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): ) output_parser: JsonOutputToolsParser if isinstance(schema, type) and is_basemodel_subclass(schema): - output_parser = PydanticToolsParser( - tools=[cast("TypeBaseModel", schema)], first_tool_only=True - ) + output_parser = PydanticToolsParser(tools=[schema], first_tool_only=True) else: key_name = convert_to_openai_tool(schema)["function"]["name"] output_parser = JsonOutputKeyToolsParser( diff --git a/libs/core/langchain_core/output_parsers/openai_tools.py b/libs/core/langchain_core/output_parsers/openai_tools.py index 78be65ecef2..c4b404b97a5 100644 --- a/libs/core/langchain_core/output_parsers/openai_tools.py +++ b/libs/core/langchain_core/output_parsers/openai_tools.py @@ -6,7 +6,8 @@ import logging from json import JSONDecodeError from typing import Annotated, Any -from pydantic import SkipValidation, ValidationError +from pydantic import BaseModel, SkipValidation, ValidationError +from pydantic.v1 import BaseModel as BaseModelV1 from langchain_core.exceptions import OutputParserException from langchain_core.messages import AIMessage, InvalidToolCall @@ -17,8 +18,6 @@ from langchain_core.outputs import ChatGeneration, Generation from langchain_core.utils.json import parse_partial_json from langchain_core.utils.pydantic import ( TypeBaseModel, - is_pydantic_v1_subclass, - is_pydantic_v2_subclass, ) logger = logging.getLogger(__name__) @@ -339,10 +338,10 @@ class PydanticToolsParser(JsonOutputToolsParser): name_dict_v2: dict[str, TypeBaseModel] = { tool.model_config.get("title") or tool.__name__: tool for tool in self.tools - if is_pydantic_v2_subclass(tool) + if issubclass(tool, BaseModel) } name_dict_v1: dict[str, TypeBaseModel] = { - tool.__name__: tool for tool in self.tools if is_pydantic_v1_subclass(tool) + tool.__name__: tool for tool in self.tools if issubclass(tool, BaseModelV1) } name_dict: dict[str, TypeBaseModel] = {**name_dict_v2, **name_dict_v1} pydantic_objects = [] diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index f6b47558aff..231553dc81a 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -38,6 +38,7 @@ from typing import ( ) from pydantic import BaseModel, ConfigDict, Field, RootModel +from pydantic.fields import FieldInfo from typing_extensions import override from langchain_core._api import beta_decorator @@ -97,9 +98,16 @@ from langchain_core.tracers.root_listeners import ( ) from langchain_core.utils.aiter import aclosing, atee from langchain_core.utils.iter import safetee -from langchain_core.utils.pydantic import create_model_v2 +from langchain_core.utils.pydantic import ( + TypeBaseModel, + create_model_v2, + get_fields, + model_json_schema, +) if TYPE_CHECKING: + from pydantic.v1.fields import ModelField + from langchain_core.callbacks.manager import ( AsyncCallbackManagerForChainRun, CallbackManagerForChainRun, @@ -364,14 +372,14 @@ class Runnable(ABC, Generic[Input, Output]): raise TypeError(msg) @property - def input_schema(self) -> type[BaseModel]: + def input_schema(self) -> TypeBaseModel: """The type of input this `Runnable` accepts specified as a Pydantic model.""" return self.get_input_schema() def get_input_schema( self, config: RunnableConfig | None = None, - ) -> type[BaseModel]: + ) -> TypeBaseModel: """Get a Pydantic model that can be used to validate input to the `Runnable`. `Runnable` objects that leverage the `configurable_fields` and @@ -437,10 +445,10 @@ class Runnable(ABC, Generic[Input, Output]): !!! version-added "Added in `langchain-core` 0.3.0" """ - return self.get_input_schema(config).model_json_schema() + return model_json_schema(self.get_input_schema(config)) @property - def output_schema(self) -> type[BaseModel]: + def output_schema(self) -> TypeBaseModel: """Output schema. The type of output this `Runnable` produces specified as a Pydantic model. @@ -450,7 +458,7 @@ class Runnable(ABC, Generic[Input, Output]): def get_output_schema( self, config: RunnableConfig | None = None, - ) -> type[BaseModel]: + ) -> TypeBaseModel: """Get a Pydantic model that can be used to validate output to the `Runnable`. `Runnable` objects that leverage the `configurable_fields` and @@ -516,7 +524,7 @@ class Runnable(ABC, Generic[Input, Output]): !!! version-added "Added in `langchain-core` 0.3.0" """ - return self.get_output_schema(config).model_json_schema() + return model_json_schema(self.get_output_schema(config)) @property def config_specs(self) -> list[ConfigurableFieldSpec]: @@ -2953,7 +2961,7 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]): def _seq_input_schema( steps: list[Runnable[Any, Any]], config: RunnableConfig | None -) -> type[BaseModel]: +) -> TypeBaseModel: # Import locally to prevent circular import from langchain_core.runnables.passthrough import ( # noqa: PLC0415 RunnableAssign, @@ -2971,7 +2979,7 @@ def _seq_input_schema( "RunnableSequenceInput", field_definitions={ k: (v.annotation, v.default) - for k, v in next_input_schema.model_fields.items() + for k, v in get_fields(next_input_schema).items() if k not in first.mapper.steps__ }, ) @@ -2983,7 +2991,7 @@ def _seq_input_schema( def _seq_output_schema( steps: list[Runnable[Any, Any]], config: RunnableConfig | None -) -> type[BaseModel]: +) -> TypeBaseModel: # Import locally to prevent circular import from langchain_core.runnables.passthrough import ( # noqa: PLC0415 RunnableAssign, @@ -3002,12 +3010,12 @@ def _seq_output_schema( "RunnableSequenceOutput", field_definitions={ **{ - k: (v.annotation, v.default) - for k, v in prev_output_schema.model_fields.items() + k: _get_schema_field_definition(v) + for k, v in get_fields(prev_output_schema).items() }, **{ - k: (v.annotation, v.default) - for k, v in mapper_output_schema.model_fields.items() + k: _get_schema_field_definition(v) + for k, v in get_fields(mapper_output_schema).items() }, }, ) @@ -3019,19 +3027,36 @@ def _seq_output_schema( return create_model_v2( "RunnableSequenceOutput", field_definitions={ - k: (v.annotation, v.default) - for k, v in prev_output_schema.model_fields.items() + k: _get_schema_field_definition(v) + for k, v in get_fields(prev_output_schema).items() if k in last.keys }, ) - field = prev_output_schema.model_fields[last.keys] + field = get_fields(prev_output_schema)[last.keys] return create_model_v2( - "RunnableSequenceOutput", root=(field.annotation, field.default) + "RunnableSequenceOutput", root=_get_schema_field_definition(field) ) return last.get_output_schema(config) +def _get_schema_field_definition(field: FieldInfo | ModelField) -> tuple[Any, Any]: + """Convert a Pydantic field to a field definition for `create_model_v2`. + + Handles both Pydantic v2 (`FieldInfo`) and v1 (`ModelField`) fields. v1 + required fields carry a `None` default with `required=True`, so they must be + translated to the `...` sentinel rather than passed through as optional. + """ + if isinstance(field, FieldInfo): + return (field.annotation, field.default) + + if field.required: + return (field.annotation, ...) + if field.default_factory is not None: + return (field.annotation, Field(default_factory=field.default_factory)) + return (field.annotation, field.default) + + _RUNNABLE_SEQUENCE_MIN_STEPS = 2 @@ -3212,7 +3237,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]): return self.last.OutputType @override - def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]: + def get_input_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel: """Get the input schema of the `Runnable`. Args: @@ -3225,9 +3250,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]): return _seq_input_schema(self.steps, config) @override - def get_output_schema( - self, config: RunnableConfig | None = None - ) -> type[BaseModel]: + def get_output_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel: """Get the output schema of the `Runnable`. Args: @@ -3984,7 +4007,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]): return Any @override - def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]: + def get_input_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel: """Get the input schema of the `Runnable`. Args: @@ -3995,23 +4018,26 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]): """ if all( - s.get_input_schema(config).model_json_schema().get("type", "object") - == "object" + s.get_input_jsonschema(config).get("type", "object") == "object" for s in self.steps__.values() ): for step in self.steps__.values(): - fields = step.get_input_schema(config).model_fields + step_input_schema = step.get_input_schema(config) + fields = get_fields(step_input_schema) root_field = fields.get("root") if root_field is not None and root_field.annotation != Any: return super().get_input_schema(config) + root_field = fields.get("__root__") + if root_field is not None and root_field.annotation != Any: + return step_input_schema # This is correct, but pydantic typings/mypy don't think so. return create_model_v2( self.get_name("Input"), field_definitions={ - k: (v.annotation, v.default) + k: _get_schema_field_definition(v) for step in self.steps__.values() - for k, v in step.get_input_schema(config).model_fields.items() + for k, v in get_fields(step.get_input_schema(config)).items() if k != "__root__" }, ) @@ -4032,6 +4058,9 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]): """ fields = {k: (v.OutputType, ...) for k, v in self.steps__.items()} + # The return type is narrowed to `type[BaseModel]` (rather than the base + # class's `TypeBaseModel`) because this override always builds the schema + # with `create_model_v2`, so it is guaranteed to be a Pydantic v2 model. return create_model_v2(self.get_name("Output"), field_definitions=fields) @property @@ -4932,7 +4961,7 @@ class RunnableLambda(Runnable[Input, Output]): return Any @override - def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]: + def get_input_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel: """The Pydantic schema for the input to this `Runnable`. Args: @@ -5925,15 +5954,13 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): # type: ignore[ ) @override - def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]: + def get_input_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel: if self.custom_input_type is not None: return super().get_input_schema(config) return self.bound.get_input_schema(merge_configs(self.config, config)) @override - def get_output_schema( - self, config: RunnableConfig | None = None - ) -> type[BaseModel]: + def get_output_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel: if self.custom_output_type is not None: return super().get_output_schema(config) return self.bound.get_output_schema(merge_configs(self.config, config)) diff --git a/libs/core/langchain_core/runnables/branch.py b/libs/core/langchain_core/runnables/branch.py index 277a4a36073..1c14476912d 100644 --- a/libs/core/langchain_core/runnables/branch.py +++ b/libs/core/langchain_core/runnables/branch.py @@ -13,7 +13,7 @@ from typing import ( cast, ) -from pydantic import BaseModel, ConfigDict +from pydantic import ConfigDict from typing_extensions import override from langchain_core.runnables.base import ( @@ -35,6 +35,7 @@ from langchain_core.runnables.utils import ( Output, get_unique_config_specs, ) +from langchain_core.utils.pydantic import TypeBaseModel _MIN_BRANCHES = 2 @@ -154,7 +155,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]): return ["langchain", "schema", "runnable"] @override - def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]: + def get_input_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel: runnables = ( [self.default] + [r for _, r in self.branches] @@ -162,10 +163,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]): ) for runnable in runnables: - if ( - runnable.get_input_schema(config).model_json_schema().get("type") - is not None - ): + if runnable.get_input_jsonschema(config).get("type") is not None: return runnable.get_input_schema(config) return super().get_input_schema(config) diff --git a/libs/core/langchain_core/runnables/configurable.py b/libs/core/langchain_core/runnables/configurable.py index a03108850fa..415c30d4d28 100644 --- a/libs/core/langchain_core/runnables/configurable.py +++ b/libs/core/langchain_core/runnables/configurable.py @@ -19,7 +19,7 @@ from typing import ( ) from weakref import WeakValueDictionary -from pydantic import BaseModel, ConfigDict +from pydantic import ConfigDict from typing_extensions import override from langchain_core.runnables.base import Runnable, RunnableSerializable @@ -44,6 +44,7 @@ from langchain_core.runnables.utils import ( if TYPE_CHECKING: from langchain_core.runnables.graph import Graph + from langchain_core.utils.pydantic import TypeBaseModel class DynamicRunnable(RunnableSerializable[Input, Output]): @@ -90,14 +91,12 @@ class DynamicRunnable(RunnableSerializable[Input, Output]): return self.default.OutputType @override - def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]: + def get_input_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel: runnable, config = self.prepare(config) return runnable.get_input_schema(config) @override - def get_output_schema( - self, config: RunnableConfig | None = None - ) -> type[BaseModel]: + def get_output_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel: runnable, config = self.prepare(config) return runnable.get_output_schema(config) diff --git a/libs/core/langchain_core/runnables/fallbacks.py b/libs/core/langchain_core/runnables/fallbacks.py index 16963722b26..7e8235a4d42 100644 --- a/libs/core/langchain_core/runnables/fallbacks.py +++ b/libs/core/langchain_core/runnables/fallbacks.py @@ -7,7 +7,7 @@ from collections.abc import AsyncIterator, Iterator, Sequence from functools import wraps from typing import TYPE_CHECKING, Any, cast -from pydantic import BaseModel, ConfigDict +from pydantic import ConfigDict from typing_extensions import override from langchain_core.callbacks.manager import AsyncCallbackManager, CallbackManager @@ -28,6 +28,7 @@ from langchain_core.runnables.utils import ( coro_with_context, get_unique_config_specs, ) +from langchain_core.utils.pydantic import TypeBaseModel if TYPE_CHECKING: from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun @@ -118,13 +119,11 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): return self.runnable.OutputType @override - def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]: + def get_input_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel: return self.runnable.get_input_schema(config) @override - def get_output_schema( - self, config: RunnableConfig | None = None - ) -> type[BaseModel]: + def get_output_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel: return self.runnable.get_output_schema(config) @property diff --git a/libs/core/langchain_core/runnables/graph.py b/libs/core/langchain_core/runnables/graph.py index 4b5bfdfb836..6886d7b6611 100644 --- a/libs/core/langchain_core/runnables/graph.py +++ b/libs/core/langchain_core/runnables/graph.py @@ -18,13 +18,15 @@ from uuid import UUID, uuid4 from langchain_core.load.serializable import to_json_not_implemented from langchain_core.runnables.base import Runnable, RunnableSerializable -from langchain_core.utils.pydantic import _IgnoreUnserializable, is_basemodel_subclass +from langchain_core.utils.pydantic import ( + TypeBaseModel, + _IgnoreUnserializable, + is_basemodel_subclass, +) if TYPE_CHECKING: from collections.abc import Callable, Sequence - from pydantic import BaseModel - from langchain_core.runnables.base import Runnable as RunnableType @@ -97,7 +99,7 @@ class Node(NamedTuple): """The unique identifier of the node.""" name: str """The name of the node.""" - data: type[BaseModel] | RunnableType[Any, Any] | None + data: TypeBaseModel | RunnableType[Any, Any] | None """The data of the node.""" metadata: dict[str, Any] | None """Optional metadata for the node. """ @@ -177,7 +179,7 @@ class MermaidDrawMethod(Enum): def node_data_str( id: str, - data: type[BaseModel] | RunnableType[Any, Any] | None, + data: TypeBaseModel | RunnableType[Any, Any] | None, ) -> str: """Convert the data of a node to a string. @@ -311,7 +313,7 @@ class Graph: def add_node( self, - data: type[BaseModel] | RunnableType[Any, Any] | None, + data: TypeBaseModel | RunnableType[Any, Any] | None, id: str | None = None, *, metadata: dict[str, Any] | None = None, diff --git a/libs/core/langchain_core/runnables/passthrough.py b/libs/core/langchain_core/runnables/passthrough.py index 0df584d2527..f6dc04ab12b 100644 --- a/libs/core/langchain_core/runnables/passthrough.py +++ b/libs/core/langchain_core/runnables/passthrough.py @@ -11,7 +11,7 @@ from typing import ( Any, ) -from pydantic import BaseModel, RootModel +from pydantic import RootModel from typing_extensions import override from langchain_core.runnables.base import ( @@ -19,6 +19,7 @@ from langchain_core.runnables.base import ( Runnable, RunnableParallel, RunnableSerializable, + _get_schema_field_definition, ) from langchain_core.runnables.config import ( RunnableConfig, @@ -34,7 +35,7 @@ from langchain_core.runnables.utils import ( ) from langchain_core.utils.aiter import atee from langchain_core.utils.iter import safetee -from langchain_core.utils.pydantic import create_model_v2 +from langchain_core.utils.pydantic import TypeBaseModel, create_model_v2, get_fields if TYPE_CHECKING: from collections.abc import AsyncIterator, Iterator, Mapping @@ -425,7 +426,7 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]): return super().get_name(suffix, name=name) @override - def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]: + def get_input_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel: map_input_schema = self.mapper.get_input_schema(config) if not issubclass(map_input_schema, RootModel): # ie. it's a dict @@ -434,9 +435,11 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]): return super().get_input_schema(config) @override - def get_output_schema( - self, config: RunnableConfig | None = None - ) -> type[BaseModel]: + def get_output_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel: + # The return type stays `TypeBaseModel` (rather than narrowing to + # `type[BaseModel]` as `RunnableParallel.get_output_schema` does) because + # the fallback branches return the mapper's output schema or delegate to + # `super().get_output_schema()`, either of which may be a Pydantic v1 model. map_input_schema = self.mapper.get_input_schema(config) map_output_schema = self.mapper.get_output_schema(config) if not issubclass(map_input_schema, RootModel) and not issubclass( @@ -444,11 +447,11 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]): ): fields = {} - for name, field_info in map_input_schema.model_fields.items(): - fields[name] = (field_info.annotation, field_info.default) + for name, field_info in get_fields(map_input_schema).items(): + fields[name] = _get_schema_field_definition(field_info) - for name, field_info in map_output_schema.model_fields.items(): - fields[name] = (field_info.annotation, field_info.default) + for name, field_info in get_fields(map_output_schema).items(): + fields[name] = _get_schema_field_definition(field_info) return create_model_v2("RunnableAssignOutput", field_definitions=fields) if not issubclass(map_output_schema, RootModel): diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py index a04f0abef48..66744e1ffe3 100644 --- a/libs/core/langchain_core/tools/base.py +++ b/libs/core/langchain_core/tools/base.py @@ -65,6 +65,7 @@ from langchain_core.utils.pydantic import ( is_basemodel_subclass, is_pydantic_v1_subclass, is_pydantic_v2_subclass, + model_json_schema, ) if TYPE_CHECKING: @@ -267,7 +268,7 @@ def create_schema_from_function( parse_docstring: bool = False, error_on_invalid_docstring: bool = False, include_injected: bool = True, -) -> type[BaseModel]: +) -> TypeBaseModel: """Create a Pydantic schema from a function's signature. Args: @@ -572,14 +573,12 @@ class ChildTool(BaseTool): """ if isinstance(self.args_schema, dict): json_schema = self.args_schema - elif self.args_schema and issubclass(self.args_schema, BaseModelV1): - json_schema = self.args_schema.schema() else: input_schema = self.tool_call_schema if isinstance(input_schema, dict): json_schema = input_schema else: - json_schema = input_schema.model_json_schema() + json_schema = model_json_schema(input_schema) return cast("dict[str, Any]", json_schema["properties"]) @property @@ -615,7 +614,7 @@ class ChildTool(BaseTool): # --- Runnable --- @override - def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]: + def get_input_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel: """The tool's input schema. Args: @@ -693,6 +692,7 @@ class ChildTool(BaseTool): if input_args is not None: if isinstance(input_args, dict): return tool_input + result: BaseModel | BaseModelV1 if issubclass(input_args, BaseModel): # Check args_schema for InjectedToolCallId for k, v in get_all_basemodel_annotations(input_args).items(): @@ -707,8 +707,9 @@ class ChildTool(BaseTool): ) raise ValueError(msg) tool_input[k] = tool_call_id - result = input_args.model_validate(tool_input) - result_dict = result.model_dump() + result_v2 = input_args.model_validate(tool_input) + result_dict = result_v2.model_dump() + result = result_v2 elif issubclass(input_args, BaseModelV1): # Check args_schema for InjectedToolCallId for k, v in get_all_basemodel_annotations(input_args).items(): @@ -723,8 +724,9 @@ class ChildTool(BaseTool): ) raise ValueError(msg) tool_input[k] = tool_call_id - result = input_args.parse_obj(tool_input) - result_dict = result.dict() + result_v1 = input_args.parse_obj(tool_input) + result_dict = result_v1.dict() + result = result_v1 else: msg = ( f"args_schema must be a Pydantic BaseModel, got {self.args_schema}" diff --git a/libs/core/langchain_core/tools/convert.py b/libs/core/langchain_core/tools/convert.py index 5781afc4734..3a6311cdbb6 100644 --- a/libs/core/langchain_core/tools/convert.py +++ b/libs/core/langchain_core/tools/convert.py @@ -11,6 +11,7 @@ from langchain_core.runnables import Runnable from langchain_core.tools.base import ArgsSchema, BaseTool from langchain_core.tools.simple import Tool from langchain_core.tools.structured import StructuredTool +from langchain_core.utils.pydantic import TypeBaseModel @overload @@ -275,7 +276,7 @@ def tool( if isinstance(dec_func, Runnable): runnable = dec_func - if runnable.input_schema.model_json_schema().get("type") != "object": + if runnable.get_input_jsonschema().get("type") != "object": msg = "Runnable must have an object schema." raise ValueError(msg) @@ -394,7 +395,7 @@ def tool( def _get_description_from_runnable(runnable: Runnable[Any, Any]) -> str: """Generate a placeholder description of a `Runnable`.""" - input_schema = runnable.input_schema.model_json_schema() + input_schema = runnable.get_input_jsonschema() return f"Takes {input_schema}." @@ -420,7 +421,7 @@ def _get_schema_from_runnable_and_arg_types( def convert_runnable_to_tool( runnable: Runnable[Any, Any], - args_schema: type[BaseModel] | None = None, + args_schema: TypeBaseModel | None = None, *, name: str | None = None, description: str | None = None, @@ -443,7 +444,7 @@ def convert_runnable_to_tool( description = description or _get_description_from_runnable(runnable) name = name or runnable.get_name() - schema = runnable.input_schema.model_json_schema() + schema = runnable.get_input_jsonschema() if schema.get("type") == "string": return Tool( name=name, diff --git a/libs/core/langchain_core/utils/pydantic.py b/libs/core/langchain_core/utils/pydantic.py index 7fed4281456..9d78932263f 100644 --- a/libs/core/langchain_core/utils/pydantic.py +++ b/libs/core/langchain_core/utils/pydantic.py @@ -69,8 +69,8 @@ PYDANTIC_MINOR_VERSION = PYDANTIC_VERSION.minor IS_PYDANTIC_V1 = False IS_PYDANTIC_V2 = True -PydanticBaseModel = BaseModel -TypeBaseModel = type[BaseModel] +PydanticBaseModel = BaseModel | BaseModelV1 +TypeBaseModel = type[BaseModel] | type[BaseModelV1] TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel) @@ -287,7 +287,7 @@ def _create_subset_model( *, descriptions: dict[str, str] | None = None, fn_description: str | None = None, -) -> type[BaseModel]: +) -> TypeBaseModel: """Create subset model using the same pydantic version as the input model. Returns: @@ -347,6 +347,49 @@ def get_fields( raise TypeError(msg) +def model_json_schema(model: TypeBaseModel) -> dict[str, Any]: + """Return the JSON schema of a Pydantic model class of either major version. + + Dispatches to the correct method for Pydantic v1 (`schema`) or v2 + (`model_json_schema`), so callers holding a `TypeBaseModel` don't have to + branch on the model's version themselves. + + Args: + model: The Pydantic model class. + + Raises: + TypeError: If the model is not a Pydantic model class. + """ + if issubclass(model, BaseModel): + return model.model_json_schema() + if issubclass(model, BaseModelV1): + return model.schema() + msg = f"Expected a Pydantic model. Got {model}" + raise TypeError(msg) + + +def model_validate(model: TypeBaseModel, obj: Any) -> PydanticBaseModel: + """Validate `obj` against a Pydantic model class of either major version. + + Dispatches to the correct method for Pydantic v1 (`parse_obj`) or v2 + (`model_validate`), so callers holding a `TypeBaseModel` don't have to + branch on the model's version themselves. + + Args: + model: The Pydantic model class to validate against. + obj: The object to validate. + + Raises: + TypeError: If the model is not a Pydantic model class. + """ + if issubclass(model, BaseModel): + return model.model_validate(obj) + if issubclass(model, BaseModelV1): + return model.parse_obj(obj) + msg = f"Expected a Pydantic model. Got {model}" + raise TypeError(msg) + + _SchemaConfig = ConfigDict( arbitrary_types_allowed=True, frozen=True, protected_namespaces=() ) diff --git a/libs/core/tests/unit_tests/output_parsers/test_pydantic_parser.py b/libs/core/tests/unit_tests/output_parsers/test_pydantic_parser.py index 2726261affd..0451b26ce49 100644 --- a/libs/core/tests/unit_tests/output_parsers/test_pydantic_parser.py +++ b/libs/core/tests/unit_tests/output_parsers/test_pydantic_parser.py @@ -175,7 +175,7 @@ def test_pydantic_output_parser_type_inference() -> None: # Ignoring mypy error that appears in python 3.8, but not 3.11. # This seems to be functionally correct, so we'll ignore the error. pydantic_parser = PydanticOutputParser[SampleModel](pydantic_object=SampleModel) - schema = pydantic_parser.get_output_schema().model_json_schema() + schema = pydantic_parser.get_output_jsonschema() assert schema == { "properties": { diff --git a/libs/core/tests/unit_tests/pydantic_utils.py b/libs/core/tests/unit_tests/pydantic_utils.py index 1379c53d52d..0499d9a321b 100644 --- a/libs/core/tests/unit_tests/pydantic_utils.py +++ b/libs/core/tests/unit_tests/pydantic_utils.py @@ -1,9 +1,18 @@ +import sys from inspect import isclass from typing import Any +import pytest from pydantic import BaseModel from pydantic.v1 import BaseModel as BaseModelV1 +# pydantic.v1 models cannot be exercised under the compatibility shim on +# Python 3.14+, so v1-specific tests are skipped there. +skip_if_no_pydantic_v1 = pytest.mark.skipif( + sys.version_info >= (3, 14), + reason="pydantic.v1 namespace not supported with Python 3.14+", +) + # Function to replace allOf with $ref def replace_all_of_with_ref(schema: Any) -> None: diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index f07a3aa3d54..550198f2f00 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -19,7 +19,10 @@ from uuid import UUID import pytest from freezegun import freeze_time from packaging import version -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, ValidationError +from pydantic.v1 import BaseModel as BaseModelV1 +from pydantic.v1 import Field as FieldV1 +from pydantic.v1 import ValidationError as ValidationErrorV1 from pytest_mock import MockerFixture from syrupy.assertion import SnapshotAssertion from typing_extensions import TypedDict, override @@ -90,9 +93,17 @@ from langchain_core.tracers import ( ) from langchain_core.tracers._compat import pydantic_copy from langchain_core.tracers.context import collect_runs -from langchain_core.utils.pydantic import PYDANTIC_VERSION +from langchain_core.utils.pydantic import ( + PYDANTIC_VERSION, + TypeBaseModel, + model_validate, +) from langchain_core.version import VERSION -from tests.unit_tests.pydantic_utils import _normalize_schema, _schema +from tests.unit_tests.pydantic_utils import ( + _normalize_schema, + _schema, + skip_if_no_pydantic_v1, +) from tests.unit_tests.stubs import AnyStr, _any_id_ai_message, _any_id_ai_message_chunk PYDANTIC_VERSION_AT_LEAST_29 = version.parse("2.9") <= PYDANTIC_VERSION @@ -499,7 +510,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None: foo_ = RunnableLambda(foo) - assert foo_.assign(bar=lambda _: "foo").get_output_schema().model_json_schema() == { + assert foo_.assign(bar=lambda _: "foo").get_output_jsonschema() == { "properties": {"bar": {"title": "Bar"}, "root": {"title": "Root"}}, "required": ["root", "bar"], "title": "RunnableAssignOutput", @@ -5820,6 +5831,167 @@ def test_runnable_typed_dict_schema() -> None: other=other_runnable, ) assert ( - repr(parallel.input_schema.model_validate({"foo": "Y", "bar": "Z"})) + repr(model_validate(parallel.input_schema, {"foo": "Y", "bar": "Z"})) == "RunnableParallelInput(root={'foo': 'Y', 'bar': 'Z'})" ) + + +class _RunnableWithInputSchema(Runnable[Any, Any]): + def __init__(self, input_schema: TypeBaseModel) -> None: + self._input_schema = input_schema + + @property + @override + def InputType(self) -> Any: + return self._input_schema + + @override + def get_input_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel: + _ = config + return self._input_schema + + @override + def invoke( + self, + input: Any, + config: RunnableConfig | None = None, + **kwargs: Any, + ) -> Any: + return input + + +@skip_if_no_pydantic_v1 +def test_runnable_parallel_preserves_required_v1_input_fields() -> None: + class InputModel(BaseModelV1): + a: int + b: int = 2 + + parallel = RunnableParallel(foo=_RunnableWithInputSchema(InputModel)) + + schema = cast("type[BaseModel]", parallel.input_schema) + assert schema.model_json_schema()["required"] == ["a"] + with pytest.raises(ValidationError): + schema.model_validate({}) + + model = schema.model_validate({"a": 1}) + assert model.model_dump() == {"a": 1, "b": 2} + + +@skip_if_no_pydantic_v1 +def test_runnable_parallel_uses_base_schema_for_v1_root_model() -> None: + class InputModel(BaseModelV1): + __root__: dict[str, int] + + parallel = RunnableParallel(foo=_RunnableWithInputSchema(InputModel)) + + schema = parallel.input_schema + assert schema is InputModel + assert schema.schema() == { + "additionalProperties": {"type": "integer"}, + "title": "InputModel", + "type": "object", + } + assert schema.parse_obj({"a": 1}).__root__ == {"a": 1} + with pytest.raises(ValidationErrorV1): + schema.parse_obj({"a": "not an int"}) + + +@skip_if_no_pydantic_v1 +def test_runnable_sequence_v1_input_schema() -> None: + """A `RunnableSequence` exposes a Pydantic v1 first-step input schema. + + Regression test: deriving the sequence schema previously assumed Pydantic v2. + """ + + class InputModel(BaseModelV1): + a: int + + sequence = _RunnableWithInputSchema(InputModel) | RunnableLambda(lambda x: x) + + assert sequence.get_input_jsonschema()["properties"] == { + "a": {"title": "A", "type": "integer"} + } + + +@skip_if_no_pydantic_v1 +def test_runnable_branch_v1_input_schema() -> None: + """A `RunnableBranch` exposes a Pydantic v1 input schema. + + Regression test: `get_input_schema` previously assumed Pydantic v2 when + inspecting each branch's schema. + """ + + class InputModel(BaseModelV1): + a: int + + branch = RunnableBranch( + (lambda _: True, _RunnableWithInputSchema(InputModel)), + _RunnableWithInputSchema(InputModel), + ) + + assert branch.get_input_jsonschema()["properties"] == { + "a": {"title": "A", "type": "integer"} + } + + +@skip_if_no_pydantic_v1 +def test_runnable_parallel_preserves_v1_default_factory() -> None: + """A v1 `default_factory` field keeps its factory in the derived schema. + + Regression test for `_get_schema_field_definition`: without the dedicated + `default_factory` branch the factory is dropped and the field defaults to + `None` instead of producing the factory's value. + """ + + class InputModel(BaseModelV1): + a: int + items: list[int] = FieldV1(default_factory=list) + + parallel = RunnableParallel(foo=_RunnableWithInputSchema(InputModel)) + + schema = cast("type[BaseModel]", parallel.input_schema) + # The factory field is optional, so only `a` is required. + assert schema.model_json_schema()["required"] == ["a"] + + # The factory runs when omitted, yielding `[]` rather than `None`. + assert schema.model_validate({"a": 1}).model_dump() == {"a": 1, "items": []} + + +@skip_if_no_pydantic_v1 +def test_runnable_sequence_v1_output_schema_with_assign() -> None: + """A sequence ending in `RunnableAssign` derives a v1 upstream output schema. + + Regression test: `_seq_output_schema` previously read `.model_fields` on the + upstream schema (v2-only) and translated required v1 fields as optional. + """ + + class InputModel(BaseModelV1): + a: int + + sequence = _RunnableWithInputSchema(InputModel) | RunnableAssign( + RunnableParallel(bar=RunnableLambda(lambda _: "bar")) + ) + + schema = sequence.get_output_jsonschema() + assert set(schema["properties"]) == {"a", "bar"} + # The required v1 field survives as required rather than becoming optional. + assert "a" in schema["required"] + + +@skip_if_no_pydantic_v1 +def test_runnable_sequence_v1_output_schema_with_pick() -> None: + """A sequence ending in `RunnablePick` derives a v1 upstream output schema. + + Regression test: `_seq_output_schema` previously read `.model_fields` on the + upstream schema (v2-only) for the `RunnablePick` branch. + """ + + class InputModel(BaseModelV1): + a: int + b: int + + sequence = _RunnableWithInputSchema(InputModel) | RunnablePick(["a"]) + + schema = sequence.get_output_jsonschema() + assert set(schema["properties"]) == {"a"} + assert "a" in schema["required"] diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index 43c23a109bb..76c72d27708 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -40,6 +40,7 @@ from langchain_core.messages import ToolCall, ToolMessage from langchain_core.messages.tool import ToolOutputMixin from langchain_core.retrievers import BaseRetriever from langchain_core.runnables import ( + Runnable, RunnableConfig, RunnableLambda, ensure_config, @@ -49,6 +50,7 @@ from langchain_core.tools import ( StructuredTool, Tool, ToolException, + convert_runnable_to_tool, tool, ) from langchain_core.tools.base import ( @@ -68,11 +70,16 @@ from langchain_core.utils.function_calling import ( convert_to_openai_tool, ) from langchain_core.utils.pydantic import ( + TypeBaseModel, _create_subset_model, create_model_v2, ) from tests.unit_tests.fake.callbacks import FakeCallbackHandler -from tests.unit_tests.pydantic_utils import _normalize_schema, _schema +from tests.unit_tests.pydantic_utils import ( + _normalize_schema, + _schema, + skip_if_no_pydantic_v1, +) try: from langgraph.prebuilt import ToolRuntime # type: ignore[import-not-found] @@ -1964,7 +1971,7 @@ def test_tool_inherited_injected_arg() -> None: return y tool_ = InheritedInjectedArgTool() - assert tool_.get_input_schema().model_json_schema() == { + assert tool_.get_input_jsonschema() == { "title": "FooSchema", # Matches the title from the provided schema "description": "foo.", "type": "object", @@ -2138,7 +2145,7 @@ def test_args_schema_explicitly_typed() -> None: tool = SomeTool(name="some_tool", description="some description") - assert tool.get_input_schema().model_json_schema() == { + assert tool.get_input_jsonschema() == { "properties": { "a": {"title": "A", "type": "integer"}, "b": {"title": "B", "type": "string"}, @@ -2160,6 +2167,92 @@ def test_args_schema_explicitly_typed() -> None: } +@skip_if_no_pydantic_v1 +def test_get_input_jsonschema_v1_args_schema() -> None: + """`get_input_jsonschema()` works for a tool with a Pydantic v1 `args_schema`. + + Regression test: previously this raised `AttributeError` because a Pydantic + v1 model does not implement `model_json_schema`. + """ + + class FooV1(BaseModelV1): + a: int + b: str + + class SomeTool(BaseTool): + args_schema: type[BaseModelV1] = FooV1 + + @override + def _run(self, *args: Any, **kwargs: Any) -> str: + return "foo" + + tool = SomeTool(name="some_tool", description="some description", args_schema=FooV1) + + assert tool.get_input_jsonschema()["properties"] == { + "a": {"title": "A", "type": "integer"}, + "b": {"title": "B", "type": "string"}, + } + + +@skip_if_no_pydantic_v1 +def test_v1_args_schema_excludes_injected_args() -> None: + """A Pydantic v1 `args_schema` must not expose injected args via `.args`. + + Injected arguments are supplied by the framework rather than the model, so + they must be hidden from the tool-call schema regardless of the `args_schema` + Pydantic version. Previously a v1 `args_schema` bypassed `tool_call_schema` + and leaked injected arguments into `.args`. + """ + + class FooV1(BaseModelV1): + a: int + injected: Annotated[str, InjectedToolArg()] = "default" + + class SomeTool(BaseTool): + args_schema: type[BaseModelV1] = FooV1 + + @override + def _run(self, *args: Any, **kwargs: Any) -> str: + return "foo" + + tool = SomeTool(name="some_tool", description="some description", args_schema=FooV1) + + assert set(tool.args) == {"a"} + assert "injected" not in tool.args + + +@skip_if_no_pydantic_v1 +def test_convert_runnable_to_tool_v1_input_schema() -> None: + """`convert_runnable_to_tool` works for a runnable with a v1 input schema. + + Regression test: deriving the tool schema from the runnable previously + assumed a Pydantic v2 input schema and raised on v1. + """ + + class FooV1(BaseModelV1): + a: int + + class V1InputRunnable(Runnable[Any, Any]): + @override + def get_input_schema( + self, config: RunnableConfig | None = None + ) -> TypeBaseModel: + return FooV1 + + @override + def invoke( + self, + input: Any, + config: RunnableConfig | None = None, + **kwargs: Any, + ) -> Any: + return input + + tool = convert_runnable_to_tool(V1InputRunnable()) + + assert set(tool.args) == {"a"} + + @pytest.mark.parametrize("pydantic_model", TEST_MODELS) def test_structured_tool_with_different_pydantic_versions(pydantic_model: Any) -> None: """This should test that one can type the args schema as a Pydantic model.""" @@ -2784,7 +2877,7 @@ def test_structured_tool_args_schema_dict(caplog: pytest.LogCaptureFixture) -> N assert _get_tool_call_json_schema(tool) == args_schema # test that the input schema is the same as the parent (Runnable) input schema assert ( - tool.get_input_schema().model_json_schema() + tool.get_input_jsonschema() == create_model_v2( tool.get_name("Input"), root=tool.InputType, @@ -2822,7 +2915,7 @@ def test_simple_tool_args_schema_dict() -> None: assert _get_tool_call_json_schema(tool) == args_schema # test that the input schema is the same as the parent (Runnable) input schema assert ( - tool.get_input_schema().model_json_schema() + tool.get_input_jsonschema() == create_model_v2( tool.get_name("Input"), root=tool.InputType, diff --git a/libs/core/tests/unit_tests/utils/test_pydantic.py b/libs/core/tests/unit_tests/utils/test_pydantic.py index a329e19d193..3830c269236 100644 --- a/libs/core/tests/unit_tests/utils/test_pydantic.py +++ b/libs/core/tests/unit_tests/utils/test_pydantic.py @@ -14,6 +14,8 @@ from langchain_core.utils.pydantic import ( get_fields, is_basemodel_instance, is_basemodel_subclass, + model_json_schema, + model_validate, pre_init, ) @@ -153,6 +155,56 @@ def test_fields_pydantic_v1_from_2() -> None: assert fields == {"x": Foo.__fields__["x"]} +def test_model_json_schema_v2() -> None: + class Foo(BaseModel): + x: int + + assert model_json_schema(Foo) == Foo.model_json_schema() + + +@pytest.mark.skipif( + sys.version_info >= (3, 14), + reason="pydantic.v1 namespace not supported with Python 3.14+", +) +def test_model_json_schema_v1() -> None: + class Foo(BaseModelV1): + x: int + + assert model_json_schema(Foo) == Foo.schema() + + +def test_model_json_schema_non_model() -> None: + with pytest.raises(TypeError, match="Expected a Pydantic model"): + model_json_schema(dict) # type: ignore[arg-type] + + +def test_model_validate_v2() -> None: + class Foo(BaseModel): + x: int + + result = model_validate(Foo, {"x": 1}) + assert isinstance(result, Foo) + assert result.x == 1 + + +@pytest.mark.skipif( + sys.version_info >= (3, 14), + reason="pydantic.v1 namespace not supported with Python 3.14+", +) +def test_model_validate_v1() -> None: + class Foo(BaseModelV1): + x: int + + result = model_validate(Foo, {"x": 1}) + assert isinstance(result, Foo) + assert result.x == 1 + + +def test_model_validate_non_model() -> None: + with pytest.raises(TypeError, match="Expected a Pydantic model"): + model_validate(dict, {"x": 1}) # type: ignore[arg-type] + + def test_create_model_v2() -> None: """Test that create model v2 works as expected.""" with warnings.catch_warnings(record=True) as record: diff --git a/libs/langchain/langchain_classic/chains/hyde/base.py b/libs/langchain/langchain_classic/chains/hyde/base.py index 7ee9780b4ab..e4b4a43ef7b 100644 --- a/libs/langchain/langchain_classic/chains/hyde/base.py +++ b/libs/langchain/langchain_classic/chains/hyde/base.py @@ -40,7 +40,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings): @property def input_keys(self) -> list[str]: """Input keys for Hyde's LLM chain.""" - return self.llm_chain.input_schema.model_json_schema()["required"] + return self.llm_chain.get_input_jsonschema()["required"] @property def output_keys(self) -> list[str]: