Compare commits

...

2 Commits

Author SHA1 Message Date
Tat Dat Duong
7b633b094f Code review 2024-03-27 17:17:59 +01:00
Tat Dat Duong
e6952b4ba1 fix: relax top-level title/description requirement for converting JSON schema to OpenAI function 2024-03-27 16:35:02 +01:00
2 changed files with 40 additions and 8 deletions

View File

@@ -295,14 +295,12 @@ def convert_to_openai_function(
k in function for k in ("name", "description", "parameters")
):
return function
# a JSON schema with title and description
elif isinstance(function, dict) and all(
k in function for k in ("title", "description", "properties")
):
# a JSON schema
elif isinstance(function, dict) and "properties" in function:
function = function.copy()
return {
"name": function.pop("title"),
"description": function.pop("description"),
"name": function.pop("title", "extract"),
"description": function.pop("description", ""),
"parameters": function,
}
elif isinstance(function, type) and issubclass(function, BaseModel):
@@ -315,8 +313,7 @@ def convert_to_openai_function(
raise ValueError(
f"Unsupported function\n\n{function}\n\nFunctions must be passed in"
" as Dict, pydantic.BaseModel, or Callable. If they're a dict they must"
" either be in OpenAI function format or valid JSON schema with top-level"
" 'title' and 'description' keys."
" either be in OpenAI function format or valid JSON schema"
)

View File

@@ -99,6 +99,41 @@ def test_convert_to_openai_function(
assert actual == expected
def test_convert_lax_jsonschema_to_openai_function() -> None:
expected = {
"name": "extract",
"description": "",
"parameters": {
"type": "object",
"properties": {
"arg1": {"description": "foo", "type": "integer"},
"arg2": {
"description": "one of 'bar', 'baz'",
"enum": ["bar", "baz"],
"type": "string",
},
},
},
}
assert (
convert_to_openai_function(
{
"type": "object",
"properties": {
"arg1": {"description": "foo", "type": "integer"},
"arg2": {
"description": "one of 'bar', 'baz'",
"enum": ["bar", "baz"],
"type": "string",
},
},
}
)
== expected
)
@pytest.mark.xfail(reason="Pydantic converts Optional[str] to str in .schema()")
def test_function_optional_param() -> None:
@tool