mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-14 06:44:16 +00:00
Compare commits
3 Commits
langchain-
...
bagatur/pa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
49b68a0cb2 | ||
|
|
8716d065a3 | ||
|
|
c15a541ccd |
@@ -21,7 +21,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import textwrap
|
||||
import re
|
||||
import uuid
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
@@ -37,8 +37,6 @@ from langchain_core.callbacks import (
|
||||
BaseCallbackManager,
|
||||
CallbackManager,
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
from langchain_core.callbacks.manager import (
|
||||
Callbacks,
|
||||
)
|
||||
from langchain_core.load.serializable import Serializable
|
||||
@@ -77,7 +75,12 @@ class SchemaAnnotationError(TypeError):
|
||||
|
||||
|
||||
def _create_subset_model(
|
||||
name: str, model: Type[BaseModel], field_names: list
|
||||
name: str,
|
||||
model: Type[BaseModel],
|
||||
field_names: list,
|
||||
*,
|
||||
description: Optional[str] = None,
|
||||
field_descriptions: Optional[Dict[str, str]] = None,
|
||||
) -> Type[BaseModel]:
|
||||
"""Create a pydantic model with only a subset of model's fields."""
|
||||
fields = {}
|
||||
@@ -89,21 +92,70 @@ def _create_subset_model(
|
||||
if field.required and not field.allow_none
|
||||
else Optional[field.outer_type_]
|
||||
)
|
||||
if field_descriptions and field_name in field_descriptions:
|
||||
field.field_info.description = field_descriptions[field_name]
|
||||
fields[field_name] = (t, field.field_info)
|
||||
rtn = create_model(name, **fields) # type: ignore
|
||||
rtn = create_model(name, __doc__=description, **fields) # type: ignore
|
||||
return rtn
|
||||
|
||||
|
||||
def _get_filtered_args(
|
||||
inferred_model: Type[BaseModel],
|
||||
func: Callable,
|
||||
) -> dict:
|
||||
def _get_filtered_args(inferred_model: Type[BaseModel], func: Callable) -> dict:
|
||||
"""Get the arguments from a function's signature."""
|
||||
schema = inferred_model.schema()["properties"]
|
||||
valid_keys = signature(func).parameters
|
||||
return {k: schema[k] for k in valid_keys if k not in ("run_manager", "callbacks")}
|
||||
|
||||
|
||||
def _parse_args_from_docstring(docstring: Optional[str]) -> Dict[str, str]:
|
||||
"""Parses the argument descriptions from a Google-style docstring.
|
||||
|
||||
Args:
|
||||
docstring: The docstring to parse.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary where keys are argument names and values are their
|
||||
descriptions.
|
||||
"""
|
||||
args_dict: Dict[str, str] = {}
|
||||
|
||||
if docstring and (args_section := re.search(r"Args:\n((?:\s*.+\n)+)", docstring)):
|
||||
arg_lines = args_section.group(1).strip().split("\n")
|
||||
else:
|
||||
return args_dict
|
||||
|
||||
current_arg = None
|
||||
current_desc = []
|
||||
|
||||
for line in arg_lines:
|
||||
if match := re.match(r"\s*(\w+).*?:\s*(.*)", line):
|
||||
if current_arg:
|
||||
args_dict[current_arg] = " ".join(current_desc).strip()
|
||||
current_arg = match.group(1)
|
||||
if current_arg in ("Returns", "Yields", "Raises"):
|
||||
current_arg = None
|
||||
break
|
||||
current_desc = [match.group(2).strip()]
|
||||
else:
|
||||
current_desc.append(line.strip())
|
||||
|
||||
if current_arg:
|
||||
args_dict[current_arg] = " ".join(current_desc).strip()
|
||||
|
||||
return args_dict
|
||||
|
||||
|
||||
def _parse_func_description_from_docstring(docstring: Optional[str]) -> Optional[str]:
|
||||
if not docstring:
|
||||
return docstring
|
||||
if description_match := re.search(
|
||||
r"(.*?)(?:Args|Returns|Yields|Raises)", docstring, flags=re.DOTALL
|
||||
):
|
||||
description = description_match.group(1)
|
||||
else:
|
||||
description = docstring
|
||||
return " ".join(li.strip() for li in description.split("\n") if li.strip())
|
||||
|
||||
|
||||
class _SchemaConfig:
|
||||
"""Configuration for the pydantic model."""
|
||||
|
||||
@@ -111,14 +163,13 @@ class _SchemaConfig:
|
||||
arbitrary_types_allowed: bool = True
|
||||
|
||||
|
||||
def create_schema_from_function(
|
||||
model_name: str,
|
||||
func: Callable,
|
||||
) -> Type[BaseModel]:
|
||||
def create_schema_from_function(model_name: str, func: Callable) -> Type[BaseModel]:
|
||||
"""Create a pydantic schema from a function's signature.
|
||||
|
||||
Args:
|
||||
model_name: Name to assign to the generated pydandic schema
|
||||
model_name: Name to assign to the generated pydantdic schema
|
||||
func: Function to generate the schema from
|
||||
|
||||
Returns:
|
||||
A pydantic model with the same arguments as the function
|
||||
"""
|
||||
@@ -131,8 +182,15 @@ def create_schema_from_function(
|
||||
del inferred_model.__fields__["callbacks"]
|
||||
# Pydantic adds placeholder virtual fields we need to strip
|
||||
valid_properties = _get_filtered_args(inferred_model, func)
|
||||
docstring = getattr(func, "__doc__", "")
|
||||
func_description = _parse_func_description_from_docstring(docstring)
|
||||
arg_descriptions = _parse_args_from_docstring(docstring)
|
||||
return _create_subset_model(
|
||||
f"{model_name}Schema", inferred_model, list(valid_properties)
|
||||
f"{model_name}Schema",
|
||||
inferred_model,
|
||||
list(valid_properties),
|
||||
description=func_description,
|
||||
field_descriptions=arg_descriptions,
|
||||
)
|
||||
|
||||
|
||||
@@ -802,7 +860,7 @@ class StructuredTool(BaseTool):
|
||||
description: The description of the tool. Defaults to the function docstring
|
||||
return_direct: Whether to return the result directly or as a callback
|
||||
args_schema: The schema of the tool's input arguments
|
||||
infer_schema: Whether to infer the schema from the function's signature
|
||||
infer_schema: DEPRECATED. args_schema is always inferred if not specified.
|
||||
**kwargs: Additional arguments to pass to the tool
|
||||
|
||||
Returns:
|
||||
@@ -826,29 +884,24 @@ class StructuredTool(BaseTool):
|
||||
else:
|
||||
raise ValueError("Function and/or coroutine must be provided")
|
||||
name = name or source_function.__name__
|
||||
description_ = description or source_function.__doc__
|
||||
if description_ is None:
|
||||
raise ValueError(
|
||||
"Function must have a docstring if description not provided."
|
||||
)
|
||||
inferred_schema = create_schema_from_function(name, source_function)
|
||||
args_schema = args_schema or inferred_schema
|
||||
description = (
|
||||
description
|
||||
or args_schema.schema().get("description")
|
||||
or inferred_schema.schema().get("description")
|
||||
)
|
||||
if description is None:
|
||||
# Only apply if using the function's docstring
|
||||
description_ = textwrap.dedent(description_).strip()
|
||||
|
||||
# Description example:
|
||||
# search_api(query: str) - Searches the API for the query.
|
||||
sig = signature(source_function)
|
||||
description_ = f"{name}{sig} - {description_.strip()}"
|
||||
_args_schema = args_schema
|
||||
if _args_schema is None and infer_schema:
|
||||
# schema name is appended within function
|
||||
_args_schema = create_schema_from_function(name, source_function)
|
||||
raise ValueError(
|
||||
"Must specify a description or pass in an args_schema or function with "
|
||||
"a docstring."
|
||||
)
|
||||
return cls(
|
||||
name=name,
|
||||
func=func,
|
||||
coroutine=coroutine,
|
||||
args_schema=_args_schema, # type: ignore[arg-type]
|
||||
description=description_,
|
||||
args_schema=args_schema,
|
||||
description=description,
|
||||
return_direct=return_direct,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import textwrap
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
@@ -24,6 +23,9 @@ from langchain_core.tools import (
|
||||
Tool,
|
||||
ToolException,
|
||||
_create_subset_model,
|
||||
_parse_args_from_docstring,
|
||||
_parse_func_description_from_docstring,
|
||||
create_schema_from_function,
|
||||
tool,
|
||||
)
|
||||
from tests.unit_tests.fake.callbacks import FakeCallbackHandler
|
||||
@@ -318,23 +320,22 @@ def test_structured_tool_from_function_docstring() -> None:
|
||||
structured_tool = StructuredTool.from_function(foo)
|
||||
assert structured_tool.name == "foo"
|
||||
assert structured_tool.args == {
|
||||
"bar": {"title": "Bar", "type": "integer"},
|
||||
"baz": {"title": "Baz", "type": "string"},
|
||||
"bar": {"title": "Bar", "type": "integer", "description": "int"},
|
||||
"baz": {"title": "Baz", "type": "string", "description": "str"},
|
||||
}
|
||||
|
||||
assert structured_tool.args_schema.schema() == {
|
||||
"properties": {
|
||||
"bar": {"title": "Bar", "type": "integer"},
|
||||
"baz": {"title": "Baz", "type": "string"},
|
||||
"bar": {"title": "Bar", "type": "integer", "description": "int"},
|
||||
"baz": {"title": "Baz", "type": "string", "description": "str"},
|
||||
},
|
||||
"title": "fooSchema",
|
||||
"description": "Docstring",
|
||||
"type": "object",
|
||||
"required": ["bar", "baz"],
|
||||
}
|
||||
|
||||
prefix = "foo(bar: int, baz: str) -> str - "
|
||||
assert foo.__doc__ is not None
|
||||
assert structured_tool.description == prefix + textwrap.dedent(foo.__doc__.strip())
|
||||
assert structured_tool.description == "Docstring"
|
||||
|
||||
|
||||
def test_structured_tool_from_function_docstring_complex_args() -> None:
|
||||
@@ -351,23 +352,32 @@ def test_structured_tool_from_function_docstring_complex_args() -> None:
|
||||
structured_tool = StructuredTool.from_function(foo)
|
||||
assert structured_tool.name == "foo"
|
||||
assert structured_tool.args == {
|
||||
"bar": {"title": "Bar", "type": "integer"},
|
||||
"baz": {"title": "Baz", "type": "array", "items": {"type": "string"}},
|
||||
"bar": {"title": "Bar", "type": "integer", "description": "int"},
|
||||
"baz": {
|
||||
"title": "Baz",
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "List[str]",
|
||||
},
|
||||
}
|
||||
|
||||
assert structured_tool.args_schema.schema() == {
|
||||
"properties": {
|
||||
"bar": {"title": "Bar", "type": "integer"},
|
||||
"baz": {"title": "Baz", "type": "array", "items": {"type": "string"}},
|
||||
"bar": {"title": "Bar", "type": "integer", "description": "int"},
|
||||
"baz": {
|
||||
"title": "Baz",
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "List[str]",
|
||||
},
|
||||
},
|
||||
"title": "fooSchema",
|
||||
"description": "Docstring",
|
||||
"type": "object",
|
||||
"required": ["bar", "baz"],
|
||||
}
|
||||
|
||||
prefix = "foo(bar: int, baz: List[str]) -> str - "
|
||||
assert foo.__doc__ is not None
|
||||
assert structured_tool.description == prefix + textwrap.dedent(foo.__doc__).strip()
|
||||
assert structured_tool.description == "Docstring"
|
||||
|
||||
|
||||
def test_structured_tool_lambda_multi_args_schema() -> None:
|
||||
@@ -451,16 +461,17 @@ def test_structured_tool_from_function_with_run_manager() -> None:
|
||||
structured_tool = StructuredTool.from_function(foo)
|
||||
|
||||
assert structured_tool.args == {
|
||||
"bar": {"title": "Bar", "type": "integer"},
|
||||
"baz": {"title": "Baz", "type": "string"},
|
||||
"bar": {"title": "Bar", "type": "integer", "description": "int"},
|
||||
"baz": {"title": "Baz", "type": "string", "description": "str"},
|
||||
}
|
||||
|
||||
assert structured_tool.args_schema.schema() == {
|
||||
"properties": {
|
||||
"bar": {"title": "Bar", "type": "integer"},
|
||||
"baz": {"title": "Baz", "type": "string"},
|
||||
"bar": {"title": "Bar", "type": "integer", "description": "int"},
|
||||
"baz": {"title": "Baz", "type": "string", "description": "str"},
|
||||
},
|
||||
"title": "fooSchema",
|
||||
"description": "Docstring",
|
||||
"type": "object",
|
||||
"required": ["bar", "baz"],
|
||||
}
|
||||
@@ -553,7 +564,7 @@ def test_tool_with_kwargs() -> None:
|
||||
def test_missing_docstring() -> None:
|
||||
"""Test error is raised when docstring is missing."""
|
||||
# expect to throw a value error if there's no docstring
|
||||
with pytest.raises(ValueError, match="Function must have a docstring"):
|
||||
with pytest.raises(ValueError):
|
||||
|
||||
@tool
|
||||
def search_api(query: str) -> str:
|
||||
@@ -686,23 +697,22 @@ def test_structured_tool_from_function() -> None:
|
||||
structured_tool = StructuredTool.from_function(foo)
|
||||
assert structured_tool.name == "foo"
|
||||
assert structured_tool.args == {
|
||||
"bar": {"title": "Bar", "type": "integer"},
|
||||
"baz": {"title": "Baz", "type": "string"},
|
||||
"bar": {"title": "Bar", "type": "integer", "description": "int"},
|
||||
"baz": {"title": "Baz", "type": "string", "description": "str"},
|
||||
}
|
||||
|
||||
assert structured_tool.args_schema.schema() == {
|
||||
"title": "fooSchema",
|
||||
"description": "Docstring",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"bar": {"title": "Bar", "type": "integer"},
|
||||
"baz": {"title": "Baz", "type": "string"},
|
||||
"bar": {"title": "Bar", "type": "integer", "description": "int"},
|
||||
"baz": {"title": "Baz", "type": "string", "description": "str"},
|
||||
},
|
||||
"required": ["bar", "baz"],
|
||||
}
|
||||
|
||||
prefix = "foo(bar: int, baz: str) -> str - "
|
||||
assert foo.__doc__ is not None
|
||||
assert structured_tool.description == prefix + textwrap.dedent(foo.__doc__.strip())
|
||||
assert structured_tool.description == "Docstring"
|
||||
|
||||
|
||||
def test_validation_error_handling_bool() -> None:
|
||||
@@ -906,3 +916,165 @@ async def test_async_tool_pass_context() -> None:
|
||||
assert (
|
||||
await foo.ainvoke({"bar": "baz"}, {"configurable": {"foo": "not-bar"}}) == "baz" # type: ignore
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"docstring",
|
||||
[
|
||||
None,
|
||||
"",
|
||||
"""
|
||||
A function without an args section.
|
||||
|
||||
Returns:
|
||||
None
|
||||
""",
|
||||
],
|
||||
)
|
||||
def test_parse_docstring_no_args_section(docstring: Optional[str]) -> None:
|
||||
assert _parse_args_from_docstring(docstring) == {}
|
||||
|
||||
|
||||
def test_parse_docstring_single_argument() -> None:
|
||||
docstring = """
|
||||
A function with a single argument.
|
||||
|
||||
Args:
|
||||
param1: The first parameter.
|
||||
"""
|
||||
expected = {"param1": "The first parameter."}
|
||||
assert _parse_args_from_docstring(docstring) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"docstring",
|
||||
[
|
||||
"""
|
||||
A function with multiple arguments.
|
||||
|
||||
Args:
|
||||
param1: The first parameter.
|
||||
param2: The second parameter.
|
||||
""",
|
||||
"""
|
||||
A function with multiline argument descriptions.
|
||||
|
||||
Args:
|
||||
param1: The first parameter.
|
||||
param2: The second
|
||||
parameter.
|
||||
""",
|
||||
"""
|
||||
A function with the args section that has blank lines.
|
||||
|
||||
Args:
|
||||
param1: The first parameter.
|
||||
|
||||
param2: The second parameter.
|
||||
""",
|
||||
"""
|
||||
A function with extra sections.
|
||||
|
||||
Args:
|
||||
param1: The first parameter.
|
||||
param2: The second parameter.
|
||||
|
||||
Returns:
|
||||
foobar
|
||||
|
||||
Yields:
|
||||
barfoo
|
||||
|
||||
Raises:
|
||||
baz
|
||||
""",
|
||||
],
|
||||
)
|
||||
def test_parse_docstring_multiple_arguments(docstring: str) -> None:
|
||||
expected = {
|
||||
"param1": "The first parameter.",
|
||||
"param2": "The second parameter.",
|
||||
}
|
||||
assert _parse_args_from_docstring(docstring) == expected
|
||||
|
||||
|
||||
def test_parse_docstring_args_with_multiple_colons_in_single_line() -> None:
|
||||
docstring = """
|
||||
A function with a colon in the description.
|
||||
|
||||
Args:
|
||||
param1: The first parameter: with colon.
|
||||
param2: The second parameter.
|
||||
"""
|
||||
expected = {
|
||||
"param1": "The first parameter: with colon.",
|
||||
"param2": "The second parameter.",
|
||||
}
|
||||
assert _parse_args_from_docstring(docstring) == expected
|
||||
|
||||
|
||||
def test_parse_docstring_description() -> None:
|
||||
docstring = """
|
||||
A function with a
|
||||
multiline description.
|
||||
"""
|
||||
assert (
|
||||
_parse_func_description_from_docstring(docstring)
|
||||
== "A function with a multiline description."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("section", ["Args", "Returns", "Yields", "Raises"])
|
||||
def test_parse_docstring_description_multiple_sections(section: str) -> None:
|
||||
docstring = f"""
|
||||
A function with a
|
||||
multiline description.
|
||||
|
||||
{section}:
|
||||
foo: bar
|
||||
"""
|
||||
assert (
|
||||
_parse_func_description_from_docstring(docstring)
|
||||
== "A function with a multiline description."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("docstring", [None, "", "\n Args:\n bar"])
|
||||
def test_parse_docstring_description_no_description(docstring: Optional[str]) -> None:
|
||||
assert not _parse_func_description_from_docstring(docstring)
|
||||
|
||||
|
||||
def foo1(a: int, b: str = "") -> float:
|
||||
"""
|
||||
do foo
|
||||
|
||||
Args:
|
||||
a (int) : this
|
||||
describes
|
||||
a
|
||||
b: this describes b
|
||||
|
||||
Returns:
|
||||
blah
|
||||
"""
|
||||
return 0.0
|
||||
|
||||
|
||||
def test_create_schema_from_function() -> None:
|
||||
expected = {
|
||||
"title": "fooSchema",
|
||||
"description": "do foo",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {"title": "A", "description": "this describes a", "type": "integer"},
|
||||
"b": {
|
||||
"title": "B",
|
||||
"description": "this describes b",
|
||||
"type": "string",
|
||||
"default": "",
|
||||
},
|
||||
},
|
||||
"required": ["a"],
|
||||
}
|
||||
actual = create_schema_from_function("foo", foo1).schema()
|
||||
assert expected == actual
|
||||
|
||||
@@ -53,6 +53,27 @@ def dummy_tool() -> BaseTool:
|
||||
return DummyFunction()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def dummy_tool_from_function() -> BaseTool:
|
||||
@tool()
|
||||
def dummy_function(arg1: int, arg2: Literal["bar", "baz"]) -> None:
|
||||
"""dummy function
|
||||
|
||||
Args:
|
||||
arg1: foo
|
||||
arg2: one of 'bar', 'baz'
|
||||
|
||||
Return:
|
||||
blah
|
||||
|
||||
Raises:
|
||||
bleh
|
||||
"""
|
||||
return
|
||||
|
||||
return dummy_function
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def json_schema() -> Dict:
|
||||
return {
|
||||
@@ -98,6 +119,7 @@ def test_convert_to_openai_function(
|
||||
pydantic: Type[BaseModel],
|
||||
function: Callable,
|
||||
dummy_tool: BaseTool,
|
||||
dummy_tool_from_function: BaseTool,
|
||||
json_schema: Dict,
|
||||
) -> None:
|
||||
expected = {
|
||||
@@ -121,6 +143,7 @@ def test_convert_to_openai_function(
|
||||
pydantic,
|
||||
function,
|
||||
dummy_tool,
|
||||
dummy_tool_from_function,
|
||||
json_schema,
|
||||
expected,
|
||||
Dummy.dummy_function,
|
||||
|
||||
Reference in New Issue
Block a user