core[patch]: allow passing JSON schema as args_schema to tools (#29812)

This commit is contained in:
Vadym Barda 2025-02-18 14:44:31 -05:00 committed by GitHub
parent 5034a8dc5c
commit d04fa1ae50
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 221 additions and 42 deletions

View File

@ -317,6 +317,9 @@ class ToolException(Exception): # noqa: N818
"""
ArgsSchema = Union[TypeBaseModel, dict[str, Any]]
class BaseTool(RunnableSerializable[Union[str, dict, ToolCall], Any]):
"""Interface LangChain tools must implement."""
@ -354,7 +357,7 @@ class ChildTool(BaseTool):
You can provide few-shot examples as a part of the description.
"""
args_schema: Annotated[Optional[TypeBaseModel], SkipValidation()] = Field(
args_schema: Annotated[Optional[ArgsSchema], SkipValidation()] = Field(
default=None, description="The tool schema."
)
"""Pydantic model class to validate and parse the tool's input arguments.
@ -364,6 +367,8 @@ class ChildTool(BaseTool):
- A subclass of pydantic.BaseModel.
or
- A subclass of pydantic.v1.BaseModel if accessing v1 namespace in pydantic 2
or
- a JSON schema dict
"""
return_direct: bool = False
"""Whether to return the tool's output directly.
@ -423,10 +428,11 @@ class ChildTool(BaseTool):
"args_schema" in kwargs
and kwargs["args_schema"] is not None
and not is_basemodel_subclass(kwargs["args_schema"])
and not isinstance(kwargs["args_schema"], dict)
):
msg = (
f"args_schema must be a subclass of pydantic BaseModel. "
f"Got: {kwargs['args_schema']}."
"args_schema must be a subclass of pydantic BaseModel or "
f"a JSON schema dict. Got: {kwargs['args_schema']}."
)
raise TypeError(msg)
super().__init__(**kwargs)
@ -443,10 +449,18 @@ class ChildTool(BaseTool):
@property
def args(self) -> dict:
return self.get_input_schema().model_json_schema()["properties"]
if isinstance(self.args_schema, dict):
json_schema = self.args_schema
else:
input_schema = self.get_input_schema()
json_schema = input_schema.model_json_schema()
return json_schema["properties"]
@property
def tool_call_schema(self) -> type[BaseModel]:
def tool_call_schema(self) -> ArgsSchema:
if isinstance(self.args_schema, dict):
return self.args_schema
full_schema = self.get_input_schema()
fields = []
for name, type_ in get_all_basemodel_annotations(full_schema).items():
@ -470,6 +484,8 @@ class ChildTool(BaseTool):
The input schema for the tool.
"""
if self.args_schema is not None:
if isinstance(self.args_schema, dict):
return super().get_input_schema(config)
return self.args_schema
else:
return create_schema_from_function(self.name, self._run)
@ -505,6 +521,12 @@ class ChildTool(BaseTool):
input_args = self.args_schema
if isinstance(tool_input, str):
if input_args is not None:
if isinstance(input_args, dict):
msg = (
"String tool inputs are not allowed when "
"using tools with JSON schema args_schema."
)
raise ValueError(msg)
key_ = next(iter(get_fields(input_args).keys()))
if hasattr(input_args, "model_validate"):
input_args.model_validate({key_: tool_input})
@ -513,7 +535,9 @@ class ChildTool(BaseTool):
return tool_input
else:
if input_args is not None:
if issubclass(input_args, BaseModel):
if isinstance(input_args, dict):
return tool_input
elif issubclass(input_args, BaseModel):
for k, v in get_all_basemodel_annotations(input_args).items():
if (
_is_injected_arg_type(v, injected_type=InjectedToolCallId)
@ -605,7 +629,12 @@ class ChildTool(BaseTool):
def _to_args_and_kwargs(
self, tool_input: Union[str, dict], tool_call_id: Optional[str]
) -> tuple[tuple, dict]:
if self.args_schema is not None and not get_fields(self.args_schema):
if (
self.args_schema is not None
and isinstance(self.args_schema, type)
and is_basemodel_subclass(self.args_schema)
and not get_fields(self.args_schema)
):
# StructuredTool with no args
return (), {}
tool_input = self._parse_input(tool_input, tool_call_id)

View File

@ -9,8 +9,6 @@ from typing import (
Union,
)
from pydantic import BaseModel
from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
@ -18,6 +16,7 @@ from langchain_core.callbacks import (
from langchain_core.messages import ToolCall
from langchain_core.runnables import RunnableConfig, run_in_executor
from langchain_core.tools.base import (
ArgsSchema,
BaseTool,
ToolException,
_get_runnable_config_param,
@ -57,7 +56,11 @@ class Tool(BaseTool):
The input arguments for the tool.
"""
if self.args_schema is not None:
return self.args_schema.model_json_schema()["properties"]
if isinstance(self.args_schema, dict):
json_schema = self.args_schema
else:
json_schema = self.args_schema.model_json_schema()
return json_schema["properties"]
# For backwards compatibility, if the function signature is ambiguous,
# assume it takes a single string input.
return {"tool_input": {"type": "string"}}
@ -132,7 +135,7 @@ class Tool(BaseTool):
name: str, # We keep these required to support backwards compatibility
description: str,
return_direct: bool = False,
args_schema: Optional[type[BaseModel]] = None,
args_schema: Optional[ArgsSchema] = None,
coroutine: Optional[
Callable[..., Awaitable[Any]]
] = None, # This is last for compatibility, but should be after func

View File

@ -12,7 +12,7 @@ from typing import (
Union,
)
from pydantic import BaseModel, Field, SkipValidation
from pydantic import Field, SkipValidation
from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
@ -22,18 +22,18 @@ from langchain_core.messages import ToolCall
from langchain_core.runnables import RunnableConfig, run_in_executor
from langchain_core.tools.base import (
FILTERED_ARGS,
ArgsSchema,
BaseTool,
_get_runnable_config_param,
create_schema_from_function,
)
from langchain_core.utils.pydantic import TypeBaseModel
class StructuredTool(BaseTool):
"""Tool that can operate on any number of inputs."""
description: str = ""
args_schema: Annotated[TypeBaseModel, SkipValidation()] = Field(
args_schema: Annotated[ArgsSchema, SkipValidation()] = Field(
..., description="The tool schema."
)
"""The input arguments' schema."""
@ -62,7 +62,12 @@ class StructuredTool(BaseTool):
@property
def args(self) -> dict:
"""The tool's input arguments."""
return self.args_schema.model_json_schema()["properties"]
if isinstance(self.args_schema, dict):
json_schema = self.args_schema
else:
input_schema = self.get_input_schema()
json_schema = input_schema.model_json_schema()
return json_schema["properties"]
def _run(
self,
@ -110,7 +115,7 @@ class StructuredTool(BaseTool):
name: Optional[str] = None,
description: Optional[str] = None,
return_direct: bool = False,
args_schema: Optional[type[BaseModel]] = None,
args_schema: Optional[ArgsSchema] = None,
infer_schema: bool = True,
*,
response_format: Literal["content", "content_and_artifact"] = "content",

View File

@ -20,6 +20,7 @@ from typing import (
)
from pydantic import BaseModel
from pydantic.v1 import BaseModel as BaseModelV1
from typing_extensions import TypedDict, get_args, get_origin, is_typeddict
from langchain_core._api import beta, deprecated
@ -75,6 +76,40 @@ def _rm_titles(kv: dict, prev_key: str = "") -> dict:
return new_kv
def _convert_json_schema_to_openai_function(
schema: dict,
*,
name: Optional[str] = None,
description: Optional[str] = None,
rm_titles: bool = True,
) -> FunctionDescription:
"""Converts a Pydantic model to a function description for the OpenAI API.
Args:
schema: The JSON schema to convert.
name: The name of the function. If not provided, the title of the schema will be
used.
description: The description of the function. If not provided, the description
of the schema will be used.
rm_titles: Whether to remove titles from the schema. Defaults to True.
Returns:
The function description.
"""
schema = dereference_refs(schema)
if "definitions" in schema: # pydantic 1
schema.pop("definitions", None)
if "$defs" in schema: # pydantic 2
schema.pop("$defs", None)
title = schema.pop("title", "")
default_description = schema.pop("description", "")
return {
"name": name or title,
"description": description or default_description,
"parameters": _rm_titles(schema) if rm_titles else schema,
}
def _convert_pydantic_to_openai_function(
model: type,
*,
@ -102,18 +137,9 @@ def _convert_pydantic_to_openai_function(
else:
msg = "Model must be a Pydantic model."
raise TypeError(msg)
schema = dereference_refs(schema)
if "definitions" in schema: # pydantic 1
schema.pop("definitions", None)
if "$defs" in schema: # pydantic 2
schema.pop("$defs", None)
title = schema.pop("title", "")
default_description = schema.pop("description", "")
return {
"name": name or title,
"description": description or default_description,
"parameters": _rm_titles(schema) if rm_titles else schema,
}
return _convert_json_schema_to_openai_function(
schema, name=name, description=description, rm_titles=rm_titles
)
convert_pydantic_to_openai_function = deprecated(
@ -289,9 +315,20 @@ def _format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
is_simple_oai_tool = isinstance(tool, simple.Tool) and not tool.args_schema
if tool.tool_call_schema and not is_simple_oai_tool:
return _convert_pydantic_to_openai_function(
tool.tool_call_schema, name=tool.name, description=tool.description
)
if isinstance(tool.tool_call_schema, dict):
return _convert_json_schema_to_openai_function(
tool.tool_call_schema, name=tool.name, description=tool.description
)
elif issubclass(tool.tool_call_schema, (BaseModel, BaseModelV1)):
return _convert_pydantic_to_openai_function(
tool.tool_call_schema, name=tool.name, description=tool.description
)
else:
error_msg = (
f"Unsupported tool call schema: {tool.tool_call_schema}. "
"Tool call schema must be a JSON schema dict or a Pydantic model."
)
raise ValueError(error_msg)
else:
return {
"name": tool.name,

View File

@ -17,6 +17,7 @@ from typing import (
Optional,
TypeVar,
Union,
cast,
)
import pytest
@ -59,11 +60,26 @@ from langchain_core.tools.base import (
get_all_basemodel_annotations,
)
from langchain_core.utils.function_calling import convert_to_openai_function
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION, _create_subset_model
from langchain_core.utils.pydantic import (
PYDANTIC_MAJOR_VERSION,
_create_subset_model,
create_model_v2,
)
from tests.unit_tests.fake.callbacks import FakeCallbackHandler
from tests.unit_tests.pydantic_utils import _schema
def _get_tool_call_json_schema(tool: BaseTool) -> dict:
tool_schema = tool.tool_call_schema
if isinstance(tool_schema, dict):
return tool_schema
if hasattr(tool_schema, "model_json_schema"):
return tool_schema.model_json_schema()
else:
return tool_schema.schema()
def test_unnamed_decorator() -> None:
"""Test functionality with unnamed decorator."""
@ -1721,7 +1737,7 @@ def test_tool_inherited_injected_arg() -> None:
"required": ["y", "x"],
}
# Should not include `y` since it's annotated as an injected tool arg
assert tool_.tool_call_schema.model_json_schema() == {
assert _get_tool_call_json_schema(tool_) == {
"title": "foo",
"description": "foo.",
"type": "object",
@ -1840,12 +1856,7 @@ def test_args_schema_as_pydantic(pydantic_model: Any) -> None:
"type": "object",
}
tool_schema = tool.tool_call_schema
tool_json_schema = (
tool_schema.model_json_schema()
if hasattr(tool_schema, "model_json_schema")
else tool_schema.schema()
)
tool_json_schema = _get_tool_call_json_schema(tool)
assert tool_json_schema == {
"description": "some description",
"properties": {
@ -1892,7 +1903,7 @@ def test_args_schema_explicitly_typed() -> None:
"type": "object",
}
assert tool.tool_call_schema.model_json_schema() == {
assert _get_tool_call_json_schema(tool) == {
"description": "some description",
"properties": {
"a": {"title": "A", "type": "integer"},
@ -1920,7 +1931,7 @@ def test_structured_tool_with_different_pydantic_versions(pydantic_model: Any) -
assert foo_tool.invoke({"a": 5, "b": "hello"}) == "foo"
args_schema = foo_tool.args_schema
args_schema = cast(BaseModel, foo_tool.args_schema)
args_json_schema = (
args_schema.model_json_schema()
if hasattr(args_schema, "model_json_schema")
@ -2210,7 +2221,7 @@ def test_tool_args_schema_pydantic_v2_with_metadata() -> None:
"""Foo."""
return x
assert foo.tool_call_schema.model_json_schema() == {
assert _get_tool_call_json_schema(foo) == {
"description": "Foo.",
"properties": {
"x": {
@ -2376,3 +2387,73 @@ def test_tool_mutate_input() -> None:
my_input = {"x": "hi"}
MyTool().invoke(my_input)
assert my_input == {"x": "hi"}
def test_structured_tool_args_schema_dict() -> None:
args_schema = {
"properties": {
"a": {"title": "A", "type": "integer"},
"b": {"title": "B", "type": "integer"},
},
"required": ["a", "b"],
"title": "add",
"type": "object",
}
tool = StructuredTool(
name="add",
description="add two numbers",
args_schema=args_schema,
func=lambda a, b: a + b,
)
assert tool.invoke({"a": 1, "b": 2}) == 3
assert tool.args_schema == args_schema
# test that the tool call schema is the same as the args schema
assert _get_tool_call_json_schema(tool) == args_schema
# test that the input schema is the same as the parent (Runnable) input schema
assert (
tool.get_input_schema().model_json_schema()
== create_model_v2(
tool.get_name("Input"),
root=tool.InputType,
module_name=tool.__class__.__module__,
).model_json_schema()
)
# test that args are extracted correctly
assert tool.args == {
"a": {"title": "A", "type": "integer"},
"b": {"title": "B", "type": "integer"},
}
def test_simple_tool_args_schema_dict() -> None:
args_schema = {
"properties": {
"a": {"title": "A", "type": "integer"},
},
"required": ["a"],
"title": "square",
"type": "object",
}
tool = Tool(
name="square",
description="square a number",
args_schema=args_schema,
func=lambda a: a * a,
)
assert tool.invoke({"a": 2}) == 4
assert tool.args_schema == args_schema
# test that the tool call schema is the same as the args schema
assert _get_tool_call_json_schema(tool) == args_schema
# test that the input schema is the same as the parent (Runnable) input schema
assert (
tool.get_input_schema().model_json_schema()
== create_model_v2(
tool.get_name("Input"),
root=tool.InputType,
module_name=tool.__class__.__module__,
).model_json_schema()
)
# test that args are extracted correctly
assert tool.args == {
"a": {"title": "A", "type": "integer"},
}

View File

@ -127,6 +127,28 @@ def dummy_structured_tool() -> StructuredTool:
)
@pytest.fixture()
def dummy_structured_tool_args_schema_dict() -> StructuredTool:
args_schema = {
"type": "object",
"properties": {
"arg1": {"type": "integer", "description": "foo"},
"arg2": {
"type": "string",
"enum": ["bar", "baz"],
"description": "one of 'bar', 'baz'",
},
},
"required": ["arg1", "arg2"],
}
return StructuredTool.from_function(
lambda x: None,
name="dummy_function",
description="Dummy function.",
args_schema=args_schema,
)
@pytest.fixture()
def dummy_pydantic() -> type[BaseModel]:
class dummy_function(BaseModel): # noqa: N801
@ -293,6 +315,7 @@ def test_convert_to_openai_function(
function: Callable,
function_docstring_annotations: Callable,
dummy_structured_tool: StructuredTool,
dummy_structured_tool_args_schema_dict: StructuredTool,
dummy_tool: BaseTool,
json_schema: dict,
anthropic_tool: dict,
@ -327,6 +350,7 @@ def test_convert_to_openai_function(
function,
function_docstring_annotations,
dummy_structured_tool,
dummy_structured_tool_args_schema_dict,
dummy_tool,
json_schema,
anthropic_tool,