mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-06 13:33:37 +00:00
core[patch]: allow passing JSON schema as args_schema to tools (#29812)
This commit is contained in:
@@ -17,6 +17,7 @@ from typing import (
|
||||
Optional,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import pytest
|
||||
@@ -59,11 +60,26 @@ from langchain_core.tools.base import (
|
||||
get_all_basemodel_annotations,
|
||||
)
|
||||
from langchain_core.utils.function_calling import convert_to_openai_function
|
||||
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION, _create_subset_model
|
||||
from langchain_core.utils.pydantic import (
|
||||
PYDANTIC_MAJOR_VERSION,
|
||||
_create_subset_model,
|
||||
create_model_v2,
|
||||
)
|
||||
from tests.unit_tests.fake.callbacks import FakeCallbackHandler
|
||||
from tests.unit_tests.pydantic_utils import _schema
|
||||
|
||||
|
||||
def _get_tool_call_json_schema(tool: BaseTool) -> dict:
|
||||
tool_schema = tool.tool_call_schema
|
||||
if isinstance(tool_schema, dict):
|
||||
return tool_schema
|
||||
|
||||
if hasattr(tool_schema, "model_json_schema"):
|
||||
return tool_schema.model_json_schema()
|
||||
else:
|
||||
return tool_schema.schema()
|
||||
|
||||
|
||||
def test_unnamed_decorator() -> None:
|
||||
"""Test functionality with unnamed decorator."""
|
||||
|
||||
@@ -1721,7 +1737,7 @@ def test_tool_inherited_injected_arg() -> None:
|
||||
"required": ["y", "x"],
|
||||
}
|
||||
# Should not include `y` since it's annotated as an injected tool arg
|
||||
assert tool_.tool_call_schema.model_json_schema() == {
|
||||
assert _get_tool_call_json_schema(tool_) == {
|
||||
"title": "foo",
|
||||
"description": "foo.",
|
||||
"type": "object",
|
||||
@@ -1840,12 +1856,7 @@ def test_args_schema_as_pydantic(pydantic_model: Any) -> None:
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
tool_schema = tool.tool_call_schema
|
||||
tool_json_schema = (
|
||||
tool_schema.model_json_schema()
|
||||
if hasattr(tool_schema, "model_json_schema")
|
||||
else tool_schema.schema()
|
||||
)
|
||||
tool_json_schema = _get_tool_call_json_schema(tool)
|
||||
assert tool_json_schema == {
|
||||
"description": "some description",
|
||||
"properties": {
|
||||
@@ -1892,7 +1903,7 @@ def test_args_schema_explicitly_typed() -> None:
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
assert tool.tool_call_schema.model_json_schema() == {
|
||||
assert _get_tool_call_json_schema(tool) == {
|
||||
"description": "some description",
|
||||
"properties": {
|
||||
"a": {"title": "A", "type": "integer"},
|
||||
@@ -1920,7 +1931,7 @@ def test_structured_tool_with_different_pydantic_versions(pydantic_model: Any) -
|
||||
|
||||
assert foo_tool.invoke({"a": 5, "b": "hello"}) == "foo"
|
||||
|
||||
args_schema = foo_tool.args_schema
|
||||
args_schema = cast(BaseModel, foo_tool.args_schema)
|
||||
args_json_schema = (
|
||||
args_schema.model_json_schema()
|
||||
if hasattr(args_schema, "model_json_schema")
|
||||
@@ -2210,7 +2221,7 @@ def test_tool_args_schema_pydantic_v2_with_metadata() -> None:
|
||||
"""Foo."""
|
||||
return x
|
||||
|
||||
assert foo.tool_call_schema.model_json_schema() == {
|
||||
assert _get_tool_call_json_schema(foo) == {
|
||||
"description": "Foo.",
|
||||
"properties": {
|
||||
"x": {
|
||||
@@ -2376,3 +2387,73 @@ def test_tool_mutate_input() -> None:
|
||||
my_input = {"x": "hi"}
|
||||
MyTool().invoke(my_input)
|
||||
assert my_input == {"x": "hi"}
|
||||
|
||||
|
||||
def test_structured_tool_args_schema_dict() -> None:
|
||||
args_schema = {
|
||||
"properties": {
|
||||
"a": {"title": "A", "type": "integer"},
|
||||
"b": {"title": "B", "type": "integer"},
|
||||
},
|
||||
"required": ["a", "b"],
|
||||
"title": "add",
|
||||
"type": "object",
|
||||
}
|
||||
tool = StructuredTool(
|
||||
name="add",
|
||||
description="add two numbers",
|
||||
args_schema=args_schema,
|
||||
func=lambda a, b: a + b,
|
||||
)
|
||||
assert tool.invoke({"a": 1, "b": 2}) == 3
|
||||
assert tool.args_schema == args_schema
|
||||
# test that the tool call schema is the same as the args schema
|
||||
assert _get_tool_call_json_schema(tool) == args_schema
|
||||
# test that the input schema is the same as the parent (Runnable) input schema
|
||||
assert (
|
||||
tool.get_input_schema().model_json_schema()
|
||||
== create_model_v2(
|
||||
tool.get_name("Input"),
|
||||
root=tool.InputType,
|
||||
module_name=tool.__class__.__module__,
|
||||
).model_json_schema()
|
||||
)
|
||||
# test that args are extracted correctly
|
||||
assert tool.args == {
|
||||
"a": {"title": "A", "type": "integer"},
|
||||
"b": {"title": "B", "type": "integer"},
|
||||
}
|
||||
|
||||
|
||||
def test_simple_tool_args_schema_dict() -> None:
|
||||
args_schema = {
|
||||
"properties": {
|
||||
"a": {"title": "A", "type": "integer"},
|
||||
},
|
||||
"required": ["a"],
|
||||
"title": "square",
|
||||
"type": "object",
|
||||
}
|
||||
tool = Tool(
|
||||
name="square",
|
||||
description="square a number",
|
||||
args_schema=args_schema,
|
||||
func=lambda a: a * a,
|
||||
)
|
||||
assert tool.invoke({"a": 2}) == 4
|
||||
assert tool.args_schema == args_schema
|
||||
# test that the tool call schema is the same as the args schema
|
||||
assert _get_tool_call_json_schema(tool) == args_schema
|
||||
# test that the input schema is the same as the parent (Runnable) input schema
|
||||
assert (
|
||||
tool.get_input_schema().model_json_schema()
|
||||
== create_model_v2(
|
||||
tool.get_name("Input"),
|
||||
root=tool.InputType,
|
||||
module_name=tool.__class__.__module__,
|
||||
).model_json_schema()
|
||||
)
|
||||
# test that args are extracted correctly
|
||||
assert tool.args == {
|
||||
"a": {"title": "A", "type": "integer"},
|
||||
}
|
||||
|
@@ -127,6 +127,28 @@ def dummy_structured_tool() -> StructuredTool:
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def dummy_structured_tool_args_schema_dict() -> StructuredTool:
|
||||
args_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"arg1": {"type": "integer", "description": "foo"},
|
||||
"arg2": {
|
||||
"type": "string",
|
||||
"enum": ["bar", "baz"],
|
||||
"description": "one of 'bar', 'baz'",
|
||||
},
|
||||
},
|
||||
"required": ["arg1", "arg2"],
|
||||
}
|
||||
return StructuredTool.from_function(
|
||||
lambda x: None,
|
||||
name="dummy_function",
|
||||
description="Dummy function.",
|
||||
args_schema=args_schema,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def dummy_pydantic() -> type[BaseModel]:
|
||||
class dummy_function(BaseModel): # noqa: N801
|
||||
@@ -293,6 +315,7 @@ def test_convert_to_openai_function(
|
||||
function: Callable,
|
||||
function_docstring_annotations: Callable,
|
||||
dummy_structured_tool: StructuredTool,
|
||||
dummy_structured_tool_args_schema_dict: StructuredTool,
|
||||
dummy_tool: BaseTool,
|
||||
json_schema: dict,
|
||||
anthropic_tool: dict,
|
||||
@@ -327,6 +350,7 @@ def test_convert_to_openai_function(
|
||||
function,
|
||||
function_docstring_annotations,
|
||||
dummy_structured_tool,
|
||||
dummy_structured_tool_args_schema_dict,
|
||||
dummy_tool,
|
||||
json_schema,
|
||||
anthropic_tool,
|
||||
|
Reference in New Issue
Block a user