mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-03 21:54:04 +00:00
[Core] Unify function schema parsing (#23370)
Use pydantic to infer nested schemas and all that fun. Include bagatur's convenient docstring parser Include annotation support Previously we didn't adequately support many typehints in the bind_tools() method on raw functions (like optionals/unions, nested types, etc.)
This commit is contained in:
parent
2a2c0d1a94
commit
6cd56821dc
@ -28,7 +28,20 @@ from abc import ABC, abstractmethod
|
|||||||
from contextvars import copy_context
|
from contextvars import copy_context
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Awaitable,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
|
from typing_extensions import Annotated, get_args, get_origin
|
||||||
|
|
||||||
from langchain_core._api import deprecated
|
from langchain_core._api import deprecated
|
||||||
from langchain_core.callbacks import (
|
from langchain_core.callbacks import (
|
||||||
@ -76,11 +89,32 @@ class SchemaAnnotationError(TypeError):
|
|||||||
"""Raised when 'args_schema' is missing or has an incorrect type annotation."""
|
"""Raised when 'args_schema' is missing or has an incorrect type annotation."""
|
||||||
|
|
||||||
|
|
||||||
|
def _is_annotated_type(typ: Type[Any]) -> bool:
|
||||||
|
return get_origin(typ) is Annotated
|
||||||
|
|
||||||
|
|
||||||
|
def _get_annotation_description(arg: str, arg_type: Type[Any]) -> str | None:
|
||||||
|
if _is_annotated_type(arg_type):
|
||||||
|
annotated_args = get_args(arg_type)
|
||||||
|
arg_type = annotated_args[0]
|
||||||
|
if len(annotated_args) > 1:
|
||||||
|
for annotation in annotated_args[1:]:
|
||||||
|
if isinstance(annotation, str):
|
||||||
|
return annotation
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _create_subset_model(
|
def _create_subset_model(
|
||||||
name: str, model: Type[BaseModel], field_names: list
|
name: str,
|
||||||
|
model: Type[BaseModel],
|
||||||
|
field_names: list,
|
||||||
|
*,
|
||||||
|
descriptions: Optional[dict] = None,
|
||||||
|
fn_description: Optional[str] = None,
|
||||||
) -> Type[BaseModel]:
|
) -> Type[BaseModel]:
|
||||||
"""Create a pydantic model with only a subset of model's fields."""
|
"""Create a pydantic model with only a subset of model's fields."""
|
||||||
fields = {}
|
fields = {}
|
||||||
|
|
||||||
for field_name in field_names:
|
for field_name in field_names:
|
||||||
field = model.__fields__[field_name]
|
field = model.__fields__[field_name]
|
||||||
t = (
|
t = (
|
||||||
@ -89,19 +123,89 @@ def _create_subset_model(
|
|||||||
if field.required and not field.allow_none
|
if field.required and not field.allow_none
|
||||||
else Optional[field.outer_type_]
|
else Optional[field.outer_type_]
|
||||||
)
|
)
|
||||||
|
if descriptions and field_name in descriptions:
|
||||||
|
field.field_info.description = descriptions[field_name]
|
||||||
fields[field_name] = (t, field.field_info)
|
fields[field_name] = (t, field.field_info)
|
||||||
|
|
||||||
rtn = create_model(name, **fields) # type: ignore
|
rtn = create_model(name, **fields) # type: ignore
|
||||||
|
rtn.__doc__ = textwrap.dedent(fn_description or model.__doc__ or "")
|
||||||
return rtn
|
return rtn
|
||||||
|
|
||||||
|
|
||||||
def _get_filtered_args(
|
def _get_filtered_args(
|
||||||
inferred_model: Type[BaseModel],
|
inferred_model: Type[BaseModel],
|
||||||
func: Callable,
|
func: Callable,
|
||||||
|
*,
|
||||||
|
filter_args: Sequence[str],
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Get the arguments from a function's signature."""
|
"""Get the arguments from a function's signature."""
|
||||||
schema = inferred_model.schema()["properties"]
|
schema = inferred_model.schema()["properties"]
|
||||||
valid_keys = signature(func).parameters
|
valid_keys = signature(func).parameters
|
||||||
return {k: schema[k] for k in valid_keys if k not in ("run_manager", "callbacks")}
|
return {
|
||||||
|
k: schema[k]
|
||||||
|
for i, (k, param) in enumerate(valid_keys.items())
|
||||||
|
if k not in filter_args and (i > 0 or param.name not in ("self", "cls"))
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_python_function_docstring(function: Callable) -> Tuple[str, dict]:
|
||||||
|
"""Parse the function and argument descriptions from the docstring of a function.
|
||||||
|
|
||||||
|
Assumes the function docstring follows Google Python style guide.
|
||||||
|
"""
|
||||||
|
docstring = inspect.getdoc(function)
|
||||||
|
if docstring:
|
||||||
|
docstring_blocks = docstring.split("\n\n")
|
||||||
|
descriptors = []
|
||||||
|
args_block = None
|
||||||
|
past_descriptors = False
|
||||||
|
for block in docstring_blocks:
|
||||||
|
if block.startswith("Args:"):
|
||||||
|
args_block = block
|
||||||
|
break
|
||||||
|
elif block.startswith("Returns:") or block.startswith("Example:"):
|
||||||
|
# Don't break in case Args come after
|
||||||
|
past_descriptors = True
|
||||||
|
elif not past_descriptors:
|
||||||
|
descriptors.append(block)
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
description = " ".join(descriptors)
|
||||||
|
else:
|
||||||
|
description = ""
|
||||||
|
args_block = None
|
||||||
|
arg_descriptions = {}
|
||||||
|
if args_block:
|
||||||
|
arg = None
|
||||||
|
for line in args_block.split("\n")[1:]:
|
||||||
|
if ":" in line:
|
||||||
|
arg, desc = line.split(":", maxsplit=1)
|
||||||
|
arg_descriptions[arg.strip()] = desc.strip()
|
||||||
|
elif arg:
|
||||||
|
arg_descriptions[arg.strip()] += " " + line.strip()
|
||||||
|
return description, arg_descriptions
|
||||||
|
|
||||||
|
|
||||||
|
def _infer_arg_descriptions(
|
||||||
|
fn: Callable, *, parse_docstring: bool = False
|
||||||
|
) -> Tuple[str, dict]:
|
||||||
|
"""Infer argument descriptions from a function's docstring."""
|
||||||
|
if parse_docstring:
|
||||||
|
description, arg_descriptions = _parse_python_function_docstring(fn)
|
||||||
|
else:
|
||||||
|
description = inspect.getdoc(fn) or ""
|
||||||
|
arg_descriptions = {}
|
||||||
|
if hasattr(inspect, "get_annotations"):
|
||||||
|
# This is for python < 3.10
|
||||||
|
annotations = inspect.get_annotations(fn) # type: ignore
|
||||||
|
else:
|
||||||
|
annotations = getattr(fn, "__annotations__", {})
|
||||||
|
for arg, arg_type in annotations.items():
|
||||||
|
if arg in arg_descriptions:
|
||||||
|
continue
|
||||||
|
if desc := _get_annotation_description(arg, arg_type):
|
||||||
|
arg_descriptions[arg] = desc
|
||||||
|
return description, arg_descriptions
|
||||||
|
|
||||||
|
|
||||||
class _SchemaConfig:
|
class _SchemaConfig:
|
||||||
@ -114,25 +218,40 @@ class _SchemaConfig:
|
|||||||
def create_schema_from_function(
|
def create_schema_from_function(
|
||||||
model_name: str,
|
model_name: str,
|
||||||
func: Callable,
|
func: Callable,
|
||||||
|
*,
|
||||||
|
filter_args: Optional[Sequence[str]] = None,
|
||||||
|
parse_docstring: bool = False,
|
||||||
) -> Type[BaseModel]:
|
) -> Type[BaseModel]:
|
||||||
"""Create a pydantic schema from a function's signature.
|
"""Create a pydantic schema from a function's signature.
|
||||||
Args:
|
Args:
|
||||||
model_name: Name to assign to the generated pydandic schema
|
model_name: Name to assign to the generated pydandic schema
|
||||||
func: Function to generate the schema from
|
func: Function to generate the schema from
|
||||||
|
filter_args: Optional list of arguments to exclude from the schema
|
||||||
|
parse_docstring: Whether to parse the function's docstring for descriptions
|
||||||
|
for each argument.
|
||||||
Returns:
|
Returns:
|
||||||
A pydantic model with the same arguments as the function
|
A pydantic model with the same arguments as the function
|
||||||
"""
|
"""
|
||||||
# https://docs.pydantic.dev/latest/usage/validation_decorator/
|
# https://docs.pydantic.dev/latest/usage/validation_decorator/
|
||||||
validated = validate_arguments(func, config=_SchemaConfig) # type: ignore
|
validated = validate_arguments(func, config=_SchemaConfig) # type: ignore
|
||||||
inferred_model = validated.model # type: ignore
|
inferred_model = validated.model # type: ignore
|
||||||
if "run_manager" in inferred_model.__fields__:
|
filter_args = (
|
||||||
del inferred_model.__fields__["run_manager"]
|
filter_args if filter_args is not None else ("run_manager", "callbacks")
|
||||||
if "callbacks" in inferred_model.__fields__:
|
)
|
||||||
del inferred_model.__fields__["callbacks"]
|
for arg in filter_args:
|
||||||
|
if arg in inferred_model.__fields__:
|
||||||
|
del inferred_model.__fields__[arg]
|
||||||
|
description, arg_descriptions = _infer_arg_descriptions(
|
||||||
|
func, parse_docstring=parse_docstring
|
||||||
|
)
|
||||||
# Pydantic adds placeholder virtual fields we need to strip
|
# Pydantic adds placeholder virtual fields we need to strip
|
||||||
valid_properties = _get_filtered_args(inferred_model, func)
|
valid_properties = _get_filtered_args(inferred_model, func, filter_args=filter_args)
|
||||||
return _create_subset_model(
|
return _create_subset_model(
|
||||||
f"{model_name}Schema", inferred_model, list(valid_properties)
|
f"{model_name}Schema",
|
||||||
|
inferred_model,
|
||||||
|
list(valid_properties),
|
||||||
|
descriptions=arg_descriptions,
|
||||||
|
fn_description=description,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -2,10 +2,8 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import inspect
|
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from types import FunctionType, MethodType
|
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
@ -14,13 +12,12 @@ from typing import (
|
|||||||
List,
|
List,
|
||||||
Literal,
|
Literal,
|
||||||
Optional,
|
Optional,
|
||||||
Tuple,
|
|
||||||
Type,
|
Type,
|
||||||
Union,
|
Union,
|
||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
from typing_extensions import Annotated, TypedDict, get_args, get_origin
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from langchain_core._api import deprecated
|
from langchain_core._api import deprecated
|
||||||
from langchain_core.messages import (
|
from langchain_core.messages import (
|
||||||
@ -123,122 +120,6 @@ def _get_python_function_name(function: Callable) -> str:
|
|||||||
return function.__name__
|
return function.__name__
|
||||||
|
|
||||||
|
|
||||||
def _parse_python_function_docstring(function: Callable) -> Tuple[str, dict]:
|
|
||||||
"""Parse the function and argument descriptions from the docstring of a function.
|
|
||||||
|
|
||||||
Assumes the function docstring follows Google Python style guide.
|
|
||||||
"""
|
|
||||||
docstring = inspect.getdoc(function)
|
|
||||||
if docstring:
|
|
||||||
docstring_blocks = docstring.split("\n\n")
|
|
||||||
descriptors = []
|
|
||||||
args_block = None
|
|
||||||
past_descriptors = False
|
|
||||||
for block in docstring_blocks:
|
|
||||||
if block.startswith("Args:"):
|
|
||||||
args_block = block
|
|
||||||
break
|
|
||||||
elif block.startswith("Returns:") or block.startswith("Example:"):
|
|
||||||
# Don't break in case Args come after
|
|
||||||
past_descriptors = True
|
|
||||||
elif not past_descriptors:
|
|
||||||
descriptors.append(block)
|
|
||||||
else:
|
|
||||||
continue
|
|
||||||
description = " ".join(descriptors)
|
|
||||||
else:
|
|
||||||
description = ""
|
|
||||||
args_block = None
|
|
||||||
arg_descriptions = {}
|
|
||||||
if args_block:
|
|
||||||
arg = None
|
|
||||||
for line in args_block.split("\n")[1:]:
|
|
||||||
if ":" in line:
|
|
||||||
arg, desc = line.split(":", maxsplit=1)
|
|
||||||
arg_descriptions[arg.strip()] = desc.strip()
|
|
||||||
elif arg:
|
|
||||||
arg_descriptions[arg.strip()] += " " + line.strip()
|
|
||||||
return description, arg_descriptions
|
|
||||||
|
|
||||||
|
|
||||||
def _is_annotated_type(typ: Type[Any]) -> bool:
|
|
||||||
return get_origin(typ) is Annotated
|
|
||||||
|
|
||||||
|
|
||||||
def _get_python_function_arguments(function: Callable, arg_descriptions: dict) -> dict:
|
|
||||||
"""Get JsonSchema describing a Python functions arguments.
|
|
||||||
|
|
||||||
Assumes all function arguments are of primitive types (int, float, str, bool) or
|
|
||||||
are subclasses of pydantic.BaseModel.
|
|
||||||
"""
|
|
||||||
properties = {}
|
|
||||||
annotations = inspect.getfullargspec(function).annotations
|
|
||||||
for arg, arg_type in annotations.items():
|
|
||||||
if arg == "return":
|
|
||||||
continue
|
|
||||||
|
|
||||||
if _is_annotated_type(arg_type):
|
|
||||||
annotated_args = get_args(arg_type)
|
|
||||||
arg_type = annotated_args[0]
|
|
||||||
if len(annotated_args) > 1:
|
|
||||||
for annotation in annotated_args[1:]:
|
|
||||||
if isinstance(annotation, str):
|
|
||||||
arg_descriptions[arg] = annotation
|
|
||||||
break
|
|
||||||
if (
|
|
||||||
isinstance(arg_type, type)
|
|
||||||
and hasattr(arg_type, "model_json_schema")
|
|
||||||
and callable(arg_type.model_json_schema)
|
|
||||||
):
|
|
||||||
properties[arg] = arg_type.model_json_schema()
|
|
||||||
elif (
|
|
||||||
isinstance(arg_type, type)
|
|
||||||
and hasattr(arg_type, "schema")
|
|
||||||
and callable(arg_type.schema)
|
|
||||||
):
|
|
||||||
properties[arg] = arg_type.schema()
|
|
||||||
elif (
|
|
||||||
hasattr(arg_type, "__name__")
|
|
||||||
and getattr(arg_type, "__name__") in PYTHON_TO_JSON_TYPES
|
|
||||||
):
|
|
||||||
properties[arg] = {"type": PYTHON_TO_JSON_TYPES[arg_type.__name__]}
|
|
||||||
elif (
|
|
||||||
hasattr(arg_type, "__dict__")
|
|
||||||
and getattr(arg_type, "__dict__").get("__origin__", None) == Literal
|
|
||||||
):
|
|
||||||
properties[arg] = {
|
|
||||||
"enum": list(arg_type.__args__),
|
|
||||||
"type": PYTHON_TO_JSON_TYPES[arg_type.__args__[0].__class__.__name__],
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
f"Argument {arg} of type {arg_type} from function {function.__name__} "
|
|
||||||
"could not be not be converted to a JSON schema."
|
|
||||||
)
|
|
||||||
|
|
||||||
if arg in arg_descriptions:
|
|
||||||
if arg not in properties:
|
|
||||||
properties[arg] = {}
|
|
||||||
properties[arg]["description"] = arg_descriptions[arg]
|
|
||||||
|
|
||||||
return properties
|
|
||||||
|
|
||||||
|
|
||||||
def _get_python_function_required_args(function: Callable) -> List[str]:
|
|
||||||
"""Get the required arguments for a Python function."""
|
|
||||||
spec = inspect.getfullargspec(function)
|
|
||||||
required = spec.args[: -len(spec.defaults)] if spec.defaults else spec.args
|
|
||||||
required += [k for k in spec.kwonlyargs if k not in (spec.kwonlydefaults or {})]
|
|
||||||
|
|
||||||
is_function_type = isinstance(function, FunctionType)
|
|
||||||
is_method_type = isinstance(function, MethodType)
|
|
||||||
if required and is_function_type and required[0] == "self":
|
|
||||||
required = required[1:]
|
|
||||||
elif required and is_method_type and required[0] == "cls":
|
|
||||||
required = required[1:]
|
|
||||||
return required
|
|
||||||
|
|
||||||
|
|
||||||
@deprecated(
|
@deprecated(
|
||||||
"0.1.16",
|
"0.1.16",
|
||||||
alternative="langchain_core.utils.function_calling.convert_to_openai_function()",
|
alternative="langchain_core.utils.function_calling.convert_to_openai_function()",
|
||||||
@ -246,23 +127,24 @@ def _get_python_function_required_args(function: Callable) -> List[str]:
|
|||||||
)
|
)
|
||||||
def convert_python_function_to_openai_function(
|
def convert_python_function_to_openai_function(
|
||||||
function: Callable,
|
function: Callable,
|
||||||
) -> Dict[str, Any]:
|
) -> FunctionDescription:
|
||||||
"""Convert a Python function to an OpenAI function-calling API compatible dict.
|
"""Convert a Python function to an OpenAI function-calling API compatible dict.
|
||||||
|
|
||||||
Assumes the Python function has type hints and a docstring with a description. If
|
Assumes the Python function has type hints and a docstring with a description. If
|
||||||
the docstring has Google Python style argument descriptions, these will be
|
the docstring has Google Python style argument descriptions, these will be
|
||||||
included as well.
|
included as well.
|
||||||
"""
|
"""
|
||||||
description, arg_descriptions = _parse_python_function_docstring(function)
|
from langchain_core import tools
|
||||||
return {
|
|
||||||
"name": _get_python_function_name(function),
|
func_name = _get_python_function_name(function)
|
||||||
"description": description,
|
model = tools.create_schema_from_function(
|
||||||
"parameters": {
|
func_name, function, filter_args=(), parse_docstring=True
|
||||||
"type": "object",
|
)
|
||||||
"properties": _get_python_function_arguments(function, arg_descriptions),
|
return convert_pydantic_to_openai_function(
|
||||||
"required": _get_python_function_required_args(function),
|
model,
|
||||||
},
|
name=func_name,
|
||||||
}
|
description=model.__doc__,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@deprecated(
|
@deprecated(
|
||||||
@ -343,7 +225,7 @@ def convert_to_openai_function(
|
|||||||
elif isinstance(function, BaseTool):
|
elif isinstance(function, BaseTool):
|
||||||
return cast(Dict, format_tool_to_openai_function(function))
|
return cast(Dict, format_tool_to_openai_function(function))
|
||||||
elif callable(function):
|
elif callable(function):
|
||||||
return convert_python_function_to_openai_function(function)
|
return cast(Dict, convert_python_function_to_openai_function(function))
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unsupported function\n\n{function}\n\nFunctions must be passed in"
|
f"Unsupported function\n\n{function}\n\nFunctions must be passed in"
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
"""Test the base tool implementation."""
|
"""Test the base tool implementation."""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import inspect
|
||||||
import json
|
import json
|
||||||
import sys
|
import sys
|
||||||
import textwrap
|
import textwrap
|
||||||
@ -10,6 +11,7 @@ from functools import partial
|
|||||||
from typing import Any, Callable, Dict, List, Optional, Type, Union
|
from typing import Any, Callable, Dict, List, Optional, Type, Union
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
from langchain_core.callbacks import (
|
from langchain_core.callbacks import (
|
||||||
AsyncCallbackManagerForToolRun,
|
AsyncCallbackManagerForToolRun,
|
||||||
@ -310,9 +312,10 @@ def test_structured_tool_from_function_docstring() -> None:
|
|||||||
|
|
||||||
def foo(bar: int, baz: str) -> str:
|
def foo(bar: int, baz: str) -> str:
|
||||||
"""Docstring
|
"""Docstring
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
bar: int
|
bar: the bar value
|
||||||
baz: str
|
baz: the baz value
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@ -328,6 +331,7 @@ def test_structured_tool_from_function_docstring() -> None:
|
|||||||
"bar": {"title": "Bar", "type": "integer"},
|
"bar": {"title": "Bar", "type": "integer"},
|
||||||
"baz": {"title": "Baz", "type": "string"},
|
"baz": {"title": "Baz", "type": "string"},
|
||||||
},
|
},
|
||||||
|
"description": inspect.getdoc(foo),
|
||||||
"title": "fooSchema",
|
"title": "fooSchema",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": ["bar", "baz"],
|
"required": ["bar", "baz"],
|
||||||
@ -342,6 +346,7 @@ def test_structured_tool_from_function_docstring_complex_args() -> None:
|
|||||||
|
|
||||||
def foo(bar: int, baz: List[str]) -> str:
|
def foo(bar: int, baz: List[str]) -> str:
|
||||||
"""Docstring
|
"""Docstring
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
bar: int
|
bar: int
|
||||||
baz: List[str]
|
baz: List[str]
|
||||||
@ -352,14 +357,23 @@ def test_structured_tool_from_function_docstring_complex_args() -> None:
|
|||||||
assert structured_tool.name == "foo"
|
assert structured_tool.name == "foo"
|
||||||
assert structured_tool.args == {
|
assert structured_tool.args == {
|
||||||
"bar": {"title": "Bar", "type": "integer"},
|
"bar": {"title": "Bar", "type": "integer"},
|
||||||
"baz": {"title": "Baz", "type": "array", "items": {"type": "string"}},
|
"baz": {
|
||||||
|
"title": "Baz",
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
assert structured_tool.args_schema.schema() == {
|
assert structured_tool.args_schema.schema() == {
|
||||||
"properties": {
|
"properties": {
|
||||||
"bar": {"title": "Bar", "type": "integer"},
|
"bar": {"title": "Bar", "type": "integer"},
|
||||||
"baz": {"title": "Baz", "type": "array", "items": {"type": "string"}},
|
"baz": {
|
||||||
|
"title": "Baz",
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
|
"description": inspect.getdoc(foo),
|
||||||
"title": "fooSchema",
|
"title": "fooSchema",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": ["bar", "baz"],
|
"required": ["bar", "baz"],
|
||||||
@ -439,6 +453,7 @@ def test_structured_tool_from_function_with_run_manager() -> None:
|
|||||||
bar: int, baz: str, callbacks: Optional[CallbackManagerForToolRun] = None
|
bar: int, baz: str, callbacks: Optional[CallbackManagerForToolRun] = None
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Docstring
|
"""Docstring
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
bar: int
|
bar: int
|
||||||
baz: str
|
baz: str
|
||||||
@ -459,6 +474,7 @@ def test_structured_tool_from_function_with_run_manager() -> None:
|
|||||||
"bar": {"title": "Bar", "type": "integer"},
|
"bar": {"title": "Bar", "type": "integer"},
|
||||||
"baz": {"title": "Baz", "type": "string"},
|
"baz": {"title": "Baz", "type": "string"},
|
||||||
},
|
},
|
||||||
|
"description": inspect.getdoc(foo),
|
||||||
"title": "fooSchema",
|
"title": "fooSchema",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": ["bar", "baz"],
|
"required": ["bar", "baz"],
|
||||||
@ -675,10 +691,11 @@ def test_structured_tool_from_function() -> None:
|
|||||||
"""Test that structured tools can be created from functions."""
|
"""Test that structured tools can be created from functions."""
|
||||||
|
|
||||||
def foo(bar: int, baz: str) -> str:
|
def foo(bar: int, baz: str) -> str:
|
||||||
"""Docstring
|
"""Docstring thing.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
bar: int
|
bar: the bar value
|
||||||
baz: str
|
baz: the baz value
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@ -692,6 +709,7 @@ def test_structured_tool_from_function() -> None:
|
|||||||
assert structured_tool.args_schema.schema() == {
|
assert structured_tool.args_schema.schema() == {
|
||||||
"title": "fooSchema",
|
"title": "fooSchema",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
|
"description": inspect.getdoc(foo),
|
||||||
"properties": {
|
"properties": {
|
||||||
"bar": {"title": "Bar", "type": "integer"},
|
"bar": {"title": "Bar", "type": "integer"},
|
||||||
"baz": {"title": "Baz", "type": "string"},
|
"baz": {"title": "Baz", "type": "string"},
|
||||||
@ -916,3 +934,56 @@ def test_tool_description() -> None:
|
|||||||
|
|
||||||
foo2 = StructuredTool.from_function(foo)
|
foo2 = StructuredTool.from_function(foo)
|
||||||
assert foo2.description == "The foo."
|
assert foo2.description == "The foo."
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_arg_descriptions() -> None:
|
||||||
|
def foo(bar: str, baz: int) -> str:
|
||||||
|
"""The foo.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
bar: The bar.
|
||||||
|
baz: The baz.
|
||||||
|
"""
|
||||||
|
return bar
|
||||||
|
|
||||||
|
foo1 = tool(foo)
|
||||||
|
args_schema = foo1.args_schema.schema() # type: ignore
|
||||||
|
assert args_schema == {
|
||||||
|
"title": "fooSchema",
|
||||||
|
"type": "object",
|
||||||
|
"description": inspect.getdoc(foo),
|
||||||
|
"properties": {
|
||||||
|
"bar": {"title": "Bar", "type": "string"},
|
||||||
|
"baz": {"title": "Baz", "type": "integer"},
|
||||||
|
},
|
||||||
|
"required": ["bar", "baz"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_annotated_descriptions() -> None:
|
||||||
|
def foo(
|
||||||
|
bar: Annotated[str, "this is the bar"], baz: Annotated[int, "this is the baz"]
|
||||||
|
) -> str:
|
||||||
|
"""The foo.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The bar only.
|
||||||
|
"""
|
||||||
|
return bar
|
||||||
|
|
||||||
|
foo1 = tool(foo)
|
||||||
|
args_schema = foo1.args_schema.schema() # type: ignore
|
||||||
|
assert args_schema == {
|
||||||
|
"title": "fooSchema",
|
||||||
|
"type": "object",
|
||||||
|
"description": inspect.getdoc(foo),
|
||||||
|
"properties": {
|
||||||
|
"bar": {"title": "Bar", "type": "string", "description": "this is the bar"},
|
||||||
|
"baz": {
|
||||||
|
"title": "Baz",
|
||||||
|
"type": "integer",
|
||||||
|
"description": "this is the baz",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["bar", "baz"],
|
||||||
|
}
|
||||||
|
@ -252,7 +252,7 @@ def test_function_no_params() -> None:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
func = convert_to_openai_function(nullary_function)
|
func = convert_to_openai_function(nullary_function)
|
||||||
req = func["parameters"]["required"]
|
req = func["parameters"].get("required")
|
||||||
assert not req
|
assert not req
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user