From 8ee8ca7c83d08e982c775ecf7e27f280e3b55e91 Mon Sep 17 00:00:00 2001 From: ccurme Date: Thu, 11 Jul 2024 16:11:45 -0400 Subject: [PATCH] core[patch]: propagate `parse_docstring` to tool decorator (#24123) Disabled by default. ```python from langchain_core.tools import tool @tool(parse_docstring=True) def foo(bar: str, baz: int) -> str: """The foo. Args: bar: this is the bar baz: this is the baz """ return bar foo.args_schema.schema() ``` ```json { "title": "fooSchema", "description": "The foo.", "type": "object", "properties": { "bar": { "title": "Bar", "description": "this is the bar", "type": "string" }, "baz": { "title": "Baz", "description": "this is the baz", "type": "integer" } }, "required": [ "bar", "baz" ] } ``` --- docs/docs/how_to/custom_tools.ipynb | 227 ++++++++++++++++-- libs/core/langchain_core/tools.py | 155 +++++++++++- .../langchain_core/utils/function_calling.py | 6 +- libs/core/tests/unit_tests/test_tools.py | 78 ++++++ 4 files changed, 427 insertions(+), 39 deletions(-) diff --git a/docs/docs/how_to/custom_tools.ipynb b/docs/docs/how_to/custom_tools.ipynb index e678ded1d02..ce2043bf838 100644 --- a/docs/docs/how_to/custom_tools.ipynb +++ b/docs/docs/how_to/custom_tools.ipynb @@ -16,13 +16,15 @@ "| args_schema | Pydantic BaseModel | Optional but recommended, can be used to provide more information (e.g., few-shot examples) or validation for expected parameters |\n", "| return_direct | boolean | Only relevant for agents. When True, after invoking the given tool, the agent will stop and return the result direcly to the user. |\n", "\n", - "LangChain provides 3 ways to create tools:\n", + "LangChain supports the creation of tools from:\n", "\n", - "1. Using [@tool decorator](https://api.python.langchain.com/en/latest/tools/langchain_core.tools.tool.html#langchain_core.tools.tool) -- the simplest way to define a custom tool.\n", - "2. Using [StructuredTool.from_function](https://api.python.langchain.com/en/latest/tools/langchain_core.tools.StructuredTool.html#langchain_core.tools.StructuredTool.from_function) class method -- this is similar to the `@tool` decorator, but allows more configuration and specification of both sync and async implementations.\n", + "1. Functions;\n", + "2. LangChain [Runnables](/docs/concepts#runnable-interface);\n", "3. By sub-classing from [BaseTool](https://api.python.langchain.com/en/latest/tools/langchain_core.tools.BaseTool.html) -- This is the most flexible method, it provides the largest degree of control, at the expense of more effort and code.\n", "\n", - "The `@tool` or the `StructuredTool.from_function` class method should be sufficient for most use cases.\n", + "Creating tools from functions may be sufficient for most use cases, and can be done via a simple [@tool decorator](https://api.python.langchain.com/en/latest/tools/langchain_core.tools.tool.html#langchain_core.tools.tool). If more configuration is needed-- e.g., specification of both sync and async implementations-- one can also use the [StructuredTool.from_function](https://api.python.langchain.com/en/latest/tools/langchain_core.tools.StructuredTool.html#langchain_core.tools.StructuredTool.from_function) class method.\n", + "\n", + "In this guide we provide an overview of these methods.\n", "\n", ":::{.callout-tip}\n", "\n", @@ -35,7 +37,9 @@ "id": "c7326b23", "metadata": {}, "source": [ - "## @tool decorator\n", + "## Creating tools from functions\n", + "\n", + "### @tool decorator\n", "\n", "This `@tool` decorator is the simplest way to define a custom tool. The decorator uses the function name as the tool name by default, but this can be overridden by passing a string as the first argument. Additionally, the decorator will use the function's docstring as the tool's description - so a docstring MUST be provided. " ] @@ -51,7 +55,7 @@ "output_type": "stream", "text": [ "multiply\n", - "multiply(a: int, b: int) -> int - Multiply two numbers.\n", + "Multiply two numbers.\n", "{'a': {'title': 'A', 'type': 'integer'}, 'b': {'title': 'B', 'type': 'integer'}}\n" ] } @@ -96,6 +100,57 @@ " return a * b" ] }, + { + "cell_type": "markdown", + "id": "8f0edc51-c586-414c-8941-c8abe779943f", + "metadata": {}, + "source": [ + "Note that `@tool` supports parsing of annotations, nested schemas, and other features:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "5626423f-053e-4a66-adca-1d794d835397", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'title': 'multiply_by_maxSchema',\n", + " 'description': 'Multiply a by the maximum of b.',\n", + " 'type': 'object',\n", + " 'properties': {'a': {'title': 'A',\n", + " 'description': 'scale factor',\n", + " 'type': 'string'},\n", + " 'b': {'title': 'B',\n", + " 'description': 'list of ints over which to take maximum',\n", + " 'type': 'array',\n", + " 'items': {'type': 'integer'}}},\n", + " 'required': ['a', 'b']}" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from typing import Annotated, List\n", + "\n", + "\n", + "@tool\n", + "def multiply_by_max(\n", + " a: Annotated[str, \"scale factor\"],\n", + " b: Annotated[List[int], \"list of ints over which to take maximum\"],\n", + ") -> int:\n", + " \"\"\"Multiply a by the maximum of b.\"\"\"\n", + " return a * max(b)\n", + "\n", + "\n", + "multiply_by_max.args_schema.schema()" + ] + }, { "cell_type": "markdown", "id": "98d6eee9", @@ -106,7 +161,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "id": "9216d03a-f6ea-4216-b7e1-0661823a4c0b", "metadata": {}, "outputs": [ @@ -115,7 +170,7 @@ "output_type": "stream", "text": [ "multiplication-tool\n", - "multiplication-tool(a: int, b: int) -> int - Multiply two numbers.\n", + "Multiply two numbers.\n", "{'a': {'title': 'A', 'description': 'first number', 'type': 'integer'}, 'b': {'title': 'B', 'description': 'second number', 'type': 'integer'}}\n", "True\n" ] @@ -143,19 +198,84 @@ "print(multiply.return_direct)" ] }, + { + "cell_type": "markdown", + "id": "33a9e94d-0b60-48f3-a4c2-247dce096e66", + "metadata": {}, + "source": [ + "#### Docstring parsing" + ] + }, + { + "cell_type": "markdown", + "id": "6d0cb586-93d4-4ff1-9779-71df7853cb68", + "metadata": {}, + "source": [ + "`@tool` can optionally parse [Google Style docstrings](https://google.github.io/styleguide/pyguide.html#383-functions-and-methods) and associate the docstring components (such as arg descriptions) to the relevant parts of the tool schema. To toggle this behavior, specify `parse_docstring`:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "336f5538-956e-47d5-9bde-b732559f9e61", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'title': 'fooSchema',\n", + " 'description': 'The foo.',\n", + " 'type': 'object',\n", + " 'properties': {'bar': {'title': 'Bar',\n", + " 'description': 'The bar.',\n", + " 'type': 'string'},\n", + " 'baz': {'title': 'Baz', 'description': 'The baz.', 'type': 'integer'}},\n", + " 'required': ['bar', 'baz']}" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "@tool(parse_docstring=True)\n", + "def foo(bar: str, baz: int) -> str:\n", + " \"\"\"The foo.\n", + "\n", + " Args:\n", + " bar: The bar.\n", + " baz: The baz.\n", + " \"\"\"\n", + " return bar\n", + "\n", + "\n", + "foo.args_schema.schema()" + ] + }, + { + "cell_type": "markdown", + "id": "f18a2503-5393-421b-99fa-4a01dd824d0e", + "metadata": {}, + "source": [ + ":::{.callout-caution}\n", + "By default, `@tool(parse_docstring=True)` will raise `ValueError` if the docstring does not parse correctly. See [API Reference](https://api.python.langchain.com/en/latest/tools/langchain_core.tools.tool.html) for detail and examples.\n", + ":::" + ] + }, { "cell_type": "markdown", "id": "b63fcc3b", "metadata": {}, "source": [ - "## StructuredTool\n", + "### StructuredTool\n", "\n", "The `StrurcturedTool.from_function` class method provides a bit more configurability than the `@tool` decorator, without requiring much additional code." ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 6, "id": "564fbe6f-11df-402d-b135-ef6ff25e1e63", "metadata": {}, "outputs": [ @@ -198,7 +318,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 7, "id": "6bc055d4-1fbe-4db5-8881-9c382eba6b1b", "metadata": {}, "outputs": [ @@ -208,7 +328,7 @@ "text": [ "6\n", "Calculator\n", - "Calculator(a: int, b: int) -> int - multiply numbers\n", + "multiply numbers\n", "{'a': {'title': 'A', 'description': 'first number', 'type': 'integer'}, 'b': {'title': 'B', 'description': 'second number', 'type': 'integer'}}\n" ] } @@ -239,6 +359,63 @@ "print(calculator.args)" ] }, + { + "cell_type": "markdown", + "id": "5517995d-54e3-449b-8fdb-03561f5e4647", + "metadata": {}, + "source": [ + "## Creating tools from Runnables\n", + "\n", + "LangChain [Runnables](/docs/concepts#runnable-interface) that accept string or `dict` input can be converted to tools using the [as_tool](https://api.python.langchain.com/en/latest/runnables/langchain_core.runnables.base.Runnable.html#langchain_core.runnables.base.Runnable.as_tool) method, which allows for the specification of names, descriptions, and additional schema information for arguments.\n", + "\n", + "Example usage:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "8ef593c5-cf72-4c10-bfc9-7d21874a0c24", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'answer_style': {'title': 'Answer Style', 'type': 'string'}}" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from langchain_core.language_models import GenericFakeChatModel\n", + "from langchain_core.output_parsers import StrOutputParser\n", + "from langchain_core.prompts import ChatPromptTemplate\n", + "\n", + "prompt = ChatPromptTemplate.from_messages(\n", + " [(\"human\", \"Hello. Please respond in the style of {answer_style}.\")]\n", + ")\n", + "\n", + "# Placeholder LLM\n", + "llm = GenericFakeChatModel(messages=iter([\"hello matey\"]))\n", + "\n", + "chain = prompt | llm | StrOutputParser()\n", + "\n", + "as_tool = chain.as_tool(\n", + " name=\"Style responder\", description=\"Description of when to use tool.\"\n", + ")\n", + "as_tool.args" + ] + }, + { + "cell_type": "markdown", + "id": "0521b787-a146-45a6-8ace-ae1ac4669dd7", + "metadata": {}, + "source": [ + "See [this guide](/docs/how_to/convert_runnable_to_tool) for more detail." + ] + }, { "cell_type": "markdown", "id": "b840074b-9c10-4ca0-aed8-626c52b2398f", @@ -251,7 +428,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 10, "id": "1dad8f8e", "metadata": {}, "outputs": [], @@ -300,7 +477,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 11, "id": "bb551c33", "metadata": {}, "outputs": [ @@ -351,7 +528,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 12, "id": "6615cb77-fd4c-4676-8965-f92cc71d4944", "metadata": {}, "outputs": [ @@ -383,7 +560,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 13, "id": "bb2af583-eadd-41f4-a645-bf8748bd3dcd", "metadata": {}, "outputs": [ @@ -428,7 +605,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 14, "id": "4ad0932c-8610-4278-8c57-f9218f654c8a", "metadata": {}, "outputs": [ @@ -473,7 +650,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 15, "id": "7094c0e8-6192-4870-a942-aad5b5ae48fd", "metadata": {}, "outputs": [], @@ -496,7 +673,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 16, "id": "b4d22022-b105-4ccc-a15b-412cb9ea3097", "metadata": {}, "outputs": [ @@ -506,7 +683,7 @@ "'Error: There is no city by the name of foobar.'" ] }, - "execution_count": 12, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -530,7 +707,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 17, "id": "3fad1728-d367-4e1b-9b54-3172981271cf", "metadata": {}, "outputs": [ @@ -540,7 +717,7 @@ "\"There is no such city, but it's probably above 0K there!\"" ] }, - "execution_count": 13, + "execution_count": 17, "metadata": {}, "output_type": "execute_result" } @@ -564,7 +741,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 18, "id": "ebfe7c1f-318d-4e58-99e1-f31e69473c46", "metadata": {}, "outputs": [ @@ -574,7 +751,7 @@ "'The following errors occurred during tool execution: `Error: There is no city by the name of foobar.`'" ] }, - "execution_count": 14, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } @@ -609,7 +786,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.4" + "version": "3.10.4" }, "vscode": { "interpreter": { diff --git a/libs/core/langchain_core/tools.py b/libs/core/langchain_core/tools.py index 49ba78f469a..9c5aaa48e65 100644 --- a/libs/core/langchain_core/tools.py +++ b/libs/core/langchain_core/tools.py @@ -85,6 +85,8 @@ from langchain_core.runnables.config import ( ) from langchain_core.runnables.utils import accepts_context +FILTERED_ARGS = ("run_manager", "callbacks") + class SchemaAnnotationError(TypeError): """Raised when 'args_schema' is missing or has an incorrect type annotation.""" @@ -149,14 +151,27 @@ def _get_filtered_args( } -def _parse_python_function_docstring(function: Callable) -> Tuple[str, dict]: +def _parse_python_function_docstring( + function: Callable, annotations: dict, error_on_invalid_docstring: bool = False +) -> Tuple[str, dict]: """Parse the function and argument descriptions from the docstring of a function. Assumes the function docstring follows Google Python style guide. """ + invalid_docstring_error = ValueError( + f"Found invalid Google-Style docstring for {function}." + ) docstring = inspect.getdoc(function) if docstring: docstring_blocks = docstring.split("\n\n") + if error_on_invalid_docstring: + filtered_annotations = { + arg for arg in annotations if arg not in (*(FILTERED_ARGS), "return") + } + if filtered_annotations and ( + len(docstring_blocks) < 2 or not docstring_blocks[1].startswith("Args:") + ): + raise (invalid_docstring_error) descriptors = [] args_block = None past_descriptors = False @@ -173,6 +188,8 @@ def _parse_python_function_docstring(function: Callable) -> Tuple[str, dict]: continue description = " ".join(descriptors) else: + if error_on_invalid_docstring: + raise (invalid_docstring_error) description = "" args_block = None arg_descriptions = {} @@ -187,20 +204,38 @@ def _parse_python_function_docstring(function: Callable) -> Tuple[str, dict]: return description, arg_descriptions +def _validate_docstring_args_against_annotations( + arg_descriptions: dict, annotations: dict +) -> None: + """Raise error if docstring arg is not in type annotations.""" + for docstring_arg in arg_descriptions: + if docstring_arg not in annotations: + raise ValueError( + f"Arg {docstring_arg} in docstring not found in function signature." + ) + + def _infer_arg_descriptions( - fn: Callable, *, parse_docstring: bool = False + fn: Callable, + *, + parse_docstring: bool = False, + error_on_invalid_docstring: bool = False, ) -> Tuple[str, dict]: """Infer argument descriptions from a function's docstring.""" - if parse_docstring: - description, arg_descriptions = _parse_python_function_docstring(fn) - else: - description = inspect.getdoc(fn) or "" - arg_descriptions = {} if hasattr(inspect, "get_annotations"): # This is for python < 3.10 annotations = inspect.get_annotations(fn) # type: ignore else: annotations = getattr(fn, "__annotations__", {}) + if parse_docstring: + description, arg_descriptions = _parse_python_function_docstring( + fn, annotations, error_on_invalid_docstring=error_on_invalid_docstring + ) + else: + description = inspect.getdoc(fn) or "" + arg_descriptions = {} + if parse_docstring: + _validate_docstring_args_against_annotations(arg_descriptions, annotations) for arg, arg_type in annotations.items(): if arg in arg_descriptions: continue @@ -222,6 +257,7 @@ def create_schema_from_function( *, filter_args: Optional[Sequence[str]] = None, parse_docstring: bool = False, + error_on_invalid_docstring: bool = False, ) -> Type[BaseModel]: """Create a pydantic schema from a function's signature. Args: @@ -229,21 +265,23 @@ def create_schema_from_function( func: Function to generate the schema from filter_args: Optional list of arguments to exclude from the schema parse_docstring: Whether to parse the function's docstring for descriptions - for each argument. + for each argument. + error_on_invalid_docstring: if ``parse_docstring`` is provided, configures + whether to raise ValueError on invalid Google Style docstrings. Returns: A pydantic model with the same arguments as the function """ # https://docs.pydantic.dev/latest/usage/validation_decorator/ validated = validate_arguments(func, config=_SchemaConfig) # type: ignore inferred_model = validated.model # type: ignore - filter_args = ( - filter_args if filter_args is not None else ("run_manager", "callbacks") - ) + filter_args = filter_args if filter_args is not None else FILTERED_ARGS for arg in filter_args: if arg in inferred_model.__fields__: del inferred_model.__fields__[arg] description, arg_descriptions = _infer_arg_descriptions( - func, parse_docstring=parse_docstring + func, + parse_docstring=parse_docstring, + error_on_invalid_docstring=error_on_invalid_docstring, ) # Pydantic adds placeholder virtual fields we need to strip valid_properties = _get_filtered_args(inferred_model, func, filter_args=filter_args) @@ -909,6 +947,8 @@ class StructuredTool(BaseTool): return_direct: bool = False, args_schema: Optional[Type[BaseModel]] = None, infer_schema: bool = True, + parse_docstring: bool = False, + error_on_invalid_docstring: bool = False, **kwargs: Any, ) -> StructuredTool: """Create tool from a given function. @@ -923,6 +963,10 @@ class StructuredTool(BaseTool): return_direct: Whether to return the result directly or as a callback args_schema: The schema of the tool's input arguments infer_schema: Whether to infer the schema from the function's signature + parse_docstring: if ``infer_schema`` and ``parse_docstring``, will attempt + to parse parameter descriptions from Google Style function docstrings. + error_on_invalid_docstring: if ``parse_docstring`` is provided, configures + whether to raise ValueError on invalid Google Style docstrings. **kwargs: Additional arguments to pass to the tool Returns: @@ -963,7 +1007,12 @@ class StructuredTool(BaseTool): _args_schema = args_schema if _args_schema is None and infer_schema: # schema name is appended within function - _args_schema = create_schema_from_function(name, source_function) + _args_schema = create_schema_from_function( + name, + source_function, + parse_docstring=parse_docstring, + error_on_invalid_docstring=error_on_invalid_docstring, + ) return cls( name=name, func=func, @@ -980,6 +1029,8 @@ def tool( return_direct: bool = False, args_schema: Optional[Type[BaseModel]] = None, infer_schema: bool = True, + parse_docstring: bool = False, + error_on_invalid_docstring: bool = True, ) -> Callable: """Make tools out of functions, can be used with or without arguments. @@ -991,6 +1042,10 @@ def tool( infer_schema: Whether to infer the schema of the arguments from the function's signature. This also makes the resultant tool accept a dictionary input to its `run()` function. + parse_docstring: if ``infer_schema`` and ``parse_docstring``, will attempt to + parse parameter descriptions from Google Style function docstrings. + error_on_invalid_docstring: if ``parse_docstring`` is provided, configures + whether to raise ValueError on invalid Google Style docstrings. Requires: - Function must be of type (str) -> str @@ -1008,6 +1063,78 @@ def tool( def search_api(query: str) -> str: # Searches the API for the query. return + + .. versionadded:: 0.2.14 + Parse Google-style docstrings: + + .. code-block:: python + + @tool(parse_docstring=True) + def foo(bar: str, baz: int) -> str: + \"\"\"The foo. + + Args: + bar: The bar. + baz: The baz. + \"\"\" + return bar + + foo.args_schema.schema() + + .. code-block:: python + + { + "title": "fooSchema", + "description": "The foo.", + "type": "object", + "properties": { + "bar": { + "title": "Bar", + "description": "The bar.", + "type": "string" + }, + "baz": { + "title": "Baz", + "description": "The baz.", + "type": "integer" + } + }, + "required": [ + "bar", + "baz" + ] + } + + Note that parsing by default will raise ``ValueError`` if the docstring + is considered invalid. A docstring is considered invalid if it contains + arguments not in the function signature, or is unable to be parsed into + a summary and "Args:" blocks. Examples below: + + .. code-block:: python + + # No args section + def invalid_docstring_1(bar: str, baz: int) -> str: + \"\"\"The foo.\"\"\" + return bar + + # Improper whitespace between summary and args section + def invalid_docstring_2(bar: str, baz: int) -> str: + \"\"\"The foo. + Args: + bar: The bar. + baz: The baz. + \"\"\" + return bar + + # Documented args absent from function signature + def invalid_docstring_3(bar: str, baz: int) -> str: + \"\"\"The foo. + + Args: + banana: The bar. + monkey: The baz. + \"\"\" + return bar """ def _make_with_name(tool_name: str) -> Callable: @@ -1052,6 +1179,8 @@ def tool( return_direct=return_direct, args_schema=schema, infer_schema=infer_schema, + parse_docstring=parse_docstring, + error_on_invalid_docstring=error_on_invalid_docstring, ) # If someone doesn't want a schema applied, we must treat it as # a simple string->string function diff --git a/libs/core/langchain_core/utils/function_calling.py b/libs/core/langchain_core/utils/function_calling.py index f4bcba6e701..280db0d3ab6 100644 --- a/libs/core/langchain_core/utils/function_calling.py +++ b/libs/core/langchain_core/utils/function_calling.py @@ -138,7 +138,11 @@ def convert_python_function_to_openai_function( func_name = _get_python_function_name(function) model = tools.create_schema_from_function( - func_name, function, filter_args=(), parse_docstring=True + func_name, + function, + filter_args=(), + parse_docstring=True, + error_on_invalid_docstring=False, ) return convert_pydantic_to_openai_function( model, diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index 7760e90f788..61266838ecc 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -959,6 +959,84 @@ def test_tool_arg_descriptions() -> None: "required": ["bar", "baz"], } + # Test parses docstring + foo2 = tool(foo, parse_docstring=True) + args_schema = foo2.args_schema.schema() # type: ignore + expected = { + "title": "fooSchema", + "description": "The foo.", + "type": "object", + "properties": { + "bar": {"title": "Bar", "description": "The bar.", "type": "string"}, + "baz": {"title": "Baz", "description": "The baz.", "type": "integer"}, + }, + "required": ["bar", "baz"], + } + assert args_schema == expected + + # Test parsing with run_manager does not raise error + def foo3( + bar: str, baz: int, run_manager: Optional[CallbackManagerForToolRun] = None + ) -> str: + """The foo. + + Args: + bar: The bar. + baz: The baz. + """ + return bar + + as_tool = tool(foo3, parse_docstring=True) + args_schema = as_tool.args_schema.schema() # type: ignore + assert args_schema["description"] == expected["description"] + assert args_schema["properties"] == expected["properties"] + + # Test parameterless tool does not raise error for missing Args section + # in docstring. + def foo4() -> str: + """The foo.""" + return "bar" + + as_tool = tool(foo4, parse_docstring=True) + args_schema = as_tool.args_schema.schema() # type: ignore + assert args_schema["description"] == expected["description"] + + def foo5(run_manager: Optional[CallbackManagerForToolRun] = None) -> str: + """The foo.""" + return "bar" + + as_tool = tool(foo5, parse_docstring=True) + args_schema = as_tool.args_schema.schema() # type: ignore + assert args_schema["description"] == expected["description"] + + +def test_tool_invalid_docstrings() -> None: + # Test invalid docstrings + def foo3(bar: str, baz: int) -> str: + """The foo.""" + return bar + + def foo4(bar: str, baz: int) -> str: + """The foo. + Args: + bar: The bar. + baz: The baz. + """ + return bar + + def foo5(bar: str, baz: int) -> str: + """The foo. + + Args: + banana: The bar. + monkey: The baz. + """ + return bar + + for func in [foo3, foo4, foo5]: + with pytest.raises(ValueError): + _ = tool(func, parse_docstring=True) + def test_tool_annotated_descriptions() -> None: def foo(