Compare commits

...

3 Commits

Author SHA1 Message Date
Bagatur
49b68a0cb2 core[minor]: update @tool inferred schema and description 2024-05-14 23:26:58 -07:00
Bagatur
8716d065a3 wip 2024-05-14 17:57:02 -07:00
Bagatur
c15a541ccd wip 2024-05-14 17:52:13 -07:00
3 changed files with 309 additions and 61 deletions

View File

@@ -21,7 +21,7 @@ from __future__ import annotations
import asyncio import asyncio
import inspect import inspect
import textwrap import re
import uuid import uuid
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
@@ -37,8 +37,6 @@ from langchain_core.callbacks import (
BaseCallbackManager, BaseCallbackManager,
CallbackManager, CallbackManager,
CallbackManagerForToolRun, CallbackManagerForToolRun,
)
from langchain_core.callbacks.manager import (
Callbacks, Callbacks,
) )
from langchain_core.load.serializable import Serializable from langchain_core.load.serializable import Serializable
@@ -77,7 +75,12 @@ class SchemaAnnotationError(TypeError):
def _create_subset_model( def _create_subset_model(
name: str, model: Type[BaseModel], field_names: list name: str,
model: Type[BaseModel],
field_names: list,
*,
description: Optional[str] = None,
field_descriptions: Optional[Dict[str, str]] = None,
) -> Type[BaseModel]: ) -> Type[BaseModel]:
"""Create a pydantic model with only a subset of model's fields.""" """Create a pydantic model with only a subset of model's fields."""
fields = {} fields = {}
@@ -89,21 +92,70 @@ def _create_subset_model(
if field.required and not field.allow_none if field.required and not field.allow_none
else Optional[field.outer_type_] else Optional[field.outer_type_]
) )
if field_descriptions and field_name in field_descriptions:
field.field_info.description = field_descriptions[field_name]
fields[field_name] = (t, field.field_info) fields[field_name] = (t, field.field_info)
rtn = create_model(name, **fields) # type: ignore rtn = create_model(name, __doc__=description, **fields) # type: ignore
return rtn return rtn
def _get_filtered_args( def _get_filtered_args(inferred_model: Type[BaseModel], func: Callable) -> dict:
inferred_model: Type[BaseModel],
func: Callable,
) -> dict:
"""Get the arguments from a function's signature.""" """Get the arguments from a function's signature."""
schema = inferred_model.schema()["properties"] schema = inferred_model.schema()["properties"]
valid_keys = signature(func).parameters valid_keys = signature(func).parameters
return {k: schema[k] for k in valid_keys if k not in ("run_manager", "callbacks")} return {k: schema[k] for k in valid_keys if k not in ("run_manager", "callbacks")}
def _parse_args_from_docstring(docstring: Optional[str]) -> Dict[str, str]:
"""Parses the argument descriptions from a Google-style docstring.
Args:
docstring: The docstring to parse.
Returns:
dict: A dictionary where keys are argument names and values are their
descriptions.
"""
args_dict: Dict[str, str] = {}
if docstring and (args_section := re.search(r"Args:\n((?:\s*.+\n)+)", docstring)):
arg_lines = args_section.group(1).strip().split("\n")
else:
return args_dict
current_arg = None
current_desc = []
for line in arg_lines:
if match := re.match(r"\s*(\w+).*?:\s*(.*)", line):
if current_arg:
args_dict[current_arg] = " ".join(current_desc).strip()
current_arg = match.group(1)
if current_arg in ("Returns", "Yields", "Raises"):
current_arg = None
break
current_desc = [match.group(2).strip()]
else:
current_desc.append(line.strip())
if current_arg:
args_dict[current_arg] = " ".join(current_desc).strip()
return args_dict
def _parse_func_description_from_docstring(docstring: Optional[str]) -> Optional[str]:
if not docstring:
return docstring
if description_match := re.search(
r"(.*?)(?:Args|Returns|Yields|Raises)", docstring, flags=re.DOTALL
):
description = description_match.group(1)
else:
description = docstring
return " ".join(li.strip() for li in description.split("\n") if li.strip())
class _SchemaConfig: class _SchemaConfig:
"""Configuration for the pydantic model.""" """Configuration for the pydantic model."""
@@ -111,14 +163,13 @@ class _SchemaConfig:
arbitrary_types_allowed: bool = True arbitrary_types_allowed: bool = True
def create_schema_from_function( def create_schema_from_function(model_name: str, func: Callable) -> Type[BaseModel]:
model_name: str,
func: Callable,
) -> Type[BaseModel]:
"""Create a pydantic schema from a function's signature. """Create a pydantic schema from a function's signature.
Args: Args:
model_name: Name to assign to the generated pydandic schema model_name: Name to assign to the generated pydantic schema
func: Function to generate the schema from func: Function to generate the schema from
Returns: Returns:
A pydantic model with the same arguments as the function A pydantic model with the same arguments as the function
""" """
@@ -131,8 +182,15 @@ def create_schema_from_function(
del inferred_model.__fields__["callbacks"] del inferred_model.__fields__["callbacks"]
# Pydantic adds placeholder virtual fields we need to strip # Pydantic adds placeholder virtual fields we need to strip
valid_properties = _get_filtered_args(inferred_model, func) valid_properties = _get_filtered_args(inferred_model, func)
docstring = getattr(func, "__doc__", "")
func_description = _parse_func_description_from_docstring(docstring)
arg_descriptions = _parse_args_from_docstring(docstring)
return _create_subset_model( return _create_subset_model(
f"{model_name}Schema", inferred_model, list(valid_properties) f"{model_name}Schema",
inferred_model,
list(valid_properties),
description=func_description,
field_descriptions=arg_descriptions,
) )
@@ -802,7 +860,7 @@ class StructuredTool(BaseTool):
description: The description of the tool. Defaults to the function docstring description: The description of the tool. Defaults to the function docstring
return_direct: Whether to return the result directly or as a callback return_direct: Whether to return the result directly or as a callback
args_schema: The schema of the tool's input arguments args_schema: The schema of the tool's input arguments
infer_schema: Whether to infer the schema from the function's signature infer_schema: DEPRECATED. args_schema is always inferred if not specified.
**kwargs: Additional arguments to pass to the tool **kwargs: Additional arguments to pass to the tool
Returns: Returns:
@@ -826,29 +884,24 @@ class StructuredTool(BaseTool):
else: else:
raise ValueError("Function and/or coroutine must be provided") raise ValueError("Function and/or coroutine must be provided")
name = name or source_function.__name__ name = name or source_function.__name__
description_ = description or source_function.__doc__ inferred_schema = create_schema_from_function(name, source_function)
if description_ is None: args_schema = args_schema or inferred_schema
raise ValueError( description = (
"Function must have a docstring if description not provided." description
) or args_schema.schema().get("description")
or inferred_schema.schema().get("description")
)
if description is None: if description is None:
# Only apply if using the function's docstring raise ValueError(
description_ = textwrap.dedent(description_).strip() "Must specify a description or pass in an args_schema or function with "
"a docstring."
# Description example: )
# search_api(query: str) - Searches the API for the query.
sig = signature(source_function)
description_ = f"{name}{sig} - {description_.strip()}"
_args_schema = args_schema
if _args_schema is None and infer_schema:
# schema name is appended within function
_args_schema = create_schema_from_function(name, source_function)
return cls( return cls(
name=name, name=name,
func=func, func=func,
coroutine=coroutine, coroutine=coroutine,
args_schema=_args_schema, # type: ignore[arg-type] args_schema=args_schema,
description=description_, description=description,
return_direct=return_direct, return_direct=return_direct,
**kwargs, **kwargs,
) )

View File

@@ -3,7 +3,6 @@
import asyncio import asyncio
import json import json
import sys import sys
import textwrap
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from functools import partial from functools import partial
@@ -24,6 +23,9 @@ from langchain_core.tools import (
Tool, Tool,
ToolException, ToolException,
_create_subset_model, _create_subset_model,
_parse_args_from_docstring,
_parse_func_description_from_docstring,
create_schema_from_function,
tool, tool,
) )
from tests.unit_tests.fake.callbacks import FakeCallbackHandler from tests.unit_tests.fake.callbacks import FakeCallbackHandler
@@ -318,23 +320,22 @@ def test_structured_tool_from_function_docstring() -> None:
structured_tool = StructuredTool.from_function(foo) structured_tool = StructuredTool.from_function(foo)
assert structured_tool.name == "foo" assert structured_tool.name == "foo"
assert structured_tool.args == { assert structured_tool.args == {
"bar": {"title": "Bar", "type": "integer"}, "bar": {"title": "Bar", "type": "integer", "description": "int"},
"baz": {"title": "Baz", "type": "string"}, "baz": {"title": "Baz", "type": "string", "description": "str"},
} }
assert structured_tool.args_schema.schema() == { assert structured_tool.args_schema.schema() == {
"properties": { "properties": {
"bar": {"title": "Bar", "type": "integer"}, "bar": {"title": "Bar", "type": "integer", "description": "int"},
"baz": {"title": "Baz", "type": "string"}, "baz": {"title": "Baz", "type": "string", "description": "str"},
}, },
"title": "fooSchema", "title": "fooSchema",
"description": "Docstring",
"type": "object", "type": "object",
"required": ["bar", "baz"], "required": ["bar", "baz"],
} }
prefix = "foo(bar: int, baz: str) -> str - " assert structured_tool.description == "Docstring"
assert foo.__doc__ is not None
assert structured_tool.description == prefix + textwrap.dedent(foo.__doc__.strip())
def test_structured_tool_from_function_docstring_complex_args() -> None: def test_structured_tool_from_function_docstring_complex_args() -> None:
@@ -351,23 +352,32 @@ def test_structured_tool_from_function_docstring_complex_args() -> None:
structured_tool = StructuredTool.from_function(foo) structured_tool = StructuredTool.from_function(foo)
assert structured_tool.name == "foo" assert structured_tool.name == "foo"
assert structured_tool.args == { assert structured_tool.args == {
"bar": {"title": "Bar", "type": "integer"}, "bar": {"title": "Bar", "type": "integer", "description": "int"},
"baz": {"title": "Baz", "type": "array", "items": {"type": "string"}}, "baz": {
"title": "Baz",
"type": "array",
"items": {"type": "string"},
"description": "List[str]",
},
} }
assert structured_tool.args_schema.schema() == { assert structured_tool.args_schema.schema() == {
"properties": { "properties": {
"bar": {"title": "Bar", "type": "integer"}, "bar": {"title": "Bar", "type": "integer", "description": "int"},
"baz": {"title": "Baz", "type": "array", "items": {"type": "string"}}, "baz": {
"title": "Baz",
"type": "array",
"items": {"type": "string"},
"description": "List[str]",
},
}, },
"title": "fooSchema", "title": "fooSchema",
"description": "Docstring",
"type": "object", "type": "object",
"required": ["bar", "baz"], "required": ["bar", "baz"],
} }
prefix = "foo(bar: int, baz: List[str]) -> str - " assert structured_tool.description == "Docstring"
assert foo.__doc__ is not None
assert structured_tool.description == prefix + textwrap.dedent(foo.__doc__).strip()
def test_structured_tool_lambda_multi_args_schema() -> None: def test_structured_tool_lambda_multi_args_schema() -> None:
@@ -451,16 +461,17 @@ def test_structured_tool_from_function_with_run_manager() -> None:
structured_tool = StructuredTool.from_function(foo) structured_tool = StructuredTool.from_function(foo)
assert structured_tool.args == { assert structured_tool.args == {
"bar": {"title": "Bar", "type": "integer"}, "bar": {"title": "Bar", "type": "integer", "description": "int"},
"baz": {"title": "Baz", "type": "string"}, "baz": {"title": "Baz", "type": "string", "description": "str"},
} }
assert structured_tool.args_schema.schema() == { assert structured_tool.args_schema.schema() == {
"properties": { "properties": {
"bar": {"title": "Bar", "type": "integer"}, "bar": {"title": "Bar", "type": "integer", "description": "int"},
"baz": {"title": "Baz", "type": "string"}, "baz": {"title": "Baz", "type": "string", "description": "str"},
}, },
"title": "fooSchema", "title": "fooSchema",
"description": "Docstring",
"type": "object", "type": "object",
"required": ["bar", "baz"], "required": ["bar", "baz"],
} }
@@ -553,7 +564,7 @@ def test_tool_with_kwargs() -> None:
def test_missing_docstring() -> None: def test_missing_docstring() -> None:
"""Test error is raised when docstring is missing.""" """Test error is raised when docstring is missing."""
# expect to throw a value error if there's no docstring # expect to throw a value error if there's no docstring
with pytest.raises(ValueError, match="Function must have a docstring"): with pytest.raises(ValueError):
@tool @tool
def search_api(query: str) -> str: def search_api(query: str) -> str:
@@ -686,23 +697,22 @@ def test_structured_tool_from_function() -> None:
structured_tool = StructuredTool.from_function(foo) structured_tool = StructuredTool.from_function(foo)
assert structured_tool.name == "foo" assert structured_tool.name == "foo"
assert structured_tool.args == { assert structured_tool.args == {
"bar": {"title": "Bar", "type": "integer"}, "bar": {"title": "Bar", "type": "integer", "description": "int"},
"baz": {"title": "Baz", "type": "string"}, "baz": {"title": "Baz", "type": "string", "description": "str"},
} }
assert structured_tool.args_schema.schema() == { assert structured_tool.args_schema.schema() == {
"title": "fooSchema", "title": "fooSchema",
"description": "Docstring",
"type": "object", "type": "object",
"properties": { "properties": {
"bar": {"title": "Bar", "type": "integer"}, "bar": {"title": "Bar", "type": "integer", "description": "int"},
"baz": {"title": "Baz", "type": "string"}, "baz": {"title": "Baz", "type": "string", "description": "str"},
}, },
"required": ["bar", "baz"], "required": ["bar", "baz"],
} }
prefix = "foo(bar: int, baz: str) -> str - " assert structured_tool.description == "Docstring"
assert foo.__doc__ is not None
assert structured_tool.description == prefix + textwrap.dedent(foo.__doc__.strip())
def test_validation_error_handling_bool() -> None: def test_validation_error_handling_bool() -> None:
@@ -906,3 +916,165 @@ async def test_async_tool_pass_context() -> None:
assert ( assert (
await foo.ainvoke({"bar": "baz"}, {"configurable": {"foo": "not-bar"}}) == "baz" # type: ignore await foo.ainvoke({"bar": "baz"}, {"configurable": {"foo": "not-bar"}}) == "baz" # type: ignore
) )
# Parametrized over None, an empty string, and a docstring that has a Returns
# section but no Args section: each must produce an empty mapping of
# argument descriptions.
@pytest.mark.parametrize(
"docstring",
[
None,
"",
"""
A function without an args section.
Returns:
None
""",
],
)
def test_parse_docstring_no_args_section(docstring: Optional[str]) -> None:
assert _parse_args_from_docstring(docstring) == {}
# A docstring with a single entry under Args must map that argument name to
# its description text.
def test_parse_docstring_single_argument() -> None:
docstring = """
A function with a single argument.
Args:
param1: The first parameter.
"""
expected = {"param1": "The first parameter."}
assert _parse_args_from_docstring(docstring) == expected
# Four docstring variants (plain, wrapped description, blank lines in the
# Args section, and trailing Returns/Yields/Raises sections) must all parse
# to the same two-argument mapping.
# NOTE(review): whitespace inside these literals is test data; the variants'
# own summaries describe what each one is meant to exercise.
@pytest.mark.parametrize(
"docstring",
[
"""
A function with multiple arguments.
Args:
param1: The first parameter.
param2: The second parameter.
""",
"""
A function with multiline argument descriptions.
Args:
param1: The first parameter.
param2: The second
parameter.
""",
"""
A function with the args section that has blank lines.
Args:
param1: The first parameter.
param2: The second parameter.
""",
"""
A function with extra sections.
Args:
param1: The first parameter.
param2: The second parameter.
Returns:
foobar
Yields:
barfoo
Raises:
baz
""",
],
)
def test_parse_docstring_multiple_arguments(docstring: str) -> None:
expected = {
"param1": "The first parameter.",
"param2": "The second parameter.",
}
assert _parse_args_from_docstring(docstring) == expected
# Only the first colon on an argument line separates name from description;
# a later colon on the same line must stay inside the description text.
def test_parse_docstring_args_with_multiple_colons_in_single_line() -> None:
docstring = """
A function with a colon in the description.
Args:
param1: The first parameter: with colon.
param2: The second parameter.
"""
expected = {
"param1": "The first parameter: with colon.",
"param2": "The second parameter.",
}
assert _parse_args_from_docstring(docstring) == expected
# A description spread over two lines, with no section headers, must come
# back joined onto one line with a single space.
def test_parse_docstring_description() -> None:
docstring = """
A function with a
multiline description.
"""
assert (
_parse_func_description_from_docstring(docstring)
== "A function with a multiline description."
)
# For each of the Args/Returns/Yields/Raises headers, the description must be
# cut off at the header so the section's contents never leak into it.
@pytest.mark.parametrize("section", ["Args", "Returns", "Yields", "Raises"])
def test_parse_docstring_description_multiple_sections(section: str) -> None:
docstring = f"""
A function with a
multiline description.
{section}:
foo: bar
"""
assert (
_parse_func_description_from_docstring(docstring)
== "A function with a multiline description."
)
# None, an empty string, and a docstring that opens directly with an Args
# section must all yield a falsy description.
@pytest.mark.parametrize("docstring", [None, "", "\n Args:\n bar"])
def test_parse_docstring_description_no_description(docstring: Optional[str]) -> None:
assert not _parse_func_description_from_docstring(docstring)
# Sample function whose signature and docstring are fed to
# create_schema_from_function in test_create_schema_from_function. The
# docstring is test data: it has a typed entry ("a (int) :") whose description
# wraps over several lines, a plain entry for "b", and a Returns section.
def foo1(a: int, b: str = "") -> float:
"""
do foo
Args:
a (int) : this
describes
a
b: this describes b
Returns:
blah
"""
return 0.0
# The schema generated for foo1 must carry the docstring summary as the
# top-level "description" and each argument's docstring text as that
# property's "description"; "b" has a default and so is not required.
def test_create_schema_from_function() -> None:
expected = {
"title": "fooSchema",
"description": "do foo",
"type": "object",
"properties": {
"a": {"title": "A", "description": "this describes a", "type": "integer"},
"b": {
"title": "B",
"description": "this describes b",
"type": "string",
"default": "",
},
},
"required": ["a"],
}
actual = create_schema_from_function("foo", foo1).schema()
assert expected == actual

