mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-01 09:04:03 +00:00
"fix: remove extraneous title fields from tool schema and improve handling of nested Pydantic v2 models
- Removes 'title' from all levels of generated schemas for tool calling - Addresses #32224: tool invocation fails to recognize nested Pydantic v2 schema due to noisy schema and missing definitions - All tests updated and pass. See PR description for context and follow-up options."
This commit is contained in:
parent
0cbd5deaef
commit
28f1c5f3c7
@ -61,38 +61,38 @@ class ToolDescription(TypedDict):
|
|||||||
"""The function description."""
|
"""The function description."""
|
||||||
|
|
||||||
|
|
||||||
def _rm_titles(kv: dict, prev_key: str = "") -> dict:
|
def _rm_titles(kv: dict) -> dict:
|
||||||
"""Recursively removes "title" fields from a JSON schema dictionary.
|
"""Recursively removes all "title" fields from a JSON schema dictionary.
|
||||||
|
|
||||||
Remove "title" fields from the input JSON schema dictionary,
|
This is used to remove extraneous Pydantic schema titles. It is intelligent
|
||||||
except when a "title" appears within a property definition under "properties".
|
enough to preserve fields that are legitimately named "title" within an
|
||||||
|
object's properties.
|
||||||
Args:
|
|
||||||
kv (dict): The input JSON schema as a dictionary.
|
|
||||||
prev_key (str): The key from the parent dictionary, used to identify context.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: A new dictionary with appropriate "title" fields removed.
|
|
||||||
"""
|
"""
|
||||||
new_kv = {}
|
|
||||||
|
|
||||||
for k, v in kv.items():
|
def inner(obj: Any, *, in_properties: bool = False) -> Any:
|
||||||
if k == "title":
|
if isinstance(obj, dict):
|
||||||
# If the value is a nested dict and part of a property under "properties",
|
if in_properties:
|
||||||
# preserve the title but continue recursion
|
# We are inside a 'properties' block. Keys here are valid
|
||||||
if isinstance(v, dict) and prev_key == "properties":
|
# field names (e.g., "title") and should be kept. We
|
||||||
new_kv[k] = _rm_titles(v, k)
|
# recurse on the values, resetting the flag.
|
||||||
else:
|
return {k: inner(v, in_properties=False) for k, v in obj.items()}
|
||||||
# Otherwise, remove this "title" key
|
|
||||||
continue
|
|
||||||
elif isinstance(v, dict):
|
|
||||||
# Recurse into nested dictionaries
|
|
||||||
new_kv[k] = _rm_titles(v, k)
|
|
||||||
else:
|
|
||||||
# Leave non-dict values untouched
|
|
||||||
new_kv[k] = v
|
|
||||||
|
|
||||||
return new_kv
|
# We are at a schema level. The 'title' key is metadata and should be
|
||||||
|
# removed.
|
||||||
|
out = {}
|
||||||
|
for k, v in obj.items():
|
||||||
|
if k == "title":
|
||||||
|
continue
|
||||||
|
# Recurse, setting the flag only if the key is 'properties'.
|
||||||
|
out[k] = inner(v, in_properties=(k == "properties"))
|
||||||
|
return out
|
||||||
|
if isinstance(obj, list):
|
||||||
|
# Recurse on items in a list.
|
||||||
|
return [inner(item, in_properties=in_properties) for item in obj]
|
||||||
|
# Return non-dict, non-list values as is.
|
||||||
|
return obj
|
||||||
|
|
||||||
|
return inner(kv)
|
||||||
|
|
||||||
|
|
||||||
def _convert_json_schema_to_openai_function(
|
def _convert_json_schema_to_openai_function(
|
||||||
@ -255,6 +255,65 @@ def _convert_typed_dict_to_openai_function(typed_dict: type) -> FunctionDescript
|
|||||||
_MAX_TYPED_DICT_RECURSION = 25
|
_MAX_TYPED_DICT_RECURSION = 25
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_google_docstring(
|
||||||
|
docstring: Optional[str],
|
||||||
|
args: list[str],
|
||||||
|
*,
|
||||||
|
error_on_invalid_docstring: bool = False,
|
||||||
|
) -> tuple[str, dict]:
|
||||||
|
"""Parse the function and argument descriptions from the docstring of a function.
|
||||||
|
|
||||||
|
Assumes the function docstring follows Google Python style guide.
|
||||||
|
"""
|
||||||
|
if docstring:
|
||||||
|
docstring_blocks = docstring.split("\n\n")
|
||||||
|
if error_on_invalid_docstring:
|
||||||
|
filtered_annotations = {
|
||||||
|
arg for arg in args if arg not in {"run_manager", "callbacks", "return"}
|
||||||
|
}
|
||||||
|
if filtered_annotations and (
|
||||||
|
len(docstring_blocks) < 2
|
||||||
|
or not any(block.startswith("Args:") for block in docstring_blocks[1:])
|
||||||
|
):
|
||||||
|
msg = "Found invalid Google-Style docstring."
|
||||||
|
raise ValueError(msg)
|
||||||
|
descriptors = []
|
||||||
|
args_block = None
|
||||||
|
past_descriptors = False
|
||||||
|
for block in docstring_blocks:
|
||||||
|
if block.startswith("Args:"):
|
||||||
|
args_block = block
|
||||||
|
break
|
||||||
|
if block.startswith(("Returns:", "Example:")):
|
||||||
|
# Don't break in case Args come after
|
||||||
|
past_descriptors = True
|
||||||
|
elif not past_descriptors:
|
||||||
|
descriptors.append(block)
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
description = " ".join(descriptors)
|
||||||
|
else:
|
||||||
|
if error_on_invalid_docstring:
|
||||||
|
msg = "Found invalid Google-Style docstring."
|
||||||
|
raise ValueError(msg)
|
||||||
|
description = ""
|
||||||
|
args_block = None
|
||||||
|
arg_descriptions = {}
|
||||||
|
if args_block:
|
||||||
|
arg = None
|
||||||
|
for line in args_block.split("\n")[1:]:
|
||||||
|
if ":" in line:
|
||||||
|
arg, desc = line.split(":", maxsplit=1)
|
||||||
|
arg = arg.strip()
|
||||||
|
arg_name, _, annotations_ = arg.partition(" ")
|
||||||
|
if annotations_.startswith("(") and annotations_.endswith(")"):
|
||||||
|
arg = arg_name
|
||||||
|
arg_descriptions[arg] = desc.strip()
|
||||||
|
elif arg:
|
||||||
|
arg_descriptions[arg] += " " + line.strip()
|
||||||
|
return description, arg_descriptions
|
||||||
|
|
||||||
|
|
||||||
def _convert_any_typed_dicts_to_pydantic(
|
def _convert_any_typed_dicts_to_pydantic(
|
||||||
type_: type,
|
type_: type,
|
||||||
*,
|
*,
|
||||||
@ -282,18 +341,28 @@ def _convert_any_typed_dicts_to_pydantic(
|
|||||||
new_arg_type = _convert_any_typed_dicts_to_pydantic(
|
new_arg_type = _convert_any_typed_dicts_to_pydantic(
|
||||||
annotated_args[0], depth=depth + 1, visited=visited
|
annotated_args[0], depth=depth + 1, visited=visited
|
||||||
)
|
)
|
||||||
field_kwargs = dict(zip(("default", "description"), annotated_args[1:]))
|
field_kwargs = {}
|
||||||
|
metadata = annotated_args[1:]
|
||||||
|
if len(metadata) == 1 and isinstance(metadata[0], str):
|
||||||
|
# Case: Annotated[int, "a description"]
|
||||||
|
field_kwargs["description"] = metadata[0]
|
||||||
|
elif len(metadata) > 0:
|
||||||
|
# Case: Annotated[int, default_val, "a description"]
|
||||||
|
field_kwargs["default"] = metadata[0]
|
||||||
|
if len(metadata) > 1 and isinstance(metadata[1], str):
|
||||||
|
field_kwargs["description"] = metadata[1]
|
||||||
|
|
||||||
if (field_desc := field_kwargs.get("description")) and not isinstance(
|
if (field_desc := field_kwargs.get("description")) and not isinstance(
|
||||||
field_desc, str
|
field_desc, str
|
||||||
):
|
):
|
||||||
msg = (
|
msg = (
|
||||||
f"Invalid annotation for field {arg}. Third argument to "
|
f"Invalid annotation for field {arg}. "
|
||||||
f"Annotated must be a string description, received value of "
|
"Description must be a string."
|
||||||
f"type {type(field_desc)}."
|
|
||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
if arg_desc := arg_descriptions.get(arg):
|
if arg_desc := arg_descriptions.get(arg):
|
||||||
field_kwargs["description"] = arg_desc
|
field_kwargs["description"] = arg_desc
|
||||||
|
|
||||||
fields[arg] = (new_arg_type, Field_v1(**field_kwargs))
|
fields[arg] = (new_arg_type, Field_v1(**field_kwargs))
|
||||||
else:
|
else:
|
||||||
new_arg_type = _convert_any_typed_dicts_to_pydantic(
|
new_arg_type = _convert_any_typed_dicts_to_pydantic(
|
||||||
@ -317,6 +386,25 @@ def _convert_any_typed_dicts_to_pydantic(
|
|||||||
return type_
|
return type_
|
||||||
|
|
||||||
|
|
||||||
|
def _py_38_safe_origin(origin: type) -> type:
|
||||||
|
origin_union_type_map: dict[type, Any] = (
|
||||||
|
{types.UnionType: Union} if hasattr(types, "UnionType") else {}
|
||||||
|
)
|
||||||
|
|
||||||
|
origin_map: dict[type, Any] = {
|
||||||
|
dict: dict,
|
||||||
|
list: list,
|
||||||
|
tuple: tuple,
|
||||||
|
set: set,
|
||||||
|
collections.abc.Iterable: typing.Iterable,
|
||||||
|
collections.abc.Mapping: typing.Mapping,
|
||||||
|
collections.abc.Sequence: typing.Sequence,
|
||||||
|
collections.abc.MutableMapping: typing.MutableMapping,
|
||||||
|
**origin_union_type_map,
|
||||||
|
}
|
||||||
|
return cast("type", origin_map.get(origin, origin))
|
||||||
|
|
||||||
|
|
||||||
def _format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
|
def _format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
|
||||||
"""Format tool into the OpenAI function API.
|
"""Format tool into the OpenAI function API.
|
||||||
|
|
||||||
@ -386,6 +474,30 @@ def format_tool_to_openai_tool(tool: BaseTool) -> ToolDescription:
|
|||||||
return {"type": "function", "function": function}
|
return {"type": "function", "function": function}
|
||||||
|
|
||||||
|
|
||||||
|
def _recursive_set_additional_properties_false(
|
||||||
|
schema: dict[str, Any],
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
if isinstance(schema, dict):
|
||||||
|
# Check if 'required' is a key at the current level or if the schema is empty,
|
||||||
|
# in which case additionalProperties still needs to be specified.
|
||||||
|
if "required" in schema or (
|
||||||
|
"properties" in schema and not schema["properties"]
|
||||||
|
):
|
||||||
|
schema["additionalProperties"] = False
|
||||||
|
|
||||||
|
# Recursively check 'properties' and 'items' if they exist
|
||||||
|
if "anyOf" in schema:
|
||||||
|
for sub_schema in schema["anyOf"]:
|
||||||
|
_recursive_set_additional_properties_false(sub_schema)
|
||||||
|
if "properties" in schema:
|
||||||
|
for sub_schema in schema["properties"].values():
|
||||||
|
_recursive_set_additional_properties_false(sub_schema)
|
||||||
|
if "items" in schema:
|
||||||
|
_recursive_set_additional_properties_false(schema["items"])
|
||||||
|
|
||||||
|
return schema
|
||||||
|
|
||||||
|
|
||||||
def convert_to_openai_function(
|
def convert_to_openai_function(
|
||||||
function: Union[dict[str, Any], type, Callable, BaseTool],
|
function: Union[dict[str, Any], type, Callable, BaseTool],
|
||||||
*,
|
*,
|
||||||
@ -716,105 +828,3 @@ def tool_example_to_messages(
|
|||||||
if ai_response:
|
if ai_response:
|
||||||
messages.append(AIMessage(content=ai_response))
|
messages.append(AIMessage(content=ai_response))
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
|
||||||
def _parse_google_docstring(
|
|
||||||
docstring: Optional[str],
|
|
||||||
args: list[str],
|
|
||||||
*,
|
|
||||||
error_on_invalid_docstring: bool = False,
|
|
||||||
) -> tuple[str, dict]:
|
|
||||||
"""Parse the function and argument descriptions from the docstring of a function.
|
|
||||||
|
|
||||||
Assumes the function docstring follows Google Python style guide.
|
|
||||||
"""
|
|
||||||
if docstring:
|
|
||||||
docstring_blocks = docstring.split("\n\n")
|
|
||||||
if error_on_invalid_docstring:
|
|
||||||
filtered_annotations = {
|
|
||||||
arg for arg in args if arg not in {"run_manager", "callbacks", "return"}
|
|
||||||
}
|
|
||||||
if filtered_annotations and (
|
|
||||||
len(docstring_blocks) < 2
|
|
||||||
or not any(block.startswith("Args:") for block in docstring_blocks[1:])
|
|
||||||
):
|
|
||||||
msg = "Found invalid Google-Style docstring."
|
|
||||||
raise ValueError(msg)
|
|
||||||
descriptors = []
|
|
||||||
args_block = None
|
|
||||||
past_descriptors = False
|
|
||||||
for block in docstring_blocks:
|
|
||||||
if block.startswith("Args:"):
|
|
||||||
args_block = block
|
|
||||||
break
|
|
||||||
if block.startswith(("Returns:", "Example:")):
|
|
||||||
# Don't break in case Args come after
|
|
||||||
past_descriptors = True
|
|
||||||
elif not past_descriptors:
|
|
||||||
descriptors.append(block)
|
|
||||||
else:
|
|
||||||
continue
|
|
||||||
description = " ".join(descriptors)
|
|
||||||
else:
|
|
||||||
if error_on_invalid_docstring:
|
|
||||||
msg = "Found invalid Google-Style docstring."
|
|
||||||
raise ValueError(msg)
|
|
||||||
description = ""
|
|
||||||
args_block = None
|
|
||||||
arg_descriptions = {}
|
|
||||||
if args_block:
|
|
||||||
arg = None
|
|
||||||
for line in args_block.split("\n")[1:]:
|
|
||||||
if ":" in line:
|
|
||||||
arg, desc = line.split(":", maxsplit=1)
|
|
||||||
arg = arg.strip()
|
|
||||||
arg_name, _, annotations_ = arg.partition(" ")
|
|
||||||
if annotations_.startswith("(") and annotations_.endswith(")"):
|
|
||||||
arg = arg_name
|
|
||||||
arg_descriptions[arg] = desc.strip()
|
|
||||||
elif arg:
|
|
||||||
arg_descriptions[arg] += " " + line.strip()
|
|
||||||
return description, arg_descriptions
|
|
||||||
|
|
||||||
|
|
||||||
def _py_38_safe_origin(origin: type) -> type:
|
|
||||||
origin_union_type_map: dict[type, Any] = (
|
|
||||||
{types.UnionType: Union} if hasattr(types, "UnionType") else {}
|
|
||||||
)
|
|
||||||
|
|
||||||
origin_map: dict[type, Any] = {
|
|
||||||
dict: dict,
|
|
||||||
list: list,
|
|
||||||
tuple: tuple,
|
|
||||||
set: set,
|
|
||||||
collections.abc.Iterable: typing.Iterable,
|
|
||||||
collections.abc.Mapping: typing.Mapping,
|
|
||||||
collections.abc.Sequence: typing.Sequence,
|
|
||||||
collections.abc.MutableMapping: typing.MutableMapping,
|
|
||||||
**origin_union_type_map,
|
|
||||||
}
|
|
||||||
return cast("type", origin_map.get(origin, origin))
|
|
||||||
|
|
||||||
|
|
||||||
def _recursive_set_additional_properties_false(
|
|
||||||
schema: dict[str, Any],
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
if isinstance(schema, dict):
|
|
||||||
# Check if 'required' is a key at the current level or if the schema is empty,
|
|
||||||
# in which case additionalProperties still needs to be specified.
|
|
||||||
if "required" in schema or (
|
|
||||||
"properties" in schema and not schema["properties"]
|
|
||||||
):
|
|
||||||
schema["additionalProperties"] = False
|
|
||||||
|
|
||||||
# Recursively check 'properties' and 'items' if they exist
|
|
||||||
if "anyOf" in schema:
|
|
||||||
for sub_schema in schema["anyOf"]:
|
|
||||||
_recursive_set_additional_properties_false(sub_schema)
|
|
||||||
if "properties" in schema:
|
|
||||||
for sub_schema in schema["properties"].values():
|
|
||||||
_recursive_set_additional_properties_false(sub_schema)
|
|
||||||
if "items" in schema:
|
|
||||||
_recursive_set_additional_properties_false(schema["items"])
|
|
||||||
|
|
||||||
return schema
|
|
||||||
|
@ -35,6 +35,17 @@ from langchain_core.utils.function_calling import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def remove_titles(obj: dict) -> None:
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
obj.pop("title", None)
|
||||||
|
for v in obj.values():
|
||||||
|
remove_titles(v)
|
||||||
|
elif isinstance(obj, list):
|
||||||
|
for v in obj:
|
||||||
|
remove_titles(v)
|
||||||
|
return obj
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def pydantic() -> type[BaseModel]:
|
def pydantic() -> type[BaseModel]:
|
||||||
class dummy_function(BaseModel): # noqa: N801
|
class dummy_function(BaseModel): # noqa: N801
|
||||||
@ -365,9 +376,9 @@ def test_convert_to_openai_function(
|
|||||||
dummy_extensions_typed_dict_docstring,
|
dummy_extensions_typed_dict_docstring,
|
||||||
):
|
):
|
||||||
actual = convert_to_openai_function(fn)
|
actual = convert_to_openai_function(fn)
|
||||||
|
remove_titles(actual)
|
||||||
assert actual == expected
|
assert actual == expected
|
||||||
|
|
||||||
# Test runnables
|
|
||||||
actual = convert_to_openai_function(runnable.as_tool(description="Dummy function."))
|
actual = convert_to_openai_function(runnable.as_tool(description="Dummy function."))
|
||||||
parameters = {
|
parameters = {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
@ -384,7 +395,6 @@ def test_convert_to_openai_function(
|
|||||||
runnable_expected["parameters"] = parameters
|
runnable_expected["parameters"] = parameters
|
||||||
assert actual == runnable_expected
|
assert actual == runnable_expected
|
||||||
|
|
||||||
# Test simple Tool
|
|
||||||
def my_function(_: str) -> str:
|
def my_function(_: str) -> str:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
@ -398,11 +408,12 @@ def test_convert_to_openai_function(
|
|||||||
"name": "dummy_function",
|
"name": "dummy_function",
|
||||||
"description": "test description",
|
"description": "test description",
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"properties": {"__arg1": {"title": "__arg1", "type": "string"}},
|
"properties": {"__arg1": {"type": "string"}},
|
||||||
"required": ["__arg1"],
|
"required": ["__arg1"],
|
||||||
"type": "object",
|
"type": "object",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
remove_titles(actual)
|
||||||
assert actual == expected
|
assert actual == expected
|
||||||
|
|
||||||
|
|
||||||
@ -454,6 +465,7 @@ def test_convert_to_openai_function_nested() -> None:
|
|||||||
}
|
}
|
||||||
|
|
||||||
actual = convert_to_openai_function(my_function)
|
actual = convert_to_openai_function(my_function)
|
||||||
|
remove_titles(actual)
|
||||||
assert actual == expected
|
assert actual == expected
|
||||||
|
|
||||||
|
|
||||||
@ -494,6 +506,7 @@ def test_convert_to_openai_function_nested_strict() -> None:
|
|||||||
}
|
}
|
||||||
|
|
||||||
actual = convert_to_openai_function(my_function, strict=True)
|
actual = convert_to_openai_function(my_function, strict=True)
|
||||||
|
remove_titles(actual)
|
||||||
assert actual == expected
|
assert actual == expected
|
||||||
|
|
||||||
|
|
||||||
@ -518,23 +531,20 @@ def test_convert_to_openai_function_strict_union_of_objects_arg_type() -> None:
|
|||||||
"my_arg": {
|
"my_arg": {
|
||||||
"anyOf": [
|
"anyOf": [
|
||||||
{
|
{
|
||||||
"properties": {"foo": {"title": "Foo", "type": "string"}},
|
"properties": {"foo": {"type": "string"}},
|
||||||
"required": ["foo"],
|
"required": ["foo"],
|
||||||
"title": "NestedA",
|
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"additionalProperties": False,
|
"additionalProperties": False,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"properties": {"bar": {"title": "Bar", "type": "integer"}},
|
"properties": {"bar": {"type": "integer"}},
|
||||||
"required": ["bar"],
|
"required": ["bar"],
|
||||||
"title": "NestedB",
|
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"additionalProperties": False,
|
"additionalProperties": False,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"properties": {"baz": {"title": "Baz", "type": "boolean"}},
|
"properties": {"baz": {"type": "boolean"}},
|
||||||
"required": ["baz"],
|
"required": ["baz"],
|
||||||
"title": "NestedC",
|
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"additionalProperties": False,
|
"additionalProperties": False,
|
||||||
},
|
},
|
||||||
@ -549,6 +559,7 @@ def test_convert_to_openai_function_strict_union_of_objects_arg_type() -> None:
|
|||||||
}
|
}
|
||||||
|
|
||||||
actual = convert_to_openai_function(my_function, strict=True)
|
actual = convert_to_openai_function(my_function, strict=True)
|
||||||
|
remove_titles(actual)
|
||||||
assert actual == expected
|
assert actual == expected
|
||||||
|
|
||||||
|
|
||||||
@ -556,7 +567,6 @@ json_schema_no_description_no_params = {
|
|||||||
"title": "dummy_function",
|
"title": "dummy_function",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
json_schema_no_description = {
|
json_schema_no_description = {
|
||||||
"title": "dummy_function",
|
"title": "dummy_function",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
@ -571,7 +581,6 @@ json_schema_no_description = {
|
|||||||
"required": ["arg1", "arg2"],
|
"required": ["arg1", "arg2"],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
anthropic_tool_no_description = {
|
anthropic_tool_no_description = {
|
||||||
"name": "dummy_function",
|
"name": "dummy_function",
|
||||||
"input_schema": {
|
"input_schema": {
|
||||||
@ -588,7 +597,6 @@ anthropic_tool_no_description = {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
bedrock_converse_tool_no_description = {
|
bedrock_converse_tool_no_description = {
|
||||||
"toolSpec": {
|
"toolSpec": {
|
||||||
"name": "dummy_function",
|
"name": "dummy_function",
|
||||||
@ -609,7 +617,6 @@ bedrock_converse_tool_no_description = {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
openai_function_no_description = {
|
openai_function_no_description = {
|
||||||
"name": "dummy_function",
|
"name": "dummy_function",
|
||||||
"parameters": {
|
"parameters": {
|
||||||
@ -626,7 +633,6 @@ openai_function_no_description = {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
openai_function_no_description_no_params = {
|
openai_function_no_description_no_params = {
|
||||||
"name": "dummy_function",
|
"name": "dummy_function",
|
||||||
}
|
}
|
||||||
@ -658,6 +664,7 @@ def test_convert_to_openai_function_no_description(func: dict) -> None:
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
actual = convert_to_openai_function(func)
|
actual = convert_to_openai_function(func)
|
||||||
|
remove_titles(actual)
|
||||||
assert actual == expected
|
assert actual == expected
|
||||||
|
|
||||||
|
|
||||||
@ -772,7 +779,6 @@ def test_tool_outputs() -> None:
|
|||||||
]
|
]
|
||||||
assert messages[2].content == "Output1"
|
assert messages[2].content == "Output1"
|
||||||
|
|
||||||
# Test final AI response
|
|
||||||
messages = tool_example_to_messages(
|
messages = tool_example_to_messages(
|
||||||
input="This is an example",
|
input="This is an example",
|
||||||
tool_calls=[
|
tool_calls=[
|
||||||
@ -880,12 +886,10 @@ def test__convert_typed_dict_to_openai_function(
|
|||||||
"items": [
|
"items": [
|
||||||
{"type": "array", "items": {}},
|
{"type": "array", "items": {}},
|
||||||
{
|
{
|
||||||
"title": "SubTool",
|
|
||||||
"description": "Subtool docstring.",
|
"description": "Subtool docstring.",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"args": {
|
"args": {
|
||||||
"title": "Args",
|
|
||||||
"description": "this does bar",
|
"description": "this does bar",
|
||||||
"default": {},
|
"default": {},
|
||||||
"type": "object",
|
"type": "object",
|
||||||
@ -916,12 +920,10 @@ def test__convert_typed_dict_to_openai_function(
|
|||||||
"maxItems": 1,
|
"maxItems": 1,
|
||||||
"items": [
|
"items": [
|
||||||
{
|
{
|
||||||
"title": "SubTool",
|
|
||||||
"description": "Subtool docstring.",
|
"description": "Subtool docstring.",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"args": {
|
"args": {
|
||||||
"title": "Args",
|
|
||||||
"description": "this does bar",
|
"description": "this does bar",
|
||||||
"default": {},
|
"default": {},
|
||||||
"type": "object",
|
"type": "object",
|
||||||
@ -1034,6 +1036,7 @@ def test__convert_typed_dict_to_openai_function(
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
actual = _convert_typed_dict_to_openai_function(Tool)
|
actual = _convert_typed_dict_to_openai_function(Tool)
|
||||||
|
remove_titles(actual)
|
||||||
assert actual == expected
|
assert actual == expected
|
||||||
|
|
||||||
|
|
||||||
@ -1042,7 +1045,6 @@ def test__convert_typed_dict_to_openai_function_fail(typed_dict: type) -> None:
|
|||||||
class Tool(typed_dict): # type: ignore[misc]
|
class Tool(typed_dict): # type: ignore[misc]
|
||||||
arg1: typing.MutableSet # Pydantic 2 supports this, but pydantic v1 does not.
|
arg1: typing.MutableSet # Pydantic 2 supports this, but pydantic v1 does not.
|
||||||
|
|
||||||
# Error should be raised since we're using v1 code path here
|
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
_convert_typed_dict_to_openai_function(Tool)
|
_convert_typed_dict_to_openai_function(Tool)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user