mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-21 06:33:41 +00:00
Compare commits
3 Commits
langchain-
...
bagatur/pa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
49b68a0cb2 | ||
|
|
8716d065a3 | ||
|
|
c15a541ccd |
@@ -21,7 +21,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
import textwrap
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
import warnings
|
import warnings
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
@@ -37,8 +37,6 @@ from langchain_core.callbacks import (
|
|||||||
BaseCallbackManager,
|
BaseCallbackManager,
|
||||||
CallbackManager,
|
CallbackManager,
|
||||||
CallbackManagerForToolRun,
|
CallbackManagerForToolRun,
|
||||||
)
|
|
||||||
from langchain_core.callbacks.manager import (
|
|
||||||
Callbacks,
|
Callbacks,
|
||||||
)
|
)
|
||||||
from langchain_core.load.serializable import Serializable
|
from langchain_core.load.serializable import Serializable
|
||||||
@@ -77,7 +75,12 @@ class SchemaAnnotationError(TypeError):
|
|||||||
|
|
||||||
|
|
||||||
def _create_subset_model(
|
def _create_subset_model(
|
||||||
name: str, model: Type[BaseModel], field_names: list
|
name: str,
|
||||||
|
model: Type[BaseModel],
|
||||||
|
field_names: list,
|
||||||
|
*,
|
||||||
|
description: Optional[str] = None,
|
||||||
|
field_descriptions: Optional[Dict[str, str]] = None,
|
||||||
) -> Type[BaseModel]:
|
) -> Type[BaseModel]:
|
||||||
"""Create a pydantic model with only a subset of model's fields."""
|
"""Create a pydantic model with only a subset of model's fields."""
|
||||||
fields = {}
|
fields = {}
|
||||||
@@ -89,21 +92,70 @@ def _create_subset_model(
|
|||||||
if field.required and not field.allow_none
|
if field.required and not field.allow_none
|
||||||
else Optional[field.outer_type_]
|
else Optional[field.outer_type_]
|
||||||
)
|
)
|
||||||
|
if field_descriptions and field_name in field_descriptions:
|
||||||
|
field.field_info.description = field_descriptions[field_name]
|
||||||
fields[field_name] = (t, field.field_info)
|
fields[field_name] = (t, field.field_info)
|
||||||
rtn = create_model(name, **fields) # type: ignore
|
rtn = create_model(name, __doc__=description, **fields) # type: ignore
|
||||||
return rtn
|
return rtn
|
||||||
|
|
||||||
|
|
||||||
def _get_filtered_args(
|
def _get_filtered_args(inferred_model: Type[BaseModel], func: Callable) -> dict:
|
||||||
inferred_model: Type[BaseModel],
|
|
||||||
func: Callable,
|
|
||||||
) -> dict:
|
|
||||||
"""Get the arguments from a function's signature."""
|
"""Get the arguments from a function's signature."""
|
||||||
schema = inferred_model.schema()["properties"]
|
schema = inferred_model.schema()["properties"]
|
||||||
valid_keys = signature(func).parameters
|
valid_keys = signature(func).parameters
|
||||||
return {k: schema[k] for k in valid_keys if k not in ("run_manager", "callbacks")}
|
return {k: schema[k] for k in valid_keys if k not in ("run_manager", "callbacks")}
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_args_from_docstring(docstring: Optional[str]) -> Dict[str, str]:
|
||||||
|
"""Parses the argument descriptions from a Google-style docstring.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
docstring: The docstring to parse.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: A dictionary where keys are argument names and values are their
|
||||||
|
descriptions.
|
||||||
|
"""
|
||||||
|
args_dict: Dict[str, str] = {}
|
||||||
|
|
||||||
|
if docstring and (args_section := re.search(r"Args:\n((?:\s*.+\n)+)", docstring)):
|
||||||
|
arg_lines = args_section.group(1).strip().split("\n")
|
||||||
|
else:
|
||||||
|
return args_dict
|
||||||
|
|
||||||
|
current_arg = None
|
||||||
|
current_desc = []
|
||||||
|
|
||||||
|
for line in arg_lines:
|
||||||
|
if match := re.match(r"\s*(\w+).*?:\s*(.*)", line):
|
||||||
|
if current_arg:
|
||||||
|
args_dict[current_arg] = " ".join(current_desc).strip()
|
||||||
|
current_arg = match.group(1)
|
||||||
|
if current_arg in ("Returns", "Yields", "Raises"):
|
||||||
|
current_arg = None
|
||||||
|
break
|
||||||
|
current_desc = [match.group(2).strip()]
|
||||||
|
else:
|
||||||
|
current_desc.append(line.strip())
|
||||||
|
|
||||||
|
if current_arg:
|
||||||
|
args_dict[current_arg] = " ".join(current_desc).strip()
|
||||||
|
|
||||||
|
return args_dict
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_func_description_from_docstring(docstring: Optional[str]) -> Optional[str]:
|
||||||
|
if not docstring:
|
||||||
|
return docstring
|
||||||
|
if description_match := re.search(
|
||||||
|
r"(.*?)(?:Args|Returns|Yields|Raises)", docstring, flags=re.DOTALL
|
||||||
|
):
|
||||||
|
description = description_match.group(1)
|
||||||
|
else:
|
||||||
|
description = docstring
|
||||||
|
return " ".join(li.strip() for li in description.split("\n") if li.strip())
|
||||||
|
|
||||||
|
|
||||||
class _SchemaConfig:
|
class _SchemaConfig:
|
||||||
"""Configuration for the pydantic model."""
|
"""Configuration for the pydantic model."""
|
||||||
|
|
||||||
@@ -111,14 +163,13 @@ class _SchemaConfig:
|
|||||||
arbitrary_types_allowed: bool = True
|
arbitrary_types_allowed: bool = True
|
||||||
|
|
||||||
|
|
||||||
def create_schema_from_function(
|
def create_schema_from_function(model_name: str, func: Callable) -> Type[BaseModel]:
|
||||||
model_name: str,
|
|
||||||
func: Callable,
|
|
||||||
) -> Type[BaseModel]:
|
|
||||||
"""Create a pydantic schema from a function's signature.
|
"""Create a pydantic schema from a function's signature.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_name: Name to assign to the generated pydandic schema
|
model_name: Name to assign to the generated pydantdic schema
|
||||||
func: Function to generate the schema from
|
func: Function to generate the schema from
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A pydantic model with the same arguments as the function
|
A pydantic model with the same arguments as the function
|
||||||
"""
|
"""
|
||||||
@@ -131,8 +182,15 @@ def create_schema_from_function(
|
|||||||
del inferred_model.__fields__["callbacks"]
|
del inferred_model.__fields__["callbacks"]
|
||||||
# Pydantic adds placeholder virtual fields we need to strip
|
# Pydantic adds placeholder virtual fields we need to strip
|
||||||
valid_properties = _get_filtered_args(inferred_model, func)
|
valid_properties = _get_filtered_args(inferred_model, func)
|
||||||
|
docstring = getattr(func, "__doc__", "")
|
||||||
|
func_description = _parse_func_description_from_docstring(docstring)
|
||||||
|
arg_descriptions = _parse_args_from_docstring(docstring)
|
||||||
return _create_subset_model(
|
return _create_subset_model(
|
||||||
f"{model_name}Schema", inferred_model, list(valid_properties)
|
f"{model_name}Schema",
|
||||||
|
inferred_model,
|
||||||
|
list(valid_properties),
|
||||||
|
description=func_description,
|
||||||
|
field_descriptions=arg_descriptions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -802,7 +860,7 @@ class StructuredTool(BaseTool):
|
|||||||
description: The description of the tool. Defaults to the function docstring
|
description: The description of the tool. Defaults to the function docstring
|
||||||
return_direct: Whether to return the result directly or as a callback
|
return_direct: Whether to return the result directly or as a callback
|
||||||
args_schema: The schema of the tool's input arguments
|
args_schema: The schema of the tool's input arguments
|
||||||
infer_schema: Whether to infer the schema from the function's signature
|
infer_schema: DEPRECATED. args_schema is always inferred if not specified.
|
||||||
**kwargs: Additional arguments to pass to the tool
|
**kwargs: Additional arguments to pass to the tool
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -826,29 +884,24 @@ class StructuredTool(BaseTool):
|
|||||||
else:
|
else:
|
||||||
raise ValueError("Function and/or coroutine must be provided")
|
raise ValueError("Function and/or coroutine must be provided")
|
||||||
name = name or source_function.__name__
|
name = name or source_function.__name__
|
||||||
description_ = description or source_function.__doc__
|
inferred_schema = create_schema_from_function(name, source_function)
|
||||||
if description_ is None:
|
args_schema = args_schema or inferred_schema
|
||||||
raise ValueError(
|
description = (
|
||||||
"Function must have a docstring if description not provided."
|
description
|
||||||
)
|
or args_schema.schema().get("description")
|
||||||
|
or inferred_schema.schema().get("description")
|
||||||
|
)
|
||||||
if description is None:
|
if description is None:
|
||||||
# Only apply if using the function's docstring
|
raise ValueError(
|
||||||
description_ = textwrap.dedent(description_).strip()
|
"Must specify a description or pass in an args_schema or function with "
|
||||||
|
"a docstring."
|
||||||
# Description example:
|
)
|
||||||
# search_api(query: str) - Searches the API for the query.
|
|
||||||
sig = signature(source_function)
|
|
||||||
description_ = f"{name}{sig} - {description_.strip()}"
|
|
||||||
_args_schema = args_schema
|
|
||||||
if _args_schema is None and infer_schema:
|
|
||||||
# schema name is appended within function
|
|
||||||
_args_schema = create_schema_from_function(name, source_function)
|
|
||||||
return cls(
|
return cls(
|
||||||
name=name,
|
name=name,
|
||||||
func=func,
|
func=func,
|
||||||
coroutine=coroutine,
|
coroutine=coroutine,
|
||||||
args_schema=_args_schema, # type: ignore[arg-type]
|
args_schema=args_schema,
|
||||||
description=description_,
|
description=description,
|
||||||
return_direct=return_direct,
|
return_direct=return_direct,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import sys
|
import sys
|
||||||
import textwrap
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from functools import partial
|
from functools import partial
|
||||||
@@ -24,6 +23,9 @@ from langchain_core.tools import (
|
|||||||
Tool,
|
Tool,
|
||||||
ToolException,
|
ToolException,
|
||||||
_create_subset_model,
|
_create_subset_model,
|
||||||
|
_parse_args_from_docstring,
|
||||||
|
_parse_func_description_from_docstring,
|
||||||
|
create_schema_from_function,
|
||||||
tool,
|
tool,
|
||||||
)
|
)
|
||||||
from tests.unit_tests.fake.callbacks import FakeCallbackHandler
|
from tests.unit_tests.fake.callbacks import FakeCallbackHandler
|
||||||
@@ -318,23 +320,22 @@ def test_structured_tool_from_function_docstring() -> None:
|
|||||||
structured_tool = StructuredTool.from_function(foo)
|
structured_tool = StructuredTool.from_function(foo)
|
||||||
assert structured_tool.name == "foo"
|
assert structured_tool.name == "foo"
|
||||||
assert structured_tool.args == {
|
assert structured_tool.args == {
|
||||||
"bar": {"title": "Bar", "type": "integer"},
|
"bar": {"title": "Bar", "type": "integer", "description": "int"},
|
||||||
"baz": {"title": "Baz", "type": "string"},
|
"baz": {"title": "Baz", "type": "string", "description": "str"},
|
||||||
}
|
}
|
||||||
|
|
||||||
assert structured_tool.args_schema.schema() == {
|
assert structured_tool.args_schema.schema() == {
|
||||||
"properties": {
|
"properties": {
|
||||||
"bar": {"title": "Bar", "type": "integer"},
|
"bar": {"title": "Bar", "type": "integer", "description": "int"},
|
||||||
"baz": {"title": "Baz", "type": "string"},
|
"baz": {"title": "Baz", "type": "string", "description": "str"},
|
||||||
},
|
},
|
||||||
"title": "fooSchema",
|
"title": "fooSchema",
|
||||||
|
"description": "Docstring",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": ["bar", "baz"],
|
"required": ["bar", "baz"],
|
||||||
}
|
}
|
||||||
|
|
||||||
prefix = "foo(bar: int, baz: str) -> str - "
|
assert structured_tool.description == "Docstring"
|
||||||
assert foo.__doc__ is not None
|
|
||||||
assert structured_tool.description == prefix + textwrap.dedent(foo.__doc__.strip())
|
|
||||||
|
|
||||||
|
|
||||||
def test_structured_tool_from_function_docstring_complex_args() -> None:
|
def test_structured_tool_from_function_docstring_complex_args() -> None:
|
||||||
@@ -351,23 +352,32 @@ def test_structured_tool_from_function_docstring_complex_args() -> None:
|
|||||||
structured_tool = StructuredTool.from_function(foo)
|
structured_tool = StructuredTool.from_function(foo)
|
||||||
assert structured_tool.name == "foo"
|
assert structured_tool.name == "foo"
|
||||||
assert structured_tool.args == {
|
assert structured_tool.args == {
|
||||||
"bar": {"title": "Bar", "type": "integer"},
|
"bar": {"title": "Bar", "type": "integer", "description": "int"},
|
||||||
"baz": {"title": "Baz", "type": "array", "items": {"type": "string"}},
|
"baz": {
|
||||||
|
"title": "Baz",
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
"description": "List[str]",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
assert structured_tool.args_schema.schema() == {
|
assert structured_tool.args_schema.schema() == {
|
||||||
"properties": {
|
"properties": {
|
||||||
"bar": {"title": "Bar", "type": "integer"},
|
"bar": {"title": "Bar", "type": "integer", "description": "int"},
|
||||||
"baz": {"title": "Baz", "type": "array", "items": {"type": "string"}},
|
"baz": {
|
||||||
|
"title": "Baz",
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
"description": "List[str]",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
"title": "fooSchema",
|
"title": "fooSchema",
|
||||||
|
"description": "Docstring",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": ["bar", "baz"],
|
"required": ["bar", "baz"],
|
||||||
}
|
}
|
||||||
|
|
||||||
prefix = "foo(bar: int, baz: List[str]) -> str - "
|
assert structured_tool.description == "Docstring"
|
||||||
assert foo.__doc__ is not None
|
|
||||||
assert structured_tool.description == prefix + textwrap.dedent(foo.__doc__).strip()
|
|
||||||
|
|
||||||
|
|
||||||
def test_structured_tool_lambda_multi_args_schema() -> None:
|
def test_structured_tool_lambda_multi_args_schema() -> None:
|
||||||
@@ -451,16 +461,17 @@ def test_structured_tool_from_function_with_run_manager() -> None:
|
|||||||
structured_tool = StructuredTool.from_function(foo)
|
structured_tool = StructuredTool.from_function(foo)
|
||||||
|
|
||||||
assert structured_tool.args == {
|
assert structured_tool.args == {
|
||||||
"bar": {"title": "Bar", "type": "integer"},
|
"bar": {"title": "Bar", "type": "integer", "description": "int"},
|
||||||
"baz": {"title": "Baz", "type": "string"},
|
"baz": {"title": "Baz", "type": "string", "description": "str"},
|
||||||
}
|
}
|
||||||
|
|
||||||
assert structured_tool.args_schema.schema() == {
|
assert structured_tool.args_schema.schema() == {
|
||||||
"properties": {
|
"properties": {
|
||||||
"bar": {"title": "Bar", "type": "integer"},
|
"bar": {"title": "Bar", "type": "integer", "description": "int"},
|
||||||
"baz": {"title": "Baz", "type": "string"},
|
"baz": {"title": "Baz", "type": "string", "description": "str"},
|
||||||
},
|
},
|
||||||
"title": "fooSchema",
|
"title": "fooSchema",
|
||||||
|
"description": "Docstring",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": ["bar", "baz"],
|
"required": ["bar", "baz"],
|
||||||
}
|
}
|
||||||
@@ -553,7 +564,7 @@ def test_tool_with_kwargs() -> None:
|
|||||||
def test_missing_docstring() -> None:
|
def test_missing_docstring() -> None:
|
||||||
"""Test error is raised when docstring is missing."""
|
"""Test error is raised when docstring is missing."""
|
||||||
# expect to throw a value error if there's no docstring
|
# expect to throw a value error if there's no docstring
|
||||||
with pytest.raises(ValueError, match="Function must have a docstring"):
|
with pytest.raises(ValueError):
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def search_api(query: str) -> str:
|
def search_api(query: str) -> str:
|
||||||
@@ -686,23 +697,22 @@ def test_structured_tool_from_function() -> None:
|
|||||||
structured_tool = StructuredTool.from_function(foo)
|
structured_tool = StructuredTool.from_function(foo)
|
||||||
assert structured_tool.name == "foo"
|
assert structured_tool.name == "foo"
|
||||||
assert structured_tool.args == {
|
assert structured_tool.args == {
|
||||||
"bar": {"title": "Bar", "type": "integer"},
|
"bar": {"title": "Bar", "type": "integer", "description": "int"},
|
||||||
"baz": {"title": "Baz", "type": "string"},
|
"baz": {"title": "Baz", "type": "string", "description": "str"},
|
||||||
}
|
}
|
||||||
|
|
||||||
assert structured_tool.args_schema.schema() == {
|
assert structured_tool.args_schema.schema() == {
|
||||||
"title": "fooSchema",
|
"title": "fooSchema",
|
||||||
|
"description": "Docstring",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"bar": {"title": "Bar", "type": "integer"},
|
"bar": {"title": "Bar", "type": "integer", "description": "int"},
|
||||||
"baz": {"title": "Baz", "type": "string"},
|
"baz": {"title": "Baz", "type": "string", "description": "str"},
|
||||||
},
|
},
|
||||||
"required": ["bar", "baz"],
|
"required": ["bar", "baz"],
|
||||||
}
|
}
|
||||||
|
|
||||||
prefix = "foo(bar: int, baz: str) -> str - "
|
assert structured_tool.description == "Docstring"
|
||||||
assert foo.__doc__ is not None
|
|
||||||
assert structured_tool.description == prefix + textwrap.dedent(foo.__doc__.strip())
|
|
||||||
|
|
||||||
|
|
||||||
def test_validation_error_handling_bool() -> None:
|
def test_validation_error_handling_bool() -> None:
|
||||||
@@ -906,3 +916,165 @@ async def test_async_tool_pass_context() -> None:
|
|||||||
assert (
|
assert (
|
||||||
await foo.ainvoke({"bar": "baz"}, {"configurable": {"foo": "not-bar"}}) == "baz" # type: ignore
|
await foo.ainvoke({"bar": "baz"}, {"configurable": {"foo": "not-bar"}}) == "baz" # type: ignore
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"docstring",
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
"",
|
||||||
|
"""
|
||||||
|
A function without an args section.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
""",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_parse_docstring_no_args_section(docstring: Optional[str]) -> None:
|
||||||
|
assert _parse_args_from_docstring(docstring) == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_docstring_single_argument() -> None:
|
||||||
|
docstring = """
|
||||||
|
A function with a single argument.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
param1: The first parameter.
|
||||||
|
"""
|
||||||
|
expected = {"param1": "The first parameter."}
|
||||||
|
assert _parse_args_from_docstring(docstring) == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"docstring",
|
||||||
|
[
|
||||||
|
"""
|
||||||
|
A function with multiple arguments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
param1: The first parameter.
|
||||||
|
param2: The second parameter.
|
||||||
|
""",
|
||||||
|
"""
|
||||||
|
A function with multiline argument descriptions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
param1: The first parameter.
|
||||||
|
param2: The second
|
||||||
|
parameter.
|
||||||
|
""",
|
||||||
|
"""
|
||||||
|
A function with the args section that has blank lines.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
param1: The first parameter.
|
||||||
|
|
||||||
|
param2: The second parameter.
|
||||||
|
""",
|
||||||
|
"""
|
||||||
|
A function with extra sections.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
param1: The first parameter.
|
||||||
|
param2: The second parameter.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
foobar
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
barfoo
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
baz
|
||||||
|
""",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_parse_docstring_multiple_arguments(docstring: str) -> None:
|
||||||
|
expected = {
|
||||||
|
"param1": "The first parameter.",
|
||||||
|
"param2": "The second parameter.",
|
||||||
|
}
|
||||||
|
assert _parse_args_from_docstring(docstring) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_docstring_args_with_multiple_colons_in_single_line() -> None:
|
||||||
|
docstring = """
|
||||||
|
A function with a colon in the description.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
param1: The first parameter: with colon.
|
||||||
|
param2: The second parameter.
|
||||||
|
"""
|
||||||
|
expected = {
|
||||||
|
"param1": "The first parameter: with colon.",
|
||||||
|
"param2": "The second parameter.",
|
||||||
|
}
|
||||||
|
assert _parse_args_from_docstring(docstring) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_docstring_description() -> None:
|
||||||
|
docstring = """
|
||||||
|
A function with a
|
||||||
|
multiline description.
|
||||||
|
"""
|
||||||
|
assert (
|
||||||
|
_parse_func_description_from_docstring(docstring)
|
||||||
|
== "A function with a multiline description."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("section", ["Args", "Returns", "Yields", "Raises"])
|
||||||
|
def test_parse_docstring_description_multiple_sections(section: str) -> None:
|
||||||
|
docstring = f"""
|
||||||
|
A function with a
|
||||||
|
multiline description.
|
||||||
|
|
||||||
|
{section}:
|
||||||
|
foo: bar
|
||||||
|
"""
|
||||||
|
assert (
|
||||||
|
_parse_func_description_from_docstring(docstring)
|
||||||
|
== "A function with a multiline description."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("docstring", [None, "", "\n Args:\n bar"])
|
||||||
|
def test_parse_docstring_description_no_description(docstring: Optional[str]) -> None:
|
||||||
|
assert not _parse_func_description_from_docstring(docstring)
|
||||||
|
|
||||||
|
|
||||||
|
def foo1(a: int, b: str = "") -> float:
|
||||||
|
"""
|
||||||
|
do foo
|
||||||
|
|
||||||
|
Args:
|
||||||
|
a (int) : this
|
||||||
|
describes
|
||||||
|
a
|
||||||
|
b: this describes b
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
blah
|
||||||
|
"""
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_schema_from_function() -> None:
|
||||||
|
expected = {
|
||||||
|
"title": "fooSchema",
|
||||||
|
"description": "do foo",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"a": {"title": "A", "description": "this describes a", "type": "integer"},
|
||||||
|
"b": {
|
||||||
|
"title": "B",
|
||||||
|
"description": "this describes b",
|
||||||
|
"type": "string",
|
||||||
|
"default": "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["a"],
|
||||||
|
}
|
||||||
|
actual = create_schema_from_function("foo", foo1).schema()
|
||||||
|
assert expected == actual
|
||||||
|
|||||||
@@ -53,6 +53,27 @@ def dummy_tool() -> BaseTool:
|
|||||||
return DummyFunction()
|
return DummyFunction()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def dummy_tool_from_function() -> BaseTool:
|
||||||
|
@tool()
|
||||||
|
def dummy_function(arg1: int, arg2: Literal["bar", "baz"]) -> None:
|
||||||
|
"""dummy function
|
||||||
|
|
||||||
|
Args:
|
||||||
|
arg1: foo
|
||||||
|
arg2: one of 'bar', 'baz'
|
||||||
|
|
||||||
|
Return:
|
||||||
|
blah
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
bleh
|
||||||
|
"""
|
||||||
|
return
|
||||||
|
|
||||||
|
return dummy_function
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def json_schema() -> Dict:
|
def json_schema() -> Dict:
|
||||||
return {
|
return {
|
||||||
@@ -98,6 +119,7 @@ def test_convert_to_openai_function(
|
|||||||
pydantic: Type[BaseModel],
|
pydantic: Type[BaseModel],
|
||||||
function: Callable,
|
function: Callable,
|
||||||
dummy_tool: BaseTool,
|
dummy_tool: BaseTool,
|
||||||
|
dummy_tool_from_function: BaseTool,
|
||||||
json_schema: Dict,
|
json_schema: Dict,
|
||||||
) -> None:
|
) -> None:
|
||||||
expected = {
|
expected = {
|
||||||
@@ -121,6 +143,7 @@ def test_convert_to_openai_function(
|
|||||||
pydantic,
|
pydantic,
|
||||||
function,
|
function,
|
||||||
dummy_tool,
|
dummy_tool,
|
||||||
|
dummy_tool_from_function,
|
||||||
json_schema,
|
json_schema,
|
||||||
expected,
|
expected,
|
||||||
Dummy.dummy_function,
|
Dummy.dummy_function,
|
||||||
|
|||||||
Reference in New Issue
Block a user