From 980c8651743b653f994ad6b97a27b0fa31ee92b4 Mon Sep 17 00:00:00 2001 From: Alejandra De Luna Date: Fri, 23 Jun 2023 04:48:27 -0400 Subject: [PATCH] fix: remove callbacks arg from Tool and StructuredTool inferred schema (#6483) Fixes #5456 This PR removes the `callbacks` argument from a tool's schema when creating a `Tool` or `StructuredTool` with the `from_function` method and `infer_schema` is set to `True`. The `callbacks` argument is now removed in the `create_schema_from_function` and `_get_filtered_args` methods. As suggested by @vowelparrot, this fix provides a straightforward solution that minimally affects the existing implementation. A test was added to verify that this change enables the expected use of `Tool` and `StructuredTool` when using a `CallbackManager` and inferring the tool's schema. - @hwchase17 --- langchain/tools/base.py | 4 +- tests/unit_tests/tools/test_base.py | 59 +++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 1 deletion(-) diff --git a/langchain/tools/base.py b/langchain/tools/base.py index c3fab242c92..a8d4c88863f 100644 --- a/langchain/tools/base.py +++ b/langchain/tools/base.py @@ -82,7 +82,7 @@ def _get_filtered_args( """Get the arguments from a function's signature.""" schema = inferred_model.schema()["properties"] valid_keys = signature(func).parameters - return {k: schema[k] for k in valid_keys if k != "run_manager"} + return {k: schema[k] for k in valid_keys if k not in ("run_manager", "callbacks")} class _SchemaConfig: @@ -108,6 +108,8 @@ def create_schema_from_function( inferred_model = validated.model # type: ignore if "run_manager" in inferred_model.__fields__: del inferred_model.__fields__["run_manager"] + if "callbacks" in inferred_model.__fields__: + del inferred_model.__fields__["callbacks"] # Pydantic adds placeholder virtual fields we need to strip valid_properties = _get_filtered_args(inferred_model, func) return _create_subset_model( diff --git a/tests/unit_tests/tools/test_base.py b/tests/unit_tests/tools/test_base.py index e486fcc9d00..eadfbcf97c5 100644 --- a/tests/unit_tests/tools/test_base.py +++ b/tests/unit_tests/tools/test_base.py @@ -19,6 +19,7 @@ from langchain.tools.base import ( StructuredTool, ToolException, ) +from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler def test_unnamed_decorator() -> None: @@ -393,6 +394,64 @@ def test_empty_args_decorator() -> None: assert empty_tool_input.run({}) == "the empty result" +def test_tool_from_function_with_run_manager() -> None: + """Test run of tool when using run_manager.""" + + def foo(bar: str, callbacks: Optional[CallbackManagerForToolRun] = None) -> str: + """Docstring + Args: + bar: str + """ + assert callbacks is not None + return "foo" + bar + + handler = FakeCallbackHandler() + tool = Tool.from_function(foo, name="foo", description="Docstring") + + assert tool.run(tool_input={"bar": "bar"}, run_manager=[handler]) == "foobar" + assert tool.run("baz", run_manager=[handler]) == "foobaz" + + +def test_structured_tool_from_function_with_run_manager() -> None: + """Test args and schema of structured tool when using callbacks.""" + + def foo( + bar: int, baz: str, callbacks: Optional[CallbackManagerForToolRun] = None + ) -> str: + """Docstring + Args: + bar: int + baz: str + """ + assert callbacks is not None + return str(bar) + baz + + handler = FakeCallbackHandler() + structured_tool = StructuredTool.from_function(foo) + + assert structured_tool.args == { + "bar": {"title": "Bar", "type": "integer"}, + "baz": {"title": "Baz", "type": "string"}, + } + + assert structured_tool.args_schema.schema() == { + "properties": { + "bar": {"title": "Bar", "type": "integer"}, + "baz": {"title": "Baz", "type": "string"}, + }, + "title": "fooSchemaSchema", + "type": "object", + "required": ["bar", "baz"], + } + + assert ( + structured_tool.run( + tool_input={"bar": "10", "baz": "baz"}, run_manger=[handler] + ) + == "10baz" + ) + + def test_named_tool_decorator() -> None: """Test functionality when arguments are provided as input to decorator."""