From d04fa1ae50f432c3ad108c68eb2e27e781d628ba Mon Sep 17 00:00:00 2001 From: Vadym Barda Date: Tue, 18 Feb 2025 14:44:31 -0500 Subject: [PATCH] core[patch]: allow passing JSON schema as args_schema to tools (#29812) --- libs/core/langchain_core/tools/base.py | 43 ++++++-- libs/core/langchain_core/tools/simple.py | 11 +- libs/core/langchain_core/tools/structured.py | 15 ++- .../langchain_core/utils/function_calling.py | 67 +++++++++--- libs/core/tests/unit_tests/test_tools.py | 103 ++++++++++++++++-- .../unit_tests/utils/test_function_calling.py | 24 ++++ 6 files changed, 221 insertions(+), 42 deletions(-) diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py index f0833cdc24a..ff62f3091d4 100644 --- a/libs/core/langchain_core/tools/base.py +++ b/libs/core/langchain_core/tools/base.py @@ -317,6 +317,9 @@ class ToolException(Exception): # noqa: N818 """ +ArgsSchema = Union[TypeBaseModel, dict[str, Any]] + + class BaseTool(RunnableSerializable[Union[str, dict, ToolCall], Any]): """Interface LangChain tools must implement.""" @@ -354,7 +357,7 @@ class ChildTool(BaseTool): You can provide few-shot examples as a part of the description. """ - args_schema: Annotated[Optional[TypeBaseModel], SkipValidation()] = Field( + args_schema: Annotated[Optional[ArgsSchema], SkipValidation()] = Field( default=None, description="The tool schema." ) """Pydantic model class to validate and parse the tool's input arguments. @@ -364,6 +367,8 @@ class ChildTool(BaseTool): - A subclass of pydantic.BaseModel. or - A subclass of pydantic.v1.BaseModel if accessing v1 namespace in pydantic 2 + or + - a JSON schema dict """ return_direct: bool = False """Whether to return the tool's output directly. @@ -423,10 +428,11 @@ class ChildTool(BaseTool): "args_schema" in kwargs and kwargs["args_schema"] is not None and not is_basemodel_subclass(kwargs["args_schema"]) + and not isinstance(kwargs["args_schema"], dict) ): msg = ( - f"args_schema must be a subclass of pydantic BaseModel. " - f"Got: {kwargs['args_schema']}." + "args_schema must be a subclass of pydantic BaseModel or " + f"a JSON schema dict. Got: {kwargs['args_schema']}." ) raise TypeError(msg) super().__init__(**kwargs) @@ -443,10 +449,18 @@ class ChildTool(BaseTool): @property def args(self) -> dict: - return self.get_input_schema().model_json_schema()["properties"] + if isinstance(self.args_schema, dict): + json_schema = self.args_schema + else: + input_schema = self.get_input_schema() + json_schema = input_schema.model_json_schema() + return json_schema["properties"] @property - def tool_call_schema(self) -> type[BaseModel]: + def tool_call_schema(self) -> ArgsSchema: + if isinstance(self.args_schema, dict): + return self.args_schema + full_schema = self.get_input_schema() fields = [] for name, type_ in get_all_basemodel_annotations(full_schema).items(): @@ -470,6 +484,8 @@ class ChildTool(BaseTool): The input schema for the tool. """ if self.args_schema is not None: + if isinstance(self.args_schema, dict): + return super().get_input_schema(config) return self.args_schema else: return create_schema_from_function(self.name, self._run) @@ -505,6 +521,12 @@ class ChildTool(BaseTool): input_args = self.args_schema if isinstance(tool_input, str): if input_args is not None: + if isinstance(input_args, dict): + msg = ( + "String tool inputs are not allowed when " + "using tools with JSON schema args_schema." + ) + raise ValueError(msg) key_ = next(iter(get_fields(input_args).keys())) if hasattr(input_args, "model_validate"): input_args.model_validate({key_: tool_input}) @@ -513,7 +535,9 @@ class ChildTool(BaseTool): return tool_input else: if input_args is not None: - if issubclass(input_args, BaseModel): + if isinstance(input_args, dict): + return tool_input + elif issubclass(input_args, BaseModel): for k, v in get_all_basemodel_annotations(input_args).items(): if ( _is_injected_arg_type(v, injected_type=InjectedToolCallId) @@ -605,7 +629,12 @@ class ChildTool(BaseTool): def _to_args_and_kwargs( self, tool_input: Union[str, dict], tool_call_id: Optional[str] ) -> tuple[tuple, dict]: - if self.args_schema is not None and not get_fields(self.args_schema): + if ( + self.args_schema is not None + and isinstance(self.args_schema, type) + and is_basemodel_subclass(self.args_schema) + and not get_fields(self.args_schema) + ): # StructuredTool with no args return (), {} tool_input = self._parse_input(tool_input, tool_call_id) diff --git a/libs/core/langchain_core/tools/simple.py b/libs/core/langchain_core/tools/simple.py index d9e38ba227c..370d2091a68 100644 --- a/libs/core/langchain_core/tools/simple.py +++ b/libs/core/langchain_core/tools/simple.py @@ -9,8 +9,6 @@ from typing import ( Union, ) -from pydantic import BaseModel - from langchain_core.callbacks import ( AsyncCallbackManagerForToolRun, CallbackManagerForToolRun, @@ -18,6 +16,7 @@ from langchain_core.callbacks import ( from langchain_core.messages import ToolCall from langchain_core.runnables import RunnableConfig, run_in_executor from langchain_core.tools.base import ( + ArgsSchema, BaseTool, ToolException, _get_runnable_config_param, @@ -57,7 +56,11 @@ class Tool(BaseTool): The input arguments for the tool. """ if self.args_schema is not None: - return self.args_schema.model_json_schema()["properties"] + if isinstance(self.args_schema, dict): + json_schema = self.args_schema + else: + json_schema = self.args_schema.model_json_schema() + return json_schema["properties"] # For backwards compatibility, if the function signature is ambiguous, # assume it takes a single string input. return {"tool_input": {"type": "string"}} @@ -132,7 +135,7 @@ class Tool(BaseTool): name: str, # We keep these required to support backwards compatibility description: str, return_direct: bool = False, - args_schema: Optional[type[BaseModel]] = None, + args_schema: Optional[ArgsSchema] = None, coroutine: Optional[ Callable[..., Awaitable[Any]] ] = None, # This is last for compatibility, but should be after func diff --git a/libs/core/langchain_core/tools/structured.py b/libs/core/langchain_core/tools/structured.py index ef185b3e1e8..6a32583696a 100644 --- a/libs/core/langchain_core/tools/structured.py +++ b/libs/core/langchain_core/tools/structured.py @@ -12,7 +12,7 @@ from typing import ( Union, ) -from pydantic import BaseModel, Field, SkipValidation +from pydantic import Field, SkipValidation from langchain_core.callbacks import ( AsyncCallbackManagerForToolRun, @@ -22,18 +22,18 @@ from langchain_core.messages import ToolCall from langchain_core.runnables import RunnableConfig, run_in_executor from langchain_core.tools.base import ( FILTERED_ARGS, + ArgsSchema, BaseTool, _get_runnable_config_param, create_schema_from_function, ) -from langchain_core.utils.pydantic import TypeBaseModel class StructuredTool(BaseTool): """Tool that can operate on any number of inputs.""" description: str = "" - args_schema: Annotated[TypeBaseModel, SkipValidation()] = Field( + args_schema: Annotated[ArgsSchema, SkipValidation()] = Field( ..., description="The tool schema." ) """The input arguments' schema.""" @@ -62,7 +62,12 @@ class StructuredTool(BaseTool): @property def args(self) -> dict: """The tool's input arguments.""" - return self.args_schema.model_json_schema()["properties"] + if isinstance(self.args_schema, dict): + json_schema = self.args_schema + else: + input_schema = self.get_input_schema() + json_schema = input_schema.model_json_schema() + return json_schema["properties"] def _run( self, @@ -110,7 +115,7 @@ class StructuredTool(BaseTool): name: Optional[str] = None, description: Optional[str] = None, return_direct: bool = False, - args_schema: Optional[type[BaseModel]] = None, + args_schema: Optional[ArgsSchema] = None, infer_schema: bool = True, *, response_format: Literal["content", "content_and_artifact"] = "content", diff --git a/libs/core/langchain_core/utils/function_calling.py b/libs/core/langchain_core/utils/function_calling.py index cdcd3dc8649..0eadd46b2b0 100644 --- a/libs/core/langchain_core/utils/function_calling.py +++ b/libs/core/langchain_core/utils/function_calling.py @@ -20,6 +20,7 @@ from typing import ( ) from pydantic import BaseModel +from pydantic.v1 import BaseModel as BaseModelV1 from typing_extensions import TypedDict, get_args, get_origin, is_typeddict from langchain_core._api import beta, deprecated @@ -75,6 +76,40 @@ def _rm_titles(kv: dict, prev_key: str = "") -> dict: return new_kv +def _convert_json_schema_to_openai_function( + schema: dict, + *, + name: Optional[str] = None, + description: Optional[str] = None, + rm_titles: bool = True, +) -> FunctionDescription: + """Converts a Pydantic model to a function description for the OpenAI API. + + Args: + schema: The JSON schema to convert. + name: The name of the function. If not provided, the title of the schema will be + used. + description: The description of the function. If not provided, the description + of the schema will be used. + rm_titles: Whether to remove titles from the schema. Defaults to True. + + Returns: + The function description. + """ + schema = dereference_refs(schema) + if "definitions" in schema: # pydantic 1 + schema.pop("definitions", None) + if "$defs" in schema: # pydantic 2 + schema.pop("$defs", None) + title = schema.pop("title", "") + default_description = schema.pop("description", "") + return { + "name": name or title, + "description": description or default_description, + "parameters": _rm_titles(schema) if rm_titles else schema, + } + + def _convert_pydantic_to_openai_function( model: type, *, @@ -102,18 +137,9 @@ def _convert_pydantic_to_openai_function( else: msg = "Model must be a Pydantic model." raise TypeError(msg) - schema = dereference_refs(schema) - if "definitions" in schema: # pydantic 1 - schema.pop("definitions", None) - if "$defs" in schema: # pydantic 2 - schema.pop("$defs", None) - title = schema.pop("title", "") - default_description = schema.pop("description", "") - return { - "name": name or title, - "description": description or default_description, - "parameters": _rm_titles(schema) if rm_titles else schema, - } + return _convert_json_schema_to_openai_function( + schema, name=name, description=description, rm_titles=rm_titles + ) convert_pydantic_to_openai_function = deprecated( @@ -289,9 +315,20 @@ def _format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription: is_simple_oai_tool = isinstance(tool, simple.Tool) and not tool.args_schema if tool.tool_call_schema and not is_simple_oai_tool: - return _convert_pydantic_to_openai_function( - tool.tool_call_schema, name=tool.name, description=tool.description - ) + if isinstance(tool.tool_call_schema, dict): + return _convert_json_schema_to_openai_function( + tool.tool_call_schema, name=tool.name, description=tool.description + ) + elif issubclass(tool.tool_call_schema, (BaseModel, BaseModelV1)): + return _convert_pydantic_to_openai_function( + tool.tool_call_schema, name=tool.name, description=tool.description + ) + else: + error_msg = ( + f"Unsupported tool call schema: {tool.tool_call_schema}. " + "Tool call schema must be a JSON schema dict or a Pydantic model." + ) + raise ValueError(error_msg) else: return { "name": tool.name, diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index 4fd7fb567e8..afe8c09113a 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -17,6 +17,7 @@ from typing import ( Optional, TypeVar, Union, + cast, ) import pytest @@ -59,11 +60,26 @@ from langchain_core.tools.base import ( get_all_basemodel_annotations, ) from langchain_core.utils.function_calling import convert_to_openai_function -from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION, _create_subset_model +from langchain_core.utils.pydantic import ( + PYDANTIC_MAJOR_VERSION, + _create_subset_model, + create_model_v2, +) from tests.unit_tests.fake.callbacks import FakeCallbackHandler from tests.unit_tests.pydantic_utils import _schema +def _get_tool_call_json_schema(tool: BaseTool) -> dict: + tool_schema = tool.tool_call_schema + if isinstance(tool_schema, dict): + return tool_schema + + if hasattr(tool_schema, "model_json_schema"): + return tool_schema.model_json_schema() + else: + return tool_schema.schema() + + def test_unnamed_decorator() -> None: """Test functionality with unnamed decorator.""" @@ -1721,7 +1737,7 @@ def test_tool_inherited_injected_arg() -> None: "required": ["y", "x"], } # Should not include `y` since it's annotated as an injected tool arg - assert tool_.tool_call_schema.model_json_schema() == { + assert _get_tool_call_json_schema(tool_) == { "title": "foo", "description": "foo.", "type": "object", @@ -1840,12 +1856,7 @@ def test_args_schema_as_pydantic(pydantic_model: Any) -> None: "type": "object", } - tool_schema = tool.tool_call_schema - tool_json_schema = ( - tool_schema.model_json_schema() - if hasattr(tool_schema, "model_json_schema") - else tool_schema.schema() - ) + tool_json_schema = _get_tool_call_json_schema(tool) assert tool_json_schema == { "description": "some description", "properties": { @@ -1892,7 +1903,7 @@ def test_args_schema_explicitly_typed() -> None: "type": "object", } - assert tool.tool_call_schema.model_json_schema() == { + assert _get_tool_call_json_schema(tool) == { "description": "some description", "properties": { "a": {"title": "A", "type": "integer"}, @@ -1920,7 +1931,7 @@ def test_structured_tool_with_different_pydantic_versions(pydantic_model: Any) - assert foo_tool.invoke({"a": 5, "b": "hello"}) == "foo" - args_schema = foo_tool.args_schema + args_schema = cast(BaseModel, foo_tool.args_schema) args_json_schema = ( args_schema.model_json_schema() if hasattr(args_schema, "model_json_schema") @@ -2210,7 +2221,7 @@ def test_tool_args_schema_pydantic_v2_with_metadata() -> None: """Foo.""" return x - assert foo.tool_call_schema.model_json_schema() == { + assert _get_tool_call_json_schema(foo) == { "description": "Foo.", "properties": { "x": { @@ -2376,3 +2387,73 @@ def test_tool_mutate_input() -> None: my_input = {"x": "hi"} MyTool().invoke(my_input) assert my_input == {"x": "hi"} + + +def test_structured_tool_args_schema_dict() -> None: + args_schema = { + "properties": { + "a": {"title": "A", "type": "integer"}, + "b": {"title": "B", "type": "integer"}, + }, + "required": ["a", "b"], + "title": "add", + "type": "object", + } + tool = StructuredTool( + name="add", + description="add two numbers", + args_schema=args_schema, + func=lambda a, b: a + b, + ) + assert tool.invoke({"a": 1, "b": 2}) == 3 + assert tool.args_schema == args_schema + # test that the tool call schema is the same as the args schema + assert _get_tool_call_json_schema(tool) == args_schema + # test that the input schema is the same as the parent (Runnable) input schema + assert ( + tool.get_input_schema().model_json_schema() + == create_model_v2( + tool.get_name("Input"), + root=tool.InputType, + module_name=tool.__class__.__module__, + ).model_json_schema() + ) + # test that args are extracted correctly + assert tool.args == { + "a": {"title": "A", "type": "integer"}, + "b": {"title": "B", "type": "integer"}, + } + + +def test_simple_tool_args_schema_dict() -> None: + args_schema = { + "properties": { + "a": {"title": "A", "type": "integer"}, + }, + "required": ["a"], + "title": "square", + "type": "object", + } + tool = Tool( + name="square", + description="square a number", + args_schema=args_schema, + func=lambda a: a * a, + ) + assert tool.invoke({"a": 2}) == 4 + assert tool.args_schema == args_schema + # test that the tool call schema is the same as the args schema + assert _get_tool_call_json_schema(tool) == args_schema + # test that the input schema is the same as the parent (Runnable) input schema + assert ( + tool.get_input_schema().model_json_schema() + == create_model_v2( + tool.get_name("Input"), + root=tool.InputType, + module_name=tool.__class__.__module__, + ).model_json_schema() + ) + # test that args are extracted correctly + assert tool.args == { + "a": {"title": "A", "type": "integer"}, + } diff --git a/libs/core/tests/unit_tests/utils/test_function_calling.py b/libs/core/tests/unit_tests/utils/test_function_calling.py index d7a90635e36..38206bceef5 100644 --- a/libs/core/tests/unit_tests/utils/test_function_calling.py +++ b/libs/core/tests/unit_tests/utils/test_function_calling.py @@ -127,6 +127,28 @@ def dummy_structured_tool() -> StructuredTool: ) +@pytest.fixture() +def dummy_structured_tool_args_schema_dict() -> StructuredTool: + args_schema = { + "type": "object", + "properties": { + "arg1": {"type": "integer", "description": "foo"}, + "arg2": { + "type": "string", + "enum": ["bar", "baz"], + "description": "one of 'bar', 'baz'", + }, + }, + "required": ["arg1", "arg2"], + } + return StructuredTool.from_function( + lambda x: None, + name="dummy_function", + description="Dummy function.", + args_schema=args_schema, + ) + + @pytest.fixture() def dummy_pydantic() -> type[BaseModel]: class dummy_function(BaseModel): # noqa: N801 @@ -293,6 +315,7 @@ def test_convert_to_openai_function( function: Callable, function_docstring_annotations: Callable, dummy_structured_tool: StructuredTool, + dummy_structured_tool_args_schema_dict: StructuredTool, dummy_tool: BaseTool, json_schema: dict, anthropic_tool: dict, @@ -327,6 +350,7 @@ def test_convert_to_openai_function( function, function_docstring_annotations, dummy_structured_tool, + dummy_structured_tool_args_schema_dict, dummy_tool, json_schema, anthropic_tool,