mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-22 14:49:29 +00:00
core[patch]: allow passing JSON schema as args_schema to tools (#29812)
This commit is contained in:
parent
5034a8dc5c
commit
d04fa1ae50
@ -317,6 +317,9 @@ class ToolException(Exception): # noqa: N818
|
||||
"""
|
||||
|
||||
|
||||
ArgsSchema = Union[TypeBaseModel, dict[str, Any]]
|
||||
|
||||
|
||||
class BaseTool(RunnableSerializable[Union[str, dict, ToolCall], Any]):
|
||||
"""Interface LangChain tools must implement."""
|
||||
|
||||
@ -354,7 +357,7 @@ class ChildTool(BaseTool):
|
||||
You can provide few-shot examples as a part of the description.
|
||||
"""
|
||||
|
||||
args_schema: Annotated[Optional[TypeBaseModel], SkipValidation()] = Field(
|
||||
args_schema: Annotated[Optional[ArgsSchema], SkipValidation()] = Field(
|
||||
default=None, description="The tool schema."
|
||||
)
|
||||
"""Pydantic model class to validate and parse the tool's input arguments.
|
||||
@ -364,6 +367,8 @@ class ChildTool(BaseTool):
|
||||
- A subclass of pydantic.BaseModel.
|
||||
or
|
||||
- A subclass of pydantic.v1.BaseModel if accessing v1 namespace in pydantic 2
|
||||
or
|
||||
- a JSON schema dict
|
||||
"""
|
||||
return_direct: bool = False
|
||||
"""Whether to return the tool's output directly.
|
||||
@ -423,10 +428,11 @@ class ChildTool(BaseTool):
|
||||
"args_schema" in kwargs
|
||||
and kwargs["args_schema"] is not None
|
||||
and not is_basemodel_subclass(kwargs["args_schema"])
|
||||
and not isinstance(kwargs["args_schema"], dict)
|
||||
):
|
||||
msg = (
|
||||
f"args_schema must be a subclass of pydantic BaseModel. "
|
||||
f"Got: {kwargs['args_schema']}."
|
||||
"args_schema must be a subclass of pydantic BaseModel or "
|
||||
f"a JSON schema dict. Got: {kwargs['args_schema']}."
|
||||
)
|
||||
raise TypeError(msg)
|
||||
super().__init__(**kwargs)
|
||||
@ -443,10 +449,18 @@ class ChildTool(BaseTool):
|
||||
|
||||
@property
|
||||
def args(self) -> dict:
|
||||
return self.get_input_schema().model_json_schema()["properties"]
|
||||
if isinstance(self.args_schema, dict):
|
||||
json_schema = self.args_schema
|
||||
else:
|
||||
input_schema = self.get_input_schema()
|
||||
json_schema = input_schema.model_json_schema()
|
||||
return json_schema["properties"]
|
||||
|
||||
@property
|
||||
def tool_call_schema(self) -> type[BaseModel]:
|
||||
def tool_call_schema(self) -> ArgsSchema:
|
||||
if isinstance(self.args_schema, dict):
|
||||
return self.args_schema
|
||||
|
||||
full_schema = self.get_input_schema()
|
||||
fields = []
|
||||
for name, type_ in get_all_basemodel_annotations(full_schema).items():
|
||||
@ -470,6 +484,8 @@ class ChildTool(BaseTool):
|
||||
The input schema for the tool.
|
||||
"""
|
||||
if self.args_schema is not None:
|
||||
if isinstance(self.args_schema, dict):
|
||||
return super().get_input_schema(config)
|
||||
return self.args_schema
|
||||
else:
|
||||
return create_schema_from_function(self.name, self._run)
|
||||
@ -505,6 +521,12 @@ class ChildTool(BaseTool):
|
||||
input_args = self.args_schema
|
||||
if isinstance(tool_input, str):
|
||||
if input_args is not None:
|
||||
if isinstance(input_args, dict):
|
||||
msg = (
|
||||
"String tool inputs are not allowed when "
|
||||
"using tools with JSON schema args_schema."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
key_ = next(iter(get_fields(input_args).keys()))
|
||||
if hasattr(input_args, "model_validate"):
|
||||
input_args.model_validate({key_: tool_input})
|
||||
@ -513,7 +535,9 @@ class ChildTool(BaseTool):
|
||||
return tool_input
|
||||
else:
|
||||
if input_args is not None:
|
||||
if issubclass(input_args, BaseModel):
|
||||
if isinstance(input_args, dict):
|
||||
return tool_input
|
||||
elif issubclass(input_args, BaseModel):
|
||||
for k, v in get_all_basemodel_annotations(input_args).items():
|
||||
if (
|
||||
_is_injected_arg_type(v, injected_type=InjectedToolCallId)
|
||||
@ -605,7 +629,12 @@ class ChildTool(BaseTool):
|
||||
def _to_args_and_kwargs(
|
||||
self, tool_input: Union[str, dict], tool_call_id: Optional[str]
|
||||
) -> tuple[tuple, dict]:
|
||||
if self.args_schema is not None and not get_fields(self.args_schema):
|
||||
if (
|
||||
self.args_schema is not None
|
||||
and isinstance(self.args_schema, type)
|
||||
and is_basemodel_subclass(self.args_schema)
|
||||
and not get_fields(self.args_schema)
|
||||
):
|
||||
# StructuredTool with no args
|
||||
return (), {}
|
||||
tool_input = self._parse_input(tool_input, tool_call_id)
|
||||
|
@ -9,8 +9,6 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
CallbackManagerForToolRun,
|
||||
@ -18,6 +16,7 @@ from langchain_core.callbacks import (
|
||||
from langchain_core.messages import ToolCall
|
||||
from langchain_core.runnables import RunnableConfig, run_in_executor
|
||||
from langchain_core.tools.base import (
|
||||
ArgsSchema,
|
||||
BaseTool,
|
||||
ToolException,
|
||||
_get_runnable_config_param,
|
||||
@ -57,7 +56,11 @@ class Tool(BaseTool):
|
||||
The input arguments for the tool.
|
||||
"""
|
||||
if self.args_schema is not None:
|
||||
return self.args_schema.model_json_schema()["properties"]
|
||||
if isinstance(self.args_schema, dict):
|
||||
json_schema = self.args_schema
|
||||
else:
|
||||
json_schema = self.args_schema.model_json_schema()
|
||||
return json_schema["properties"]
|
||||
# For backwards compatibility, if the function signature is ambiguous,
|
||||
# assume it takes a single string input.
|
||||
return {"tool_input": {"type": "string"}}
|
||||
@ -132,7 +135,7 @@ class Tool(BaseTool):
|
||||
name: str, # We keep these required to support backwards compatibility
|
||||
description: str,
|
||||
return_direct: bool = False,
|
||||
args_schema: Optional[type[BaseModel]] = None,
|
||||
args_schema: Optional[ArgsSchema] = None,
|
||||
coroutine: Optional[
|
||||
Callable[..., Awaitable[Any]]
|
||||
] = None, # This is last for compatibility, but should be after func
|
||||
|
@ -12,7 +12,7 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, Field, SkipValidation
|
||||
from pydantic import Field, SkipValidation
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
@ -22,18 +22,18 @@ from langchain_core.messages import ToolCall
|
||||
from langchain_core.runnables import RunnableConfig, run_in_executor
|
||||
from langchain_core.tools.base import (
|
||||
FILTERED_ARGS,
|
||||
ArgsSchema,
|
||||
BaseTool,
|
||||
_get_runnable_config_param,
|
||||
create_schema_from_function,
|
||||
)
|
||||
from langchain_core.utils.pydantic import TypeBaseModel
|
||||
|
||||
|
||||
class StructuredTool(BaseTool):
|
||||
"""Tool that can operate on any number of inputs."""
|
||||
|
||||
description: str = ""
|
||||
args_schema: Annotated[TypeBaseModel, SkipValidation()] = Field(
|
||||
args_schema: Annotated[ArgsSchema, SkipValidation()] = Field(
|
||||
..., description="The tool schema."
|
||||
)
|
||||
"""The input arguments' schema."""
|
||||
@ -62,7 +62,12 @@ class StructuredTool(BaseTool):
|
||||
@property
|
||||
def args(self) -> dict:
|
||||
"""The tool's input arguments."""
|
||||
return self.args_schema.model_json_schema()["properties"]
|
||||
if isinstance(self.args_schema, dict):
|
||||
json_schema = self.args_schema
|
||||
else:
|
||||
input_schema = self.get_input_schema()
|
||||
json_schema = input_schema.model_json_schema()
|
||||
return json_schema["properties"]
|
||||
|
||||
def _run(
|
||||
self,
|
||||
@ -110,7 +115,7 @@ class StructuredTool(BaseTool):
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
return_direct: bool = False,
|
||||
args_schema: Optional[type[BaseModel]] = None,
|
||||
args_schema: Optional[ArgsSchema] = None,
|
||||
infer_schema: bool = True,
|
||||
*,
|
||||
response_format: Literal["content", "content_and_artifact"] = "content",
|
||||
|
@ -20,6 +20,7 @@ from typing import (
|
||||
)
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
from typing_extensions import TypedDict, get_args, get_origin, is_typeddict
|
||||
|
||||
from langchain_core._api import beta, deprecated
|
||||
@ -75,6 +76,40 @@ def _rm_titles(kv: dict, prev_key: str = "") -> dict:
|
||||
return new_kv
|
||||
|
||||
|
||||
def _convert_json_schema_to_openai_function(
|
||||
schema: dict,
|
||||
*,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
rm_titles: bool = True,
|
||||
) -> FunctionDescription:
|
||||
"""Converts a Pydantic model to a function description for the OpenAI API.
|
||||
|
||||
Args:
|
||||
schema: The JSON schema to convert.
|
||||
name: The name of the function. If not provided, the title of the schema will be
|
||||
used.
|
||||
description: The description of the function. If not provided, the description
|
||||
of the schema will be used.
|
||||
rm_titles: Whether to remove titles from the schema. Defaults to True.
|
||||
|
||||
Returns:
|
||||
The function description.
|
||||
"""
|
||||
schema = dereference_refs(schema)
|
||||
if "definitions" in schema: # pydantic 1
|
||||
schema.pop("definitions", None)
|
||||
if "$defs" in schema: # pydantic 2
|
||||
schema.pop("$defs", None)
|
||||
title = schema.pop("title", "")
|
||||
default_description = schema.pop("description", "")
|
||||
return {
|
||||
"name": name or title,
|
||||
"description": description or default_description,
|
||||
"parameters": _rm_titles(schema) if rm_titles else schema,
|
||||
}
|
||||
|
||||
|
||||
def _convert_pydantic_to_openai_function(
|
||||
model: type,
|
||||
*,
|
||||
@ -102,18 +137,9 @@ def _convert_pydantic_to_openai_function(
|
||||
else:
|
||||
msg = "Model must be a Pydantic model."
|
||||
raise TypeError(msg)
|
||||
schema = dereference_refs(schema)
|
||||
if "definitions" in schema: # pydantic 1
|
||||
schema.pop("definitions", None)
|
||||
if "$defs" in schema: # pydantic 2
|
||||
schema.pop("$defs", None)
|
||||
title = schema.pop("title", "")
|
||||
default_description = schema.pop("description", "")
|
||||
return {
|
||||
"name": name or title,
|
||||
"description": description or default_description,
|
||||
"parameters": _rm_titles(schema) if rm_titles else schema,
|
||||
}
|
||||
return _convert_json_schema_to_openai_function(
|
||||
schema, name=name, description=description, rm_titles=rm_titles
|
||||
)
|
||||
|
||||
|
||||
convert_pydantic_to_openai_function = deprecated(
|
||||
@ -289,9 +315,20 @@ def _format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
|
||||
|
||||
is_simple_oai_tool = isinstance(tool, simple.Tool) and not tool.args_schema
|
||||
if tool.tool_call_schema and not is_simple_oai_tool:
|
||||
return _convert_pydantic_to_openai_function(
|
||||
tool.tool_call_schema, name=tool.name, description=tool.description
|
||||
)
|
||||
if isinstance(tool.tool_call_schema, dict):
|
||||
return _convert_json_schema_to_openai_function(
|
||||
tool.tool_call_schema, name=tool.name, description=tool.description
|
||||
)
|
||||
elif issubclass(tool.tool_call_schema, (BaseModel, BaseModelV1)):
|
||||
return _convert_pydantic_to_openai_function(
|
||||
tool.tool_call_schema, name=tool.name, description=tool.description
|
||||
)
|
||||
else:
|
||||
error_msg = (
|
||||
f"Unsupported tool call schema: {tool.tool_call_schema}. "
|
||||
"Tool call schema must be a JSON schema dict or a Pydantic model."
|
||||
)
|
||||
raise ValueError(error_msg)
|
||||
else:
|
||||
return {
|
||||
"name": tool.name,
|
||||
|
@ -17,6 +17,7 @@ from typing import (
|
||||
Optional,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import pytest
|
||||
@ -59,11 +60,26 @@ from langchain_core.tools.base import (
|
||||
get_all_basemodel_annotations,
|
||||
)
|
||||
from langchain_core.utils.function_calling import convert_to_openai_function
|
||||
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION, _create_subset_model
|
||||
from langchain_core.utils.pydantic import (
|
||||
PYDANTIC_MAJOR_VERSION,
|
||||
_create_subset_model,
|
||||
create_model_v2,
|
||||
)
|
||||
from tests.unit_tests.fake.callbacks import FakeCallbackHandler
|
||||
from tests.unit_tests.pydantic_utils import _schema
|
||||
|
||||
|
||||
def _get_tool_call_json_schema(tool: BaseTool) -> dict:
|
||||
tool_schema = tool.tool_call_schema
|
||||
if isinstance(tool_schema, dict):
|
||||
return tool_schema
|
||||
|
||||
if hasattr(tool_schema, "model_json_schema"):
|
||||
return tool_schema.model_json_schema()
|
||||
else:
|
||||
return tool_schema.schema()
|
||||
|
||||
|
||||
def test_unnamed_decorator() -> None:
|
||||
"""Test functionality with unnamed decorator."""
|
||||
|
||||
@ -1721,7 +1737,7 @@ def test_tool_inherited_injected_arg() -> None:
|
||||
"required": ["y", "x"],
|
||||
}
|
||||
# Should not include `y` since it's annotated as an injected tool arg
|
||||
assert tool_.tool_call_schema.model_json_schema() == {
|
||||
assert _get_tool_call_json_schema(tool_) == {
|
||||
"title": "foo",
|
||||
"description": "foo.",
|
||||
"type": "object",
|
||||
@ -1840,12 +1856,7 @@ def test_args_schema_as_pydantic(pydantic_model: Any) -> None:
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
tool_schema = tool.tool_call_schema
|
||||
tool_json_schema = (
|
||||
tool_schema.model_json_schema()
|
||||
if hasattr(tool_schema, "model_json_schema")
|
||||
else tool_schema.schema()
|
||||
)
|
||||
tool_json_schema = _get_tool_call_json_schema(tool)
|
||||
assert tool_json_schema == {
|
||||
"description": "some description",
|
||||
"properties": {
|
||||
@ -1892,7 +1903,7 @@ def test_args_schema_explicitly_typed() -> None:
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
assert tool.tool_call_schema.model_json_schema() == {
|
||||
assert _get_tool_call_json_schema(tool) == {
|
||||
"description": "some description",
|
||||
"properties": {
|
||||
"a": {"title": "A", "type": "integer"},
|
||||
@ -1920,7 +1931,7 @@ def test_structured_tool_with_different_pydantic_versions(pydantic_model: Any) -
|
||||
|
||||
assert foo_tool.invoke({"a": 5, "b": "hello"}) == "foo"
|
||||
|
||||
args_schema = foo_tool.args_schema
|
||||
args_schema = cast(BaseModel, foo_tool.args_schema)
|
||||
args_json_schema = (
|
||||
args_schema.model_json_schema()
|
||||
if hasattr(args_schema, "model_json_schema")
|
||||
@ -2210,7 +2221,7 @@ def test_tool_args_schema_pydantic_v2_with_metadata() -> None:
|
||||
"""Foo."""
|
||||
return x
|
||||
|
||||
assert foo.tool_call_schema.model_json_schema() == {
|
||||
assert _get_tool_call_json_schema(foo) == {
|
||||
"description": "Foo.",
|
||||
"properties": {
|
||||
"x": {
|
||||
@ -2376,3 +2387,73 @@ def test_tool_mutate_input() -> None:
|
||||
my_input = {"x": "hi"}
|
||||
MyTool().invoke(my_input)
|
||||
assert my_input == {"x": "hi"}
|
||||
|
||||
|
||||
def test_structured_tool_args_schema_dict() -> None:
|
||||
args_schema = {
|
||||
"properties": {
|
||||
"a": {"title": "A", "type": "integer"},
|
||||
"b": {"title": "B", "type": "integer"},
|
||||
},
|
||||
"required": ["a", "b"],
|
||||
"title": "add",
|
||||
"type": "object",
|
||||
}
|
||||
tool = StructuredTool(
|
||||
name="add",
|
||||
description="add two numbers",
|
||||
args_schema=args_schema,
|
||||
func=lambda a, b: a + b,
|
||||
)
|
||||
assert tool.invoke({"a": 1, "b": 2}) == 3
|
||||
assert tool.args_schema == args_schema
|
||||
# test that the tool call schema is the same as the args schema
|
||||
assert _get_tool_call_json_schema(tool) == args_schema
|
||||
# test that the input schema is the same as the parent (Runnable) input schema
|
||||
assert (
|
||||
tool.get_input_schema().model_json_schema()
|
||||
== create_model_v2(
|
||||
tool.get_name("Input"),
|
||||
root=tool.InputType,
|
||||
module_name=tool.__class__.__module__,
|
||||
).model_json_schema()
|
||||
)
|
||||
# test that args are extracted correctly
|
||||
assert tool.args == {
|
||||
"a": {"title": "A", "type": "integer"},
|
||||
"b": {"title": "B", "type": "integer"},
|
||||
}
|
||||
|
||||
|
||||
def test_simple_tool_args_schema_dict() -> None:
|
||||
args_schema = {
|
||||
"properties": {
|
||||
"a": {"title": "A", "type": "integer"},
|
||||
},
|
||||
"required": ["a"],
|
||||
"title": "square",
|
||||
"type": "object",
|
||||
}
|
||||
tool = Tool(
|
||||
name="square",
|
||||
description="square a number",
|
||||
args_schema=args_schema,
|
||||
func=lambda a: a * a,
|
||||
)
|
||||
assert tool.invoke({"a": 2}) == 4
|
||||
assert tool.args_schema == args_schema
|
||||
# test that the tool call schema is the same as the args schema
|
||||
assert _get_tool_call_json_schema(tool) == args_schema
|
||||
# test that the input schema is the same as the parent (Runnable) input schema
|
||||
assert (
|
||||
tool.get_input_schema().model_json_schema()
|
||||
== create_model_v2(
|
||||
tool.get_name("Input"),
|
||||
root=tool.InputType,
|
||||
module_name=tool.__class__.__module__,
|
||||
).model_json_schema()
|
||||
)
|
||||
# test that args are extracted correctly
|
||||
assert tool.args == {
|
||||
"a": {"title": "A", "type": "integer"},
|
||||
}
|
||||
|
@ -127,6 +127,28 @@ def dummy_structured_tool() -> StructuredTool:
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def dummy_structured_tool_args_schema_dict() -> StructuredTool:
|
||||
args_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"arg1": {"type": "integer", "description": "foo"},
|
||||
"arg2": {
|
||||
"type": "string",
|
||||
"enum": ["bar", "baz"],
|
||||
"description": "one of 'bar', 'baz'",
|
||||
},
|
||||
},
|
||||
"required": ["arg1", "arg2"],
|
||||
}
|
||||
return StructuredTool.from_function(
|
||||
lambda x: None,
|
||||
name="dummy_function",
|
||||
description="Dummy function.",
|
||||
args_schema=args_schema,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def dummy_pydantic() -> type[BaseModel]:
|
||||
class dummy_function(BaseModel): # noqa: N801
|
||||
@ -293,6 +315,7 @@ def test_convert_to_openai_function(
|
||||
function: Callable,
|
||||
function_docstring_annotations: Callable,
|
||||
dummy_structured_tool: StructuredTool,
|
||||
dummy_structured_tool_args_schema_dict: StructuredTool,
|
||||
dummy_tool: BaseTool,
|
||||
json_schema: dict,
|
||||
anthropic_tool: dict,
|
||||
@ -327,6 +350,7 @@ def test_convert_to_openai_function(
|
||||
function,
|
||||
function_docstring_annotations,
|
||||
dummy_structured_tool,
|
||||
dummy_structured_tool_args_schema_dict,
|
||||
dummy_tool,
|
||||
json_schema,
|
||||
anthropic_tool,
|
||||
|
Loading…
Reference in New Issue
Block a user