[Core] Unify function schema parsing (#23370)

Use pydantic to infer nested schemas and all that fun.
Include bagatur's convenient docstring parser
Include annotation support


Previously we didn't adequately support many typehints in the
bind_tools() method on raw functions (like optionals/unions, nested
types, etc.)
This commit is contained in:
William FH 2024-07-03 09:55:38 -07:00 committed by GitHub
parent 2a2c0d1a94
commit 6cd56821dc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 221 additions and 149 deletions

View File

@ -28,7 +28,20 @@ from abc import ABC, abstractmethod
from contextvars import copy_context from contextvars import copy_context
from functools import partial from functools import partial
from inspect import signature from inspect import signature
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union from typing import (
Any,
Awaitable,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
)
from typing_extensions import Annotated, get_args, get_origin
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.callbacks import ( from langchain_core.callbacks import (
@ -76,11 +89,32 @@ class SchemaAnnotationError(TypeError):
"""Raised when 'args_schema' is missing or has an incorrect type annotation.""" """Raised when 'args_schema' is missing or has an incorrect type annotation."""
def _is_annotated_type(typ: Type[Any]) -> bool:
return get_origin(typ) is Annotated
def _get_annotation_description(arg: str, arg_type: Type[Any]) -> str | None:
if _is_annotated_type(arg_type):
annotated_args = get_args(arg_type)
arg_type = annotated_args[0]
if len(annotated_args) > 1:
for annotation in annotated_args[1:]:
if isinstance(annotation, str):
return annotation
return None
def _create_subset_model( def _create_subset_model(
name: str, model: Type[BaseModel], field_names: list name: str,
model: Type[BaseModel],
field_names: list,
*,
descriptions: Optional[dict] = None,
fn_description: Optional[str] = None,
) -> Type[BaseModel]: ) -> Type[BaseModel]:
"""Create a pydantic model with only a subset of model's fields.""" """Create a pydantic model with only a subset of model's fields."""
fields = {} fields = {}
for field_name in field_names: for field_name in field_names:
field = model.__fields__[field_name] field = model.__fields__[field_name]
t = ( t = (
@ -89,19 +123,89 @@ def _create_subset_model(
if field.required and not field.allow_none if field.required and not field.allow_none
else Optional[field.outer_type_] else Optional[field.outer_type_]
) )
if descriptions and field_name in descriptions:
field.field_info.description = descriptions[field_name]
fields[field_name] = (t, field.field_info) fields[field_name] = (t, field.field_info)
rtn = create_model(name, **fields) # type: ignore rtn = create_model(name, **fields) # type: ignore
rtn.__doc__ = textwrap.dedent(fn_description or model.__doc__ or "")
return rtn return rtn
def _get_filtered_args( def _get_filtered_args(
inferred_model: Type[BaseModel], inferred_model: Type[BaseModel],
func: Callable, func: Callable,
*,
filter_args: Sequence[str],
) -> dict: ) -> dict:
"""Get the arguments from a function's signature.""" """Get the arguments from a function's signature."""
schema = inferred_model.schema()["properties"] schema = inferred_model.schema()["properties"]
valid_keys = signature(func).parameters valid_keys = signature(func).parameters
return {k: schema[k] for k in valid_keys if k not in ("run_manager", "callbacks")} return {
k: schema[k]
for i, (k, param) in enumerate(valid_keys.items())
if k not in filter_args and (i > 0 or param.name not in ("self", "cls"))
}
def _parse_python_function_docstring(function: Callable) -> Tuple[str, dict]:
"""Parse the function and argument descriptions from the docstring of a function.
Assumes the function docstring follows Google Python style guide.
"""
docstring = inspect.getdoc(function)
if docstring:
docstring_blocks = docstring.split("\n\n")
descriptors = []
args_block = None
past_descriptors = False
for block in docstring_blocks:
if block.startswith("Args:"):
args_block = block
break
elif block.startswith("Returns:") or block.startswith("Example:"):
# Don't break in case Args come after
past_descriptors = True
elif not past_descriptors:
descriptors.append(block)
else:
continue
description = " ".join(descriptors)
else:
description = ""
args_block = None
arg_descriptions = {}
if args_block:
arg = None
for line in args_block.split("\n")[1:]:
if ":" in line:
arg, desc = line.split(":", maxsplit=1)
arg_descriptions[arg.strip()] = desc.strip()
elif arg:
arg_descriptions[arg.strip()] += " " + line.strip()
return description, arg_descriptions
def _infer_arg_descriptions(
fn: Callable, *, parse_docstring: bool = False
) -> Tuple[str, dict]:
"""Infer argument descriptions from a function's docstring."""
if parse_docstring:
description, arg_descriptions = _parse_python_function_docstring(fn)
else:
description = inspect.getdoc(fn) or ""
arg_descriptions = {}
if hasattr(inspect, "get_annotations"):
# This is for python < 3.10
annotations = inspect.get_annotations(fn) # type: ignore
else:
annotations = getattr(fn, "__annotations__", {})
for arg, arg_type in annotations.items():
if arg in arg_descriptions:
continue
if desc := _get_annotation_description(arg, arg_type):
arg_descriptions[arg] = desc
return description, arg_descriptions
class _SchemaConfig: class _SchemaConfig:
@ -114,25 +218,40 @@ class _SchemaConfig:
def create_schema_from_function( def create_schema_from_function(
model_name: str, model_name: str,
func: Callable, func: Callable,
*,
filter_args: Optional[Sequence[str]] = None,
parse_docstring: bool = False,
) -> Type[BaseModel]: ) -> Type[BaseModel]:
"""Create a pydantic schema from a function's signature. """Create a pydantic schema from a function's signature.
Args: Args:
model_name: Name to assign to the generated pydandic schema model_name: Name to assign to the generated pydandic schema
func: Function to generate the schema from func: Function to generate the schema from
filter_args: Optional list of arguments to exclude from the schema
parse_docstring: Whether to parse the function's docstring for descriptions
for each argument.
Returns: Returns:
A pydantic model with the same arguments as the function A pydantic model with the same arguments as the function
""" """
# https://docs.pydantic.dev/latest/usage/validation_decorator/ # https://docs.pydantic.dev/latest/usage/validation_decorator/
validated = validate_arguments(func, config=_SchemaConfig) # type: ignore validated = validate_arguments(func, config=_SchemaConfig) # type: ignore
inferred_model = validated.model # type: ignore inferred_model = validated.model # type: ignore
if "run_manager" in inferred_model.__fields__: filter_args = (
del inferred_model.__fields__["run_manager"] filter_args if filter_args is not None else ("run_manager", "callbacks")
if "callbacks" in inferred_model.__fields__: )
del inferred_model.__fields__["callbacks"] for arg in filter_args:
if arg in inferred_model.__fields__:
del inferred_model.__fields__[arg]
description, arg_descriptions = _infer_arg_descriptions(
func, parse_docstring=parse_docstring
)
# Pydantic adds placeholder virtual fields we need to strip # Pydantic adds placeholder virtual fields we need to strip
valid_properties = _get_filtered_args(inferred_model, func) valid_properties = _get_filtered_args(inferred_model, func, filter_args=filter_args)
return _create_subset_model( return _create_subset_model(
f"{model_name}Schema", inferred_model, list(valid_properties) f"{model_name}Schema",
inferred_model,
list(valid_properties),
descriptions=arg_descriptions,
fn_description=description,
) )

View File

@ -2,10 +2,8 @@
from __future__ import annotations from __future__ import annotations
import inspect
import logging import logging
import uuid import uuid
from types import FunctionType, MethodType
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
@ -14,13 +12,12 @@ from typing import (
List, List,
Literal, Literal,
Optional, Optional,
Tuple,
Type, Type,
Union, Union,
cast, cast,
) )
from typing_extensions import Annotated, TypedDict, get_args, get_origin from typing_extensions import TypedDict
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.messages import ( from langchain_core.messages import (
@ -123,122 +120,6 @@ def _get_python_function_name(function: Callable) -> str:
return function.__name__ return function.__name__
def _parse_python_function_docstring(function: Callable) -> Tuple[str, dict]:
"""Parse the function and argument descriptions from the docstring of a function.
Assumes the function docstring follows Google Python style guide.
"""
docstring = inspect.getdoc(function)
if docstring:
docstring_blocks = docstring.split("\n\n")
descriptors = []
args_block = None
past_descriptors = False
for block in docstring_blocks:
if block.startswith("Args:"):
args_block = block
break
elif block.startswith("Returns:") or block.startswith("Example:"):
# Don't break in case Args come after
past_descriptors = True
elif not past_descriptors:
descriptors.append(block)
else:
continue
description = " ".join(descriptors)
else:
description = ""
args_block = None
arg_descriptions = {}
if args_block:
arg = None
for line in args_block.split("\n")[1:]:
if ":" in line:
arg, desc = line.split(":", maxsplit=1)
arg_descriptions[arg.strip()] = desc.strip()
elif arg:
arg_descriptions[arg.strip()] += " " + line.strip()
return description, arg_descriptions
def _is_annotated_type(typ: Type[Any]) -> bool:
return get_origin(typ) is Annotated
def _get_python_function_arguments(function: Callable, arg_descriptions: dict) -> dict:
"""Get JsonSchema describing a Python functions arguments.
Assumes all function arguments are of primitive types (int, float, str, bool) or
are subclasses of pydantic.BaseModel.
"""
properties = {}
annotations = inspect.getfullargspec(function).annotations
for arg, arg_type in annotations.items():
if arg == "return":
continue
if _is_annotated_type(arg_type):
annotated_args = get_args(arg_type)
arg_type = annotated_args[0]
if len(annotated_args) > 1:
for annotation in annotated_args[1:]:
if isinstance(annotation, str):
arg_descriptions[arg] = annotation
break
if (
isinstance(arg_type, type)
and hasattr(arg_type, "model_json_schema")
and callable(arg_type.model_json_schema)
):
properties[arg] = arg_type.model_json_schema()
elif (
isinstance(arg_type, type)
and hasattr(arg_type, "schema")
and callable(arg_type.schema)
):
properties[arg] = arg_type.schema()
elif (
hasattr(arg_type, "__name__")
and getattr(arg_type, "__name__") in PYTHON_TO_JSON_TYPES
):
properties[arg] = {"type": PYTHON_TO_JSON_TYPES[arg_type.__name__]}
elif (
hasattr(arg_type, "__dict__")
and getattr(arg_type, "__dict__").get("__origin__", None) == Literal
):
properties[arg] = {
"enum": list(arg_type.__args__),
"type": PYTHON_TO_JSON_TYPES[arg_type.__args__[0].__class__.__name__],
}
else:
logger.warning(
f"Argument {arg} of type {arg_type} from function {function.__name__} "
"could not be not be converted to a JSON schema."
)
if arg in arg_descriptions:
if arg not in properties:
properties[arg] = {}
properties[arg]["description"] = arg_descriptions[arg]
return properties
def _get_python_function_required_args(function: Callable) -> List[str]:
"""Get the required arguments for a Python function."""
spec = inspect.getfullargspec(function)
required = spec.args[: -len(spec.defaults)] if spec.defaults else spec.args
required += [k for k in spec.kwonlyargs if k not in (spec.kwonlydefaults or {})]
is_function_type = isinstance(function, FunctionType)
is_method_type = isinstance(function, MethodType)
if required and is_function_type and required[0] == "self":
required = required[1:]
elif required and is_method_type and required[0] == "cls":
required = required[1:]
return required
@deprecated( @deprecated(
"0.1.16", "0.1.16",
alternative="langchain_core.utils.function_calling.convert_to_openai_function()", alternative="langchain_core.utils.function_calling.convert_to_openai_function()",
@ -246,23 +127,24 @@ def _get_python_function_required_args(function: Callable) -> List[str]:
) )
def convert_python_function_to_openai_function( def convert_python_function_to_openai_function(
function: Callable, function: Callable,
) -> Dict[str, Any]: ) -> FunctionDescription:
"""Convert a Python function to an OpenAI function-calling API compatible dict. """Convert a Python function to an OpenAI function-calling API compatible dict.
Assumes the Python function has type hints and a docstring with a description. If Assumes the Python function has type hints and a docstring with a description. If
the docstring has Google Python style argument descriptions, these will be the docstring has Google Python style argument descriptions, these will be
included as well. included as well.
""" """
description, arg_descriptions = _parse_python_function_docstring(function) from langchain_core import tools
return {
"name": _get_python_function_name(function), func_name = _get_python_function_name(function)
"description": description, model = tools.create_schema_from_function(
"parameters": { func_name, function, filter_args=(), parse_docstring=True
"type": "object", )
"properties": _get_python_function_arguments(function, arg_descriptions), return convert_pydantic_to_openai_function(
"required": _get_python_function_required_args(function), model,
}, name=func_name,
} description=model.__doc__,
)
@deprecated( @deprecated(
@ -343,7 +225,7 @@ def convert_to_openai_function(
elif isinstance(function, BaseTool): elif isinstance(function, BaseTool):
return cast(Dict, format_tool_to_openai_function(function)) return cast(Dict, format_tool_to_openai_function(function))
elif callable(function): elif callable(function):
return convert_python_function_to_openai_function(function) return cast(Dict, convert_python_function_to_openai_function(function))
else: else:
raise ValueError( raise ValueError(
f"Unsupported function\n\n{function}\n\nFunctions must be passed in" f"Unsupported function\n\n{function}\n\nFunctions must be passed in"

View File

@ -1,6 +1,7 @@
"""Test the base tool implementation.""" """Test the base tool implementation."""
import asyncio import asyncio
import inspect
import json import json
import sys import sys
import textwrap import textwrap
@ -10,6 +11,7 @@ from functools import partial
from typing import Any, Callable, Dict, List, Optional, Type, Union from typing import Any, Callable, Dict, List, Optional, Type, Union
import pytest import pytest
from typing_extensions import Annotated
from langchain_core.callbacks import ( from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun, AsyncCallbackManagerForToolRun,
@ -310,9 +312,10 @@ def test_structured_tool_from_function_docstring() -> None:
def foo(bar: int, baz: str) -> str: def foo(bar: int, baz: str) -> str:
"""Docstring """Docstring
Args: Args:
bar: int bar: the bar value
baz: str baz: the baz value
""" """
raise NotImplementedError() raise NotImplementedError()
@ -328,6 +331,7 @@ def test_structured_tool_from_function_docstring() -> None:
"bar": {"title": "Bar", "type": "integer"}, "bar": {"title": "Bar", "type": "integer"},
"baz": {"title": "Baz", "type": "string"}, "baz": {"title": "Baz", "type": "string"},
}, },
"description": inspect.getdoc(foo),
"title": "fooSchema", "title": "fooSchema",
"type": "object", "type": "object",
"required": ["bar", "baz"], "required": ["bar", "baz"],
@ -342,6 +346,7 @@ def test_structured_tool_from_function_docstring_complex_args() -> None:
def foo(bar: int, baz: List[str]) -> str: def foo(bar: int, baz: List[str]) -> str:
"""Docstring """Docstring
Args: Args:
bar: int bar: int
baz: List[str] baz: List[str]
@ -352,14 +357,23 @@ def test_structured_tool_from_function_docstring_complex_args() -> None:
assert structured_tool.name == "foo" assert structured_tool.name == "foo"
assert structured_tool.args == { assert structured_tool.args == {
"bar": {"title": "Bar", "type": "integer"}, "bar": {"title": "Bar", "type": "integer"},
"baz": {"title": "Baz", "type": "array", "items": {"type": "string"}}, "baz": {
"title": "Baz",
"type": "array",
"items": {"type": "string"},
},
} }
assert structured_tool.args_schema.schema() == { assert structured_tool.args_schema.schema() == {
"properties": { "properties": {
"bar": {"title": "Bar", "type": "integer"}, "bar": {"title": "Bar", "type": "integer"},
"baz": {"title": "Baz", "type": "array", "items": {"type": "string"}}, "baz": {
"title": "Baz",
"type": "array",
"items": {"type": "string"},
},
}, },
"description": inspect.getdoc(foo),
"title": "fooSchema", "title": "fooSchema",
"type": "object", "type": "object",
"required": ["bar", "baz"], "required": ["bar", "baz"],
@ -439,6 +453,7 @@ def test_structured_tool_from_function_with_run_manager() -> None:
bar: int, baz: str, callbacks: Optional[CallbackManagerForToolRun] = None bar: int, baz: str, callbacks: Optional[CallbackManagerForToolRun] = None
) -> str: ) -> str:
"""Docstring """Docstring
Args: Args:
bar: int bar: int
baz: str baz: str
@ -459,6 +474,7 @@ def test_structured_tool_from_function_with_run_manager() -> None:
"bar": {"title": "Bar", "type": "integer"}, "bar": {"title": "Bar", "type": "integer"},
"baz": {"title": "Baz", "type": "string"}, "baz": {"title": "Baz", "type": "string"},
}, },
"description": inspect.getdoc(foo),
"title": "fooSchema", "title": "fooSchema",
"type": "object", "type": "object",
"required": ["bar", "baz"], "required": ["bar", "baz"],
@ -675,10 +691,11 @@ def test_structured_tool_from_function() -> None:
"""Test that structured tools can be created from functions.""" """Test that structured tools can be created from functions."""
def foo(bar: int, baz: str) -> str: def foo(bar: int, baz: str) -> str:
"""Docstring """Docstring thing.
Args: Args:
bar: int bar: the bar value
baz: str baz: the baz value
""" """
raise NotImplementedError() raise NotImplementedError()
@ -692,6 +709,7 @@ def test_structured_tool_from_function() -> None:
assert structured_tool.args_schema.schema() == { assert structured_tool.args_schema.schema() == {
"title": "fooSchema", "title": "fooSchema",
"type": "object", "type": "object",
"description": inspect.getdoc(foo),
"properties": { "properties": {
"bar": {"title": "Bar", "type": "integer"}, "bar": {"title": "Bar", "type": "integer"},
"baz": {"title": "Baz", "type": "string"}, "baz": {"title": "Baz", "type": "string"},
@ -916,3 +934,56 @@ def test_tool_description() -> None:
foo2 = StructuredTool.from_function(foo) foo2 = StructuredTool.from_function(foo)
assert foo2.description == "The foo." assert foo2.description == "The foo."
def test_tool_arg_descriptions() -> None:
def foo(bar: str, baz: int) -> str:
"""The foo.
Args:
bar: The bar.
baz: The baz.
"""
return bar
foo1 = tool(foo)
args_schema = foo1.args_schema.schema() # type: ignore
assert args_schema == {
"title": "fooSchema",
"type": "object",
"description": inspect.getdoc(foo),
"properties": {
"bar": {"title": "Bar", "type": "string"},
"baz": {"title": "Baz", "type": "integer"},
},
"required": ["bar", "baz"],
}
def test_tool_annotated_descriptions() -> None:
def foo(
bar: Annotated[str, "this is the bar"], baz: Annotated[int, "this is the baz"]
) -> str:
"""The foo.
Returns:
The bar only.
"""
return bar
foo1 = tool(foo)
args_schema = foo1.args_schema.schema() # type: ignore
assert args_schema == {
"title": "fooSchema",
"type": "object",
"description": inspect.getdoc(foo),
"properties": {
"bar": {"title": "Bar", "type": "string", "description": "this is the bar"},
"baz": {
"title": "Baz",
"type": "integer",
"description": "this is the baz",
},
},
"required": ["bar", "baz"],
}

View File

@ -252,7 +252,7 @@ def test_function_no_params() -> None:
pass pass
func = convert_to_openai_function(nullary_function) func = convert_to_openai_function(nullary_function)
req = func["parameters"]["required"] req = func["parameters"].get("required")
assert not req assert not req