mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-07 14:03:26 +00:00
core[patch]: propagate parse_docstring
to tool decorator (#24123)
Disabled by default.

```python
from langchain_core.tools import tool

@tool(parse_docstring=True)
def foo(bar: str, baz: int) -> str:
    """The foo.

    Args:
        bar: this is the bar
        baz: this is the baz
    """
    return bar

foo.args_schema.schema()
```

```json
{
  "title": "fooSchema",
  "description": "The foo.",
  "type": "object",
  "properties": {
    "bar": {
      "title": "Bar",
      "description": "this is the bar",
      "type": "string"
    },
    "baz": {
      "title": "Baz",
      "description": "this is the baz",
      "type": "integer"
    }
  },
  "required": [
    "bar",
    "baz"
  ]
}
```
This commit is contained in:
@@ -85,6 +85,8 @@ from langchain_core.runnables.config import (
|
||||
)
|
||||
from langchain_core.runnables.utils import accepts_context
|
||||
|
||||
FILTERED_ARGS = ("run_manager", "callbacks")
|
||||
|
||||
|
||||
class SchemaAnnotationError(TypeError):
|
||||
"""Raised when 'args_schema' is missing or has an incorrect type annotation."""
|
||||
@@ -149,14 +151,27 @@ def _get_filtered_args(
|
||||
}
|
||||
|
||||
|
||||
def _parse_python_function_docstring(function: Callable) -> Tuple[str, dict]:
|
||||
def _parse_python_function_docstring(
|
||||
function: Callable, annotations: dict, error_on_invalid_docstring: bool = False
|
||||
) -> Tuple[str, dict]:
|
||||
"""Parse the function and argument descriptions from the docstring of a function.
|
||||
|
||||
Assumes the function docstring follows Google Python style guide.
|
||||
"""
|
||||
invalid_docstring_error = ValueError(
|
||||
f"Found invalid Google-Style docstring for {function}."
|
||||
)
|
||||
docstring = inspect.getdoc(function)
|
||||
if docstring:
|
||||
docstring_blocks = docstring.split("\n\n")
|
||||
if error_on_invalid_docstring:
|
||||
filtered_annotations = {
|
||||
arg for arg in annotations if arg not in (*(FILTERED_ARGS), "return")
|
||||
}
|
||||
if filtered_annotations and (
|
||||
len(docstring_blocks) < 2 or not docstring_blocks[1].startswith("Args:")
|
||||
):
|
||||
raise (invalid_docstring_error)
|
||||
descriptors = []
|
||||
args_block = None
|
||||
past_descriptors = False
|
||||
@@ -173,6 +188,8 @@ def _parse_python_function_docstring(function: Callable) -> Tuple[str, dict]:
|
||||
continue
|
||||
description = " ".join(descriptors)
|
||||
else:
|
||||
if error_on_invalid_docstring:
|
||||
raise (invalid_docstring_error)
|
||||
description = ""
|
||||
args_block = None
|
||||
arg_descriptions = {}
|
||||
@@ -187,20 +204,38 @@ def _parse_python_function_docstring(function: Callable) -> Tuple[str, dict]:
|
||||
return description, arg_descriptions
|
||||
|
||||
|
||||
def _validate_docstring_args_against_annotations(
|
||||
arg_descriptions: dict, annotations: dict
|
||||
) -> None:
|
||||
"""Raise error if docstring arg is not in type annotations."""
|
||||
for docstring_arg in arg_descriptions:
|
||||
if docstring_arg not in annotations:
|
||||
raise ValueError(
|
||||
f"Arg {docstring_arg} in docstring not found in function signature."
|
||||
)
|
||||
|
||||
|
||||
def _infer_arg_descriptions(
|
||||
fn: Callable, *, parse_docstring: bool = False
|
||||
fn: Callable,
|
||||
*,
|
||||
parse_docstring: bool = False,
|
||||
error_on_invalid_docstring: bool = False,
|
||||
) -> Tuple[str, dict]:
|
||||
"""Infer argument descriptions from a function's docstring."""
|
||||
if parse_docstring:
|
||||
description, arg_descriptions = _parse_python_function_docstring(fn)
|
||||
else:
|
||||
description = inspect.getdoc(fn) or ""
|
||||
arg_descriptions = {}
|
||||
if hasattr(inspect, "get_annotations"):
|
||||
# inspect.get_annotations is only available on Python >= 3.10
|
||||
annotations = inspect.get_annotations(fn) # type: ignore
|
||||
else:
|
||||
annotations = getattr(fn, "__annotations__", {})
|
||||
if parse_docstring:
|
||||
description, arg_descriptions = _parse_python_function_docstring(
|
||||
fn, annotations, error_on_invalid_docstring=error_on_invalid_docstring
|
||||
)
|
||||
else:
|
||||
description = inspect.getdoc(fn) or ""
|
||||
arg_descriptions = {}
|
||||
if parse_docstring:
|
||||
_validate_docstring_args_against_annotations(arg_descriptions, annotations)
|
||||
for arg, arg_type in annotations.items():
|
||||
if arg in arg_descriptions:
|
||||
continue
|
||||
@@ -222,6 +257,7 @@ def create_schema_from_function(
|
||||
*,
|
||||
filter_args: Optional[Sequence[str]] = None,
|
||||
parse_docstring: bool = False,
|
||||
error_on_invalid_docstring: bool = False,
|
||||
) -> Type[BaseModel]:
|
||||
"""Create a pydantic schema from a function's signature.
|
||||
Args:
|
||||
@@ -229,21 +265,23 @@ def create_schema_from_function(
|
||||
func: Function to generate the schema from
|
||||
filter_args: Optional list of arguments to exclude from the schema
|
||||
parse_docstring: Whether to parse the function's docstring for descriptions
|
||||
for each argument.
|
||||
for each argument.
|
||||
error_on_invalid_docstring: if ``parse_docstring`` is provided, configures
|
||||
whether to raise ValueError on invalid Google Style docstrings.
|
||||
Returns:
|
||||
A pydantic model with the same arguments as the function
|
||||
"""
|
||||
# https://docs.pydantic.dev/latest/usage/validation_decorator/
|
||||
validated = validate_arguments(func, config=_SchemaConfig) # type: ignore
|
||||
inferred_model = validated.model # type: ignore
|
||||
filter_args = (
|
||||
filter_args if filter_args is not None else ("run_manager", "callbacks")
|
||||
)
|
||||
filter_args = filter_args if filter_args is not None else FILTERED_ARGS
|
||||
for arg in filter_args:
|
||||
if arg in inferred_model.__fields__:
|
||||
del inferred_model.__fields__[arg]
|
||||
description, arg_descriptions = _infer_arg_descriptions(
|
||||
func, parse_docstring=parse_docstring
|
||||
func,
|
||||
parse_docstring=parse_docstring,
|
||||
error_on_invalid_docstring=error_on_invalid_docstring,
|
||||
)
|
||||
# Pydantic adds placeholder virtual fields we need to strip
|
||||
valid_properties = _get_filtered_args(inferred_model, func, filter_args=filter_args)
|
||||
@@ -909,6 +947,8 @@ class StructuredTool(BaseTool):
|
||||
return_direct: bool = False,
|
||||
args_schema: Optional[Type[BaseModel]] = None,
|
||||
infer_schema: bool = True,
|
||||
parse_docstring: bool = False,
|
||||
error_on_invalid_docstring: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> StructuredTool:
|
||||
"""Create tool from a given function.
|
||||
@@ -923,6 +963,10 @@ class StructuredTool(BaseTool):
|
||||
return_direct: Whether to return the result directly or as a callback
|
||||
args_schema: The schema of the tool's input arguments
|
||||
infer_schema: Whether to infer the schema from the function's signature
|
||||
parse_docstring: if ``infer_schema`` and ``parse_docstring``, will attempt
|
||||
to parse parameter descriptions from Google Style function docstrings.
|
||||
error_on_invalid_docstring: if ``parse_docstring`` is provided, configures
|
||||
whether to raise ValueError on invalid Google Style docstrings.
|
||||
**kwargs: Additional arguments to pass to the tool
|
||||
|
||||
Returns:
|
||||
@@ -963,7 +1007,12 @@ class StructuredTool(BaseTool):
|
||||
_args_schema = args_schema
|
||||
if _args_schema is None and infer_schema:
|
||||
# schema name is appended within function
|
||||
_args_schema = create_schema_from_function(name, source_function)
|
||||
_args_schema = create_schema_from_function(
|
||||
name,
|
||||
source_function,
|
||||
parse_docstring=parse_docstring,
|
||||
error_on_invalid_docstring=error_on_invalid_docstring,
|
||||
)
|
||||
return cls(
|
||||
name=name,
|
||||
func=func,
|
||||
@@ -980,6 +1029,8 @@ def tool(
|
||||
return_direct: bool = False,
|
||||
args_schema: Optional[Type[BaseModel]] = None,
|
||||
infer_schema: bool = True,
|
||||
parse_docstring: bool = False,
|
||||
error_on_invalid_docstring: bool = True,
|
||||
) -> Callable:
|
||||
"""Make tools out of functions, can be used with or without arguments.
|
||||
|
||||
@@ -991,6 +1042,10 @@ def tool(
|
||||
infer_schema: Whether to infer the schema of the arguments from
|
||||
the function's signature. This also makes the resultant tool
|
||||
accept a dictionary input to its `run()` function.
|
||||
parse_docstring: if ``infer_schema`` and ``parse_docstring``, will attempt to
|
||||
parse parameter descriptions from Google Style function docstrings.
|
||||
error_on_invalid_docstring: if ``parse_docstring`` is provided, configures
|
||||
whether to raise ValueError on invalid Google Style docstrings.
|
||||
|
||||
Requires:
|
||||
- Function must be of type (str) -> str
|
||||
@@ -1008,6 +1063,78 @@ def tool(
|
||||
def search_api(query: str) -> str:
|
||||
# Searches the API for the query.
|
||||
return
|
||||
|
||||
.. versionadded:: 0.2.14
|
||||
Parse Google-style docstrings:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@tool(parse_docstring=True)
|
||||
def foo(bar: str, baz: int) -> str:
|
||||
\"\"\"The foo.
|
||||
|
||||
Args:
|
||||
bar: The bar.
|
||||
baz: The baz.
|
||||
\"\"\"
|
||||
return bar
|
||||
|
||||
foo.args_schema.schema()
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
{
|
||||
"title": "fooSchema",
|
||||
"description": "The foo.",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"bar": {
|
||||
"title": "Bar",
|
||||
"description": "The bar.",
|
||||
"type": "string"
|
||||
},
|
||||
"baz": {
|
||||
"title": "Baz",
|
||||
"description": "The baz.",
|
||||
"type": "integer"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"bar",
|
||||
"baz"
|
||||
]
|
||||
}
|
||||
|
||||
Note that parsing by default will raise ``ValueError`` if the docstring
|
||||
is considered invalid. A docstring is considered invalid if it contains
|
||||
arguments not in the function signature, or is unable to be parsed into
|
||||
a summary and "Args:" blocks. Examples below:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# No args section
|
||||
def invalid_docstring_1(bar: str, baz: int) -> str:
|
||||
\"\"\"The foo.\"\"\"
|
||||
return bar
|
||||
|
||||
# Improper whitespace between summary and args section
|
||||
def invalid_docstring_2(bar: str, baz: int) -> str:
|
||||
\"\"\"The foo.
|
||||
Args:
|
||||
bar: The bar.
|
||||
baz: The baz.
|
||||
\"\"\"
|
||||
return bar
|
||||
|
||||
# Documented args absent from function signature
|
||||
def invalid_docstring_3(bar: str, baz: int) -> str:
|
||||
\"\"\"The foo.
|
||||
|
||||
Args:
|
||||
banana: The bar.
|
||||
monkey: The baz.
|
||||
\"\"\"
|
||||
return bar
|
||||
"""
|
||||
|
||||
def _make_with_name(tool_name: str) -> Callable:
|
||||
@@ -1052,6 +1179,8 @@ def tool(
|
||||
return_direct=return_direct,
|
||||
args_schema=schema,
|
||||
infer_schema=infer_schema,
|
||||
parse_docstring=parse_docstring,
|
||||
error_on_invalid_docstring=error_on_invalid_docstring,
|
||||
)
|
||||
# If someone doesn't want a schema applied, we must treat it as
|
||||
# a simple string->string function
|
||||
|
@@ -138,7 +138,11 @@ def convert_python_function_to_openai_function(
|
||||
|
||||
func_name = _get_python_function_name(function)
|
||||
model = tools.create_schema_from_function(
|
||||
func_name, function, filter_args=(), parse_docstring=True
|
||||
func_name,
|
||||
function,
|
||||
filter_args=(),
|
||||
parse_docstring=True,
|
||||
error_on_invalid_docstring=False,
|
||||
)
|
||||
return convert_pydantic_to_openai_function(
|
||||
model,
|
||||
|
@@ -959,6 +959,84 @@ def test_tool_arg_descriptions() -> None:
|
||||
"required": ["bar", "baz"],
|
||||
}
|
||||
|
||||
# Test parses docstring
|
||||
foo2 = tool(foo, parse_docstring=True)
|
||||
args_schema = foo2.args_schema.schema() # type: ignore
|
||||
expected = {
|
||||
"title": "fooSchema",
|
||||
"description": "The foo.",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"bar": {"title": "Bar", "description": "The bar.", "type": "string"},
|
||||
"baz": {"title": "Baz", "description": "The baz.", "type": "integer"},
|
||||
},
|
||||
"required": ["bar", "baz"],
|
||||
}
|
||||
assert args_schema == expected
|
||||
|
||||
# Test parsing with run_manager does not raise error
|
||||
def foo3(
|
||||
bar: str, baz: int, run_manager: Optional[CallbackManagerForToolRun] = None
|
||||
) -> str:
|
||||
"""The foo.
|
||||
|
||||
Args:
|
||||
bar: The bar.
|
||||
baz: The baz.
|
||||
"""
|
||||
return bar
|
||||
|
||||
as_tool = tool(foo3, parse_docstring=True)
|
||||
args_schema = as_tool.args_schema.schema() # type: ignore
|
||||
assert args_schema["description"] == expected["description"]
|
||||
assert args_schema["properties"] == expected["properties"]
|
||||
|
||||
# Test parameterless tool does not raise error for missing Args section
|
||||
# in docstring.
|
||||
def foo4() -> str:
|
||||
"""The foo."""
|
||||
return "bar"
|
||||
|
||||
as_tool = tool(foo4, parse_docstring=True)
|
||||
args_schema = as_tool.args_schema.schema() # type: ignore
|
||||
assert args_schema["description"] == expected["description"]
|
||||
|
||||
def foo5(run_manager: Optional[CallbackManagerForToolRun] = None) -> str:
|
||||
"""The foo."""
|
||||
return "bar"
|
||||
|
||||
as_tool = tool(foo5, parse_docstring=True)
|
||||
args_schema = as_tool.args_schema.schema() # type: ignore
|
||||
assert args_schema["description"] == expected["description"]
|
||||
|
||||
|
||||
def test_tool_invalid_docstrings() -> None:
|
||||
# Test invalid docstrings
|
||||
def foo3(bar: str, baz: int) -> str:
|
||||
"""The foo."""
|
||||
return bar
|
||||
|
||||
def foo4(bar: str, baz: int) -> str:
|
||||
"""The foo.
|
||||
Args:
|
||||
bar: The bar.
|
||||
baz: The baz.
|
||||
"""
|
||||
return bar
|
||||
|
||||
def foo5(bar: str, baz: int) -> str:
|
||||
"""The foo.
|
||||
|
||||
Args:
|
||||
banana: The bar.
|
||||
monkey: The baz.
|
||||
"""
|
||||
return bar
|
||||
|
||||
for func in [foo3, foo4, foo5]:
|
||||
with pytest.raises(ValueError):
|
||||
_ = tool(func, parse_docstring=True)
|
||||
|
||||
|
||||
def test_tool_annotated_descriptions() -> None:
|
||||
def foo(
|
||||
|
Reference in New Issue
Block a user