diff --git a/libs/core/langchain_core/tools.py b/libs/core/langchain_core/tools.py index 38119ceb705..7ad7da84f35 100644 --- a/libs/core/langchain_core/tools.py +++ b/libs/core/langchain_core/tools.py @@ -21,6 +21,7 @@ from __future__ import annotations import asyncio import inspect +import textwrap import uuid import warnings from abc import ABC, abstractmethod @@ -825,16 +826,19 @@ class StructuredTool(BaseTool): else: raise ValueError("Function and/or coroutine must be provided") name = name or source_function.__name__ - description = description or source_function.__doc__ - if description is None: + description_ = description or source_function.__doc__ + if description_ is None: raise ValueError( "Function must have a docstring if description not provided." ) + if description is None: + # Only apply if using the function's docstring + description_ = textwrap.dedent(description_).strip() # Description example: # search_api(query: str) - Searches the API for the query. sig = signature(source_function) - description = f"{name}{sig} - {description.strip()}" + description_ = f"{name}{sig} - {description_.strip()}" _args_schema = args_schema if _args_schema is None and infer_schema: # schema name is appended within function @@ -844,7 +848,7 @@ class StructuredTool(BaseTool): func=func, coroutine=coroutine, args_schema=_args_schema, # type: ignore[arg-type] - description=description, + description=description_, return_direct=return_direct, **kwargs, ) diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index a424d6ecf60..f17e77fa013 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -3,6 +3,7 @@ import asyncio import json import sys +import textwrap from datetime import datetime from enum import Enum from functools import partial @@ -333,7 +334,7 @@ def test_structured_tool_from_function_docstring() -> None: prefix = "foo(bar: int, baz: str) -> str - " assert foo.__doc__ is not None - assert structured_tool.description == prefix + foo.__doc__.strip() + assert structured_tool.description == prefix + textwrap.dedent(foo.__doc__.strip()) def test_structured_tool_from_function_docstring_complex_args() -> None: @@ -366,7 +367,7 @@ def test_structured_tool_from_function_docstring_complex_args() -> None: prefix = "foo(bar: int, baz: List[str]) -> str - " assert foo.__doc__ is not None - assert structured_tool.description == prefix + foo.__doc__.strip() + assert structured_tool.description == prefix + textwrap.dedent(foo.__doc__).strip() def test_structured_tool_lambda_multi_args_schema() -> None: @@ -701,7 +702,7 @@ def test_structured_tool_from_function() -> None: prefix = "foo(bar: int, baz: str) -> str - " assert foo.__doc__ is not None - assert structured_tool.description == prefix + foo.__doc__.strip() + assert structured_tool.description == prefix + textwrap.dedent(foo.__doc__.strip()) def test_validation_error_handling_bool() -> None: