1
0
mirror of https://github.com/hwchase17/langchain.git synced 2025-05-05 23:28:47 +00:00

[Core] Unify function schema parsing ()

Use pydantic to infer nested schemas and all that fun.
Include bagatur's convenient docstring parser
Include annotation support


Previously we didn't adequately support many typehints in the
bind_tools() method on raw functions (like optionals/unions, nested
types, etc.)
This commit is contained in:
William FH 2024-07-03 09:55:38 -07:00 committed by GitHub
parent 2a2c0d1a94
commit 6cd56821dc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 221 additions and 149 deletions
libs/core

View File

@ -28,7 +28,20 @@ from abc import ABC, abstractmethod
from contextvars import copy_context
from functools import partial
from inspect import signature
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union
from typing import (
Any,
Awaitable,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
)
from typing_extensions import Annotated, get_args, get_origin
from langchain_core._api import deprecated
from langchain_core.callbacks import (
@ -76,11 +89,32 @@ class SchemaAnnotationError(TypeError):
"""Raised when 'args_schema' is missing or has an incorrect type annotation."""
def _is_annotated_type(typ: Type[Any]) -> bool:
return get_origin(typ) is Annotated
def _get_annotation_description(arg: str, arg_type: Type[Any]) -> str | None:
if _is_annotated_type(arg_type):
annotated_args = get_args(arg_type)
arg_type = annotated_args[0]
if len(annotated_args) > 1:
for annotation in annotated_args[1:]:
if isinstance(annotation, str):
return annotation
return None
def _create_subset_model(
name: str, model: Type[BaseModel], field_names: list
name: str,
model: Type[BaseModel],
field_names: list,
*,
descriptions: Optional[dict] = None,
fn_description: Optional[str] = None,
) -> Type[BaseModel]:
"""Create a pydantic model with only a subset of model's fields."""
fields = {}
for field_name in field_names:
field = model.__fields__[field_name]
t = (
@ -89,19 +123,89 @@ def _create_subset_model(
if field.required and not field.allow_none
else Optional[field.outer_type_]
)
if descriptions and field_name in descriptions:
field.field_info.description = descriptions[field_name]
fields[field_name] = (t, field.field_info)
rtn = create_model(name, **fields) # type: ignore
rtn.__doc__ = textwrap.dedent(fn_description or model.__doc__ or "")
return rtn
def _get_filtered_args(
inferred_model: Type[BaseModel],
func: Callable,
*,
filter_args: Sequence[str],
) -> dict:
"""Get the arguments from a function's signature."""
schema = inferred_model.schema()["properties"]
valid_keys = signature(func).parameters
return {k: schema[k] for k in valid_keys if k not in ("run_manager", "callbacks")}
return {
k: schema[k]
for i, (k, param) in enumerate(valid_keys.items())
if k not in filter_args and (i > 0 or param.name not in ("self", "cls"))
}
def _parse_python_function_docstring(function: Callable) -> Tuple[str, dict]:
"""Parse the function and argument descriptions from the docstring of a function.
Assumes the function docstring follows Google Python style guide.
"""
docstring = inspect.getdoc(function)
if docstring:
docstring_blocks = docstring.split("\n\n")
descriptors = []
args_block = None
past_descriptors = False
for block in docstring_blocks:
if block.startswith("Args:"):
args_block = block
break
elif block.startswith("Returns:") or block.startswith("Example:"):
# Don't break in case Args come after
past_descriptors = True
elif not past_descriptors:
descriptors.append(block)
else:
continue
description = " ".join(descriptors)
else:
description = ""
args_block = None
arg_descriptions = {}
if args_block:
arg = None
for line in args_block.split("\n")[1:]:
if ":" in line:
arg, desc = line.split(":", maxsplit=1)
arg_descriptions[arg.strip()] = desc.strip()
elif arg:
arg_descriptions[arg.strip()] += " " + line.strip()
return description, arg_descriptions
def _infer_arg_descriptions(
fn: Callable, *, parse_docstring: bool = False
) -> Tuple[str, dict]:
"""Infer argument descriptions from a function's docstring."""
if parse_docstring:
description, arg_descriptions = _parse_python_function_docstring(fn)
else:
description = inspect.getdoc(fn) or ""
arg_descriptions = {}
if hasattr(inspect, "get_annotations"):
# This is for python < 3.10
annotations = inspect.get_annotations(fn) # type: ignore
else:
annotations = getattr(fn, "__annotations__", {})
for arg, arg_type in annotations.items():
if arg in arg_descriptions:
continue
if desc := _get_annotation_description(arg, arg_type):
arg_descriptions[arg] = desc
return description, arg_descriptions
class _SchemaConfig:
@ -114,25 +218,40 @@ class _SchemaConfig:
def create_schema_from_function(
model_name: str,
func: Callable,
*,
filter_args: Optional[Sequence[str]] = None,
parse_docstring: bool = False,
) -> Type[BaseModel]:
"""Create a pydantic schema from a function's signature.
Args:
model_name: Name to assign to the generated pydandic schema
func: Function to generate the schema from
filter_args: Optional list of arguments to exclude from the schema
parse_docstring: Whether to parse the function's docstring for descriptions
for each argument.
Returns:
A pydantic model with the same arguments as the function
"""
# https://docs.pydantic.dev/latest/usage/validation_decorator/
validated = validate_arguments(func, config=_SchemaConfig) # type: ignore
inferred_model = validated.model # type: ignore
if "run_manager" in inferred_model.__fields__:
del inferred_model.__fields__["run_manager"]
if "callbacks" in inferred_model.__fields__:
del inferred_model.__fields__["callbacks"]
filter_args = (
filter_args if filter_args is not None else ("run_manager", "callbacks")
)
for arg in filter_args:
if arg in inferred_model.__fields__:
del inferred_model.__fields__[arg]
description, arg_descriptions = _infer_arg_descriptions(
func, parse_docstring=parse_docstring
)
# Pydantic adds placeholder virtual fields we need to strip
valid_properties = _get_filtered_args(inferred_model, func)
valid_properties = _get_filtered_args(inferred_model, func, filter_args=filter_args)
return _create_subset_model(
f"{model_name}Schema", inferred_model, list(valid_properties)
f"{model_name}Schema",
inferred_model,
list(valid_properties),
descriptions=arg_descriptions,
fn_description=description,
)

View File

@ -2,10 +2,8 @@
from __future__ import annotations
import inspect
import logging
import uuid
from types import FunctionType, MethodType
from typing import (
TYPE_CHECKING,
Any,
@ -14,13 +12,12 @@ from typing import (
List,
Literal,
Optional,
Tuple,
Type,
Union,
cast,
)
from typing_extensions import Annotated, TypedDict, get_args, get_origin
from typing_extensions import TypedDict
from langchain_core._api import deprecated
from langchain_core.messages import (
@ -123,122 +120,6 @@ def _get_python_function_name(function: Callable) -> str:
return function.__name__
def _parse_python_function_docstring(function: Callable) -> Tuple[str, dict]:
"""Parse the function and argument descriptions from the docstring of a function.
Assumes the function docstring follows Google Python style guide.
"""
docstring = inspect.getdoc(function)
if docstring:
docstring_blocks = docstring.split("\n\n")
descriptors = []
args_block = None
past_descriptors = False
for block in docstring_blocks:
if block.startswith("Args:"):
args_block = block
break
elif block.startswith("Returns:") or block.startswith("Example:"):
# Don't break in case Args come after
past_descriptors = True
elif not past_descriptors:
descriptors.append(block)
else:
continue
description = " ".join(descriptors)
else:
description = ""
args_block = None
arg_descriptions = {}
if args_block:
arg = None
for line in args_block.split("\n")[1:]:
if ":" in line:
arg, desc = line.split(":", maxsplit=1)
arg_descriptions[arg.strip()] = desc.strip()
elif arg:
arg_descriptions[arg.strip()] += " " + line.strip()
return description, arg_descriptions
def _is_annotated_type(typ: Type[Any]) -> bool:
return get_origin(typ) is Annotated
def _get_python_function_arguments(function: Callable, arg_descriptions: dict) -> dict:
"""Get JsonSchema describing a Python functions arguments.
Assumes all function arguments are of primitive types (int, float, str, bool) or
are subclasses of pydantic.BaseModel.
"""
properties = {}
annotations = inspect.getfullargspec(function).annotations
for arg, arg_type in annotations.items():
if arg == "return":
continue
if _is_annotated_type(arg_type):
annotated_args = get_args(arg_type)
arg_type = annotated_args[0]
if len(annotated_args) > 1:
for annotation in annotated_args[1:]:
if isinstance(annotation, str):
arg_descriptions[arg] = annotation
break
if (
isinstance(arg_type, type)
and hasattr(arg_type, "model_json_schema")
and callable(arg_type.model_json_schema)
):
properties[arg] = arg_type.model_json_schema()
elif (
isinstance(arg_type, type)
and hasattr(arg_type, "schema")
and callable(arg_type.schema)
):
properties[arg] = arg_type.schema()
elif (
hasattr(arg_type, "__name__")
and getattr(arg_type, "__name__") in PYTHON_TO_JSON_TYPES
):
properties[arg] = {"type": PYTHON_TO_JSON_TYPES[arg_type.__name__]}
elif (
hasattr(arg_type, "__dict__")
and getattr(arg_type, "__dict__").get("__origin__", None) == Literal
):
properties[arg] = {
"enum": list(arg_type.__args__),
"type": PYTHON_TO_JSON_TYPES[arg_type.__args__[0].__class__.__name__],
}
else:
logger.warning(
f"Argument {arg} of type {arg_type} from function {function.__name__} "
"could not be not be converted to a JSON schema."
)
if arg in arg_descriptions:
if arg not in properties:
properties[arg] = {}
properties[arg]["description"] = arg_descriptions[arg]
return properties
def _get_python_function_required_args(function: Callable) -> List[str]:
"""Get the required arguments for a Python function."""
spec = inspect.getfullargspec(function)
required = spec.args[: -len(spec.defaults)] if spec.defaults else spec.args
required += [k for k in spec.kwonlyargs if k not in (spec.kwonlydefaults or {})]
is_function_type = isinstance(function, FunctionType)
is_method_type = isinstance(function, MethodType)
if required and is_function_type and required[0] == "self":
required = required[1:]
elif required and is_method_type and required[0] == "cls":
required = required[1:]
return required
@deprecated(
"0.1.16",
alternative="langchain_core.utils.function_calling.convert_to_openai_function()",
@ -246,23 +127,24 @@ def _get_python_function_required_args(function: Callable) -> List[str]:
)
def convert_python_function_to_openai_function(
function: Callable,
) -> Dict[str, Any]:
) -> FunctionDescription:
"""Convert a Python function to an OpenAI function-calling API compatible dict.
Assumes the Python function has type hints and a docstring with a description. If
the docstring has Google Python style argument descriptions, these will be
included as well.
"""
description, arg_descriptions = _parse_python_function_docstring(function)
return {
"name": _get_python_function_name(function),
"description": description,
"parameters": {
"type": "object",
"properties": _get_python_function_arguments(function, arg_descriptions),
"required": _get_python_function_required_args(function),
},
}
from langchain_core import tools
func_name = _get_python_function_name(function)
model = tools.create_schema_from_function(
func_name, function, filter_args=(), parse_docstring=True
)
return convert_pydantic_to_openai_function(
model,
name=func_name,
description=model.__doc__,
)
@deprecated(
@ -343,7 +225,7 @@ def convert_to_openai_function(
elif isinstance(function, BaseTool):
return cast(Dict, format_tool_to_openai_function(function))
elif callable(function):
return convert_python_function_to_openai_function(function)
return cast(Dict, convert_python_function_to_openai_function(function))
else:
raise ValueError(
f"Unsupported function\n\n{function}\n\nFunctions must be passed in"

View File

@ -1,6 +1,7 @@
"""Test the base tool implementation."""
import asyncio
import inspect
import json
import sys
import textwrap
@ -10,6 +11,7 @@ from functools import partial
from typing import Any, Callable, Dict, List, Optional, Type, Union
import pytest
from typing_extensions import Annotated
from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
@ -310,9 +312,10 @@ def test_structured_tool_from_function_docstring() -> None:
def foo(bar: int, baz: str) -> str:
"""Docstring
Args:
bar: int
baz: str
bar: the bar value
baz: the baz value
"""
raise NotImplementedError()
@ -328,6 +331,7 @@ def test_structured_tool_from_function_docstring() -> None:
"bar": {"title": "Bar", "type": "integer"},
"baz": {"title": "Baz", "type": "string"},
},
"description": inspect.getdoc(foo),
"title": "fooSchema",
"type": "object",
"required": ["bar", "baz"],
@ -342,6 +346,7 @@ def test_structured_tool_from_function_docstring_complex_args() -> None:
def foo(bar: int, baz: List[str]) -> str:
"""Docstring
Args:
bar: int
baz: List[str]
@ -352,14 +357,23 @@ def test_structured_tool_from_function_docstring_complex_args() -> None:
assert structured_tool.name == "foo"
assert structured_tool.args == {
"bar": {"title": "Bar", "type": "integer"},
"baz": {"title": "Baz", "type": "array", "items": {"type": "string"}},
"baz": {
"title": "Baz",
"type": "array",
"items": {"type": "string"},
},
}
assert structured_tool.args_schema.schema() == {
"properties": {
"bar": {"title": "Bar", "type": "integer"},
"baz": {"title": "Baz", "type": "array", "items": {"type": "string"}},
"baz": {
"title": "Baz",
"type": "array",
"items": {"type": "string"},
},
},
"description": inspect.getdoc(foo),
"title": "fooSchema",
"type": "object",
"required": ["bar", "baz"],
@ -439,6 +453,7 @@ def test_structured_tool_from_function_with_run_manager() -> None:
bar: int, baz: str, callbacks: Optional[CallbackManagerForToolRun] = None
) -> str:
"""Docstring
Args:
bar: int
baz: str
@ -459,6 +474,7 @@ def test_structured_tool_from_function_with_run_manager() -> None:
"bar": {"title": "Bar", "type": "integer"},
"baz": {"title": "Baz", "type": "string"},
},
"description": inspect.getdoc(foo),
"title": "fooSchema",
"type": "object",
"required": ["bar", "baz"],
@ -675,10 +691,11 @@ def test_structured_tool_from_function() -> None:
"""Test that structured tools can be created from functions."""
def foo(bar: int, baz: str) -> str:
"""Docstring
"""Docstring thing.
Args:
bar: int
baz: str
bar: the bar value
baz: the baz value
"""
raise NotImplementedError()
@ -692,6 +709,7 @@ def test_structured_tool_from_function() -> None:
assert structured_tool.args_schema.schema() == {
"title": "fooSchema",
"type": "object",
"description": inspect.getdoc(foo),
"properties": {
"bar": {"title": "Bar", "type": "integer"},
"baz": {"title": "Baz", "type": "string"},
@ -916,3 +934,56 @@ def test_tool_description() -> None:
foo2 = StructuredTool.from_function(foo)
assert foo2.description == "The foo."
def test_tool_arg_descriptions() -> None:
def foo(bar: str, baz: int) -> str:
"""The foo.
Args:
bar: The bar.
baz: The baz.
"""
return bar
foo1 = tool(foo)
args_schema = foo1.args_schema.schema() # type: ignore
assert args_schema == {
"title": "fooSchema",
"type": "object",
"description": inspect.getdoc(foo),
"properties": {
"bar": {"title": "Bar", "type": "string"},
"baz": {"title": "Baz", "type": "integer"},
},
"required": ["bar", "baz"],
}
def test_tool_annotated_descriptions() -> None:
def foo(
bar: Annotated[str, "this is the bar"], baz: Annotated[int, "this is the baz"]
) -> str:
"""The foo.
Returns:
The bar only.
"""
return bar
foo1 = tool(foo)
args_schema = foo1.args_schema.schema() # type: ignore
assert args_schema == {
"title": "fooSchema",
"type": "object",
"description": inspect.getdoc(foo),
"properties": {
"bar": {"title": "Bar", "type": "string", "description": "this is the bar"},
"baz": {
"title": "Baz",
"type": "integer",
"description": "this is the baz",
},
},
"required": ["bar", "baz"],
}

View File

@ -252,7 +252,7 @@ def test_function_no_params() -> None:
pass
func = convert_to_openai_function(nullary_function)
req = func["parameters"]["required"]
req = func["parameters"].get("required")
assert not req