Compare commits

...

1 Commits

Author SHA1 Message Date
William Fu-Hinthorn
b486b1e315 [Core] Tool function parsing
Support list, dict, set
2024-06-21 15:47:17 -07:00
2 changed files with 58 additions and 6 deletions

View File

@@ -40,6 +40,10 @@ PYTHON_TO_JSON_TYPES = {
"int": "integer",
"float": "number",
"bool": "boolean",
"list": "array",
"tuple": "array",
"dict": "object",
"set": {"type": "array", "uniqueItems": True},
}
@@ -185,6 +189,7 @@ def _get_python_function_arguments(function: Callable, arg_descriptions: dict) -
if isinstance(annotation, str):
arg_descriptions[arg] = annotation
break
origin = get_origin(arg_type)
if (
isinstance(arg_type, type)
and hasattr(arg_type, "model_json_schema")
@@ -197,11 +202,12 @@ def _get_python_function_arguments(function: Callable, arg_descriptions: dict) -
and callable(arg_type.schema)
):
properties[arg] = arg_type.schema()
elif (
hasattr(arg_type, "__name__")
and getattr(arg_type, "__name__") in PYTHON_TO_JSON_TYPES
):
properties[arg] = {"type": PYTHON_TO_JSON_TYPES[arg_type.__name__]}
elif getattr(arg_type, "__name__", None) in PYTHON_TO_JSON_TYPES:
typ_ = PYTHON_TO_JSON_TYPES[arg_type.__name__]
properties[arg] = {"type": typ_} if isinstance(typ_, str) else typ_
elif getattr(origin, "__name__", None) in PYTHON_TO_JSON_TYPES:
typ_ = PYTHON_TO_JSON_TYPES[getattr(origin, "__name__")]
properties[arg] = {"type": typ_} if isinstance(typ_, str) else typ_
elif (
hasattr(arg_type, "__dict__")
and getattr(arg_type, "__dict__").get("__origin__", None) == Literal
@@ -211,6 +217,7 @@ def _get_python_function_arguments(function: Callable, arg_descriptions: dict) -
"type": PYTHON_TO_JSON_TYPES[arg_type.__args__[0].__class__.__name__],
}
else:
# TODO We don't handle optionals or union types well here.
logger.warning(
f"Argument {arg} of type {arg_type} from function {function.__name__} "
"could not be not be converted to a JSON schema."

View File

@@ -1,5 +1,5 @@
# mypy: disable-error-code="annotation-unchecked"
from typing import Any, Callable, Dict, List, Literal, Optional, Type
from typing import Any, Callable, Dict, List, Literal, Optional, Set, Tuple, Type
import pytest
from pydantic import BaseModel as BaseModelV2Maybe # pydantic: ignore
@@ -230,6 +230,51 @@ def test_convert_to_openai_function_nested() -> None:
assert actual == expected
def test_convert_list_dict_tuple() -> None:
def my_function(
arg1: list,
arg2: dict,
arg3: tuple,
arg4: List,
arg5: Dict,
arg6: Tuple,
arg7: set,
arg8: Set,
) -> None:
"""dummy function"""
pass
expected = {
"name": "my_function",
"description": "dummy function",
"parameters": {
"type": "object",
"properties": {
"arg1": {"type": "array"},
"arg2": {"type": "object"},
"arg3": {"type": "array"},
"arg4": {"type": "array"},
"arg5": {"type": "object"},
"arg6": {"type": "array"},
"arg7": {"type": "array", "uniqueItems": True},
"arg8": {"type": "array", "uniqueItems": True},
},
"required": [
"arg1",
"arg2",
"arg3",
"arg4",
"arg5",
"arg6",
"arg7",
"arg8",
],
},
}
actual = convert_to_openai_function(my_function)
assert actual == expected
@pytest.mark.xfail(reason="Pydantic converts Optional[str] to str in .schema()")
def test_function_optional_param() -> None:
@tool