core[patch], integrations[patch]: convert TypedDict to tool schema support (#24641)

Supports the following UX:

```python
    class SubTool(TypedDict):
        """Subtool docstring"""

        args: Annotated[Dict[str, Any], {}, "this does bar"]

    class Tool(TypedDict):
        """Docstring
        Args:
            arg1: foo
        """

        arg1: str
        arg2: Union[int, str]
        arg3: Optional[List[SubTool]]
        arg4: Annotated[Literal["bar", "baz"], ..., "this does foo"]
        arg5: Annotated[Optional[float], None]
```

- Can parse Google-style docstrings
- Can use Annotated to specify a default value (second arg)
- Can use Annotated to specify an arg description (third arg)
- Can have nested complex types
This commit is contained in:
Bagatur
2024-07-31 11:27:24 -07:00
committed by GitHub
parent d24b82357f
commit 8461934c2b
17 changed files with 1371 additions and 468 deletions

View File

@@ -1,16 +1,44 @@
# mypy: disable-error-code="annotation-unchecked"
from typing import Any, Callable, Dict, List, Literal, Optional, Type
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Literal,
Mapping,
MutableMapping,
MutableSet,
Optional,
Sequence,
Set,
Tuple,
Type,
Union,
)
from typing import TypedDict as TypingTypedDict
import pytest
from pydantic import BaseModel as BaseModelV2Maybe # pydantic: ignore
from pydantic import Field as FieldV2Maybe # pydantic: ignore
from typing_extensions import Annotated, TypedDict
from typing_extensions import (
Annotated as ExtensionsAnnotated,
)
from typing_extensions import (
TypedDict as ExtensionsTypedDict,
)
try:
from typing import Annotated as TypingAnnotated # type: ignore[attr-defined]
except ImportError:
TypingAnnotated = ExtensionsAnnotated
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import Runnable, RunnableLambda
from langchain_core.tools import BaseTool, tool
from langchain_core.utils.function_calling import (
_convert_typed_dict_to_openai_function,
convert_to_openai_function,
tool_example_to_messages,
)
@@ -28,10 +56,10 @@ def pydantic() -> Type[BaseModel]:
@pytest.fixture()
def annotated_function() -> Callable:
def Annotated_function() -> Callable:
def dummy_function(
arg1: Annotated[int, "foo"],
arg2: Annotated[Literal["bar", "baz"], "one of 'bar', 'baz'"],
arg1: ExtensionsAnnotated[int, "foo"],
arg2: ExtensionsAnnotated[Literal["bar", "baz"], "one of 'bar', 'baz'"],
) -> None:
"""dummy function"""
pass
@@ -55,9 +83,9 @@ def function() -> Callable:
@pytest.fixture()
def runnable() -> Runnable:
class Args(TypedDict):
arg1: Annotated[int, "foo"]
arg2: Annotated[Literal["bar", "baz"], "one of 'bar', 'baz'"]
class Args(ExtensionsTypedDict):
arg1: ExtensionsAnnotated[int, "foo"]
arg2: ExtensionsAnnotated[Literal["bar", "baz"], "one of 'bar', 'baz'"]
def dummy_function(input_dict: Args) -> None:
pass
@@ -106,6 +134,60 @@ def dummy_pydantic_v2() -> Type[BaseModelV2Maybe]:
return dummy_function
@pytest.fixture()
def dummy_typing_typed_dict() -> Type:
    """Return a ``typing.TypedDict`` tool described through ``Annotated`` metadata.

    The class is deliberately named ``dummy_function`` and is built from the
    ``TypingTypedDict`` / ``TypingAnnotated`` aliases imported at the top of
    the module.
    """

    # Each field is Annotated[<type>, <second arg>, <third arg>]: ``...`` fills
    # the second slot and the third slot carries the field's description text.
    class dummy_function(TypingTypedDict):
        """dummy function"""

        arg1: TypingAnnotated[int, ..., "foo"]  # noqa: F821
        arg2: TypingAnnotated[Literal["bar", "baz"], ..., "one of 'bar', 'baz'"]  # noqa: F722

    return dummy_function
@pytest.fixture()
def dummy_typing_typed_dict_docstring() -> Type:
    # Fixture: a ``TypingTypedDict`` subclass named ``dummy_function`` whose
    # fields use plain annotations; the per-argument text lives in the
    # ``Args:`` section of the class docstring instead of ``Annotated``.
    # NOTE(review): this view of the file has lost its blank lines, so the
    # docstring below is reproduced exactly as shown -- confirm its layout
    # against the real file before relying on it.
    class dummy_function(TypingTypedDict):
        """dummy function
        Args:
            arg1: foo
            arg2: one of 'bar', 'baz'
        """

        arg1: int
        arg2: Literal["bar", "baz"]

    return dummy_function
