"fix: remove extraneous title fields from tool schema and improve handling of nested Pydantic v2 models

- Removes 'title' from all levels of generated schemas for tool calling
- Addresses #32224: tool invocation fails to recognize nested Pydantic v2 schema due to noisy schema and missing definitions
- All tests updated and pass. See PR description for context and follow-up options."
RN 2025-07-26 19:18:14 -07:00
parent 0cbd5deaef
commit 28f1c5f3c7
2 changed files with 166 additions and 154 deletions
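
For context on the "noisy schema" mentioned above: Pydantic v2's model_json_schema() attaches a "title" to the model and to every field, and places nested models under "$defs". A minimal sketch of that behavior (the models below are hypothetical, not taken from this PR):

from pydantic import BaseModel


class Address(BaseModel):  # hypothetical nested model
    city: str


class Person(BaseModel):  # hypothetical outer model
    name: str
    address: Address


schema = Person.model_json_schema()
# schema["title"] == "Person"
# schema["properties"]["name"]["title"] == "Name"
# The nested model sits under schema["$defs"]["Address"] with its own "title"
# entries. This is the extraneous metadata the commit strips before the schema
# is handed to a tool-calling API.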

Changed file 1 of 2:

@@ -61,38 +61,38 @@ class ToolDescription(TypedDict):
     """The function description."""
 
 
-def _rm_titles(kv: dict, prev_key: str = "") -> dict:
-    """Recursively removes "title" fields from a JSON schema dictionary.
+def _rm_titles(kv: dict) -> dict:
+    """Recursively removes all "title" fields from a JSON schema dictionary.
 
-    Remove "title" fields from the input JSON schema dictionary,
-    except when a "title" appears within a property definition under "properties".
-
-    Args:
-        kv (dict): The input JSON schema as a dictionary.
-        prev_key (str): The key from the parent dictionary, used to identify context.
-
-    Returns:
-        dict: A new dictionary with appropriate "title" fields removed.
+    This is used to remove extraneous Pydantic schema titles. It is intelligent
+    enough to preserve fields that are legitimately named "title" within an
+    object's properties.
     """
-    new_kv = {}
-    for k, v in kv.items():
-        if k == "title":
-            # If the value is a nested dict and part of a property under "properties",
-            # preserve the title but continue recursion
-            if isinstance(v, dict) and prev_key == "properties":
-                new_kv[k] = _rm_titles(v, k)
-            else:
-                # Otherwise, remove this "title" key
-                continue
-        elif isinstance(v, dict):
-            # Recurse into nested dictionaries
-            new_kv[k] = _rm_titles(v, k)
-        else:
-            # Leave non-dict values untouched
-            new_kv[k] = v
-    return new_kv
+
+    def inner(obj: Any, *, in_properties: bool = False) -> Any:
+        if isinstance(obj, dict):
+            if in_properties:
+                # We are inside a 'properties' block. Keys here are valid
+                # field names (e.g., "title") and should be kept. We
+                # recurse on the values, resetting the flag.
+                return {k: inner(v, in_properties=False) for k, v in obj.items()}
+
+            # We are at a schema level. The 'title' key is metadata and should be
+            # removed.
+            out = {}
+            for k, v in obj.items():
+                if k == "title":
+                    continue
+                # Recurse, setting the flag only if the key is 'properties'.
+                out[k] = inner(v, in_properties=(k == "properties"))
+            return out
+        if isinstance(obj, list):
+            # Recurse on items in a list.
+            return [inner(item, in_properties=in_properties) for item in obj]
+        # Return non-dict, non-list values as is.
+        return obj
+
+    return inner(kv)
 
 
 def _convert_json_schema_to_openai_function(
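
A quick illustration of the rewritten _rm_titles above (a sketch, not part of the diff): schema-level "title" metadata is dropped at every depth, while a property that is genuinely named "title" keeps its key.

schema = {
    "title": "Person",  # schema metadata, removed
    "type": "object",
    "properties": {
        # a real field that happens to be named "title": the key is kept,
        # but the schema metadata nested inside it is still cleaned
        "title": {"title": "Title", "type": "string"},
        "name": {"title": "Name", "type": "string"},
    },
}
assert _rm_titles(schema) == {
    "type": "object",
    "properties": {"title": {"type": "string"}, "name": {"type": "string"}},
}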
@@ -255,6 +255,65 @@ def _convert_typed_dict_to_openai_function(typed_dict: type) -> FunctionDescript
 _MAX_TYPED_DICT_RECURSION = 25
 
 
+def _parse_google_docstring(
+    docstring: Optional[str],
+    args: list[str],
+    *,
+    error_on_invalid_docstring: bool = False,
+) -> tuple[str, dict]:
+    """Parse the function and argument descriptions from the docstring of a function.
+
+    Assumes the function docstring follows Google Python style guide.
+    """
+    if docstring:
+        docstring_blocks = docstring.split("\n\n")
+        if error_on_invalid_docstring:
+            filtered_annotations = {
+                arg for arg in args if arg not in {"run_manager", "callbacks", "return"}
+            }
+            if filtered_annotations and (
+                len(docstring_blocks) < 2
+                or not any(block.startswith("Args:") for block in docstring_blocks[1:])
+            ):
+                msg = "Found invalid Google-Style docstring."
+                raise ValueError(msg)
+        descriptors = []
+        args_block = None
+        past_descriptors = False
+        for block in docstring_blocks:
+            if block.startswith("Args:"):
+                args_block = block
+                break
+            if block.startswith(("Returns:", "Example:")):
+                # Don't break in case Args come after
+                past_descriptors = True
+            elif not past_descriptors:
+                descriptors.append(block)
+            else:
+                continue
+        description = " ".join(descriptors)
+    else:
+        if error_on_invalid_docstring:
+            msg = "Found invalid Google-Style docstring."
+            raise ValueError(msg)
+        description = ""
+        args_block = None
+    arg_descriptions = {}
+    if args_block:
+        arg = None
+        for line in args_block.split("\n")[1:]:
+            if ":" in line:
+                arg, desc = line.split(":", maxsplit=1)
+                arg = arg.strip()
+                arg_name, _, annotations_ = arg.partition(" ")
+                if annotations_.startswith("(") and annotations_.endswith(")"):
+                    arg = arg_name
+                arg_descriptions[arg] = desc.strip()
+            elif arg:
+                arg_descriptions[arg] += " " + line.strip()
+    return description, arg_descriptions
+
+
 def _convert_any_typed_dicts_to_pydantic(
     type_: type,
     *,
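
_parse_google_docstring is only relocated by this commit; the matching deletion appears at the end of this file's diff. For reference, a sketch of what it returns for a Google-style docstring cleaned up with inspect.getdoc (the add function below is hypothetical):

import inspect


def add(a: int, b: int) -> int:
    """Add two numbers.

    Args:
        a: The first number.
        b (int): The second number.
    """
    return a + b


description, arg_descriptions = _parse_google_docstring(inspect.getdoc(add), ["a", "b"])
# description == "Add two numbers."
# arg_descriptions == {"a": "The first number.", "b": "The second number."}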
@@ -282,18 +341,28 @@ def _convert_any_typed_dicts_to_pydantic(
                 new_arg_type = _convert_any_typed_dicts_to_pydantic(
                     annotated_args[0], depth=depth + 1, visited=visited
                 )
-                field_kwargs = dict(zip(("default", "description"), annotated_args[1:]))
+                field_kwargs = {}
+                metadata = annotated_args[1:]
+                if len(metadata) == 1 and isinstance(metadata[0], str):
+                    # Case: Annotated[int, "a description"]
+                    field_kwargs["description"] = metadata[0]
+                elif len(metadata) > 0:
+                    # Case: Annotated[int, default_val, "a description"]
+                    field_kwargs["default"] = metadata[0]
+                    if len(metadata) > 1 and isinstance(metadata[1], str):
+                        field_kwargs["description"] = metadata[1]
+
                 if (field_desc := field_kwargs.get("description")) and not isinstance(
                     field_desc, str
                 ):
                     msg = (
-                        f"Invalid annotation for field {arg}. Third argument to "
-                        f"Annotated must be a string description, received value of "
-                        f"type {type(field_desc)}."
+                        f"Invalid annotation for field {arg}. "
+                        "Description must be a string."
                     )
                     raise ValueError(msg)
                 if arg_desc := arg_descriptions.get(arg):
                     field_kwargs["description"] = arg_desc
                 fields[arg] = (new_arg_type, Field_v1(**field_kwargs))
             else:
                 new_arg_type = _convert_any_typed_dicts_to_pydantic(
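
The hunk above replaces the old zip-based handling of Annotated metadata, under which a lone string such as Annotated[str, "what to search for"] landed in "default" rather than "description". A sketch of the two shapes the new branching distinguishes (the TypedDict below is hypothetical):

from typing_extensions import Annotated, TypedDict


class SearchArgs(TypedDict):  # hypothetical input to the TypedDict converter
    query: Annotated[str, "what to search for"]  # description only
    limit: Annotated[int, 10, "maximum number of hits"]  # default, then description

# With the new branching:
#   query -> field_kwargs == {"description": "what to search for"}
#   limit -> field_kwargs == {"default": 10, "description": "maximum number of hits"}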
@@ -317,6 +386,25 @@ def _convert_any_typed_dicts_to_pydantic(
     return type_
 
 
+def _py_38_safe_origin(origin: type) -> type:
+    origin_union_type_map: dict[type, Any] = (
+        {types.UnionType: Union} if hasattr(types, "UnionType") else {}
+    )
+    origin_map: dict[type, Any] = {
+        dict: dict,
+        list: list,
+        tuple: tuple,
+        set: set,
+        collections.abc.Iterable: typing.Iterable,
+        collections.abc.Mapping: typing.Mapping,
+        collections.abc.Sequence: typing.Sequence,
+        collections.abc.MutableMapping: typing.MutableMapping,
+        **origin_union_type_map,
+    }
+    return cast("type", origin_map.get(origin, origin))
+
+
 def _format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
     """Format tool into the OpenAI function API.
@@ -386,6 +474,30 @@ def format_tool_to_openai_tool(tool: BaseTool) -> ToolDescription:
     return {"type": "function", "function": function}
 
 
+def _recursive_set_additional_properties_false(
+    schema: dict[str, Any],
+) -> dict[str, Any]:
+    if isinstance(schema, dict):
+        # Check if 'required' is a key at the current level or if the schema is empty,
+        # in which case additionalProperties still needs to be specified.
+        if "required" in schema or (
+            "properties" in schema and not schema["properties"]
+        ):
+            schema["additionalProperties"] = False
+
+        # Recursively check 'properties' and 'items' if they exist
+        if "anyOf" in schema:
+            for sub_schema in schema["anyOf"]:
+                _recursive_set_additional_properties_false(sub_schema)
+        if "properties" in schema:
+            for sub_schema in schema["properties"].values():
+                _recursive_set_additional_properties_false(sub_schema)
+        if "items" in schema:
+            _recursive_set_additional_properties_false(schema["items"])
+
+    return schema
+
+
 def convert_to_openai_function(
     function: Union[dict[str, Any], type, Callable, BaseTool],
     *,
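
_recursive_set_additional_properties_false is likewise only relocated here; it is used on the strict=True path (see the strict-mode tests in the second file), since OpenAI's strict tool calling expects "additionalProperties": false on every object schema. A sketch of its effect (the schema below is invented for illustration):

schema = {
    "type": "object",
    "properties": {
        "address": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
    "required": ["address"],
}
_recursive_set_additional_properties_false(schema)
# Both the outer object and the nested "address" object now carry
# "additionalProperties": False.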
@@ -716,105 +828,3 @@ def tool_example_to_messages(
     if ai_response:
         messages.append(AIMessage(content=ai_response))
     return messages
-
-
-def _parse_google_docstring(
-    docstring: Optional[str],
-    args: list[str],
-    *,
-    error_on_invalid_docstring: bool = False,
-) -> tuple[str, dict]:
-    """Parse the function and argument descriptions from the docstring of a function.
-
-    Assumes the function docstring follows Google Python style guide.
-    """
-    if docstring:
-        docstring_blocks = docstring.split("\n\n")
-        if error_on_invalid_docstring:
-            filtered_annotations = {
-                arg for arg in args if arg not in {"run_manager", "callbacks", "return"}
-            }
-            if filtered_annotations and (
-                len(docstring_blocks) < 2
-                or not any(block.startswith("Args:") for block in docstring_blocks[1:])
-            ):
-                msg = "Found invalid Google-Style docstring."
-                raise ValueError(msg)
-        descriptors = []
-        args_block = None
-        past_descriptors = False
-        for block in docstring_blocks:
-            if block.startswith("Args:"):
-                args_block = block
-                break
-            if block.startswith(("Returns:", "Example:")):
-                # Don't break in case Args come after
-                past_descriptors = True
-            elif not past_descriptors:
-                descriptors.append(block)
-            else:
-                continue
-        description = " ".join(descriptors)
-    else:
-        if error_on_invalid_docstring:
-            msg = "Found invalid Google-Style docstring."
-            raise ValueError(msg)
-        description = ""
-        args_block = None
-    arg_descriptions = {}
-    if args_block:
-        arg = None
-        for line in args_block.split("\n")[1:]:
-            if ":" in line:
-                arg, desc = line.split(":", maxsplit=1)
-                arg = arg.strip()
-                arg_name, _, annotations_ = arg.partition(" ")
-                if annotations_.startswith("(") and annotations_.endswith(")"):
-                    arg = arg_name
-                arg_descriptions[arg] = desc.strip()
-            elif arg:
-                arg_descriptions[arg] += " " + line.strip()
-    return description, arg_descriptions
-
-
-def _py_38_safe_origin(origin: type) -> type:
-    origin_union_type_map: dict[type, Any] = (
-        {types.UnionType: Union} if hasattr(types, "UnionType") else {}
-    )
-    origin_map: dict[type, Any] = {
-        dict: dict,
-        list: list,
-        tuple: tuple,
-        set: set,
-        collections.abc.Iterable: typing.Iterable,
-        collections.abc.Mapping: typing.Mapping,
-        collections.abc.Sequence: typing.Sequence,
-        collections.abc.MutableMapping: typing.MutableMapping,
-        **origin_union_type_map,
-    }
-    return cast("type", origin_map.get(origin, origin))
-
-
-def _recursive_set_additional_properties_false(
-    schema: dict[str, Any],
-) -> dict[str, Any]:
-    if isinstance(schema, dict):
-        # Check if 'required' is a key at the current level or if the schema is empty,
-        # in which case additionalProperties still needs to be specified.
-        if "required" in schema or (
-            "properties" in schema and not schema["properties"]
-        ):
-            schema["additionalProperties"] = False
-
-        # Recursively check 'properties' and 'items' if they exist
-        if "anyOf" in schema:
-            for sub_schema in schema["anyOf"]:
-                _recursive_set_additional_properties_false(sub_schema)
-        if "properties" in schema:
-            for sub_schema in schema["properties"].values():
-                _recursive_set_additional_properties_false(sub_schema)
-        if "items" in schema:
-            _recursive_set_additional_properties_false(schema["items"])
-
-    return schema

Changed file 2 of 2:

@@ -35,6 +35,17 @@ from langchain_core.utils.function_calling import (
 )
 
 
+def remove_titles(obj: dict) -> None:
+    if isinstance(obj, dict):
+        obj.pop("title", None)
+        for v in obj.values():
+            remove_titles(v)
+    elif isinstance(obj, list):
+        for v in obj:
+            remove_titles(v)
+    return obj
+
+
 @pytest.fixture
 def pydantic() -> type[BaseModel]:
     class dummy_function(BaseModel):  # noqa: N801
@@ -365,9 +376,9 @@ def test_convert_to_openai_function(
         dummy_extensions_typed_dict_docstring,
     ):
         actual = convert_to_openai_function(fn)
+        remove_titles(actual)
         assert actual == expected
 
-    # Test runnables
     actual = convert_to_openai_function(runnable.as_tool(description="Dummy function."))
     parameters = {
         "type": "object",
@@ -384,7 +395,6 @@ def test_convert_to_openai_function(
     runnable_expected["parameters"] = parameters
     assert actual == runnable_expected
 
-    # Test simple Tool
     def my_function(_: str) -> str:
         return ""
 
@@ -398,11 +408,12 @@ def test_convert_to_openai_function(
         "name": "dummy_function",
         "description": "test description",
         "parameters": {
-            "properties": {"__arg1": {"title": "__arg1", "type": "string"}},
+            "properties": {"__arg1": {"type": "string"}},
            "required": ["__arg1"],
             "type": "object",
         },
     }
+    remove_titles(actual)
     assert actual == expected
 
@@ -454,6 +465,7 @@ def test_convert_to_openai_function_nested() -> None:
     }
 
     actual = convert_to_openai_function(my_function)
+    remove_titles(actual)
     assert actual == expected
 
@@ -494,6 +506,7 @@ def test_convert_to_openai_function_nested_strict() -> None:
     }
 
     actual = convert_to_openai_function(my_function, strict=True)
+    remove_titles(actual)
     assert actual == expected
@@ -518,23 +531,20 @@ def test_convert_to_openai_function_strict_union_of_objects_arg_type() -> None:
                 "my_arg": {
                     "anyOf": [
                         {
-                            "properties": {"foo": {"title": "Foo", "type": "string"}},
+                            "properties": {"foo": {"type": "string"}},
                             "required": ["foo"],
-                            "title": "NestedA",
                             "type": "object",
                             "additionalProperties": False,
                         },
                         {
-                            "properties": {"bar": {"title": "Bar", "type": "integer"}},
+                            "properties": {"bar": {"type": "integer"}},
                             "required": ["bar"],
-                            "title": "NestedB",
                             "type": "object",
                             "additionalProperties": False,
                         },
                         {
-                            "properties": {"baz": {"title": "Baz", "type": "boolean"}},
+                            "properties": {"baz": {"type": "boolean"}},
                             "required": ["baz"],
-                            "title": "NestedC",
                             "type": "object",
                             "additionalProperties": False,
                         },
@@ -549,6 +559,7 @@ def test_convert_to_openai_function_strict_union_of_objects_arg_type() -> None:
     }
 
     actual = convert_to_openai_function(my_function, strict=True)
+    remove_titles(actual)
     assert actual == expected
@@ -556,7 +567,6 @@ json_schema_no_description_no_params = {
     "title": "dummy_function",
 }
 
-
 json_schema_no_description = {
     "title": "dummy_function",
     "type": "object",
@@ -571,7 +581,6 @@ json_schema_no_description = {
     "required": ["arg1", "arg2"],
 }
 
-
 anthropic_tool_no_description = {
     "name": "dummy_function",
     "input_schema": {
@@ -588,7 +597,6 @@ anthropic_tool_no_description = {
     },
 }
 
-
 bedrock_converse_tool_no_description = {
     "toolSpec": {
         "name": "dummy_function",
@@ -609,7 +617,6 @@ bedrock_converse_tool_no_description = {
     }
 }
 
-
 openai_function_no_description = {
     "name": "dummy_function",
     "parameters": {
@@ -626,7 +633,6 @@ openai_function_no_description = {
     },
 }
 
-
 openai_function_no_description_no_params = {
     "name": "dummy_function",
 }
@@ -658,6 +664,7 @@ def test_convert_to_openai_function_no_description(func: dict) -> None:
         },
     }
 
     actual = convert_to_openai_function(func)
+    remove_titles(actual)
     assert actual == expected
 
@@ -772,7 +779,6 @@ def test_tool_outputs() -> None:
     ]
     assert messages[2].content == "Output1"
 
-    # Test final AI response
     messages = tool_example_to_messages(
         input="This is an example",
         tool_calls=[
@@ -880,12 +886,10 @@ def test__convert_typed_dict_to_openai_function(
                 "items": [
                     {"type": "array", "items": {}},
                     {
-                        "title": "SubTool",
                         "description": "Subtool docstring.",
                         "type": "object",
                         "properties": {
                             "args": {
-                                "title": "Args",
                                 "description": "this does bar",
                                 "default": {},
                                 "type": "object",
@@ -916,12 +920,10 @@ def test__convert_typed_dict_to_openai_function(
                 "maxItems": 1,
                 "items": [
                     {
-                        "title": "SubTool",
                         "description": "Subtool docstring.",
                         "type": "object",
                         "properties": {
                             "args": {
-                                "title": "Args",
                                 "description": "this does bar",
                                 "default": {},
                                 "type": "object",
@@ -1034,6 +1036,7 @@ def test__convert_typed_dict_to_openai_function(
         },
     }
 
     actual = _convert_typed_dict_to_openai_function(Tool)
+    remove_titles(actual)
     assert actual == expected
@@ -1042,7 +1045,6 @@ def test__convert_typed_dict_to_openai_function_fail(typed_dict: type) -> None:
     class Tool(typed_dict):  # type: ignore[misc]
         arg1: typing.MutableSet  # Pydantic 2 supports this, but pydantic v1 does not.
 
-    # Error should be raised since we're using v1 code path here
     with pytest.raises(TypeError):
         _convert_typed_dict_to_openai_function(Tool)