core[patch]: allow passing JSON schema as args_schema to tools (#29812)

This commit is contained in:
Vadym Barda
2025-02-18 14:44:31 -05:00
committed by GitHub
parent 5034a8dc5c
commit d04fa1ae50
6 changed files with 221 additions and 42 deletions

View File

@@ -17,6 +17,7 @@ from typing import (
Optional,
TypeVar,
Union,
cast,
)
import pytest
@@ -59,11 +60,26 @@ from langchain_core.tools.base import (
get_all_basemodel_annotations,
)
from langchain_core.utils.function_calling import convert_to_openai_function
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION, _create_subset_model
from langchain_core.utils.pydantic import (
PYDANTIC_MAJOR_VERSION,
_create_subset_model,
create_model_v2,
)
from tests.unit_tests.fake.callbacks import FakeCallbackHandler
from tests.unit_tests.pydantic_utils import _schema
def _get_tool_call_json_schema(tool: BaseTool) -> dict:
tool_schema = tool.tool_call_schema
if isinstance(tool_schema, dict):
return tool_schema
if hasattr(tool_schema, "model_json_schema"):
return tool_schema.model_json_schema()
else:
return tool_schema.schema()
def test_unnamed_decorator() -> None:
"""Test functionality with unnamed decorator."""
@@ -1721,7 +1737,7 @@ def test_tool_inherited_injected_arg() -> None:
"required": ["y", "x"],
}
# Should not include `y` since it's annotated as an injected tool arg
assert tool_.tool_call_schema.model_json_schema() == {
assert _get_tool_call_json_schema(tool_) == {
"title": "foo",
"description": "foo.",
"type": "object",
@@ -1840,12 +1856,7 @@ def test_args_schema_as_pydantic(pydantic_model: Any) -> None:
"type": "object",
}
tool_schema = tool.tool_call_schema
tool_json_schema = (
tool_schema.model_json_schema()
if hasattr(tool_schema, "model_json_schema")
else tool_schema.schema()
)
tool_json_schema = _get_tool_call_json_schema(tool)
assert tool_json_schema == {
"description": "some description",
"properties": {
@@ -1892,7 +1903,7 @@ def test_args_schema_explicitly_typed() -> None:
"type": "object",
}
assert tool.tool_call_schema.model_json_schema() == {
assert _get_tool_call_json_schema(tool) == {
"description": "some description",
"properties": {
"a": {"title": "A", "type": "integer"},
@@ -1920,7 +1931,7 @@ def test_structured_tool_with_different_pydantic_versions(pydantic_model: Any) -
assert foo_tool.invoke({"a": 5, "b": "hello"}) == "foo"
args_schema = foo_tool.args_schema
args_schema = cast(BaseModel, foo_tool.args_schema)
args_json_schema = (
args_schema.model_json_schema()
if hasattr(args_schema, "model_json_schema")
@@ -2210,7 +2221,7 @@ def test_tool_args_schema_pydantic_v2_with_metadata() -> None:
"""Foo."""
return x
assert foo.tool_call_schema.model_json_schema() == {
assert _get_tool_call_json_schema(foo) == {
"description": "Foo.",
"properties": {
"x": {
@@ -2376,3 +2387,73 @@ def test_tool_mutate_input() -> None:
my_input = {"x": "hi"}
MyTool().invoke(my_input)
assert my_input == {"x": "hi"}
def test_structured_tool_args_schema_dict() -> None:
args_schema = {
"properties": {
"a": {"title": "A", "type": "integer"},
"b": {"title": "B", "type": "integer"},
},
"required": ["a", "b"],
"title": "add",
"type": "object",
}
tool = StructuredTool(
name="add",
description="add two numbers",
args_schema=args_schema,
func=lambda a, b: a + b,
)
assert tool.invoke({"a": 1, "b": 2}) == 3
assert tool.args_schema == args_schema
# test that the tool call schema is the same as the args schema
assert _get_tool_call_json_schema(tool) == args_schema
# test that the input schema is the same as the parent (Runnable) input schema
assert (
tool.get_input_schema().model_json_schema()
== create_model_v2(
tool.get_name("Input"),
root=tool.InputType,
module_name=tool.__class__.__module__,
).model_json_schema()
)
# test that args are extracted correctly
assert tool.args == {
"a": {"title": "A", "type": "integer"},
"b": {"title": "B", "type": "integer"},
}
def test_simple_tool_args_schema_dict() -> None:
args_schema = {
"properties": {
"a": {"title": "A", "type": "integer"},
},
"required": ["a"],
"title": "square",
"type": "object",
}
tool = Tool(
name="square",
description="square a number",
args_schema=args_schema,
func=lambda a: a * a,
)
assert tool.invoke({"a": 2}) == 4
assert tool.args_schema == args_schema
# test that the tool call schema is the same as the args schema
assert _get_tool_call_json_schema(tool) == args_schema
# test that the input schema is the same as the parent (Runnable) input schema
assert (
tool.get_input_schema().model_json_schema()
== create_model_v2(
tool.get_name("Input"),
root=tool.InputType,
module_name=tool.__class__.__module__,
).model_json_schema()
)
# test that args are extracted correctly
assert tool.args == {
"a": {"title": "A", "type": "integer"},
}

View File

@@ -127,6 +127,28 @@ def dummy_structured_tool() -> StructuredTool:
)
@pytest.fixture()
def dummy_structured_tool_args_schema_dict() -> StructuredTool:
args_schema = {
"type": "object",
"properties": {
"arg1": {"type": "integer", "description": "foo"},
"arg2": {
"type": "string",
"enum": ["bar", "baz"],
"description": "one of 'bar', 'baz'",
},
},
"required": ["arg1", "arg2"],
}
return StructuredTool.from_function(
lambda x: None,
name="dummy_function",
description="Dummy function.",
args_schema=args_schema,
)
@pytest.fixture()
def dummy_pydantic() -> type[BaseModel]:
class dummy_function(BaseModel): # noqa: N801
@@ -293,6 +315,7 @@ def test_convert_to_openai_function(
function: Callable,
function_docstring_annotations: Callable,
dummy_structured_tool: StructuredTool,
dummy_structured_tool_args_schema_dict: StructuredTool,
dummy_tool: BaseTool,
json_schema: dict,
anthropic_tool: dict,
@@ -327,6 +350,7 @@ def test_convert_to_openai_function(
function,
function_docstring_annotations,
dummy_structured_tool,
dummy_structured_tool_args_schema_dict,
dummy_tool,
json_schema,
anthropic_tool,