From 6ed0aa323935c2eeb20a3aa73c7cf287236a55f7 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Thu, 16 May 2024 11:17:53 -0400 Subject: [PATCH] core[major]: only use function description (#21622) Do not prefix function signature --- * Reason for this is that information is already present with tool calling models. * This will save on tokens for those models, and makes it more obvious what the description is! * The @tool can get more parameters to allow a user to re-introduce the the signature if we want --- libs/core/langchain_core/tools.py | 21 +++++++++++++++---- libs/core/tests/unit_tests/test_tools.py | 21 +++++++++++++------ .../tests/unit_tests/agents/test_mrkl.py | 2 +- .../tests/unit_tests/tools/test_render.py | 8 +++---- 4 files changed, 37 insertions(+), 15 deletions(-) diff --git a/libs/core/langchain_core/tools.py b/libs/core/langchain_core/tools.py index 7ad7da84f35..b1321042ec8 100644 --- a/libs/core/langchain_core/tools.py +++ b/libs/core/langchain_core/tools.py @@ -837,8 +837,7 @@ class StructuredTool(BaseTool): # Description example: # search_api(query: str) - Searches the API for the query. - sig = signature(source_function) - description_ = f"{name}{sig} - {description_.strip()}" + description_ = f"{description_.strip()}" _args_schema = args_schema if _args_schema is None and infer_schema: # schema name is appended within function @@ -1057,7 +1056,16 @@ def render_text_description(tools: List[BaseTool]) -> str: search: This tool is used for search calculator: This tool is used for math """ - return "\n".join([f"{tool.name}: {tool.description}" for tool in tools]) + descriptions = [] + for tool in tools: + if hasattr(tool, "func") and tool.func: + sig = signature(tool.func) + description = f"{tool.name}{sig} - {tool.description}" + else: + description = f"{tool.name} - {tool.description}" + + descriptions.append(description) + return "\n".join(descriptions) def render_text_description_and_args(tools: List[BaseTool]) -> str: @@ -1074,7 +1082,12 @@ args: {"expression": {"type": "string"}} tool_strings = [] for tool in tools: args_schema = str(tool.args) - tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}") + if hasattr(tool, "func") and tool.func: + sig = signature(tool.func) + description = f"{tool.name}{sig} - {tool.description}" + else: + description = f"{tool.name} - {tool.description}" + tool_strings.append(f"{description}, args: {args_schema}") return "\n".join(tool_strings) diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index f17e77fa013..78a44417431 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -332,9 +332,8 @@ def test_structured_tool_from_function_docstring() -> None: "required": ["bar", "baz"], } - prefix = "foo(bar: int, baz: str) -> str - " assert foo.__doc__ is not None - assert structured_tool.description == prefix + textwrap.dedent(foo.__doc__.strip()) + assert structured_tool.description == textwrap.dedent(foo.__doc__.strip()) def test_structured_tool_from_function_docstring_complex_args() -> None: @@ -365,9 +364,8 @@ def test_structured_tool_from_function_docstring_complex_args() -> None: "required": ["bar", "baz"], } - prefix = "foo(bar: int, baz: List[str]) -> str - " assert foo.__doc__ is not None - assert structured_tool.description == prefix + textwrap.dedent(foo.__doc__).strip() + assert structured_tool.description == textwrap.dedent(foo.__doc__).strip() def test_structured_tool_lambda_multi_args_schema() -> None: @@ -700,9 +698,8 @@ def test_structured_tool_from_function() -> None: "required": ["bar", "baz"], } - prefix = "foo(bar: int, baz: str) -> str - " assert foo.__doc__ is not None - assert structured_tool.description == prefix + textwrap.dedent(foo.__doc__.strip()) + assert structured_tool.description == textwrap.dedent(foo.__doc__.strip()) def test_validation_error_handling_bool() -> None: @@ -906,3 +903,15 @@ async def test_async_tool_pass_context() -> None: assert ( await foo.ainvoke({"bar": "baz"}, {"configurable": {"foo": "not-bar"}}) == "baz" # type: ignore ) + + +def test_tool_description() -> None: + def foo(bar: str) -> str: + """The foo.""" + return bar + + foo1 = tool(foo) + assert foo1.description == "The foo." # type: ignore + + foo2 = StructuredTool.from_function(foo) + assert foo2.description == "The foo." diff --git a/libs/langchain/tests/unit_tests/agents/test_mrkl.py b/libs/langchain/tests/unit_tests/agents/test_mrkl.py index f1df200357c..c05dbcc80d1 100644 --- a/libs/langchain/tests/unit_tests/agents/test_mrkl.py +++ b/libs/langchain/tests/unit_tests/agents/test_mrkl.py @@ -152,7 +152,7 @@ def test_from_chains() -> None: Tool(name="bar", func=lambda x: "bar", description="foobar2"), ] agent = ZeroShotAgent.from_llm_and_tools(FakeLLM(), chain_configs) - expected_tools_prompt = "foo: foobar1\nbar: foobar2" + expected_tools_prompt = "foo(x) - foobar1\nbar(x) - foobar2" expected_tool_names = "foo, bar" expected_template = "\n\n".join( [ diff --git a/libs/langchain/tests/unit_tests/tools/test_render.py b/libs/langchain/tests/unit_tests/tools/test_render.py index e7bea150951..c1cd56c9b07 100644 --- a/libs/langchain/tests/unit_tests/tools/test_render.py +++ b/libs/langchain/tests/unit_tests/tools/test_render.py @@ -28,15 +28,15 @@ def tools() -> List[BaseTool]: def test_render_text_description(tools: List[BaseTool]) -> None: tool_string = render_text_description(tools) - expected_string = """search: search(query: str) -> str - Lookup things online. -calculator: calculator(expression: str) -> str - Do math.""" + expected_string = """search(query: str) -> str - Lookup things online. +calculator(expression: str) -> str - Do math.""" assert tool_string == expected_string def test_render_text_description_and_args(tools: List[BaseTool]) -> None: tool_string = render_text_description_and_args(tools) - expected_string = """search: search(query: str) -> str - Lookup things online., \ + expected_string = """search(query: str) -> str - Lookup things online., \ args: {'query': {'title': 'Query', 'type': 'string'}} -calculator: calculator(expression: str) -> str - Do math., \ +calculator(expression: str) -> str - Do math., \ args: {'expression': {'title': 'Expression', 'type': 'string'}}""" assert tool_string == expected_string