Compare commits

...

1 Commits

Author SHA1 Message Date
William Fu-Hinthorn
d21748047d Add tool parameter descriptions 2024-06-20 14:08:27 -07:00
2 changed files with 156 additions and 4 deletions

View File

@@ -21,14 +21,27 @@ from __future__ import annotations
import asyncio
import inspect
import logging
import textwrap
import typing
import uuid
import warnings
from abc import ABC, abstractmethod
from contextvars import copy_context
from functools import partial
from inspect import signature
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union
from typing import (
Any,
Awaitable,
Callable,
Dict,
List,
Mapping,
Optional,
Tuple,
Type,
Union,
)
from langchain_core._api import deprecated
from langchain_core.callbacks import (
@@ -71,13 +84,18 @@ from langchain_core.runnables.config import (
)
from langchain_core.runnables.utils import accepts_context
logger = logging.getLogger(__name__)
class SchemaAnnotationError(TypeError):
"""Raised when 'args_schema' is missing or has an incorrect type annotation."""
def _create_subset_model(
name: str, model: Type[BaseModel], field_names: list
name: str,
model: Type[BaseModel],
field_names: list,
descriptions: Optional[Mapping[str, str]] = None,
) -> Type[BaseModel]:
"""Create a pydantic model with only a subset of model's fields."""
fields = {}
@@ -89,6 +107,10 @@ def _create_subset_model(
if field.required and not field.allow_none
else Optional[field.outer_type_]
)
# Inject the description into the field_info
description = descriptions.get(field_name) if descriptions else None
if description:
field.field_info.description = description
fields[field_name] = (t, field.field_info)
rtn = create_model(name, **fields) # type: ignore
return rtn
@@ -104,6 +126,31 @@ def _get_filtered_args(
return {k: schema[k] for k in valid_keys if k not in ("run_manager", "callbacks")}
def _get_description_from_annotation(ann: Any) -> Optional[str]:
possible_descriptions = [
arg for arg in typing.get_args(ann) if isinstance(arg, str)
]
return "\n".join(possible_descriptions) if possible_descriptions else None
def _get_descriptions(func: Callable) -> Dict[str, str]:
"""Get the descriptions from a function's signature."""
descriptions = {}
for param in inspect.signature(func).parameters.values():
if param.annotation is not inspect.Parameter.empty:
try:
description = _get_description_from_annotation(param.annotation)
except Exception as e:
logger.warning(
"Could not infer tool parameter description"
f" from annotation : {repr(e)}"
)
description = None
if description:
descriptions[param.name] = description
return descriptions
class _SchemaConfig:
"""Configuration for the pydantic model."""
@@ -131,8 +178,13 @@ def create_schema_from_function(
del inferred_model.__fields__["callbacks"]
# Pydantic adds placeholder virtual fields we need to strip
valid_properties = _get_filtered_args(inferred_model, func)
# TODO: we could pass through additional metadata here
descriptions = _get_descriptions(func)
return _create_subset_model(
f"{model_name}Schema", inferred_model, list(valid_properties)
f"{model_name}Schema",
inferred_model,
list(valid_properties),
descriptions=descriptions,
)

View File

@@ -10,6 +10,7 @@ from functools import partial
from typing import Any, Callable, Dict, List, Optional, Type, Union
import pytest
from typing_extensions import Annotated
from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
@@ -24,6 +25,7 @@ from langchain_core.tools import (
Tool,
ToolException,
_create_subset_model,
create_schema_from_function,
tool,
)
from tests.unit_tests.fake.callbacks import FakeCallbackHandler
@@ -54,7 +56,12 @@ class _MockStructuredTool(BaseTool):
args_schema: Type[BaseModel] = _MockSchema
description: str = "A Structured Tool"
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
def _run(
self,
arg1: int,
arg2: bool,
arg3: Optional[dict] = None,
) -> str:
return f"{arg1} {arg2} {arg3}"
async def _arun(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
@@ -71,6 +78,33 @@ def test_structured_args() -> None:
assert structured_api.run(args) == expected_result
@pytest.mark.skipif(sys.version_info < (3, 10), reason="Requires Python 3.10 or above")
def test_structured_args_description() -> None:
class _AnnotatedTool(BaseTool):
name: str = "structured_api"
description: str = "A Structured Tool"
def _run(
self,
arg1: int,
arg2: Annotated[bool, "V important"],
arg3: Optional[dict] = None,
) -> str:
return f"{arg1} {arg2} {arg3}"
async def _arun(
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
) -> str:
raise NotImplementedError
expected = {
"arg1": {"title": "Arg1", "type": "integer"},
"arg2": {"title": "Arg2", "type": "boolean", "description": "V important"},
"arg3": {"title": "Arg3", "type": "object"},
}
assert _AnnotatedTool().args == expected
def test_misannotated_base_tool_raises_error() -> None:
"""Test that a BaseTool with the incorrect typehint raises an exception.""" ""
with pytest.raises(SchemaAnnotationError):
@@ -874,6 +908,72 @@ def test_tool_invoke_optional_args(inputs: dict, expected: Optional[dict]) -> No
foo.invoke(inputs) # type: ignore
@pytest.mark.skipif(sys.version_info < (3, 10), reason="Requires Python 3.10 or above")
def test_create_schema_from_function_with_descriptions() -> None:
def foo(bar: int, baz: str) -> str:
"""Docstring
Args:
bar: int
baz: str
"""
raise NotImplementedError()
async def foo_async(bar: int, baz: str) -> str:
"""Docstring
Args:
bar: int
baz: str
"""
raise NotImplementedError()
for func in [foo, foo_async]:
schema = create_schema_from_function("foo", func)
expected = {
"title": "fooSchema",
"type": "object",
"properties": {
"bar": {"title": "Bar", "type": "integer"},
"baz": {"title": "Baz", "type": "string"},
},
"required": ["bar", "baz"],
}
assert schema.schema() == expected
def foo_annotated(
bar: Annotated[int, "This is bar", {"gte": 5}, "it's useful"],
) -> str:
"""Docstring
Args:
bar: int
"""
raise NotImplementedError
async def foo_async_annotated(
bar: Annotated[int, "This is bar", {"gte": 5}, "it's useful"],
) -> str:
"""Docstring
Args:
bar: int
"""
raise bar
for func in [foo_annotated, foo_async_annotated]:
schema = create_schema_from_function("foo_annotated", func)
annotated_expected = {
"title": "foo_annotatedSchema",
"type": "object",
"properties": {
"bar": {
"title": "Bar",
"type": "integer",
"description": "This is bar\nit's useful",
},
},
"required": ["bar"],
}
assert schema.schema() == annotated_expected
def test_tool_pass_context() -> None:
@tool
def foo(bar: str) -> str: