[Patch] Dedent docstring (#20959)

Technically a slight prompt breaking change, but I think positive EV in
that it saves tokens and results in more sane / in-distribution prompts
This commit is contained in:
William FH 2024-04-30 07:40:57 -07:00 committed by GitHub
parent 845d8e0025
commit 5c63ac3dd7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 12 additions and 7 deletions

View File

@ -21,6 +21,7 @@ from __future__ import annotations
import asyncio import asyncio
import inspect import inspect
import textwrap
import uuid import uuid
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
@ -825,16 +826,19 @@ class StructuredTool(BaseTool):
else: else:
raise ValueError("Function and/or coroutine must be provided") raise ValueError("Function and/or coroutine must be provided")
name = name or source_function.__name__ name = name or source_function.__name__
description = description or source_function.__doc__ description_ = description or source_function.__doc__
if description is None: if description_ is None:
raise ValueError( raise ValueError(
"Function must have a docstring if description not provided." "Function must have a docstring if description not provided."
) )
if description is None:
# Only apply if using the function's docstring
description_ = textwrap.dedent(description_).strip()
# Description example: # Description example:
# search_api(query: str) - Searches the API for the query. # search_api(query: str) - Searches the API for the query.
sig = signature(source_function) sig = signature(source_function)
description = f"{name}{sig} - {description.strip()}" description_ = f"{name}{sig} - {description_.strip()}"
_args_schema = args_schema _args_schema = args_schema
if _args_schema is None and infer_schema: if _args_schema is None and infer_schema:
# schema name is appended within function # schema name is appended within function
@ -844,7 +848,7 @@ class StructuredTool(BaseTool):
func=func, func=func,
coroutine=coroutine, coroutine=coroutine,
args_schema=_args_schema, # type: ignore[arg-type] args_schema=_args_schema, # type: ignore[arg-type]
description=description, description=description_,
return_direct=return_direct, return_direct=return_direct,
**kwargs, **kwargs,
) )

View File

@ -3,6 +3,7 @@
import asyncio import asyncio
import json import json
import sys import sys
import textwrap
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from functools import partial from functools import partial
@ -333,7 +334,7 @@ def test_structured_tool_from_function_docstring() -> None:
prefix = "foo(bar: int, baz: str) -> str - " prefix = "foo(bar: int, baz: str) -> str - "
assert foo.__doc__ is not None assert foo.__doc__ is not None
assert structured_tool.description == prefix + foo.__doc__.strip() assert structured_tool.description == prefix + textwrap.dedent(foo.__doc__.strip())
def test_structured_tool_from_function_docstring_complex_args() -> None: def test_structured_tool_from_function_docstring_complex_args() -> None:
@ -366,7 +367,7 @@ def test_structured_tool_from_function_docstring_complex_args() -> None:
prefix = "foo(bar: int, baz: List[str]) -> str - " prefix = "foo(bar: int, baz: List[str]) -> str - "
assert foo.__doc__ is not None assert foo.__doc__ is not None
assert structured_tool.description == prefix + foo.__doc__.strip() assert structured_tool.description == prefix + textwrap.dedent(foo.__doc__).strip()
def test_structured_tool_lambda_multi_args_schema() -> None: def test_structured_tool_lambda_multi_args_schema() -> None:
@ -701,7 +702,7 @@ def test_structured_tool_from_function() -> None:
prefix = "foo(bar: int, baz: str) -> str - " prefix = "foo(bar: int, baz: str) -> str - "
assert foo.__doc__ is not None assert foo.__doc__ is not None
assert structured_tool.description == prefix + foo.__doc__.strip() assert structured_tool.description == prefix + textwrap.dedent(foo.__doc__.strip())
def test_validation_error_handling_bool() -> None: def test_validation_error_handling_bool() -> None: