Compare commits

...

3 Commits

Author SHA1 Message Date
Bagatur
49b68a0cb2 core[minor]: update @tool inferred schema and description 2024-05-14 23:26:58 -07:00
Bagatur
8716d065a3 wip 2024-05-14 17:57:02 -07:00
Bagatur
c15a541ccd wip 2024-05-14 17:52:13 -07:00
3 changed files with 309 additions and 61 deletions

View File

@@ -21,7 +21,7 @@ from __future__ import annotations
import asyncio
import inspect
import textwrap
import re
import uuid
import warnings
from abc import ABC, abstractmethod
@@ -37,8 +37,6 @@ from langchain_core.callbacks import (
BaseCallbackManager,
CallbackManager,
CallbackManagerForToolRun,
)
from langchain_core.callbacks.manager import (
Callbacks,
)
from langchain_core.load.serializable import Serializable
@@ -77,7 +75,12 @@ class SchemaAnnotationError(TypeError):
def _create_subset_model(
name: str, model: Type[BaseModel], field_names: list
name: str,
model: Type[BaseModel],
field_names: list,
*,
description: Optional[str] = None,
field_descriptions: Optional[Dict[str, str]] = None,
) -> Type[BaseModel]:
"""Create a pydantic model with only a subset of model's fields."""
fields = {}
@@ -89,21 +92,70 @@ def _create_subset_model(
if field.required and not field.allow_none
else Optional[field.outer_type_]
)
if field_descriptions and field_name in field_descriptions:
field.field_info.description = field_descriptions[field_name]
fields[field_name] = (t, field.field_info)
rtn = create_model(name, **fields) # type: ignore
rtn = create_model(name, __doc__=description, **fields) # type: ignore
return rtn
def _get_filtered_args(
inferred_model: Type[BaseModel],
func: Callable,
) -> dict:
def _get_filtered_args(inferred_model: Type[BaseModel], func: Callable) -> dict:
"""Get the arguments from a function's signature."""
schema = inferred_model.schema()["properties"]
valid_keys = signature(func).parameters
return {k: schema[k] for k in valid_keys if k not in ("run_manager", "callbacks")}
def _parse_args_from_docstring(docstring: Optional[str]) -> Dict[str, str]:
"""Parses the argument descriptions from a Google-style docstring.
Args:
docstring: The docstring to parse.
Returns:
dict: A dictionary where keys are argument names and values are their
descriptions.
"""
args_dict: Dict[str, str] = {}
if docstring and (args_section := re.search(r"Args:\n((?:\s*.+\n)+)", docstring)):
arg_lines = args_section.group(1).strip().split("\n")
else:
return args_dict
current_arg = None
current_desc = []
for line in arg_lines:
if match := re.match(r"\s*(\w+).*?:\s*(.*)", line):
if current_arg:
args_dict[current_arg] = " ".join(current_desc).strip()
current_arg = match.group(1)
if current_arg in ("Returns", "Yields", "Raises"):
current_arg = None
break
current_desc = [match.group(2).strip()]
else:
current_desc.append(line.strip())
if current_arg:
args_dict[current_arg] = " ".join(current_desc).strip()
return args_dict
def _parse_func_description_from_docstring(docstring: Optional[str]) -> Optional[str]:
if not docstring:
return docstring
if description_match := re.search(
r"(.*?)(?:Args|Returns|Yields|Raises)", docstring, flags=re.DOTALL
):
description = description_match.group(1)
else:
description = docstring
return " ".join(li.strip() for li in description.split("\n") if li.strip())
class _SchemaConfig:
"""Configuration for the pydantic model."""
@@ -111,14 +163,13 @@ class _SchemaConfig:
arbitrary_types_allowed: bool = True
def create_schema_from_function(
model_name: str,
func: Callable,
) -> Type[BaseModel]:
def create_schema_from_function(model_name: str, func: Callable) -> Type[BaseModel]:
"""Create a pydantic schema from a function's signature.
Args:
model_name: Name to assign to the generated pydandic schema
model_name: Name to assign to the generated pydantic schema
func: Function to generate the schema from
Returns:
A pydantic model with the same arguments as the function
"""
@@ -131,8 +182,15 @@ def create_schema_from_function(
del inferred_model.__fields__["callbacks"]
# Pydantic adds placeholder virtual fields we need to strip
valid_properties = _get_filtered_args(inferred_model, func)
docstring = getattr(func, "__doc__", "")
func_description = _parse_func_description_from_docstring(docstring)
arg_descriptions = _parse_args_from_docstring(docstring)
return _create_subset_model(
f"{model_name}Schema", inferred_model, list(valid_properties)
f"{model_name}Schema",
inferred_model,
list(valid_properties),
description=func_description,
field_descriptions=arg_descriptions,
)
@@ -802,7 +860,7 @@ class StructuredTool(BaseTool):
description: The description of the tool. Defaults to the function docstring
return_direct: Whether to return the result directly or as a callback
args_schema: The schema of the tool's input arguments
infer_schema: Whether to infer the schema from the function's signature
infer_schema: DEPRECATED. args_schema is always inferred if not specified.
**kwargs: Additional arguments to pass to the tool
Returns:
@@ -826,29 +884,24 @@ class StructuredTool(BaseTool):
else:
raise ValueError("Function and/or coroutine must be provided")
name = name or source_function.__name__
description_ = description or source_function.__doc__
if description_ is None:
raise ValueError(
"Function must have a docstring if description not provided."
)
inferred_schema = create_schema_from_function(name, source_function)
args_schema = args_schema or inferred_schema
description = (
description
or args_schema.schema().get("description")
or inferred_schema.schema().get("description")
)
if description is None:
# Only apply if using the function's docstring
description_ = textwrap.dedent(description_).strip()
# Description example:
# search_api(query: str) - Searches the API for the query.
sig = signature(source_function)
description_ = f"{name}{sig} - {description_.strip()}"
_args_schema = args_schema
if _args_schema is None and infer_schema:
# schema name is appended within function
_args_schema = create_schema_from_function(name, source_function)
raise ValueError(
"Must specify a description or pass in an args_schema or function with "
"a docstring."
)
return cls(
name=name,
func=func,
coroutine=coroutine,
args_schema=_args_schema, # type: ignore[arg-type]
description=description_,
args_schema=args_schema,
description=description,
return_direct=return_direct,
**kwargs,
)

View File

@@ -3,7 +3,6 @@
import asyncio
import json
import sys
import textwrap
from datetime import datetime
from enum import Enum
from functools import partial
@@ -24,6 +23,9 @@ from langchain_core.tools import (
Tool,
ToolException,
_create_subset_model,
_parse_args_from_docstring,
_parse_func_description_from_docstring,
create_schema_from_function,
tool,
)
from tests.unit_tests.fake.callbacks import FakeCallbackHandler
@@ -318,23 +320,22 @@ def test_structured_tool_from_function_docstring() -> None:
structured_tool = StructuredTool.from_function(foo)
assert structured_tool.name == "foo"
assert structured_tool.args == {
"bar": {"title": "Bar", "type": "integer"},
"baz": {"title": "Baz", "type": "string"},
"bar": {"title": "Bar", "type": "integer", "description": "int"},
"baz": {"title": "Baz", "type": "string", "description": "str"},
}
assert structured_tool.args_schema.schema() == {
"properties": {
"bar": {"title": "Bar", "type": "integer"},
"baz": {"title": "Baz", "type": "string"},
"bar": {"title": "Bar", "type": "integer", "description": "int"},
"baz": {"title": "Baz", "type": "string", "description": "str"},
},
"title": "fooSchema",
"description": "Docstring",
"type": "object",
"required": ["bar", "baz"],
}
prefix = "foo(bar: int, baz: str) -> str - "
assert foo.__doc__ is not None
assert structured_tool.description == prefix + textwrap.dedent(foo.__doc__.strip())
assert structured_tool.description == "Docstring"
def test_structured_tool_from_function_docstring_complex_args() -> None:
@@ -351,23 +352,32 @@ def test_structured_tool_from_function_docstring_complex_args() -> None:
structured_tool = StructuredTool.from_function(foo)
assert structured_tool.name == "foo"
assert structured_tool.args == {
"bar": {"title": "Bar", "type": "integer"},
"baz": {"title": "Baz", "type": "array", "items": {"type": "string"}},
"bar": {"title": "Bar", "type": "integer", "description": "int"},
"baz": {
"title": "Baz",
"type": "array",
"items": {"type": "string"},
"description": "List[str]",
},
}
assert structured_tool.args_schema.schema() == {
"properties": {
"bar": {"title": "Bar", "type": "integer"},
"baz": {"title": "Baz", "type": "array", "items": {"type": "string"}},
"bar": {"title": "Bar", "type": "integer", "description": "int"},
"baz": {
"title": "Baz",
"type": "array",
"items": {"type": "string"},
"description": "List[str]",
},
},
"title": "fooSchema",
"description": "Docstring",
"type": "object",
"required": ["bar", "baz"],
}
prefix = "foo(bar: int, baz: List[str]) -> str - "
assert foo.__doc__ is not None
assert structured_tool.description == prefix + textwrap.dedent(foo.__doc__).strip()
assert structured_tool.description == "Docstring"
def test_structured_tool_lambda_multi_args_schema() -> None:
@@ -451,16 +461,17 @@ def test_structured_tool_from_function_with_run_manager() -> None:
structured_tool = StructuredTool.from_function(foo)
assert structured_tool.args == {
"bar": {"title": "Bar", "type": "integer"},
"baz": {"title": "Baz", "type": "string"},
"bar": {"title": "Bar", "type": "integer", "description": "int"},
"baz": {"title": "Baz", "type": "string", "description": "str"},
}
assert structured_tool.args_schema.schema() == {
"properties": {
"bar": {"title": "Bar", "type": "integer"},
"baz": {"title": "Baz", "type": "string"},
"bar": {"title": "Bar", "type": "integer", "description": "int"},
"baz": {"title": "Baz", "type": "string", "description": "str"},
},
"title": "fooSchema",
"description": "Docstring",
"type": "object",
"required": ["bar", "baz"],
}
@@ -553,7 +564,7 @@ def test_tool_with_kwargs() -> None:
def test_missing_docstring() -> None:
"""Test error is raised when docstring is missing."""
# expect to throw a value error if there's no docstring
with pytest.raises(ValueError, match="Function must have a docstring"):
with pytest.raises(ValueError):
@tool
def search_api(query: str) -> str:
@@ -686,23 +697,22 @@ def test_structured_tool_from_function() -> None:
structured_tool = StructuredTool.from_function(foo)
assert structured_tool.name == "foo"
assert structured_tool.args == {
"bar": {"title": "Bar", "type": "integer"},
"baz": {"title": "Baz", "type": "string"},
"bar": {"title": "Bar", "type": "integer", "description": "int"},
"baz": {"title": "Baz", "type": "string", "description": "str"},
}
assert structured_tool.args_schema.schema() == {
"title": "fooSchema",
"description": "Docstring",
"type": "object",
"properties": {
"bar": {"title": "Bar", "type": "integer"},
"baz": {"title": "Baz", "type": "string"},
"bar": {"title": "Bar", "type": "integer", "description": "int"},
"baz": {"title": "Baz", "type": "string", "description": "str"},
},
"required": ["bar", "baz"],
}
prefix = "foo(bar: int, baz: str) -> str - "
assert foo.__doc__ is not None
assert structured_tool.description == prefix + textwrap.dedent(foo.__doc__.strip())
assert structured_tool.description == "Docstring"
def test_validation_error_handling_bool() -> None:
@@ -906,3 +916,165 @@ async def test_async_tool_pass_context() -> None:
assert (
await foo.ainvoke({"bar": "baz"}, {"configurable": {"foo": "not-bar"}}) == "baz" # type: ignore
)
@pytest.mark.parametrize(
"docstring",
[
None,
"",
"""
A function without an args section.
Returns:
None
""",
],
)
def test_parse_docstring_no_args_section(docstring: Optional[str]) -> None:
assert _parse_args_from_docstring(docstring) == {}
def test_parse_docstring_single_argument() -> None:
docstring = """
A function with a single argument.
Args:
param1: The first parameter.
"""
expected = {"param1": "The first parameter."}
assert _parse_args_from_docstring(docstring) == expected
@pytest.mark.parametrize(
"docstring",
[
"""
A function with multiple arguments.
Args:
param1: The first parameter.
param2: The second parameter.
""",
"""
A function with multiline argument descriptions.
Args:
param1: The first parameter.
param2: The second
parameter.
""",
"""
A function with the args section that has blank lines.
Args:
param1: The first parameter.
param2: The second parameter.
""",
"""
A function with extra sections.
Args:
param1: The first parameter.
param2: The second parameter.
Returns:
foobar
Yields:
barfoo
Raises:
baz
""",
],
)
def test_parse_docstring_multiple_arguments(docstring: str) -> None:
expected = {
"param1": "The first parameter.",
"param2": "The second parameter.",
}
assert _parse_args_from_docstring(docstring) == expected
def test_parse_docstring_args_with_multiple_colons_in_single_line() -> None:
docstring = """
A function with a colon in the description.
Args:
param1: The first parameter: with colon.
param2: The second parameter.
"""
expected = {
"param1": "The first parameter: with colon.",
"param2": "The second parameter.",
}
assert _parse_args_from_docstring(docstring) == expected
def test_parse_docstring_description() -> None:
docstring = """
A function with a
multiline description.
"""
assert (
_parse_func_description_from_docstring(docstring)
== "A function with a multiline description."
)
@pytest.mark.parametrize("section", ["Args", "Returns", "Yields", "Raises"])
def test_parse_docstring_description_multiple_sections(section: str) -> None:
docstring = f"""
A function with a
multiline description.
{section}:
foo: bar
"""
assert (
_parse_func_description_from_docstring(docstring)
== "A function with a multiline description."
)
@pytest.mark.parametrize("docstring", [None, "", "\n Args:\n bar"])
def test_parse_docstring_description_no_description(docstring: Optional[str]) -> None:
assert not _parse_func_description_from_docstring(docstring)
def foo1(a: int, b: str = "") -> float:
"""
do foo
Args:
a (int) : this
describes
a
b: this describes b
Returns:
blah
"""
return 0.0
def test_create_schema_from_function() -> None:
expected = {
"title": "fooSchema",
"description": "do foo",
"type": "object",
"properties": {
"a": {"title": "A", "description": "this describes a", "type": "integer"},
"b": {
"title": "B",
"description": "this describes b",
"type": "string",
"default": "",
},
},
"required": ["a"],
}
actual = create_schema_from_function("foo", foo1).schema()
assert expected == actual

View File

@@ -53,6 +53,27 @@ def dummy_tool() -> BaseTool:
return DummyFunction()
@pytest.fixture()
def dummy_tool_from_function() -> BaseTool:
@tool()
def dummy_function(arg1: int, arg2: Literal["bar", "baz"]) -> None:
"""dummy function
Args:
arg1: foo
arg2: one of 'bar', 'baz'
Return:
blah
Raises:
bleh
"""
return
return dummy_function
@pytest.fixture()
def json_schema() -> Dict:
return {
@@ -98,6 +119,7 @@ def test_convert_to_openai_function(
pydantic: Type[BaseModel],
function: Callable,
dummy_tool: BaseTool,
dummy_tool_from_function: BaseTool,
json_schema: Dict,
) -> None:
expected = {
@@ -121,6 +143,7 @@ def test_convert_to_openai_function(
pydantic,
function,
dummy_tool,
dummy_tool_from_function,
json_schema,
expected,
Dummy.dummy_function,