mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-09 21:08:59 +00:00
x
This commit is contained in:
parent
7ff405077d
commit
d3930f9906
@ -648,7 +648,9 @@ class ChildTool(BaseTool):
|
||||
if isinstance(input_args, dict):
|
||||
return tool_input
|
||||
if issubclass(input_args, BaseModel):
|
||||
for k, v in get_all_basemodel_annotations(input_args).items():
|
||||
# Check args_schema for injected arguments
|
||||
schema_annotations = get_all_basemodel_annotations(input_args)
|
||||
for k, v in schema_annotations.items():
|
||||
if (
|
||||
_is_injected_arg_type(v, injected_type=InjectedToolCallId)
|
||||
and k not in tool_input
|
||||
@ -663,10 +665,58 @@ class ChildTool(BaseTool):
|
||||
)
|
||||
raise ValueError(msg)
|
||||
tool_input[k] = tool_call_id
|
||||
|
||||
# Also check function signature for injected arguments not in schema
|
||||
func_to_check = None
|
||||
if hasattr(self, 'func') and self.func is not None:
|
||||
# For StructuredTool, check the original function
|
||||
func_to_check = self.func
|
||||
elif hasattr(self, 'coroutine') and self.coroutine is not None:
|
||||
# For async StructuredTool, check the original coroutine
|
||||
func_to_check = self.coroutine
|
||||
else:
|
||||
# For other tools, check the _run method
|
||||
func_to_check = self._run
|
||||
|
||||
if func_to_check:
|
||||
try:
|
||||
func_annotations = get_type_hints(func_to_check, include_extras=True)
|
||||
for param_name, param_type in func_annotations.items():
|
||||
# Check for InjectedToolCallId specifically
|
||||
if (
|
||||
_is_injected_arg_type(param_type, injected_type=InjectedToolCallId)
|
||||
and param_name not in schema_annotations
|
||||
and param_name not in tool_input
|
||||
):
|
||||
if tool_call_id is None:
|
||||
msg = (
|
||||
"When tool includes an InjectedToolCallId "
|
||||
"argument, tool must always be invoked with a full "
|
||||
"model ToolCall of the form: {'args': {...}, "
|
||||
"'name': '...', 'type': 'tool_call', "
|
||||
"'tool_call_id': '...'}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
tool_input[param_name] = tool_call_id
|
||||
# Check for any other injected arguments (for extensibility)
|
||||
elif (
|
||||
_is_injected_arg_type(param_type)
|
||||
and param_name not in schema_annotations
|
||||
and param_name not in tool_input
|
||||
):
|
||||
# For other injected types, we don't inject values automatically
|
||||
# but we allow them to be missing from validation
|
||||
pass
|
||||
except (TypeError, AttributeError):
|
||||
# Handle cases where get_type_hints fails
|
||||
pass
|
||||
|
||||
result = input_args.model_validate(tool_input)
|
||||
result_dict = result.model_dump()
|
||||
elif issubclass(input_args, BaseModelV1):
|
||||
for k, v in get_all_basemodel_annotations(input_args).items():
|
||||
# Check args_schema for injected arguments
|
||||
schema_annotations = get_all_basemodel_annotations(input_args)
|
||||
for k, v in schema_annotations.items():
|
||||
if (
|
||||
_is_injected_arg_type(v, injected_type=InjectedToolCallId)
|
||||
and k not in tool_input
|
||||
@ -681,6 +731,52 @@ class ChildTool(BaseTool):
|
||||
)
|
||||
raise ValueError(msg)
|
||||
tool_input[k] = tool_call_id
|
||||
|
||||
# Also check function signature for injected arguments not in schema
|
||||
func_to_check = None
|
||||
if hasattr(self, 'func') and self.func is not None:
|
||||
# For StructuredTool, check the original function
|
||||
func_to_check = self.func
|
||||
elif hasattr(self, 'coroutine') and self.coroutine is not None:
|
||||
# For async StructuredTool, check the original coroutine
|
||||
func_to_check = self.coroutine
|
||||
else:
|
||||
# For other tools, check the _run method
|
||||
func_to_check = self._run
|
||||
|
||||
if func_to_check:
|
||||
try:
|
||||
func_annotations = get_type_hints(func_to_check, include_extras=True)
|
||||
for param_name, param_type in func_annotations.items():
|
||||
# Check for InjectedToolCallId specifically
|
||||
if (
|
||||
_is_injected_arg_type(param_type, injected_type=InjectedToolCallId)
|
||||
and param_name not in schema_annotations
|
||||
and param_name not in tool_input
|
||||
):
|
||||
if tool_call_id is None:
|
||||
msg = (
|
||||
"When tool includes an InjectedToolCallId "
|
||||
"argument, tool must always be invoked with a full "
|
||||
"model ToolCall of the form: {'args': {...}, "
|
||||
"'name': '...', 'type': 'tool_call', "
|
||||
"'tool_call_id': '...'}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
tool_input[param_name] = tool_call_id
|
||||
# Check for any other injected arguments (for extensibility)
|
||||
elif (
|
||||
_is_injected_arg_type(param_type)
|
||||
and param_name not in schema_annotations
|
||||
and param_name not in tool_input
|
||||
):
|
||||
# For other injected types, we don't inject values automatically
|
||||
# but we allow them to be missing from validation
|
||||
pass
|
||||
except (TypeError, AttributeError):
|
||||
# Handle cases where get_type_hints fails
|
||||
pass
|
||||
|
||||
result = input_args.parse_obj(tool_input)
|
||||
result_dict = result.dict()
|
||||
else:
|
||||
|
@ -2680,3 +2680,77 @@ def test_tool_args_schema_with_annotated_type() -> None:
|
||||
"type": "array",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def test_injected_args_exclusion_from_schema_bug() -> None:
|
||||
"""Test that demonstrates the bug where injected args must be in args_schema.
|
||||
|
||||
ISSUE: InjectedToolCallId parameters must be explicitly included in the
|
||||
tool's args_schema for injection to work, even though they should be
|
||||
automatically excluded from the LLM schema.
|
||||
|
||||
EXPECTED BEHAVIOR:
|
||||
- InjectedToolCallId should work without being in args_schema
|
||||
- These parameters should be automatically injected by the tool framework
|
||||
- The LLM should never see these parameters in the tool schema
|
||||
|
||||
ACTUAL BEHAVIOR:
|
||||
- InjectedToolCallId must be included in args_schema
|
||||
- Without them in args_schema, injection fails with missing parameter error
|
||||
"""
|
||||
|
||||
# Test case 1: Tool with injected arg NOT in args_schema (should work but currently fails)
|
||||
class BrokenToolSchema(BaseModel):
|
||||
"""This should work but currently doesn't - injection fails."""
|
||||
|
||||
text: str = Field(description="Text to process")
|
||||
# InjectedToolCallId should NOT need to be here
|
||||
|
||||
@tool(
|
||||
"broken_tool",
|
||||
description="Tool that should work with automatic injection",
|
||||
args_schema=BrokenToolSchema,
|
||||
)
|
||||
def broken_tool(
|
||||
text: str,
|
||||
tool_call_id: Annotated[str, InjectedToolCallId],
|
||||
) -> str:
|
||||
"""This tool should work but currently fails because injected params aren't in args_schema."""
|
||||
return f"Processed '{text}' (call_id: {tool_call_id})"
|
||||
|
||||
# Test tool call structure
|
||||
tool_call = {
|
||||
"name": "test_tool",
|
||||
"args": {"text": "test data"},
|
||||
"id": "test_123",
|
||||
"type": "tool_call",
|
||||
}
|
||||
|
||||
# The broken tool should work but currently fails
|
||||
with pytest.raises(ValidationError, match="1 validation error"):
|
||||
# This should work: the tool_call_id should be automatically injected
|
||||
# from the tool call structure, but currently fails because tool_call_id
|
||||
# is not in the args_schema
|
||||
broken_tool.invoke(tool_call)
|
||||
|
||||
# Verify that tool_call_schema correctly excludes injected args for both tools
|
||||
broken_schema = _get_tool_call_json_schema(broken_tool)
|
||||
|
||||
# Both should have the same tool call schema (excluding injected args)
|
||||
expected_schema_properties = {
|
||||
"text": {"description": "Text to process", "title": "Text", "type": "string"}
|
||||
}
|
||||
assert broken_schema["properties"] == expected_schema_properties
|
||||
|
||||
# Neither schema should include tool_call_id in the tool call schema
|
||||
assert "tool_call_id" not in broken_schema["properties"]
|
||||
|
||||
# But the full input schema should include tool_call_id for validation
|
||||
broken_input_schema = _schema(broken_tool.get_input_schema())
|
||||
|
||||
# The broken tool's input schema should also include tool_call_id for validation,
|
||||
# but this is where the bug manifests - it's missing
|
||||
# This assertion documents the current buggy behavior:
|
||||
assert (
|
||||
"tool_call_id" not in broken_input_schema["properties"]
|
||||
) # BUG: should be present
|
||||
|
Loading…
Reference in New Issue
Block a user