mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-30 16:24:24 +00:00
"fix: remove extraneous title fields from tool schema and improve handling of nested Pydantic v2 models

- Removes 'title' from all levels of generated schemas for tool calling.
- Addresses #32224: tool invocation fails to recognize nested Pydantic v2 schema due to noisy schema and missing definitions.
- All tests updated and pass.

See PR description for context and follow-up options."
This commit is contained in:
parent
0cbd5deaef
commit
28f1c5f3c7
@ -61,38 +61,38 @@ class ToolDescription(TypedDict):
|
||||
"""The function description."""
|
||||
|
||||
|
||||
def _rm_titles(kv: dict, prev_key: str = "") -> dict:
|
||||
"""Recursively removes "title" fields from a JSON schema dictionary.
|
||||
def _rm_titles(kv: dict) -> dict:
|
||||
"""Recursively removes all "title" fields from a JSON schema dictionary.
|
||||
|
||||
Remove "title" fields from the input JSON schema dictionary,
|
||||
except when a "title" appears within a property definition under "properties".
|
||||
|
||||
Args:
|
||||
kv (dict): The input JSON schema as a dictionary.
|
||||
prev_key (str): The key from the parent dictionary, used to identify context.
|
||||
|
||||
Returns:
|
||||
dict: A new dictionary with appropriate "title" fields removed.
|
||||
This is used to remove extraneous Pydantic schema titles. It is intelligent
|
||||
enough to preserve fields that are legitimately named "title" within an
|
||||
object's properties.
|
||||
"""
|
||||
new_kv = {}
|
||||
|
||||
for k, v in kv.items():
|
||||
if k == "title":
|
||||
# If the value is a nested dict and part of a property under "properties",
|
||||
# preserve the title but continue recursion
|
||||
if isinstance(v, dict) and prev_key == "properties":
|
||||
new_kv[k] = _rm_titles(v, k)
|
||||
else:
|
||||
# Otherwise, remove this "title" key
|
||||
continue
|
||||
elif isinstance(v, dict):
|
||||
# Recurse into nested dictionaries
|
||||
new_kv[k] = _rm_titles(v, k)
|
||||
else:
|
||||
# Leave non-dict values untouched
|
||||
new_kv[k] = v
|
||||
def inner(obj: Any, *, in_properties: bool = False) -> Any:
|
||||
if isinstance(obj, dict):
|
||||
if in_properties:
|
||||
# We are inside a 'properties' block. Keys here are valid
|
||||
# field names (e.g., "title") and should be kept. We
|
||||
# recurse on the values, resetting the flag.
|
||||
return {k: inner(v, in_properties=False) for k, v in obj.items()}
|
||||
|
||||
return new_kv
|
||||
# We are at a schema level. The 'title' key is metadata and should be
|
||||
# removed.
|
||||
out = {}
|
||||
for k, v in obj.items():
|
||||
if k == "title":
|
||||
continue
|
||||
# Recurse, setting the flag only if the key is 'properties'.
|
||||
out[k] = inner(v, in_properties=(k == "properties"))
|
||||
return out
|
||||
if isinstance(obj, list):
|
||||
# Recurse on items in a list.
|
||||
return [inner(item, in_properties=in_properties) for item in obj]
|
||||
# Return non-dict, non-list values as is.
|
||||
return obj
|
||||
|
||||
return inner(kv)
|
||||
|
||||
|
||||
def _convert_json_schema_to_openai_function(
|
||||
@ -255,6 +255,65 @@ def _convert_typed_dict_to_openai_function(typed_dict: type) -> FunctionDescript
|
||||
_MAX_TYPED_DICT_RECURSION = 25
|
||||
|
||||
|
||||
def _parse_google_docstring(
|
||||
docstring: Optional[str],
|
||||
args: list[str],
|
||||
*,
|
||||
error_on_invalid_docstring: bool = False,
|
||||
) -> tuple[str, dict]:
|
||||
"""Parse the function and argument descriptions from the docstring of a function.
|
||||
|
||||
Assumes the function docstring follows Google Python style guide.
|
||||
"""
|
||||
if docstring:
|
||||
docstring_blocks = docstring.split("\n\n")
|
||||
if error_on_invalid_docstring:
|
||||
filtered_annotations = {
|
||||
arg for arg in args if arg not in {"run_manager", "callbacks", "return"}
|
||||
}
|
||||
if filtered_annotations and (
|
||||
len(docstring_blocks) < 2
|
||||
or not any(block.startswith("Args:") for block in docstring_blocks[1:])
|
||||
):
|
||||
msg = "Found invalid Google-Style docstring."
|
||||
raise ValueError(msg)
|
||||
descriptors = []
|
||||
args_block = None
|
||||
past_descriptors = False
|
||||
for block in docstring_blocks:
|
||||
if block.startswith("Args:"):
|
||||
args_block = block
|
||||
break
|
||||
if block.startswith(("Returns:", "Example:")):
|
||||
# Don't break in case Args come after
|
||||
past_descriptors = True
|
||||
elif not past_descriptors:
|
||||
descriptors.append(block)
|
||||
else:
|
||||
continue
|
||||
description = " ".join(descriptors)
|
||||
else:
|
||||
if error_on_invalid_docstring:
|
||||
msg = "Found invalid Google-Style docstring."
|
||||
raise ValueError(msg)
|
||||
description = ""
|
||||
args_block = None
|
||||
arg_descriptions = {}
|
||||
if args_block:
|
||||
arg = None
|
||||
for line in args_block.split("\n")[1:]:
|
||||
if ":" in line:
|
||||
arg, desc = line.split(":", maxsplit=1)
|
||||
arg = arg.strip()
|
||||
arg_name, _, annotations_ = arg.partition(" ")
|
||||
if annotations_.startswith("(") and annotations_.endswith(")"):
|
||||
arg = arg_name
|
||||
arg_descriptions[arg] = desc.strip()
|
||||
elif arg:
|
||||
arg_descriptions[arg] += " " + line.strip()
|
||||
return description, arg_descriptions
|
||||
|
||||
|
||||
def _convert_any_typed_dicts_to_pydantic(
|
||||
type_: type,
|
||||
*,
|
||||
@ -282,18 +341,28 @@ def _convert_any_typed_dicts_to_pydantic(
|
||||
new_arg_type = _convert_any_typed_dicts_to_pydantic(
|
||||
annotated_args[0], depth=depth + 1, visited=visited
|
||||
)
|
||||
field_kwargs = dict(zip(("default", "description"), annotated_args[1:]))
|
||||
field_kwargs = {}
|
||||
metadata = annotated_args[1:]
|
||||
if len(metadata) == 1 and isinstance(metadata[0], str):
|
||||
# Case: Annotated[int, "a description"]
|
||||
field_kwargs["description"] = metadata[0]
|
||||
elif len(metadata) > 0:
|
||||
# Case: Annotated[int, default_val, "a description"]
|
||||
field_kwargs["default"] = metadata[0]
|
||||
if len(metadata) > 1 and isinstance(metadata[1], str):
|
||||
field_kwargs["description"] = metadata[1]
|
||||
|
||||
if (field_desc := field_kwargs.get("description")) and not isinstance(
|
||||
field_desc, str
|
||||
):
|
||||
msg = (
|
||||
f"Invalid annotation for field {arg}. Third argument to "
|
||||
f"Annotated must be a string description, received value of "
|
||||
f"type {type(field_desc)}."
|
||||
f"Invalid annotation for field {arg}. "
|
||||
"Description must be a string."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
if arg_desc := arg_descriptions.get(arg):
|
||||
field_kwargs["description"] = arg_desc
|
||||
|
||||
fields[arg] = (new_arg_type, Field_v1(**field_kwargs))
|
||||
else:
|
||||
new_arg_type = _convert_any_typed_dicts_to_pydantic(
|
||||
@ -317,6 +386,25 @@ def _convert_any_typed_dicts_to_pydantic(
|
||||
return type_
|
||||
|
||||
|
||||
def _py_38_safe_origin(origin: type) -> type:
|
||||
origin_union_type_map: dict[type, Any] = (
|
||||
{types.UnionType: Union} if hasattr(types, "UnionType") else {}
|
||||
)
|
||||
|
||||
origin_map: dict[type, Any] = {
|
||||
dict: dict,
|
||||
list: list,
|
||||
tuple: tuple,
|
||||
set: set,
|
||||
collections.abc.Iterable: typing.Iterable,
|
||||
collections.abc.Mapping: typing.Mapping,
|
||||
collections.abc.Sequence: typing.Sequence,
|
||||
collections.abc.MutableMapping: typing.MutableMapping,
|
||||
**origin_union_type_map,
|
||||
}
|
||||
return cast("type", origin_map.get(origin, origin))
|
||||
|
||||
|
||||
def _format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
|
||||
"""Format tool into the OpenAI function API.
|
||||
|
||||
@ -386,6 +474,30 @@ def format_tool_to_openai_tool(tool: BaseTool) -> ToolDescription:
|
||||
return {"type": "function", "function": function}
|
||||
|
||||
|
||||
def _recursive_set_additional_properties_false(
|
||||
schema: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
if isinstance(schema, dict):
|
||||
# Check if 'required' is a key at the current level or if the schema is empty,
|
||||
# in which case additionalProperties still needs to be specified.
|
||||
if "required" in schema or (
|
||||
"properties" in schema and not schema["properties"]
|
||||
):
|
||||
schema["additionalProperties"] = False
|
||||
|
||||
# Recursively check 'properties' and 'items' if they exist
|
||||
if "anyOf" in schema:
|
||||
for sub_schema in schema["anyOf"]:
|
||||
_recursive_set_additional_properties_false(sub_schema)
|
||||
if "properties" in schema:
|
||||
for sub_schema in schema["properties"].values():
|
||||
_recursive_set_additional_properties_false(sub_schema)
|
||||
if "items" in schema:
|
||||
_recursive_set_additional_properties_false(schema["items"])
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
def convert_to_openai_function(
|
||||
function: Union[dict[str, Any], type, Callable, BaseTool],
|
||||
*,
|
||||
@ -716,105 +828,3 @@ def tool_example_to_messages(
|
||||
if ai_response:
|
||||
messages.append(AIMessage(content=ai_response))
|
||||
return messages
|
||||
|
||||
|
||||
def _parse_google_docstring(
|
||||
docstring: Optional[str],
|
||||
args: list[str],
|
||||
*,
|
||||
error_on_invalid_docstring: bool = False,
|
||||
) -> tuple[str, dict]:
|
||||
"""Parse the function and argument descriptions from the docstring of a function.
|
||||
|
||||
Assumes the function docstring follows Google Python style guide.
|
||||
"""
|
||||
if docstring:
|
||||
docstring_blocks = docstring.split("\n\n")
|
||||
if error_on_invalid_docstring:
|
||||
filtered_annotations = {
|
||||
arg for arg in args if arg not in {"run_manager", "callbacks", "return"}
|
||||
}
|
||||
if filtered_annotations and (
|
||||
len(docstring_blocks) < 2
|
||||
or not any(block.startswith("Args:") for block in docstring_blocks[1:])
|
||||
):
|
||||
msg = "Found invalid Google-Style docstring."
|
||||
raise ValueError(msg)
|
||||
descriptors = []
|
||||
args_block = None
|
||||
past_descriptors = False
|
||||
for block in docstring_blocks:
|
||||
if block.startswith("Args:"):
|
||||
args_block = block
|
||||
break
|
||||
if block.startswith(("Returns:", "Example:")):
|
||||
# Don't break in case Args come after
|
||||
past_descriptors = True
|
||||
elif not past_descriptors:
|
||||
descriptors.append(block)
|
||||
else:
|
||||
continue
|
||||
description = " ".join(descriptors)
|
||||
else:
|
||||
if error_on_invalid_docstring:
|
||||
msg = "Found invalid Google-Style docstring."
|
||||
raise ValueError(msg)
|
||||
description = ""
|
||||
args_block = None
|
||||
arg_descriptions = {}
|
||||
if args_block:
|
||||
arg = None
|
||||
for line in args_block.split("\n")[1:]:
|
||||
if ":" in line:
|
||||
arg, desc = line.split(":", maxsplit=1)
|
||||
arg = arg.strip()
|
||||
arg_name, _, annotations_ = arg.partition(" ")
|
||||
if annotations_.startswith("(") and annotations_.endswith(")"):
|
||||
arg = arg_name
|
||||
arg_descriptions[arg] = desc.strip()
|
||||
elif arg:
|
||||
arg_descriptions[arg] += " " + line.strip()
|
||||
return description, arg_descriptions
|
||||
|
||||
|
||||
def _py_38_safe_origin(origin: type) -> type:
|
||||
origin_union_type_map: dict[type, Any] = (
|
||||
{types.UnionType: Union} if hasattr(types, "UnionType") else {}
|
||||
)
|
||||
|
||||
origin_map: dict[type, Any] = {
|
||||
dict: dict,
|
||||
list: list,
|
||||
tuple: tuple,
|
||||
set: set,
|
||||
collections.abc.Iterable: typing.Iterable,
|
||||
collections.abc.Mapping: typing.Mapping,
|
||||
collections.abc.Sequence: typing.Sequence,
|
||||
collections.abc.MutableMapping: typing.MutableMapping,
|
||||
**origin_union_type_map,
|
||||
}
|
||||
return cast("type", origin_map.get(origin, origin))
|
||||
|
||||
|
||||
def _recursive_set_additional_properties_false(
|
||||
schema: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
if isinstance(schema, dict):
|
||||
# Check if 'required' is a key at the current level or if the schema is empty,
|
||||
# in which case additionalProperties still needs to be specified.
|
||||
if "required" in schema or (
|
||||
"properties" in schema and not schema["properties"]
|
||||
):
|
||||
schema["additionalProperties"] = False
|
||||
|
||||
# Recursively check 'properties' and 'items' if they exist
|
||||
if "anyOf" in schema:
|
||||
for sub_schema in schema["anyOf"]:
|
||||
_recursive_set_additional_properties_false(sub_schema)
|
||||
if "properties" in schema:
|
||||
for sub_schema in schema["properties"].values():
|
||||
_recursive_set_additional_properties_false(sub_schema)
|
||||
if "items" in schema:
|
||||
_recursive_set_additional_properties_false(schema["items"])
|
||||
|
||||
return schema
|
||||
|
@ -35,6 +35,17 @@ from langchain_core.utils.function_calling import (
|
||||
)
|
||||
|
||||
|
||||
def remove_titles(obj: object) -> object:
    """Strip every "title" key from a nested structure, in place.

    Dicts lose their "title" entry and have their values visited; lists have
    their items visited; any other value is left alone. Unlike a
    schema-aware cleaner, this also removes a key named "title" that sits
    inside a "properties" mapping.

    Args:
        obj: The structure to clean. Dicts inside it are mutated.

    Returns:
        The same ``obj`` that was passed in, for convenience.
    """
    if isinstance(obj, dict):
        # Pop before iterating so the dict does not change size mid-loop.
        obj.pop("title", None)
        for value in obj.values():
            remove_titles(value)
    elif isinstance(obj, list):
        for item in obj:
            remove_titles(item)
    return obj
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pydantic() -> type[BaseModel]:
|
||||
class dummy_function(BaseModel): # noqa: N801
|
||||
@ -365,9 +376,9 @@ def test_convert_to_openai_function(
|
||||
dummy_extensions_typed_dict_docstring,
|
||||
):
|
||||
actual = convert_to_openai_function(fn)
|
||||
remove_titles(actual)
|
||||
assert actual == expected
|
||||
|
||||
# Test runnables
|
||||
actual = convert_to_openai_function(runnable.as_tool(description="Dummy function."))
|
||||
parameters = {
|
||||
"type": "object",
|
||||
@ -384,7 +395,6 @@ def test_convert_to_openai_function(
|
||||
runnable_expected["parameters"] = parameters
|
||||
assert actual == runnable_expected
|
||||
|
||||
# Test simple Tool
|
||||
def my_function(_: str) -> str:
|
||||
return ""
|
||||
|
||||
@ -398,11 +408,12 @@ def test_convert_to_openai_function(
|
||||
"name": "dummy_function",
|
||||
"description": "test description",
|
||||
"parameters": {
|
||||
"properties": {"__arg1": {"title": "__arg1", "type": "string"}},
|
||||
"properties": {"__arg1": {"type": "string"}},
|
||||
"required": ["__arg1"],
|
||||
"type": "object",
|
||||
},
|
||||
}
|
||||
remove_titles(actual)
|
||||
assert actual == expected
|
||||
|
||||
|
||||
@ -454,6 +465,7 @@ def test_convert_to_openai_function_nested() -> None:
|
||||
}
|
||||
|
||||
actual = convert_to_openai_function(my_function)
|
||||
remove_titles(actual)
|
||||
assert actual == expected
|
||||
|
||||
|
||||
@ -494,6 +506,7 @@ def test_convert_to_openai_function_nested_strict() -> None:
|
||||
}
|
||||
|
||||
actual = convert_to_openai_function(my_function, strict=True)
|
||||
remove_titles(actual)
|
||||
assert actual == expected
|
||||
|
||||
|
||||
@ -518,23 +531,20 @@ def test_convert_to_openai_function_strict_union_of_objects_arg_type() -> None:
|
||||
"my_arg": {
|
||||
"anyOf": [
|
||||
{
|
||||
"properties": {"foo": {"title": "Foo", "type": "string"}},
|
||||
"properties": {"foo": {"type": "string"}},
|
||||
"required": ["foo"],
|
||||
"title": "NestedA",
|
||||
"type": "object",
|
||||
"additionalProperties": False,
|
||||
},
|
||||
{
|
||||
"properties": {"bar": {"title": "Bar", "type": "integer"}},
|
||||
"properties": {"bar": {"type": "integer"}},
|
||||
"required": ["bar"],
|
||||
"title": "NestedB",
|
||||
"type": "object",
|
||||
"additionalProperties": False,
|
||||
},
|
||||
{
|
||||
"properties": {"baz": {"title": "Baz", "type": "boolean"}},
|
||||
"properties": {"baz": {"type": "boolean"}},
|
||||
"required": ["baz"],
|
||||
"title": "NestedC",
|
||||
"type": "object",
|
||||
"additionalProperties": False,
|
||||
},
|
||||
@ -549,6 +559,7 @@ def test_convert_to_openai_function_strict_union_of_objects_arg_type() -> None:
|
||||
}
|
||||
|
||||
actual = convert_to_openai_function(my_function, strict=True)
|
||||
remove_titles(actual)
|
||||
assert actual == expected
|
||||
|
||||
|
||||
@ -556,7 +567,6 @@ json_schema_no_description_no_params = {
|
||||
"title": "dummy_function",
|
||||
}
|
||||
|
||||
|
||||
json_schema_no_description = {
|
||||
"title": "dummy_function",
|
||||
"type": "object",
|
||||
@ -571,7 +581,6 @@ json_schema_no_description = {
|
||||
"required": ["arg1", "arg2"],
|
||||
}
|
||||
|
||||
|
||||
anthropic_tool_no_description = {
|
||||
"name": "dummy_function",
|
||||
"input_schema": {
|
||||
@ -588,7 +597,6 @@ anthropic_tool_no_description = {
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
bedrock_converse_tool_no_description = {
|
||||
"toolSpec": {
|
||||
"name": "dummy_function",
|
||||
@ -609,7 +617,6 @@ bedrock_converse_tool_no_description = {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
openai_function_no_description = {
|
||||
"name": "dummy_function",
|
||||
"parameters": {
|
||||
@ -626,7 +633,6 @@ openai_function_no_description = {
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
openai_function_no_description_no_params = {
|
||||
"name": "dummy_function",
|
||||
}
|
||||
@ -658,6 +664,7 @@ def test_convert_to_openai_function_no_description(func: dict) -> None:
|
||||
},
|
||||
}
|
||||
actual = convert_to_openai_function(func)
|
||||
remove_titles(actual)
|
||||
assert actual == expected
|
||||
|
||||
|
||||
@ -772,7 +779,6 @@ def test_tool_outputs() -> None:
|
||||
]
|
||||
assert messages[2].content == "Output1"
|
||||
|
||||
# Test final AI response
|
||||
messages = tool_example_to_messages(
|
||||
input="This is an example",
|
||||
tool_calls=[
|
||||
@ -880,12 +886,10 @@ def test__convert_typed_dict_to_openai_function(
|
||||
"items": [
|
||||
{"type": "array", "items": {}},
|
||||
{
|
||||
"title": "SubTool",
|
||||
"description": "Subtool docstring.",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"args": {
|
||||
"title": "Args",
|
||||
"description": "this does bar",
|
||||
"default": {},
|
||||
"type": "object",
|
||||
@ -916,12 +920,10 @@ def test__convert_typed_dict_to_openai_function(
|
||||
"maxItems": 1,
|
||||
"items": [
|
||||
{
|
||||
"title": "SubTool",
|
||||
"description": "Subtool docstring.",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"args": {
|
||||
"title": "Args",
|
||||
"description": "this does bar",
|
||||
"default": {},
|
||||
"type": "object",
|
||||
@ -1034,6 +1036,7 @@ def test__convert_typed_dict_to_openai_function(
|
||||
},
|
||||
}
|
||||
actual = _convert_typed_dict_to_openai_function(Tool)
|
||||
remove_titles(actual)
|
||||
assert actual == expected
|
||||
|
||||
|
||||
@ -1042,7 +1045,6 @@ def test__convert_typed_dict_to_openai_function_fail(typed_dict: type) -> None:
|
||||
class Tool(typed_dict): # type: ignore[misc]
|
||||
arg1: typing.MutableSet # Pydantic 2 supports this, but pydantic v1 does not.
|
||||
|
||||
# Error should be raised since we're using v1 code path here
|
||||
with pytest.raises(TypeError):
|
||||
_convert_typed_dict_to_openai_function(Tool)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user