View File

@@ -53,6 +53,27 @@ def dummy_tool() -> BaseTool:
return DummyFunction() return DummyFunction()
# Fixture returning a tool built with the @tool() decorator from a function
# whose docstring has Args, Return and Raises sections.
# NOTE(review): the header reads "Return:" (singular) — confirm that is
# intentional, since the docstring parsers only recognise "Returns".
@pytest.fixture()
def dummy_tool_from_function() -> BaseTool:
@tool()
def dummy_function(arg1: int, arg2: Literal["bar", "baz"]) -> None:
"""dummy function
Args:
arg1: foo
arg2: one of 'bar', 'baz'
Return:
blah
Raises:
bleh
"""
return
return dummy_function
@pytest.fixture() @pytest.fixture()
def json_schema() -> Dict: def json_schema() -> Dict:
return { return {
@@ -98,6 +119,7 @@ def test_convert_to_openai_function(
pydantic: Type[BaseModel], pydantic: Type[BaseModel],
function: Callable, function: Callable,
dummy_tool: BaseTool, dummy_tool: BaseTool,
dummy_tool_from_function: BaseTool,
json_schema: Dict, json_schema: Dict,
) -> None: ) -> None:
expected = { expected = {
@@ -121,6 +143,7 @@ def test_convert_to_openai_function(
pydantic, pydantic,
function, function,
dummy_tool, dummy_tool,
dummy_tool_from_function,
json_schema, json_schema,
expected, expected,
Dummy.dummy_function, Dummy.dummy_function,