mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-05 23:28:47 +00:00
[Core] Unify function schema parsing (#23370)
Use pydantic to infer nested schemas and all that fun. Include bagatur's convenient docstring parser Include annotation support Previously we didn't adequately support many typehints in the bind_tools() method on raw functions (like optionals/unions, nested types, etc.)
This commit is contained in:
parent
2a2c0d1a94
commit
6cd56821dc
libs/core
langchain_core
tests/unit_tests
@ -28,7 +28,20 @@ from abc import ABC, abstractmethod
|
||||
from contextvars import copy_context
|
||||
from functools import partial
|
||||
from inspect import signature
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from typing_extensions import Annotated, get_args, get_origin
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.callbacks import (
|
||||
@ -76,11 +89,32 @@ class SchemaAnnotationError(TypeError):
|
||||
"""Raised when 'args_schema' is missing or has an incorrect type annotation."""
|
||||
|
||||
|
||||
def _is_annotated_type(typ: Type[Any]) -> bool:
|
||||
return get_origin(typ) is Annotated
|
||||
|
||||
|
||||
def _get_annotation_description(arg: str, arg_type: Type[Any]) -> str | None:
|
||||
if _is_annotated_type(arg_type):
|
||||
annotated_args = get_args(arg_type)
|
||||
arg_type = annotated_args[0]
|
||||
if len(annotated_args) > 1:
|
||||
for annotation in annotated_args[1:]:
|
||||
if isinstance(annotation, str):
|
||||
return annotation
|
||||
return None
|
||||
|
||||
|
||||
def _create_subset_model(
|
||||
name: str, model: Type[BaseModel], field_names: list
|
||||
name: str,
|
||||
model: Type[BaseModel],
|
||||
field_names: list,
|
||||
*,
|
||||
descriptions: Optional[dict] = None,
|
||||
fn_description: Optional[str] = None,
|
||||
) -> Type[BaseModel]:
|
||||
"""Create a pydantic model with only a subset of model's fields."""
|
||||
fields = {}
|
||||
|
||||
for field_name in field_names:
|
||||
field = model.__fields__[field_name]
|
||||
t = (
|
||||
@ -89,19 +123,89 @@ def _create_subset_model(
|
||||
if field.required and not field.allow_none
|
||||
else Optional[field.outer_type_]
|
||||
)
|
||||
if descriptions and field_name in descriptions:
|
||||
field.field_info.description = descriptions[field_name]
|
||||
fields[field_name] = (t, field.field_info)
|
||||
|
||||
rtn = create_model(name, **fields) # type: ignore
|
||||
rtn.__doc__ = textwrap.dedent(fn_description or model.__doc__ or "")
|
||||
return rtn
|
||||
|
||||
|
||||
def _get_filtered_args(
|
||||
inferred_model: Type[BaseModel],
|
||||
func: Callable,
|
||||
*,
|
||||
filter_args: Sequence[str],
|
||||
) -> dict:
|
||||
"""Get the arguments from a function's signature."""
|
||||
schema = inferred_model.schema()["properties"]
|
||||
valid_keys = signature(func).parameters
|
||||
return {k: schema[k] for k in valid_keys if k not in ("run_manager", "callbacks")}
|
||||
return {
|
||||
k: schema[k]
|
||||
for i, (k, param) in enumerate(valid_keys.items())
|
||||
if k not in filter_args and (i > 0 or param.name not in ("self", "cls"))
|
||||
}
|
||||
|
||||
|
||||
def _parse_python_function_docstring(function: Callable) -> Tuple[str, dict]:
|
||||
"""Parse the function and argument descriptions from the docstring of a function.
|
||||
|
||||
Assumes the function docstring follows Google Python style guide.
|
||||
"""
|
||||
docstring = inspect.getdoc(function)
|
||||
if docstring:
|
||||
docstring_blocks = docstring.split("\n\n")
|
||||
descriptors = []
|
||||
args_block = None
|
||||
past_descriptors = False
|
||||
for block in docstring_blocks:
|
||||
if block.startswith("Args:"):
|
||||
args_block = block
|
||||
break
|
||||
elif block.startswith("Returns:") or block.startswith("Example:"):
|
||||
# Don't break in case Args come after
|
||||
past_descriptors = True
|
||||
elif not past_descriptors:
|
||||
descriptors.append(block)
|
||||
else:
|
||||
continue
|
||||
description = " ".join(descriptors)
|
||||
else:
|
||||
description = ""
|
||||
args_block = None
|
||||
arg_descriptions = {}
|
||||
if args_block:
|
||||
arg = None
|
||||
for line in args_block.split("\n")[1:]:
|
||||
if ":" in line:
|
||||
arg, desc = line.split(":", maxsplit=1)
|
||||
arg_descriptions[arg.strip()] = desc.strip()
|
||||
elif arg:
|
||||
arg_descriptions[arg.strip()] += " " + line.strip()
|
||||
return description, arg_descriptions
|
||||
|
||||
|
||||
def _infer_arg_descriptions(
|
||||
fn: Callable, *, parse_docstring: bool = False
|
||||
) -> Tuple[str, dict]:
|
||||
"""Infer argument descriptions from a function's docstring."""
|
||||
if parse_docstring:
|
||||
description, arg_descriptions = _parse_python_function_docstring(fn)
|
||||
else:
|
||||
description = inspect.getdoc(fn) or ""
|
||||
arg_descriptions = {}
|
||||
if hasattr(inspect, "get_annotations"):
|
||||
# This is for python < 3.10
|
||||
annotations = inspect.get_annotations(fn) # type: ignore
|
||||
else:
|
||||
annotations = getattr(fn, "__annotations__", {})
|
||||
for arg, arg_type in annotations.items():
|
||||
if arg in arg_descriptions:
|
||||
continue
|
||||
if desc := _get_annotation_description(arg, arg_type):
|
||||
arg_descriptions[arg] = desc
|
||||
return description, arg_descriptions
|
||||
|
||||
|
||||
class _SchemaConfig:
|
||||
@ -114,25 +218,40 @@ class _SchemaConfig:
|
||||
def create_schema_from_function(
|
||||
model_name: str,
|
||||
func: Callable,
|
||||
*,
|
||||
filter_args: Optional[Sequence[str]] = None,
|
||||
parse_docstring: bool = False,
|
||||
) -> Type[BaseModel]:
|
||||
"""Create a pydantic schema from a function's signature.
|
||||
Args:
|
||||
model_name: Name to assign to the generated pydandic schema
|
||||
func: Function to generate the schema from
|
||||
filter_args: Optional list of arguments to exclude from the schema
|
||||
parse_docstring: Whether to parse the function's docstring for descriptions
|
||||
for each argument.
|
||||
Returns:
|
||||
A pydantic model with the same arguments as the function
|
||||
"""
|
||||
# https://docs.pydantic.dev/latest/usage/validation_decorator/
|
||||
validated = validate_arguments(func, config=_SchemaConfig) # type: ignore
|
||||
inferred_model = validated.model # type: ignore
|
||||
if "run_manager" in inferred_model.__fields__:
|
||||
del inferred_model.__fields__["run_manager"]
|
||||
if "callbacks" in inferred_model.__fields__:
|
||||
del inferred_model.__fields__["callbacks"]
|
||||
filter_args = (
|
||||
filter_args if filter_args is not None else ("run_manager", "callbacks")
|
||||
)
|
||||
for arg in filter_args:
|
||||
if arg in inferred_model.__fields__:
|
||||
del inferred_model.__fields__[arg]
|
||||
description, arg_descriptions = _infer_arg_descriptions(
|
||||
func, parse_docstring=parse_docstring
|
||||
)
|
||||
# Pydantic adds placeholder virtual fields we need to strip
|
||||
valid_properties = _get_filtered_args(inferred_model, func)
|
||||
valid_properties = _get_filtered_args(inferred_model, func, filter_args=filter_args)
|
||||
return _create_subset_model(
|
||||
f"{model_name}Schema", inferred_model, list(valid_properties)
|
||||
f"{model_name}Schema",
|
||||
inferred_model,
|
||||
list(valid_properties),
|
||||
descriptions=arg_descriptions,
|
||||
fn_description=description,
|
||||
)
|
||||
|
||||
|
||||
|
@ -2,10 +2,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
import uuid
|
||||
from types import FunctionType, MethodType
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
@ -14,13 +12,12 @@ from typing import (
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from typing_extensions import Annotated, TypedDict, get_args, get_origin
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.messages import (
|
||||
@ -123,122 +120,6 @@ def _get_python_function_name(function: Callable) -> str:
|
||||
return function.__name__
|
||||
|
||||
|
||||
def _parse_python_function_docstring(function: Callable) -> Tuple[str, dict]:
|
||||
"""Parse the function and argument descriptions from the docstring of a function.
|
||||
|
||||
Assumes the function docstring follows Google Python style guide.
|
||||
"""
|
||||
docstring = inspect.getdoc(function)
|
||||
if docstring:
|
||||
docstring_blocks = docstring.split("\n\n")
|
||||
descriptors = []
|
||||
args_block = None
|
||||
past_descriptors = False
|
||||
for block in docstring_blocks:
|
||||
if block.startswith("Args:"):
|
||||
args_block = block
|
||||
break
|
||||
elif block.startswith("Returns:") or block.startswith("Example:"):
|
||||
# Don't break in case Args come after
|
||||
past_descriptors = True
|
||||
elif not past_descriptors:
|
||||
descriptors.append(block)
|
||||
else:
|
||||
continue
|
||||
description = " ".join(descriptors)
|
||||
else:
|
||||
description = ""
|
||||
args_block = None
|
||||
arg_descriptions = {}
|
||||
if args_block:
|
||||
arg = None
|
||||
for line in args_block.split("\n")[1:]:
|
||||
if ":" in line:
|
||||
arg, desc = line.split(":", maxsplit=1)
|
||||
arg_descriptions[arg.strip()] = desc.strip()
|
||||
elif arg:
|
||||
arg_descriptions[arg.strip()] += " " + line.strip()
|
||||
return description, arg_descriptions
|
||||
|
||||
|
||||
def _is_annotated_type(typ: Type[Any]) -> bool:
|
||||
return get_origin(typ) is Annotated
|
||||
|
||||
|
||||
def _get_python_function_arguments(function: Callable, arg_descriptions: dict) -> dict:
|
||||
"""Get JsonSchema describing a Python functions arguments.
|
||||
|
||||
Assumes all function arguments are of primitive types (int, float, str, bool) or
|
||||
are subclasses of pydantic.BaseModel.
|
||||
"""
|
||||
properties = {}
|
||||
annotations = inspect.getfullargspec(function).annotations
|
||||
for arg, arg_type in annotations.items():
|
||||
if arg == "return":
|
||||
continue
|
||||
|
||||
if _is_annotated_type(arg_type):
|
||||
annotated_args = get_args(arg_type)
|
||||
arg_type = annotated_args[0]
|
||||
if len(annotated_args) > 1:
|
||||
for annotation in annotated_args[1:]:
|
||||
if isinstance(annotation, str):
|
||||
arg_descriptions[arg] = annotation
|
||||
break
|
||||
if (
|
||||
isinstance(arg_type, type)
|
||||
and hasattr(arg_type, "model_json_schema")
|
||||
and callable(arg_type.model_json_schema)
|
||||
):
|
||||
properties[arg] = arg_type.model_json_schema()
|
||||
elif (
|
||||
isinstance(arg_type, type)
|
||||
and hasattr(arg_type, "schema")
|
||||
and callable(arg_type.schema)
|
||||
):
|
||||
properties[arg] = arg_type.schema()
|
||||
elif (
|
||||
hasattr(arg_type, "__name__")
|
||||
and getattr(arg_type, "__name__") in PYTHON_TO_JSON_TYPES
|
||||
):
|
||||
properties[arg] = {"type": PYTHON_TO_JSON_TYPES[arg_type.__name__]}
|
||||
elif (
|
||||
hasattr(arg_type, "__dict__")
|
||||
and getattr(arg_type, "__dict__").get("__origin__", None) == Literal
|
||||
):
|
||||
properties[arg] = {
|
||||
"enum": list(arg_type.__args__),
|
||||
"type": PYTHON_TO_JSON_TYPES[arg_type.__args__[0].__class__.__name__],
|
||||
}
|
||||
else:
|
||||
logger.warning(
|
||||
f"Argument {arg} of type {arg_type} from function {function.__name__} "
|
||||
"could not be not be converted to a JSON schema."
|
||||
)
|
||||
|
||||
if arg in arg_descriptions:
|
||||
if arg not in properties:
|
||||
properties[arg] = {}
|
||||
properties[arg]["description"] = arg_descriptions[arg]
|
||||
|
||||
return properties
|
||||
|
||||
|
||||
def _get_python_function_required_args(function: Callable) -> List[str]:
|
||||
"""Get the required arguments for a Python function."""
|
||||
spec = inspect.getfullargspec(function)
|
||||
required = spec.args[: -len(spec.defaults)] if spec.defaults else spec.args
|
||||
required += [k for k in spec.kwonlyargs if k not in (spec.kwonlydefaults or {})]
|
||||
|
||||
is_function_type = isinstance(function, FunctionType)
|
||||
is_method_type = isinstance(function, MethodType)
|
||||
if required and is_function_type and required[0] == "self":
|
||||
required = required[1:]
|
||||
elif required and is_method_type and required[0] == "cls":
|
||||
required = required[1:]
|
||||
return required
|
||||
|
||||
|
||||
@deprecated(
|
||||
"0.1.16",
|
||||
alternative="langchain_core.utils.function_calling.convert_to_openai_function()",
|
||||
@ -246,23 +127,24 @@ def _get_python_function_required_args(function: Callable) -> List[str]:
|
||||
)
|
||||
def convert_python_function_to_openai_function(
|
||||
function: Callable,
|
||||
) -> Dict[str, Any]:
|
||||
) -> FunctionDescription:
|
||||
"""Convert a Python function to an OpenAI function-calling API compatible dict.
|
||||
|
||||
Assumes the Python function has type hints and a docstring with a description. If
|
||||
the docstring has Google Python style argument descriptions, these will be
|
||||
included as well.
|
||||
"""
|
||||
description, arg_descriptions = _parse_python_function_docstring(function)
|
||||
return {
|
||||
"name": _get_python_function_name(function),
|
||||
"description": description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": _get_python_function_arguments(function, arg_descriptions),
|
||||
"required": _get_python_function_required_args(function),
|
||||
},
|
||||
}
|
||||
from langchain_core import tools
|
||||
|
||||
func_name = _get_python_function_name(function)
|
||||
model = tools.create_schema_from_function(
|
||||
func_name, function, filter_args=(), parse_docstring=True
|
||||
)
|
||||
return convert_pydantic_to_openai_function(
|
||||
model,
|
||||
name=func_name,
|
||||
description=model.__doc__,
|
||||
)
|
||||
|
||||
|
||||
@deprecated(
|
||||
@ -343,7 +225,7 @@ def convert_to_openai_function(
|
||||
elif isinstance(function, BaseTool):
|
||||
return cast(Dict, format_tool_to_openai_function(function))
|
||||
elif callable(function):
|
||||
return convert_python_function_to_openai_function(function)
|
||||
return cast(Dict, convert_python_function_to_openai_function(function))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported function\n\n{function}\n\nFunctions must be passed in"
|
||||
|
@ -1,6 +1,7 @@
|
||||
"""Test the base tool implementation."""
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import sys
|
||||
import textwrap
|
||||
@ -10,6 +11,7 @@ from functools import partial
|
||||
from typing import Any, Callable, Dict, List, Optional, Type, Union
|
||||
|
||||
import pytest
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
@ -310,9 +312,10 @@ def test_structured_tool_from_function_docstring() -> None:
|
||||
|
||||
def foo(bar: int, baz: str) -> str:
|
||||
"""Docstring
|
||||
|
||||
Args:
|
||||
bar: int
|
||||
baz: str
|
||||
bar: the bar value
|
||||
baz: the baz value
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@ -328,6 +331,7 @@ def test_structured_tool_from_function_docstring() -> None:
|
||||
"bar": {"title": "Bar", "type": "integer"},
|
||||
"baz": {"title": "Baz", "type": "string"},
|
||||
},
|
||||
"description": inspect.getdoc(foo),
|
||||
"title": "fooSchema",
|
||||
"type": "object",
|
||||
"required": ["bar", "baz"],
|
||||
@ -342,6 +346,7 @@ def test_structured_tool_from_function_docstring_complex_args() -> None:
|
||||
|
||||
def foo(bar: int, baz: List[str]) -> str:
|
||||
"""Docstring
|
||||
|
||||
Args:
|
||||
bar: int
|
||||
baz: List[str]
|
||||
@ -352,14 +357,23 @@ def test_structured_tool_from_function_docstring_complex_args() -> None:
|
||||
assert structured_tool.name == "foo"
|
||||
assert structured_tool.args == {
|
||||
"bar": {"title": "Bar", "type": "integer"},
|
||||
"baz": {"title": "Baz", "type": "array", "items": {"type": "string"}},
|
||||
"baz": {
|
||||
"title": "Baz",
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
},
|
||||
}
|
||||
|
||||
assert structured_tool.args_schema.schema() == {
|
||||
"properties": {
|
||||
"bar": {"title": "Bar", "type": "integer"},
|
||||
"baz": {"title": "Baz", "type": "array", "items": {"type": "string"}},
|
||||
"baz": {
|
||||
"title": "Baz",
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
},
|
||||
},
|
||||
"description": inspect.getdoc(foo),
|
||||
"title": "fooSchema",
|
||||
"type": "object",
|
||||
"required": ["bar", "baz"],
|
||||
@ -439,6 +453,7 @@ def test_structured_tool_from_function_with_run_manager() -> None:
|
||||
bar: int, baz: str, callbacks: Optional[CallbackManagerForToolRun] = None
|
||||
) -> str:
|
||||
"""Docstring
|
||||
|
||||
Args:
|
||||
bar: int
|
||||
baz: str
|
||||
@ -459,6 +474,7 @@ def test_structured_tool_from_function_with_run_manager() -> None:
|
||||
"bar": {"title": "Bar", "type": "integer"},
|
||||
"baz": {"title": "Baz", "type": "string"},
|
||||
},
|
||||
"description": inspect.getdoc(foo),
|
||||
"title": "fooSchema",
|
||||
"type": "object",
|
||||
"required": ["bar", "baz"],
|
||||
@ -675,10 +691,11 @@ def test_structured_tool_from_function() -> None:
|
||||
"""Test that structured tools can be created from functions."""
|
||||
|
||||
def foo(bar: int, baz: str) -> str:
|
||||
"""Docstring
|
||||
"""Docstring thing.
|
||||
|
||||
Args:
|
||||
bar: int
|
||||
baz: str
|
||||
bar: the bar value
|
||||
baz: the baz value
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@ -692,6 +709,7 @@ def test_structured_tool_from_function() -> None:
|
||||
assert structured_tool.args_schema.schema() == {
|
||||
"title": "fooSchema",
|
||||
"type": "object",
|
||||
"description": inspect.getdoc(foo),
|
||||
"properties": {
|
||||
"bar": {"title": "Bar", "type": "integer"},
|
||||
"baz": {"title": "Baz", "type": "string"},
|
||||
@ -916,3 +934,56 @@ def test_tool_description() -> None:
|
||||
|
||||
foo2 = StructuredTool.from_function(foo)
|
||||
assert foo2.description == "The foo."
|
||||
|
||||
|
||||
def test_tool_arg_descriptions() -> None:
|
||||
def foo(bar: str, baz: int) -> str:
|
||||
"""The foo.
|
||||
|
||||
Args:
|
||||
bar: The bar.
|
||||
baz: The baz.
|
||||
"""
|
||||
return bar
|
||||
|
||||
foo1 = tool(foo)
|
||||
args_schema = foo1.args_schema.schema() # type: ignore
|
||||
assert args_schema == {
|
||||
"title": "fooSchema",
|
||||
"type": "object",
|
||||
"description": inspect.getdoc(foo),
|
||||
"properties": {
|
||||
"bar": {"title": "Bar", "type": "string"},
|
||||
"baz": {"title": "Baz", "type": "integer"},
|
||||
},
|
||||
"required": ["bar", "baz"],
|
||||
}
|
||||
|
||||
|
||||
def test_tool_annotated_descriptions() -> None:
|
||||
def foo(
|
||||
bar: Annotated[str, "this is the bar"], baz: Annotated[int, "this is the baz"]
|
||||
) -> str:
|
||||
"""The foo.
|
||||
|
||||
Returns:
|
||||
The bar only.
|
||||
"""
|
||||
return bar
|
||||
|
||||
foo1 = tool(foo)
|
||||
args_schema = foo1.args_schema.schema() # type: ignore
|
||||
assert args_schema == {
|
||||
"title": "fooSchema",
|
||||
"type": "object",
|
||||
"description": inspect.getdoc(foo),
|
||||
"properties": {
|
||||
"bar": {"title": "Bar", "type": "string", "description": "this is the bar"},
|
||||
"baz": {
|
||||
"title": "Baz",
|
||||
"type": "integer",
|
||||
"description": "this is the baz",
|
||||
},
|
||||
},
|
||||
"required": ["bar", "baz"],
|
||||
}
|
||||
|
@ -252,7 +252,7 @@ def test_function_no_params() -> None:
|
||||
pass
|
||||
|
||||
func = convert_to_openai_function(nullary_function)
|
||||
req = func["parameters"]["required"]
|
||||
req = func["parameters"].get("required")
|
||||
assert not req
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user