mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 15:19:33 +00:00
core[patch]: allow passing JSON schema as args_schema to tools (#29812)
This commit is contained in:
parent
5034a8dc5c
commit
d04fa1ae50
@ -317,6 +317,9 @@ class ToolException(Exception): # noqa: N818
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
ArgsSchema = Union[TypeBaseModel, dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
class BaseTool(RunnableSerializable[Union[str, dict, ToolCall], Any]):
|
class BaseTool(RunnableSerializable[Union[str, dict, ToolCall], Any]):
|
||||||
"""Interface LangChain tools must implement."""
|
"""Interface LangChain tools must implement."""
|
||||||
|
|
||||||
@ -354,7 +357,7 @@ class ChildTool(BaseTool):
|
|||||||
You can provide few-shot examples as a part of the description.
|
You can provide few-shot examples as a part of the description.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
args_schema: Annotated[Optional[TypeBaseModel], SkipValidation()] = Field(
|
args_schema: Annotated[Optional[ArgsSchema], SkipValidation()] = Field(
|
||||||
default=None, description="The tool schema."
|
default=None, description="The tool schema."
|
||||||
)
|
)
|
||||||
"""Pydantic model class to validate and parse the tool's input arguments.
|
"""Pydantic model class to validate and parse the tool's input arguments.
|
||||||
@ -364,6 +367,8 @@ class ChildTool(BaseTool):
|
|||||||
- A subclass of pydantic.BaseModel.
|
- A subclass of pydantic.BaseModel.
|
||||||
or
|
or
|
||||||
- A subclass of pydantic.v1.BaseModel if accessing v1 namespace in pydantic 2
|
- A subclass of pydantic.v1.BaseModel if accessing v1 namespace in pydantic 2
|
||||||
|
or
|
||||||
|
- a JSON schema dict
|
||||||
"""
|
"""
|
||||||
return_direct: bool = False
|
return_direct: bool = False
|
||||||
"""Whether to return the tool's output directly.
|
"""Whether to return the tool's output directly.
|
||||||
@ -423,10 +428,11 @@ class ChildTool(BaseTool):
|
|||||||
"args_schema" in kwargs
|
"args_schema" in kwargs
|
||||||
and kwargs["args_schema"] is not None
|
and kwargs["args_schema"] is not None
|
||||||
and not is_basemodel_subclass(kwargs["args_schema"])
|
and not is_basemodel_subclass(kwargs["args_schema"])
|
||||||
|
and not isinstance(kwargs["args_schema"], dict)
|
||||||
):
|
):
|
||||||
msg = (
|
msg = (
|
||||||
f"args_schema must be a subclass of pydantic BaseModel. "
|
"args_schema must be a subclass of pydantic BaseModel or "
|
||||||
f"Got: {kwargs['args_schema']}."
|
f"a JSON schema dict. Got: {kwargs['args_schema']}."
|
||||||
)
|
)
|
||||||
raise TypeError(msg)
|
raise TypeError(msg)
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
@ -443,10 +449,18 @@ class ChildTool(BaseTool):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def args(self) -> dict:
|
def args(self) -> dict:
|
||||||
return self.get_input_schema().model_json_schema()["properties"]
|
if isinstance(self.args_schema, dict):
|
||||||
|
json_schema = self.args_schema
|
||||||
|
else:
|
||||||
|
input_schema = self.get_input_schema()
|
||||||
|
json_schema = input_schema.model_json_schema()
|
||||||
|
return json_schema["properties"]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def tool_call_schema(self) -> type[BaseModel]:
|
def tool_call_schema(self) -> ArgsSchema:
|
||||||
|
if isinstance(self.args_schema, dict):
|
||||||
|
return self.args_schema
|
||||||
|
|
||||||
full_schema = self.get_input_schema()
|
full_schema = self.get_input_schema()
|
||||||
fields = []
|
fields = []
|
||||||
for name, type_ in get_all_basemodel_annotations(full_schema).items():
|
for name, type_ in get_all_basemodel_annotations(full_schema).items():
|
||||||
@ -470,6 +484,8 @@ class ChildTool(BaseTool):
|
|||||||
The input schema for the tool.
|
The input schema for the tool.
|
||||||
"""
|
"""
|
||||||
if self.args_schema is not None:
|
if self.args_schema is not None:
|
||||||
|
if isinstance(self.args_schema, dict):
|
||||||
|
return super().get_input_schema(config)
|
||||||
return self.args_schema
|
return self.args_schema
|
||||||
else:
|
else:
|
||||||
return create_schema_from_function(self.name, self._run)
|
return create_schema_from_function(self.name, self._run)
|
||||||
@ -505,6 +521,12 @@ class ChildTool(BaseTool):
|
|||||||
input_args = self.args_schema
|
input_args = self.args_schema
|
||||||
if isinstance(tool_input, str):
|
if isinstance(tool_input, str):
|
||||||
if input_args is not None:
|
if input_args is not None:
|
||||||
|
if isinstance(input_args, dict):
|
||||||
|
msg = (
|
||||||
|
"String tool inputs are not allowed when "
|
||||||
|
"using tools with JSON schema args_schema."
|
||||||
|
)
|
||||||
|
raise ValueError(msg)
|
||||||
key_ = next(iter(get_fields(input_args).keys()))
|
key_ = next(iter(get_fields(input_args).keys()))
|
||||||
if hasattr(input_args, "model_validate"):
|
if hasattr(input_args, "model_validate"):
|
||||||
input_args.model_validate({key_: tool_input})
|
input_args.model_validate({key_: tool_input})
|
||||||
@ -513,7 +535,9 @@ class ChildTool(BaseTool):
|
|||||||
return tool_input
|
return tool_input
|
||||||
else:
|
else:
|
||||||
if input_args is not None:
|
if input_args is not None:
|
||||||
if issubclass(input_args, BaseModel):
|
if isinstance(input_args, dict):
|
||||||
|
return tool_input
|
||||||
|
elif issubclass(input_args, BaseModel):
|
||||||
for k, v in get_all_basemodel_annotations(input_args).items():
|
for k, v in get_all_basemodel_annotations(input_args).items():
|
||||||
if (
|
if (
|
||||||
_is_injected_arg_type(v, injected_type=InjectedToolCallId)
|
_is_injected_arg_type(v, injected_type=InjectedToolCallId)
|
||||||
@ -605,7 +629,12 @@ class ChildTool(BaseTool):
|
|||||||
def _to_args_and_kwargs(
|
def _to_args_and_kwargs(
|
||||||
self, tool_input: Union[str, dict], tool_call_id: Optional[str]
|
self, tool_input: Union[str, dict], tool_call_id: Optional[str]
|
||||||
) -> tuple[tuple, dict]:
|
) -> tuple[tuple, dict]:
|
||||||
if self.args_schema is not None and not get_fields(self.args_schema):
|
if (
|
||||||
|
self.args_schema is not None
|
||||||
|
and isinstance(self.args_schema, type)
|
||||||
|
and is_basemodel_subclass(self.args_schema)
|
||||||
|
and not get_fields(self.args_schema)
|
||||||
|
):
|
||||||
# StructuredTool with no args
|
# StructuredTool with no args
|
||||||
return (), {}
|
return (), {}
|
||||||
tool_input = self._parse_input(tool_input, tool_call_id)
|
tool_input = self._parse_input(tool_input, tool_call_id)
|
||||||
|
@ -9,8 +9,6 @@ from typing import (
|
|||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from langchain_core.callbacks import (
|
from langchain_core.callbacks import (
|
||||||
AsyncCallbackManagerForToolRun,
|
AsyncCallbackManagerForToolRun,
|
||||||
CallbackManagerForToolRun,
|
CallbackManagerForToolRun,
|
||||||
@ -18,6 +16,7 @@ from langchain_core.callbacks import (
|
|||||||
from langchain_core.messages import ToolCall
|
from langchain_core.messages import ToolCall
|
||||||
from langchain_core.runnables import RunnableConfig, run_in_executor
|
from langchain_core.runnables import RunnableConfig, run_in_executor
|
||||||
from langchain_core.tools.base import (
|
from langchain_core.tools.base import (
|
||||||
|
ArgsSchema,
|
||||||
BaseTool,
|
BaseTool,
|
||||||
ToolException,
|
ToolException,
|
||||||
_get_runnable_config_param,
|
_get_runnable_config_param,
|
||||||
@ -57,7 +56,11 @@ class Tool(BaseTool):
|
|||||||
The input arguments for the tool.
|
The input arguments for the tool.
|
||||||
"""
|
"""
|
||||||
if self.args_schema is not None:
|
if self.args_schema is not None:
|
||||||
return self.args_schema.model_json_schema()["properties"]
|
if isinstance(self.args_schema, dict):
|
||||||
|
json_schema = self.args_schema
|
||||||
|
else:
|
||||||
|
json_schema = self.args_schema.model_json_schema()
|
||||||
|
return json_schema["properties"]
|
||||||
# For backwards compatibility, if the function signature is ambiguous,
|
# For backwards compatibility, if the function signature is ambiguous,
|
||||||
# assume it takes a single string input.
|
# assume it takes a single string input.
|
||||||
return {"tool_input": {"type": "string"}}
|
return {"tool_input": {"type": "string"}}
|
||||||
@ -132,7 +135,7 @@ class Tool(BaseTool):
|
|||||||
name: str, # We keep these required to support backwards compatibility
|
name: str, # We keep these required to support backwards compatibility
|
||||||
description: str,
|
description: str,
|
||||||
return_direct: bool = False,
|
return_direct: bool = False,
|
||||||
args_schema: Optional[type[BaseModel]] = None,
|
args_schema: Optional[ArgsSchema] = None,
|
||||||
coroutine: Optional[
|
coroutine: Optional[
|
||||||
Callable[..., Awaitable[Any]]
|
Callable[..., Awaitable[Any]]
|
||||||
] = None, # This is last for compatibility, but should be after func
|
] = None, # This is last for compatibility, but should be after func
|
||||||
|
@ -12,7 +12,7 @@ from typing import (
|
|||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, SkipValidation
|
from pydantic import Field, SkipValidation
|
||||||
|
|
||||||
from langchain_core.callbacks import (
|
from langchain_core.callbacks import (
|
||||||
AsyncCallbackManagerForToolRun,
|
AsyncCallbackManagerForToolRun,
|
||||||
@ -22,18 +22,18 @@ from langchain_core.messages import ToolCall
|
|||||||
from langchain_core.runnables import RunnableConfig, run_in_executor
|
from langchain_core.runnables import RunnableConfig, run_in_executor
|
||||||
from langchain_core.tools.base import (
|
from langchain_core.tools.base import (
|
||||||
FILTERED_ARGS,
|
FILTERED_ARGS,
|
||||||
|
ArgsSchema,
|
||||||
BaseTool,
|
BaseTool,
|
||||||
_get_runnable_config_param,
|
_get_runnable_config_param,
|
||||||
create_schema_from_function,
|
create_schema_from_function,
|
||||||
)
|
)
|
||||||
from langchain_core.utils.pydantic import TypeBaseModel
|
|
||||||
|
|
||||||
|
|
||||||
class StructuredTool(BaseTool):
|
class StructuredTool(BaseTool):
|
||||||
"""Tool that can operate on any number of inputs."""
|
"""Tool that can operate on any number of inputs."""
|
||||||
|
|
||||||
description: str = ""
|
description: str = ""
|
||||||
args_schema: Annotated[TypeBaseModel, SkipValidation()] = Field(
|
args_schema: Annotated[ArgsSchema, SkipValidation()] = Field(
|
||||||
..., description="The tool schema."
|
..., description="The tool schema."
|
||||||
)
|
)
|
||||||
"""The input arguments' schema."""
|
"""The input arguments' schema."""
|
||||||
@ -62,7 +62,12 @@ class StructuredTool(BaseTool):
|
|||||||
@property
|
@property
|
||||||
def args(self) -> dict:
|
def args(self) -> dict:
|
||||||
"""The tool's input arguments."""
|
"""The tool's input arguments."""
|
||||||
return self.args_schema.model_json_schema()["properties"]
|
if isinstance(self.args_schema, dict):
|
||||||
|
json_schema = self.args_schema
|
||||||
|
else:
|
||||||
|
input_schema = self.get_input_schema()
|
||||||
|
json_schema = input_schema.model_json_schema()
|
||||||
|
return json_schema["properties"]
|
||||||
|
|
||||||
def _run(
|
def _run(
|
||||||
self,
|
self,
|
||||||
@ -110,7 +115,7 @@ class StructuredTool(BaseTool):
|
|||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
description: Optional[str] = None,
|
description: Optional[str] = None,
|
||||||
return_direct: bool = False,
|
return_direct: bool = False,
|
||||||
args_schema: Optional[type[BaseModel]] = None,
|
args_schema: Optional[ArgsSchema] = None,
|
||||||
infer_schema: bool = True,
|
infer_schema: bool = True,
|
||||||
*,
|
*,
|
||||||
response_format: Literal["content", "content_and_artifact"] = "content",
|
response_format: Literal["content", "content_and_artifact"] = "content",
|
||||||
|
@ -20,6 +20,7 @@ from typing import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from pydantic.v1 import BaseModel as BaseModelV1
|
||||||
from typing_extensions import TypedDict, get_args, get_origin, is_typeddict
|
from typing_extensions import TypedDict, get_args, get_origin, is_typeddict
|
||||||
|
|
||||||
from langchain_core._api import beta, deprecated
|
from langchain_core._api import beta, deprecated
|
||||||
@ -75,6 +76,40 @@ def _rm_titles(kv: dict, prev_key: str = "") -> dict:
|
|||||||
return new_kv
|
return new_kv
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_json_schema_to_openai_function(
|
||||||
|
schema: dict,
|
||||||
|
*,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
description: Optional[str] = None,
|
||||||
|
rm_titles: bool = True,
|
||||||
|
) -> FunctionDescription:
|
||||||
|
"""Converts a Pydantic model to a function description for the OpenAI API.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
schema: The JSON schema to convert.
|
||||||
|
name: The name of the function. If not provided, the title of the schema will be
|
||||||
|
used.
|
||||||
|
description: The description of the function. If not provided, the description
|
||||||
|
of the schema will be used.
|
||||||
|
rm_titles: Whether to remove titles from the schema. Defaults to True.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The function description.
|
||||||
|
"""
|
||||||
|
schema = dereference_refs(schema)
|
||||||
|
if "definitions" in schema: # pydantic 1
|
||||||
|
schema.pop("definitions", None)
|
||||||
|
if "$defs" in schema: # pydantic 2
|
||||||
|
schema.pop("$defs", None)
|
||||||
|
title = schema.pop("title", "")
|
||||||
|
default_description = schema.pop("description", "")
|
||||||
|
return {
|
||||||
|
"name": name or title,
|
||||||
|
"description": description or default_description,
|
||||||
|
"parameters": _rm_titles(schema) if rm_titles else schema,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def _convert_pydantic_to_openai_function(
|
def _convert_pydantic_to_openai_function(
|
||||||
model: type,
|
model: type,
|
||||||
*,
|
*,
|
||||||
@ -102,18 +137,9 @@ def _convert_pydantic_to_openai_function(
|
|||||||
else:
|
else:
|
||||||
msg = "Model must be a Pydantic model."
|
msg = "Model must be a Pydantic model."
|
||||||
raise TypeError(msg)
|
raise TypeError(msg)
|
||||||
schema = dereference_refs(schema)
|
return _convert_json_schema_to_openai_function(
|
||||||
if "definitions" in schema: # pydantic 1
|
schema, name=name, description=description, rm_titles=rm_titles
|
||||||
schema.pop("definitions", None)
|
)
|
||||||
if "$defs" in schema: # pydantic 2
|
|
||||||
schema.pop("$defs", None)
|
|
||||||
title = schema.pop("title", "")
|
|
||||||
default_description = schema.pop("description", "")
|
|
||||||
return {
|
|
||||||
"name": name or title,
|
|
||||||
"description": description or default_description,
|
|
||||||
"parameters": _rm_titles(schema) if rm_titles else schema,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
convert_pydantic_to_openai_function = deprecated(
|
convert_pydantic_to_openai_function = deprecated(
|
||||||
@ -289,9 +315,20 @@ def _format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
|
|||||||
|
|
||||||
is_simple_oai_tool = isinstance(tool, simple.Tool) and not tool.args_schema
|
is_simple_oai_tool = isinstance(tool, simple.Tool) and not tool.args_schema
|
||||||
if tool.tool_call_schema and not is_simple_oai_tool:
|
if tool.tool_call_schema and not is_simple_oai_tool:
|
||||||
return _convert_pydantic_to_openai_function(
|
if isinstance(tool.tool_call_schema, dict):
|
||||||
tool.tool_call_schema, name=tool.name, description=tool.description
|
return _convert_json_schema_to_openai_function(
|
||||||
)
|
tool.tool_call_schema, name=tool.name, description=tool.description
|
||||||
|
)
|
||||||
|
elif issubclass(tool.tool_call_schema, (BaseModel, BaseModelV1)):
|
||||||
|
return _convert_pydantic_to_openai_function(
|
||||||
|
tool.tool_call_schema, name=tool.name, description=tool.description
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
error_msg = (
|
||||||
|
f"Unsupported tool call schema: {tool.tool_call_schema}. "
|
||||||
|
"Tool call schema must be a JSON schema dict or a Pydantic model."
|
||||||
|
)
|
||||||
|
raise ValueError(error_msg)
|
||||||
else:
|
else:
|
||||||
return {
|
return {
|
||||||
"name": tool.name,
|
"name": tool.name,
|
||||||
|
@ -17,6 +17,7 @@ from typing import (
|
|||||||
Optional,
|
Optional,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -59,11 +60,26 @@ from langchain_core.tools.base import (
|
|||||||
get_all_basemodel_annotations,
|
get_all_basemodel_annotations,
|
||||||
)
|
)
|
||||||
from langchain_core.utils.function_calling import convert_to_openai_function
|
from langchain_core.utils.function_calling import convert_to_openai_function
|
||||||
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION, _create_subset_model
|
from langchain_core.utils.pydantic import (
|
||||||
|
PYDANTIC_MAJOR_VERSION,
|
||||||
|
_create_subset_model,
|
||||||
|
create_model_v2,
|
||||||
|
)
|
||||||
from tests.unit_tests.fake.callbacks import FakeCallbackHandler
|
from tests.unit_tests.fake.callbacks import FakeCallbackHandler
|
||||||
from tests.unit_tests.pydantic_utils import _schema
|
from tests.unit_tests.pydantic_utils import _schema
|
||||||
|
|
||||||
|
|
||||||
|
def _get_tool_call_json_schema(tool: BaseTool) -> dict:
|
||||||
|
tool_schema = tool.tool_call_schema
|
||||||
|
if isinstance(tool_schema, dict):
|
||||||
|
return tool_schema
|
||||||
|
|
||||||
|
if hasattr(tool_schema, "model_json_schema"):
|
||||||
|
return tool_schema.model_json_schema()
|
||||||
|
else:
|
||||||
|
return tool_schema.schema()
|
||||||
|
|
||||||
|
|
||||||
def test_unnamed_decorator() -> None:
|
def test_unnamed_decorator() -> None:
|
||||||
"""Test functionality with unnamed decorator."""
|
"""Test functionality with unnamed decorator."""
|
||||||
|
|
||||||
@ -1721,7 +1737,7 @@ def test_tool_inherited_injected_arg() -> None:
|
|||||||
"required": ["y", "x"],
|
"required": ["y", "x"],
|
||||||
}
|
}
|
||||||
# Should not include `y` since it's annotated as an injected tool arg
|
# Should not include `y` since it's annotated as an injected tool arg
|
||||||
assert tool_.tool_call_schema.model_json_schema() == {
|
assert _get_tool_call_json_schema(tool_) == {
|
||||||
"title": "foo",
|
"title": "foo",
|
||||||
"description": "foo.",
|
"description": "foo.",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
@ -1840,12 +1856,7 @@ def test_args_schema_as_pydantic(pydantic_model: Any) -> None:
|
|||||||
"type": "object",
|
"type": "object",
|
||||||
}
|
}
|
||||||
|
|
||||||
tool_schema = tool.tool_call_schema
|
tool_json_schema = _get_tool_call_json_schema(tool)
|
||||||
tool_json_schema = (
|
|
||||||
tool_schema.model_json_schema()
|
|
||||||
if hasattr(tool_schema, "model_json_schema")
|
|
||||||
else tool_schema.schema()
|
|
||||||
)
|
|
||||||
assert tool_json_schema == {
|
assert tool_json_schema == {
|
||||||
"description": "some description",
|
"description": "some description",
|
||||||
"properties": {
|
"properties": {
|
||||||
@ -1892,7 +1903,7 @@ def test_args_schema_explicitly_typed() -> None:
|
|||||||
"type": "object",
|
"type": "object",
|
||||||
}
|
}
|
||||||
|
|
||||||
assert tool.tool_call_schema.model_json_schema() == {
|
assert _get_tool_call_json_schema(tool) == {
|
||||||
"description": "some description",
|
"description": "some description",
|
||||||
"properties": {
|
"properties": {
|
||||||
"a": {"title": "A", "type": "integer"},
|
"a": {"title": "A", "type": "integer"},
|
||||||
@ -1920,7 +1931,7 @@ def test_structured_tool_with_different_pydantic_versions(pydantic_model: Any) -
|
|||||||
|
|
||||||
assert foo_tool.invoke({"a": 5, "b": "hello"}) == "foo"
|
assert foo_tool.invoke({"a": 5, "b": "hello"}) == "foo"
|
||||||
|
|
||||||
args_schema = foo_tool.args_schema
|
args_schema = cast(BaseModel, foo_tool.args_schema)
|
||||||
args_json_schema = (
|
args_json_schema = (
|
||||||
args_schema.model_json_schema()
|
args_schema.model_json_schema()
|
||||||
if hasattr(args_schema, "model_json_schema")
|
if hasattr(args_schema, "model_json_schema")
|
||||||
@ -2210,7 +2221,7 @@ def test_tool_args_schema_pydantic_v2_with_metadata() -> None:
|
|||||||
"""Foo."""
|
"""Foo."""
|
||||||
return x
|
return x
|
||||||
|
|
||||||
assert foo.tool_call_schema.model_json_schema() == {
|
assert _get_tool_call_json_schema(foo) == {
|
||||||
"description": "Foo.",
|
"description": "Foo.",
|
||||||
"properties": {
|
"properties": {
|
||||||
"x": {
|
"x": {
|
||||||
@ -2376,3 +2387,73 @@ def test_tool_mutate_input() -> None:
|
|||||||
my_input = {"x": "hi"}
|
my_input = {"x": "hi"}
|
||||||
MyTool().invoke(my_input)
|
MyTool().invoke(my_input)
|
||||||
assert my_input == {"x": "hi"}
|
assert my_input == {"x": "hi"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_structured_tool_args_schema_dict() -> None:
|
||||||
|
args_schema = {
|
||||||
|
"properties": {
|
||||||
|
"a": {"title": "A", "type": "integer"},
|
||||||
|
"b": {"title": "B", "type": "integer"},
|
||||||
|
},
|
||||||
|
"required": ["a", "b"],
|
||||||
|
"title": "add",
|
||||||
|
"type": "object",
|
||||||
|
}
|
||||||
|
tool = StructuredTool(
|
||||||
|
name="add",
|
||||||
|
description="add two numbers",
|
||||||
|
args_schema=args_schema,
|
||||||
|
func=lambda a, b: a + b,
|
||||||
|
)
|
||||||
|
assert tool.invoke({"a": 1, "b": 2}) == 3
|
||||||
|
assert tool.args_schema == args_schema
|
||||||
|
# test that the tool call schema is the same as the args schema
|
||||||
|
assert _get_tool_call_json_schema(tool) == args_schema
|
||||||
|
# test that the input schema is the same as the parent (Runnable) input schema
|
||||||
|
assert (
|
||||||
|
tool.get_input_schema().model_json_schema()
|
||||||
|
== create_model_v2(
|
||||||
|
tool.get_name("Input"),
|
||||||
|
root=tool.InputType,
|
||||||
|
module_name=tool.__class__.__module__,
|
||||||
|
).model_json_schema()
|
||||||
|
)
|
||||||
|
# test that args are extracted correctly
|
||||||
|
assert tool.args == {
|
||||||
|
"a": {"title": "A", "type": "integer"},
|
||||||
|
"b": {"title": "B", "type": "integer"},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_simple_tool_args_schema_dict() -> None:
|
||||||
|
args_schema = {
|
||||||
|
"properties": {
|
||||||
|
"a": {"title": "A", "type": "integer"},
|
||||||
|
},
|
||||||
|
"required": ["a"],
|
||||||
|
"title": "square",
|
||||||
|
"type": "object",
|
||||||
|
}
|
||||||
|
tool = Tool(
|
||||||
|
name="square",
|
||||||
|
description="square a number",
|
||||||
|
args_schema=args_schema,
|
||||||
|
func=lambda a: a * a,
|
||||||
|
)
|
||||||
|
assert tool.invoke({"a": 2}) == 4
|
||||||
|
assert tool.args_schema == args_schema
|
||||||
|
# test that the tool call schema is the same as the args schema
|
||||||
|
assert _get_tool_call_json_schema(tool) == args_schema
|
||||||
|
# test that the input schema is the same as the parent (Runnable) input schema
|
||||||
|
assert (
|
||||||
|
tool.get_input_schema().model_json_schema()
|
||||||
|
== create_model_v2(
|
||||||
|
tool.get_name("Input"),
|
||||||
|
root=tool.InputType,
|
||||||
|
module_name=tool.__class__.__module__,
|
||||||
|
).model_json_schema()
|
||||||
|
)
|
||||||
|
# test that args are extracted correctly
|
||||||
|
assert tool.args == {
|
||||||
|
"a": {"title": "A", "type": "integer"},
|
||||||
|
}
|
||||||
|
@ -127,6 +127,28 @@ def dummy_structured_tool() -> StructuredTool:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def dummy_structured_tool_args_schema_dict() -> StructuredTool:
|
||||||
|
args_schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"arg1": {"type": "integer", "description": "foo"},
|
||||||
|
"arg2": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["bar", "baz"],
|
||||||
|
"description": "one of 'bar', 'baz'",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["arg1", "arg2"],
|
||||||
|
}
|
||||||
|
return StructuredTool.from_function(
|
||||||
|
lambda x: None,
|
||||||
|
name="dummy_function",
|
||||||
|
description="Dummy function.",
|
||||||
|
args_schema=args_schema,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def dummy_pydantic() -> type[BaseModel]:
|
def dummy_pydantic() -> type[BaseModel]:
|
||||||
class dummy_function(BaseModel): # noqa: N801
|
class dummy_function(BaseModel): # noqa: N801
|
||||||
@ -293,6 +315,7 @@ def test_convert_to_openai_function(
|
|||||||
function: Callable,
|
function: Callable,
|
||||||
function_docstring_annotations: Callable,
|
function_docstring_annotations: Callable,
|
||||||
dummy_structured_tool: StructuredTool,
|
dummy_structured_tool: StructuredTool,
|
||||||
|
dummy_structured_tool_args_schema_dict: StructuredTool,
|
||||||
dummy_tool: BaseTool,
|
dummy_tool: BaseTool,
|
||||||
json_schema: dict,
|
json_schema: dict,
|
||||||
anthropic_tool: dict,
|
anthropic_tool: dict,
|
||||||
@ -327,6 +350,7 @@ def test_convert_to_openai_function(
|
|||||||
function,
|
function,
|
||||||
function_docstring_annotations,
|
function_docstring_annotations,
|
||||||
dummy_structured_tool,
|
dummy_structured_tool,
|
||||||
|
dummy_structured_tool_args_schema_dict,
|
||||||
dummy_tool,
|
dummy_tool,
|
||||||
json_schema,
|
json_schema,
|
||||||
anthropic_tool,
|
anthropic_tool,
|
||||||
|
Loading…
Reference in New Issue
Block a user