@pytest.fixture()
def dummy_extensions_typed_dict() -> Type:
    """Return a ``typing_extensions.TypedDict`` tool described via ``Annotated``.

    Same shape as the ``typing``-based fixture, but built from the
    ``ExtensionsTypedDict`` / ``ExtensionsAnnotated`` aliases.
    """

    # Each field is Annotated[<type>, <second arg>, <third arg>]: ``...`` fills
    # the second slot and the third slot carries the field's description text.
    class dummy_function(ExtensionsTypedDict):
        """dummy function"""

        arg1: ExtensionsAnnotated[int, ..., "foo"]
        arg2: ExtensionsAnnotated[Literal["bar", "baz"], ..., "one of 'bar', 'baz'"]

    return dummy_function
@pytest.fixture()
def dummy_extensions_typed_dict_docstring() -> Type:
    # Fixture: an ``ExtensionsTypedDict`` subclass named ``dummy_function``
    # whose fields use plain annotations; the per-argument text lives in the
    # ``Args:`` section of the class docstring instead of ``Annotated``.
    # NOTE(review): this view of the file has lost its blank lines, so the
    # docstring below is reproduced exactly as shown -- confirm its layout
    # against the real file before relying on it.
    class dummy_function(ExtensionsTypedDict):
        """dummy function
        Args:
            arg1: foo
            arg2: one of 'bar', 'baz'
        """

        arg1: int
        arg2: Literal["bar", "baz"]

    return dummy_function
