[Core] Unify function schema parsing (#23370)

Use pydantic to infer function schemas, including nested types.
Include bagatur's convenient docstring parser.
Include support for descriptions supplied via `Annotated` type hints.


Previously we didn't adequately support many type hints (such as
optionals/unions, nested types, etc.) when raw functions were passed to
the bind_tools() method.
This commit is contained in:
William FH 2024-07-03 09:55:38 -07:00 committed by GitHub
parent 2a2c0d1a94
commit 6cd56821dc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 221 additions and 149 deletions

View File

@ -28,7 +28,20 @@ from abc import ABC, abstractmethod
from contextvars import copy_context
from functools import partial
from inspect import signature
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union
from typing import (
Any,
Awaitable,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
)
from typing_extensions import Annotated, get_args, get_origin
from langchain_core._api import deprecated
from langchain_core.callbacks import (
@ -76,11 +89,32 @@ class SchemaAnnotationError(TypeError):
"""Raised when 'args_schema' is missing or has an incorrect type annotation."""
def _is_annotated_type(typ: Type[Any]) -> bool:
return get_origin(typ) is Annotated
def _get_annotation_description(arg: str, arg_type: Type[Any]) -> str | None:
    """Return the first string metadata entry of an ``Annotated`` type hint.

    Args:
        arg: Name of the argument the hint belongs to; it is not used for
            the lookup itself.
        arg_type: The type hint to inspect.

    Returns:
        The first ``str`` found among the ``Annotated`` metadata, or ``None``
        when the hint is not ``Annotated`` or carries no string metadata.
    """
    if not _is_annotated_type(arg_type):
        return None
    # Element 0 is the underlying type; everything after it is metadata.
    metadata = get_args(arg_type)[1:]
    return next((item for item in metadata if isinstance(item, str)), None)
def _create_subset_model(
name: str, model: Type[BaseModel], field_names: list
name: str,
model: Type[BaseModel],
field_names: list,
*,
descriptions: Optional[dict] = None,
fn_description: Optional[str] = None,
) -> Type[BaseModel]:
"""Create a pydantic model with only a subset of model's fields."""
fields = {}
for field_name in field_names:
field = model.__fields__[field_name]
t = (
@ -89,19 +123,89 @@ def _create_subset_model(
if field.required and not field.allow_none
else Optional[field.outer_type_]
)
if descriptions and field_name in descriptions:
field.field_info.description = descriptions[field_name]
fields[field_name] = (t, field.field_info)
rtn = create_model(name, **fields) # type: ignore
rtn.__doc__ = textwrap.dedent(fn_description or model.__doc__ or "")
return rtn
def _get_filtered_args(
inferred_model: Type[BaseModel],
func: Callable,
*,
filter_args: Sequence[str],
) -> dict:
"""Get the arguments from a function's signature."""
schema = inferred_model.schema()["properties"]
valid_keys = signature(func).parameters
return {k: schema[k] for k in valid_keys if k not in ("run_manager", "callbacks")}
return {
k: schema[k]
for i, (k, param) in enumerate(valid_keys.items())
if k not in filter_args and (i > 0 or param.name not in ("self", "cls"))
}
def _parse_python_function_docstring(function: Callable) -> Tuple[str, dict]:
"""Parse the function and argument descriptions from the docstring of a function.
Assumes the function docstring follows Google Python style guide.
"""
docstring = inspect.getdoc(function)
if docstring:
docstring_blocks = docstring.split("\n\n")
descriptors = []
args_block = None
past_descriptors = False
for block in docstring_blocks:
if block.startswith("Args:"):
args_block = block
break
elif block.startswith("Returns:") or block.startswith("Example:"):
# Don't break in case Args come after
past_descriptors = True
elif not past_descriptors:
descriptors.append(block)
else:
continue
description = " ".join(descriptors)
else:
description = ""
args_block = None
arg_descriptions = {}
if args_block:
arg = None
for line in args_block.split("\n")[1:]:
if ":" in line:
arg, desc = line.split(":", maxsplit=1)
arg_descriptions[arg.strip()] = desc.strip()
elif arg:
arg_descriptions[arg.strip()] += " " + line.strip()
return description, arg_descriptions
def _infer_arg_descriptions(
fn: Callable, *, parse_docstring: bool = False
) -> Tuple[str, dict]:
"""Infer argument descriptions from a function's docstring."""
if parse_docstring:
description, arg_descriptions = _parse_python_function_docstring(fn)
else:
description = inspect.getdoc(fn) or ""
arg_descriptions = {}
if hasattr(inspect, "get_annotations"):
# This is for python < 3.10
annotations = inspect.get_annotations(fn) # type: ignore
else:
annotations = getattr(fn, "__annotations__", {})
for arg, arg_type in annotations.items():
if arg in arg_descriptions:
continue
if desc := _get_annotation_description(arg, arg_type):
arg_descriptions[arg] = desc
return description, arg_descriptions
class _SchemaConfig:
@ -114,25 +218,40 @@ class _SchemaConfig:
def create_schema_from_function(
model_name: str,
func: Callable,
*,
filter_args: Optional[Sequence[str]] = None,
parse_docstring: bool = False,
) -> Type[BaseModel]:
"""Create a pydantic schema from a function's signature.
Args:
model_name: Name to assign to the generated pydantic schema
func: Function to generate the schema from
filter_args: Optional list of arguments to exclude from the schema
parse_docstring: Whether to parse the function's docstring for descriptions
for each argument.
Returns:
A pydantic model with the same arguments as the function
"""
# https://docs.pydantic.dev/latest/usage/validation_decorator/
validated = validate_arguments(func, config=_SchemaConfig) # type: ignore
inferred_model = validated.model # type: ignore
if "run_manager" in inferred_model.__fields__:
del inferred_model.__fields__["run_manager"]
if "callbacks" in inferred_model.__fields__:
del inferred_model.__fields__["callbacks"]
filter_args = (
filter_args if filter_args is not None else ("run_manager", "callbacks")
)
for arg in filter_args:
if arg in inferred_model.__fields__:
del inferred_model.__fields__[arg]
description, arg_descriptions = _infer_arg_descriptions(
func, parse_docstring=parse_docstring
)
# Pydantic adds placeholder virtual fields we need to strip
valid_properties = _get_filtered_args(inferred_model, func)
valid_properties = _get_filtered_args(inferred_model, func, filter_args=filter_args)
return _create_subset_model(
f"{model_name}Schema", inferred_model, list(valid_properties)
f"{model_name}Schema",
inferred_model,
list(valid_properties),
descriptions=arg_descriptions,
fn_description=description,
)

View File

@ -2,10 +2,8 @@
from __future__ import annotations
import inspect
import logging
import uuid
from types import FunctionType, MethodType
from typing import (
TYPE_CHECKING,
Any,
@ -14,13 +12,12 @@ from typing import (
List,
Literal,
Optional,
Tuple,
Type,
Union,
cast,
)
from typing_extensions import Annotated, TypedDict, get_args, get_origin
from typing_extensions import TypedDict
from langchain_core._api import deprecated
from langchain_core.messages import (
@ -123,122 +120,6 @@ def _get_python_function_name(function: Callable) -> str:
return function.__name__
def _parse_python_function_docstring(function: Callable) -> Tuple[str, dict]:
"""Parse the function and argument descriptions from the docstring of a function.
Assumes the function docstring follows Google Python style guide.
"""
docstring = inspect.getdoc(function)
if docstring:
docstring_blocks = docstring.split("\n\n")
descriptors = []
args_block = None
past_descriptors = False
for block in docstring_blocks:
if block.startswith("Args:"):
args_block = block
break
elif block.startswith("Returns:") or block.startswith("Example:"):
# Don't break in case Args come after
past_descriptors = True
elif not past_descriptors:
descriptors.append(block)
else:
continue
description = " ".join(descriptors)
else:
description = ""
args_block = None
arg_descriptions = {}
if args_block:
arg = None
for line in args_block.split("\n")[1:]:
if ":" in line:
arg, desc = line.split(":", maxsplit=1)
arg_descriptions[arg.strip()] = desc.strip()
elif arg:
arg_descriptions[arg.strip()] += " " + line.strip()
return description, arg_descriptions
def _is_annotated_type(typ: Type[Any]) -> bool:
return get_origin(typ) is Annotated
def _get_python_function_arguments(function: Callable, arg_descriptions: dict) -> dict:
    """Get JsonSchema describing a Python function's arguments.

    Handles classes exposing a callable ``model_json_schema`` or ``schema``,
    types whose ``__name__`` is a key of ``PYTHON_TO_JSON_TYPES``, and
    ``Literal`` hints. Any other annotation is reported with a warning and
    gets no type information.

    Args:
        function: The function whose annotations are converted.
        arg_descriptions: Mapping of argument name to description. It is
            updated in place: the first string metadata entry of an
            ``Annotated`` hint replaces any existing entry for that argument.

    Returns:
        A mapping of argument name to its JSON-schema fragment, including a
        ``"description"`` key where a description is known.
    """
    properties = {}
    annotations = inspect.getfullargspec(function).annotations
    for arg, arg_type in annotations.items():
        if arg == "return":
            continue
        if _is_annotated_type(arg_type):
            annotated_args = get_args(arg_type)
            # Element 0 is the real type; the rest is metadata.
            arg_type = annotated_args[0]
            for annotation in annotated_args[1:]:
                if isinstance(annotation, str):
                    arg_descriptions[arg] = annotation
                    break
        if (
            isinstance(arg_type, type)
            and hasattr(arg_type, "model_json_schema")
            and callable(arg_type.model_json_schema)
        ):
            properties[arg] = arg_type.model_json_schema()
        elif (
            isinstance(arg_type, type)
            and hasattr(arg_type, "schema")
            and callable(arg_type.schema)
        ):
            properties[arg] = arg_type.schema()
        elif (
            hasattr(arg_type, "__name__")
            and arg_type.__name__ in PYTHON_TO_JSON_TYPES
        ):
            properties[arg] = {"type": PYTHON_TO_JSON_TYPES[arg_type.__name__]}
        elif (
            hasattr(arg_type, "__dict__")
            and arg_type.__dict__.get("__origin__", None) == Literal
        ):
            # The JSON type is derived from the first literal value only.
            properties[arg] = {
                "enum": list(arg_type.__args__),
                "type": PYTHON_TO_JSON_TYPES[arg_type.__args__[0].__class__.__name__],
            }
        else:
            logger.warning(
                f"Argument {arg} of type {arg_type} from function {function.__name__} "
                "could not be converted to a JSON schema."
            )
        if arg in arg_descriptions:
            # A description alone still produces an entry for the argument.
            if arg not in properties:
                properties[arg] = {}
            properties[arg]["description"] = arg_descriptions[arg]
    return properties
def _get_python_function_required_args(function: Callable) -> List[str]:
"""Get the required arguments for a Python function."""
spec = inspect.getfullargspec(function)
required = spec.args[: -len(spec.defaults)] if spec.defaults else spec.args
required += [k for k in spec.kwonlyargs if k not in (spec.kwonlydefaults or {})]
is_function_type = isinstance(function, FunctionType)
is_method_type = isinstance(function, MethodType)
if required and is_function_type and required[0] == "self":
required = required[1:]
elif required and is_method_type and required[0] == "cls":
required = required[1:]
return required
@deprecated(
"0.1.16",
alternative="langchain_core.utils.function_calling.convert_to_openai_function()",
@ -246,23 +127,24 @@ def _get_python_function_required_args(function: Callable) -> List[str]:
)
def convert_python_function_to_openai_function(
function: Callable,
) -> Dict[str, Any]:
) -> FunctionDescription:
"""Convert a Python function to an OpenAI function-calling API compatible dict.
Assumes the Python function has type hints and a docstring with a description. If
the docstring has Google Python style argument descriptions, these will be
included as well.
"""
description, arg_descriptions = _parse_python_function_docstring(function)
return {
"name": _get_python_function_name(function),
"description": description,
"parameters": {
"type": "object",
"properties": _get_python_function_arguments(function, arg_descriptions),
"required": _get_python_function_required_args(function),
},
}
from langchain_core import tools
func_name = _get_python_function_name(function)
model = tools.create_schema_from_function(
func_name, function, filter_args=(), parse_docstring=True
)
return convert_pydantic_to_openai_function(
model,
name=func_name,
description=model.__doc__,
)
@deprecated(
@ -343,7 +225,7 @@ def convert_to_openai_function(
elif isinstance(function, BaseTool):
return cast(Dict, format_tool_to_openai_function(function))
elif callable(function):
return convert_python_function_to_openai_function(function)
return cast(Dict, convert_python_function_to_openai_function(function))
else:
raise ValueError(
f"Unsupported function\n\n{function}\n\nFunctions must be passed in"

View File

@ -1,6 +1,7 @@
"""Test the base tool implementation."""
import asyncio
import inspect
import json
import sys
import textwrap
@ -10,6 +11,7 @@ from functools import partial
from typing import Any, Callable, Dict, List, Optional, Type, Union
import pytest
from typing_extensions import Annotated
from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
@ -310,9 +312,10 @@ def test_structured_tool_from_function_docstring() -> None:
def foo(bar: int, baz: str) -> str:
"""Docstring
Args:
bar: int
baz: str
bar: the bar value
baz: the baz value
"""
raise NotImplementedError()
@ -328,6 +331,7 @@ def test_structured_tool_from_function_docstring() -> None:
"bar": {"title": "Bar", "type": "integer"},
"baz": {"title": "Baz", "type": "string"},
},
"description": inspect.getdoc(foo),
"title": "fooSchema",
"type": "object",
"required": ["bar", "baz"],
@ -342,6 +346,7 @@ def test_structured_tool_from_function_docstring_complex_args() -> None:
def foo(bar: int, baz: List[str]) -> str:
"""Docstring
Args:
bar: int
baz: List[str]
@ -352,14 +357,23 @@ def test_structured_tool_from_function_docstring_complex_args() -> None:
assert structured_tool.name == "foo"
assert structured_tool.args == {
"bar": {"title": "Bar", "type": "integer"},
"baz": {"title": "Baz", "type": "array", "items": {"type": "string"}},
"baz": {
"title": "Baz",
"type": "array",
"items": {"type": "string"},
},
}
assert structured_tool.args_schema.schema() == {
"properties": {
"bar": {"title": "Bar", "type": "integer"},
"baz": {"title": "Baz", "type": "array", "items": {"type": "string"}},
"baz": {
"title": "Baz",
"type": "array",
"items": {"type": "string"},
},
},
"description": inspect.getdoc(foo),
"title": "fooSchema",
"type": "object",
"required": ["bar", "baz"],
@ -439,6 +453,7 @@ def test_structured_tool_from_function_with_run_manager() -> None:
bar: int, baz: str, callbacks: Optional[CallbackManagerForToolRun] = None
) -> str:
"""Docstring
Args:
bar: int
baz: str
@ -459,6 +474,7 @@ def test_structured_tool_from_function_with_run_manager() -> None:
"bar": {"title": "Bar", "type": "integer"},
"baz": {"title": "Baz", "type": "string"},
},
"description": inspect.getdoc(foo),
"title": "fooSchema",
"type": "object",
"required": ["bar", "baz"],
@ -675,10 +691,11 @@ def test_structured_tool_from_function() -> None:
"""Test that structured tools can be created from functions."""
def foo(bar: int, baz: str) -> str:
"""Docstring
"""Docstring thing.
Args:
bar: int
baz: str
bar: the bar value
baz: the baz value
"""
raise NotImplementedError()
@ -692,6 +709,7 @@ def test_structured_tool_from_function() -> None:
assert structured_tool.args_schema.schema() == {
"title": "fooSchema",
"type": "object",
"description": inspect.getdoc(foo),
"properties": {
"bar": {"title": "Bar", "type": "integer"},
"baz": {"title": "Baz", "type": "string"},
@ -916,3 +934,56 @@ def test_tool_description() -> None:
foo2 = StructuredTool.from_function(foo)
assert foo2.description == "The foo."
def test_tool_arg_descriptions() -> None:
    """The schema of a tool built from ``foo`` carries foo's docstring.

    The expected properties contain no per-argument ``description`` keys even
    though the docstring has an ``Args:`` section.
    """

    def foo(bar: str, baz: int) -> str:
        """The foo.
        Args:
            bar: The bar.
            baz: The baz.
        """
        return bar

    wrapped = tool(foo)
    expected_properties = {
        "bar": {"title": "Bar", "type": "string"},
        "baz": {"title": "Baz", "type": "integer"},
    }
    expected_schema = {
        "title": "fooSchema",
        "type": "object",
        "description": inspect.getdoc(foo),
        "properties": expected_properties,
        "required": ["bar", "baz"],
    }
    assert wrapped.args_schema.schema() == expected_schema  # type: ignore
def test_tool_annotated_descriptions() -> None:
    """String metadata in ``Annotated`` hints shows up as field descriptions."""

    def foo(
        bar: Annotated[str, "this is the bar"], baz: Annotated[int, "this is the baz"]
    ) -> str:
        """The foo.
        Returns:
            The bar only.
        """
        return bar

    wrapped = tool(foo)
    expected_properties = {
        "bar": {"title": "Bar", "type": "string", "description": "this is the bar"},
        "baz": {
            "title": "Baz",
            "type": "integer",
            "description": "this is the baz",
        },
    }
    expected_schema = {
        "title": "fooSchema",
        "type": "object",
        "description": inspect.getdoc(foo),
        "properties": expected_properties,
        "required": ["bar", "baz"],
    }
    assert wrapped.args_schema.schema() == expected_schema  # type: ignore

View File

@ -252,7 +252,7 @@ def test_function_no_params() -> None:
pass
func = convert_to_openai_function(nullary_function)
req = func["parameters"]["required"]
req = func["parameters"].get("required")
assert not req