@pytest.fixture()
def json_schema() -> Dict:
return {
@@ -152,9 +234,13 @@ def test_convert_to_openai_function(
function: Callable,
dummy_tool: BaseTool,
json_schema: Dict,
annotated_function: Callable,
Annotated_function: Callable,
dummy_pydantic: Type[BaseModel],
runnable: Runnable,
dummy_typing_typed_dict: Type,
dummy_typing_typed_dict_docstring: Type,
dummy_extensions_typed_dict: Type,
dummy_extensions_typed_dict_docstring: Type,
) -> None:
expected = {
"name": "dummy_function",
@@ -181,8 +267,12 @@ def test_convert_to_openai_function(
expected,
Dummy.dummy_function,
DummyWithClassMethod.dummy_function,
annotated_function,
Annotated_function,
dummy_pydantic,
dummy_typing_typed_dict,
dummy_typing_typed_dict_docstring,
dummy_extensions_typed_dict,
dummy_extensions_typed_dict_docstring,
):
actual = convert_to_openai_function(fn) # type: ignore
assert actual == expected
@@ -356,3 +446,259 @@ def test_tool_outputs() -> None:
},
]
assert messages[2].content == "Output1"
@pytest.mark.parametrize("use_extension_typed_dict", [True, False])
@pytest.mark.parametrize("use_extension_annotated", [True, False])
def test__convert_typed_dict_to_openai_function(
    use_extension_typed_dict: bool, use_extension_annotated: bool
) -> None:
    """Check the function schema produced for a nested, annotated TypedDict.

    The test runs for every combination of ``TypedDict`` and ``Annotated``
    taken from ``typing`` or ``typing_extensions``.

    Args:
        use_extension_typed_dict: Build the tool classes from
            ``typing_extensions.TypedDict`` instead of ``typing.TypedDict``.
        use_extension_annotated: Annotate the fields with
            ``typing_extensions.Annotated`` instead of ``typing.Annotated``.
    """
    if use_extension_typed_dict:
        TypedDict = ExtensionsTypedDict
    else:
        TypedDict = TypingTypedDict
    # Previously both branches picked TypingAnnotated, which made the
    # ``use_extension_annotated`` parameter a no-op and left the
    # typing_extensions flavour untested.
    if use_extension_annotated:
        Annotated = ExtensionsAnnotated
    else:
        Annotated = TypingAnnotated

    class SubTool(TypedDict):
        """Subtool docstring"""

        args: Annotated[Dict[str, Any], {}, "this does bar"]  # noqa: F722 # type: ignore

    class Tool(TypedDict):
        """Docstring
        Args:
            arg1: foo
        """

        arg1: str
        arg2: Union[int, str, bool]
        arg3: Optional[List[SubTool]]
        arg4: Annotated[Literal["bar", "baz"], ..., "this does foo"]  # noqa: F722
        arg5: Annotated[Optional[float], None]
        arg6: Annotated[
            Optional[Sequence[Mapping[str, Tuple[Iterable[Any], SubTool]]]], []
        ]
        arg7: Annotated[List[SubTool], ...]
        arg8: Annotated[Tuple[SubTool], ...]
        arg9: Annotated[Sequence[SubTool], ...]
        arg10: Annotated[Iterable[SubTool], ...]
        arg11: Annotated[Set[SubTool], ...]
        arg12: Annotated[Dict[str, SubTool], ...]
        arg13: Annotated[Mapping[str, SubTool], ...]
        arg14: Annotated[MutableMapping[str, SubTool], ...]
        arg15: Annotated[bool, False, "flag"]  # noqa: F821 # type: ignore

    # Note on the expectation below: a SubTool nested inside a Tuple (arg6,
    # arg8) is expected to carry "title" keys, while a SubTool nested in the
    # other containers is not.
    expected = {
        "name": "Tool",
        "description": "Docstring",
        "parameters": {
            "type": "object",
            "properties": {
                "arg1": {"description": "foo", "type": "string"},
                "arg2": {
                    "anyOf": [
                        {"type": "integer"},
                        {"type": "string"},
                        {"type": "boolean"},
                    ]
                },
                "arg3": {
                    "type": "array",
                    "items": {
                        "description": "Subtool docstring",
                        "type": "object",
                        "properties": {
                            "args": {
                                "description": "this does bar",
                                "default": {},
                                "type": "object",
                            }
                        },
                    },
                },
                "arg4": {
                    "description": "this does foo",
                    "enum": ["bar", "baz"],
                    "type": "string",
                },
                "arg5": {"type": "number"},
                "arg6": {
                    "default": [],
                    "type": "array",
                    "items": {
                        "type": "object",
                        "additionalProperties": {
                            "type": "array",
                            "minItems": 2,
                            "maxItems": 2,
                            "items": [
                                {"type": "array", "items": {}},
                                {
                                    "title": "SubTool",
                                    "description": "Subtool docstring",
                                    "type": "object",
                                    "properties": {
                                        "args": {
                                            "title": "Args",
                                            "description": "this does bar",
                                            "default": {},
                                            "type": "object",
                                        }
                                    },
                                },
                            ],
                        },
                    },
                },
                "arg7": {
                    "type": "array",
                    "items": {
                        "description": "Subtool docstring",
                        "type": "object",
                        "properties": {
                            "args": {
                                "description": "this does bar",
                                "default": {},
                                "type": "object",
                            }
                        },
                    },
                },
                "arg8": {
                    "type": "array",
                    "minItems": 1,
                    "maxItems": 1,
                    "items": [
                        {
                            "title": "SubTool",
                            "description": "Subtool docstring",
                            "type": "object",
                            "properties": {
                                "args": {
                                    "title": "Args",
                                    "description": "this does bar",
                                    "default": {},
                                    "type": "object",
                                }
                            },
                        }
                    ],
                },
                "arg9": {
                    "type": "array",
                    "items": {
                        "description": "Subtool docstring",
                        "type": "object",
                        "properties": {
                            "args": {
                                "description": "this does bar",
                                "default": {},
                                "type": "object",
                            }
                        },
                    },
                },
                "arg10": {
                    "type": "array",
                    "items": {
                        "description": "Subtool docstring",
                        "type": "object",
                        "properties": {
                            "args": {
                                "description": "this does bar",
                                "default": {},
                                "type": "object",
                            }
                        },
                    },
                },
                "arg11": {
                    "type": "array",
                    "items": {
                        "description": "Subtool docstring",
                        "type": "object",
                        "properties": {
                            "args": {
                                "description": "this does bar",
                                "default": {},
                                "type": "object",
                            }
                        },
                    },
                    "uniqueItems": True,
                },
                "arg12": {
                    "type": "object",
                    "additionalProperties": {
                        "description": "Subtool docstring",
                        "type": "object",
                        "properties": {
                            "args": {
                                "description": "this does bar",
                                "default": {},
                                "type": "object",
                            }
                        },
                    },
                },
                "arg13": {
                    "type": "object",
                    "additionalProperties": {
                        "description": "Subtool docstring",
                        "type": "object",
                        "properties": {
                            "args": {
                                "description": "this does bar",
                                "default": {},
                                "type": "object",
                            }
                        },
                    },
                },
                "arg14": {
                    "type": "object",
                    "additionalProperties": {
                        "description": "Subtool docstring",
                        "type": "object",
                        "properties": {
                            "args": {
                                "description": "this does bar",
                                "default": {},
                                "type": "object",
                            }
                        },
                    },
                },
                "arg15": {"description": "flag", "default": False, "type": "boolean"},
            },
            # arg5, arg6 and arg15 supply a default through Annotated and are
            # therefore expected to be absent from the required list.
            "required": [
                "arg1",
                "arg2",
                "arg3",
                "arg4",
                "arg7",
                "arg8",
                "arg9",
                "arg10",
                "arg11",
                "arg12",
                "arg13",
                "arg14",
            ],
        },
    }
    actual = _convert_typed_dict_to_openai_function(Tool)
    assert actual == expected
@pytest.mark.parametrize("typed_dict", [ExtensionsTypedDict, TypingTypedDict])
def test__convert_typed_dict_to_openai_function_fail(typed_dict: Type) -> None:
    """Expect a ``TypeError`` when a TypedDict field is a bare ``MutableSet``.

    Args:
        typed_dict: The ``TypedDict`` base to subclass, taken from either
            ``typing_extensions`` or ``typing``.
    """

    class UnsupportedTool(typed_dict):
        # Field type that Pydantic doesn't support.
        arg1: MutableSet

    with pytest.raises(TypeError):
        _convert_typed_dict_to_openai_function(UnsupportedTool)