mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-11 05:45:01 +00:00
x
This commit is contained in:
parent
d3930f9906
commit
b9ebae37a4
@ -384,6 +384,10 @@ class ToolException(Exception): # noqa: N818
|
|||||||
|
|
||||||
|
|
||||||
ArgsSchema = Union[TypeBaseModel, dict[str, Any]]
|
ArgsSchema = Union[TypeBaseModel, dict[str, Any]]
|
||||||
|
"""Used to define the schema for tool arguments.
|
||||||
|
|
||||||
|
This schema should not include injected arguments like `InjectedToolCallId`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class BaseTool(RunnableSerializable[Union[str, dict, ToolCall], Any]):
|
class BaseTool(RunnableSerializable[Union[str, dict, ToolCall], Any]):
|
||||||
@ -575,11 +579,17 @@ class ChildTool(BaseTool):
|
|||||||
) -> type[BaseModel]:
|
) -> type[BaseModel]:
|
||||||
"""The tool's input schema.
|
"""The tool's input schema.
|
||||||
|
|
||||||
|
All required arguments will be included as part of this schema, including
|
||||||
|
any injected arguments like `InjectedToolCallId`.
|
||||||
|
|
||||||
|
A schema without injected arguments can be obtained using the
|
||||||
|
`tool_call_schema` property instead.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config: The configuration for the tool.
|
config: The configuration for the tool.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The input schema for the tool.
|
The full input schema for the tool, including injected arguments.
|
||||||
"""
|
"""
|
||||||
if self.args_schema is not None:
|
if self.args_schema is not None:
|
||||||
if isinstance(self.args_schema, dict):
|
if isinstance(self.args_schema, dict):
|
||||||
@ -626,31 +636,29 @@ class ChildTool(BaseTool):
|
|||||||
InjectedToolCallId is required but not provided.
|
InjectedToolCallId is required but not provided.
|
||||||
NotImplementedError: If args_schema is not a supported type.
|
NotImplementedError: If args_schema is not a supported type.
|
||||||
"""
|
"""
|
||||||
input_args = self.args_schema
|
args_schema = self.args_schema
|
||||||
if isinstance(tool_input, str):
|
if isinstance(tool_input, str):
|
||||||
if input_args is not None:
|
if args_schema is not None:
|
||||||
if isinstance(input_args, dict):
|
if isinstance(args_schema, dict):
|
||||||
msg = (
|
msg = (
|
||||||
"String tool inputs are not allowed when "
|
"String tool inputs are not allowed when "
|
||||||
"using tools with JSON schema args_schema."
|
"using tools with JSON schema args_schema."
|
||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
key_ = next(iter(get_fields(input_args).keys()))
|
key_ = next(iter(get_fields(args_schema).keys()))
|
||||||
if issubclass(input_args, BaseModel):
|
if issubclass(args_schema, BaseModel):
|
||||||
input_args.model_validate({key_: tool_input})
|
args_schema.model_validate({key_: tool_input})
|
||||||
elif issubclass(input_args, BaseModelV1):
|
elif issubclass(args_schema, BaseModelV1):
|
||||||
input_args.parse_obj({key_: tool_input})
|
args_schema.parse_obj({key_: tool_input})
|
||||||
else:
|
else:
|
||||||
msg = f"args_schema must be a Pydantic BaseModel, got {input_args}"
|
msg = f"args_schema must be a Pydantic BaseModel, got {args_schema}"
|
||||||
raise TypeError(msg)
|
raise TypeError(msg)
|
||||||
return tool_input
|
return tool_input
|
||||||
if input_args is not None:
|
if args_schema is not None:
|
||||||
if isinstance(input_args, dict):
|
if isinstance(args_schema, dict):
|
||||||
return tool_input
|
return tool_input
|
||||||
if issubclass(input_args, BaseModel):
|
if issubclass(args_schema, BaseModel):
|
||||||
# Check args_schema for injected arguments
|
for k, v in get_all_basemodel_annotations(args_schema).items():
|
||||||
schema_annotations = get_all_basemodel_annotations(input_args)
|
|
||||||
for k, v in schema_annotations.items():
|
|
||||||
if (
|
if (
|
||||||
_is_injected_arg_type(v, injected_type=InjectedToolCallId)
|
_is_injected_arg_type(v, injected_type=InjectedToolCallId)
|
||||||
and k not in tool_input
|
and k not in tool_input
|
||||||
@ -665,58 +673,10 @@ class ChildTool(BaseTool):
|
|||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
tool_input[k] = tool_call_id
|
tool_input[k] = tool_call_id
|
||||||
|
result = args_schema.model_validate(tool_input)
|
||||||
# Also check function signature for injected arguments not in schema
|
|
||||||
func_to_check = None
|
|
||||||
if hasattr(self, 'func') and self.func is not None:
|
|
||||||
# For StructuredTool, check the original function
|
|
||||||
func_to_check = self.func
|
|
||||||
elif hasattr(self, 'coroutine') and self.coroutine is not None:
|
|
||||||
# For async StructuredTool, check the original coroutine
|
|
||||||
func_to_check = self.coroutine
|
|
||||||
else:
|
|
||||||
# For other tools, check the _run method
|
|
||||||
func_to_check = self._run
|
|
||||||
|
|
||||||
if func_to_check:
|
|
||||||
try:
|
|
||||||
func_annotations = get_type_hints(func_to_check, include_extras=True)
|
|
||||||
for param_name, param_type in func_annotations.items():
|
|
||||||
# Check for InjectedToolCallId specifically
|
|
||||||
if (
|
|
||||||
_is_injected_arg_type(param_type, injected_type=InjectedToolCallId)
|
|
||||||
and param_name not in schema_annotations
|
|
||||||
and param_name not in tool_input
|
|
||||||
):
|
|
||||||
if tool_call_id is None:
|
|
||||||
msg = (
|
|
||||||
"When tool includes an InjectedToolCallId "
|
|
||||||
"argument, tool must always be invoked with a full "
|
|
||||||
"model ToolCall of the form: {'args': {...}, "
|
|
||||||
"'name': '...', 'type': 'tool_call', "
|
|
||||||
"'tool_call_id': '...'}"
|
|
||||||
)
|
|
||||||
raise ValueError(msg)
|
|
||||||
tool_input[param_name] = tool_call_id
|
|
||||||
# Check for any other injected arguments (for extensibility)
|
|
||||||
elif (
|
|
||||||
_is_injected_arg_type(param_type)
|
|
||||||
and param_name not in schema_annotations
|
|
||||||
and param_name not in tool_input
|
|
||||||
):
|
|
||||||
# For other injected types, we don't inject values automatically
|
|
||||||
# but we allow them to be missing from validation
|
|
||||||
pass
|
|
||||||
except (TypeError, AttributeError):
|
|
||||||
# Handle cases where get_type_hints fails
|
|
||||||
pass
|
|
||||||
|
|
||||||
result = input_args.model_validate(tool_input)
|
|
||||||
result_dict = result.model_dump()
|
result_dict = result.model_dump()
|
||||||
elif issubclass(input_args, BaseModelV1):
|
elif issubclass(args_schema, BaseModelV1):
|
||||||
# Check args_schema for injected arguments
|
for k, v in get_all_basemodel_annotations(args_schema).items():
|
||||||
schema_annotations = get_all_basemodel_annotations(input_args)
|
|
||||||
for k, v in schema_annotations.items():
|
|
||||||
if (
|
if (
|
||||||
_is_injected_arg_type(v, injected_type=InjectedToolCallId)
|
_is_injected_arg_type(v, injected_type=InjectedToolCallId)
|
||||||
and k not in tool_input
|
and k not in tool_input
|
||||||
@ -731,53 +691,7 @@ class ChildTool(BaseTool):
|
|||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
tool_input[k] = tool_call_id
|
tool_input[k] = tool_call_id
|
||||||
|
result = args_schema.parse_obj(tool_input)
|
||||||
# Also check function signature for injected arguments not in schema
|
|
||||||
func_to_check = None
|
|
||||||
if hasattr(self, 'func') and self.func is not None:
|
|
||||||
# For StructuredTool, check the original function
|
|
||||||
func_to_check = self.func
|
|
||||||
elif hasattr(self, 'coroutine') and self.coroutine is not None:
|
|
||||||
# For async StructuredTool, check the original coroutine
|
|
||||||
func_to_check = self.coroutine
|
|
||||||
else:
|
|
||||||
# For other tools, check the _run method
|
|
||||||
func_to_check = self._run
|
|
||||||
|
|
||||||
if func_to_check:
|
|
||||||
try:
|
|
||||||
func_annotations = get_type_hints(func_to_check, include_extras=True)
|
|
||||||
for param_name, param_type in func_annotations.items():
|
|
||||||
# Check for InjectedToolCallId specifically
|
|
||||||
if (
|
|
||||||
_is_injected_arg_type(param_type, injected_type=InjectedToolCallId)
|
|
||||||
and param_name not in schema_annotations
|
|
||||||
and param_name not in tool_input
|
|
||||||
):
|
|
||||||
if tool_call_id is None:
|
|
||||||
msg = (
|
|
||||||
"When tool includes an InjectedToolCallId "
|
|
||||||
"argument, tool must always be invoked with a full "
|
|
||||||
"model ToolCall of the form: {'args': {...}, "
|
|
||||||
"'name': '...', 'type': 'tool_call', "
|
|
||||||
"'tool_call_id': '...'}"
|
|
||||||
)
|
|
||||||
raise ValueError(msg)
|
|
||||||
tool_input[param_name] = tool_call_id
|
|
||||||
# Check for any other injected arguments (for extensibility)
|
|
||||||
elif (
|
|
||||||
_is_injected_arg_type(param_type)
|
|
||||||
and param_name not in schema_annotations
|
|
||||||
and param_name not in tool_input
|
|
||||||
):
|
|
||||||
# For other injected types, we don't inject values automatically
|
|
||||||
# but we allow them to be missing from validation
|
|
||||||
pass
|
|
||||||
except (TypeError, AttributeError):
|
|
||||||
# Handle cases where get_type_hints fails
|
|
||||||
pass
|
|
||||||
|
|
||||||
result = input_args.parse_obj(tool_input)
|
|
||||||
result_dict = result.dict()
|
result_dict = result.dict()
|
||||||
else:
|
else:
|
||||||
msg = (
|
msg = (
|
||||||
|
@ -2682,26 +2682,11 @@ def test_tool_args_schema_with_annotated_type() -> None:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_injected_args_exclusion_from_schema_bug() -> None:
|
def test_args_schema_without_injected_arguments() -> None:
|
||||||
"""Test that demonstrates the bug where injected args must be in args_schema.
|
"""Test that injected args work correctly in tool schemas."""
|
||||||
|
|
||||||
ISSUE: InjectedToolCallId parameters must be explicitly included in the
|
class CustomSchema(BaseModel):
|
||||||
tool's args_schema for injection to work, even though they should be
|
"""Schema for LLM inputs."""
|
||||||
automatically excluded from the LLM schema.
|
|
||||||
|
|
||||||
EXPECTED BEHAVIOR:
|
|
||||||
- InjectedToolCallId should work without being in args_schema
|
|
||||||
- These parameters should be automatically injected by the tool framework
|
|
||||||
- The LLM should never see these parameters in the tool schema
|
|
||||||
|
|
||||||
ACTUAL BEHAVIOR:
|
|
||||||
- InjectedToolCallId must be included in args_schema
|
|
||||||
- Without them in args_schema, injection fails with missing parameter error
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Test case 1: Tool with injected arg NOT in args_schema (should work but currently fails)
|
|
||||||
class BrokenToolSchema(BaseModel):
|
|
||||||
"""This should work but currently doesn't - injection fails."""
|
|
||||||
|
|
||||||
text: str = Field(description="Text to process")
|
text: str = Field(description="Text to process")
|
||||||
# InjectedToolCallId should NOT need to be here
|
# InjectedToolCallId should NOT need to be here
|
||||||
@ -2709,13 +2694,14 @@ def test_injected_args_exclusion_from_schema_bug() -> None:
|
|||||||
@tool(
|
@tool(
|
||||||
"broken_tool",
|
"broken_tool",
|
||||||
description="Tool that should work with automatic injection",
|
description="Tool that should work with automatic injection",
|
||||||
args_schema=BrokenToolSchema,
|
parse_docstring=True,
|
||||||
|
args_schema=CustomSchema,
|
||||||
)
|
)
|
||||||
def broken_tool(
|
def broken_tool(
|
||||||
text: str,
|
text: str,
|
||||||
tool_call_id: Annotated[str, InjectedToolCallId],
|
tool_call_id: Annotated[str, InjectedToolCallId],
|
||||||
) -> str:
|
) -> str:
|
||||||
"""This tool should work but currently fails because injected params aren't in args_schema."""
|
"""Tool with injected arguments."""
|
||||||
return f"Processed '{text}' (call_id: {tool_call_id})"
|
return f"Processed '{text}' (call_id: {tool_call_id})"
|
||||||
|
|
||||||
# Test tool call structure
|
# Test tool call structure
|
||||||
@ -2726,31 +2712,45 @@ def test_injected_args_exclusion_from_schema_bug() -> None:
|
|||||||
"type": "tool_call",
|
"type": "tool_call",
|
||||||
}
|
}
|
||||||
|
|
||||||
# The broken tool should work but currently fails
|
result = broken_tool.invoke(tool_call)
|
||||||
with pytest.raises(ValidationError, match="1 validation error"):
|
assert isinstance(result, ToolMessage)
|
||||||
# This should work: the tool_call_id should be automatically injected
|
assert result.content == "Processed 'test data' (call_id: test_123)"
|
||||||
# from the tool call structure, but currently fails because tool_call_id
|
|
||||||
# is not in the args_schema
|
|
||||||
broken_tool.invoke(tool_call)
|
|
||||||
|
|
||||||
# Verify that tool_call_schema correctly excludes injected args for both tools
|
|
||||||
broken_schema = _get_tool_call_json_schema(broken_tool)
|
|
||||||
|
|
||||||
# Both should have the same tool call schema (excluding injected args)
|
|
||||||
expected_schema_properties = {
|
|
||||||
"text": {"description": "Text to process", "title": "Text", "type": "string"}
|
|
||||||
}
|
|
||||||
assert broken_schema["properties"] == expected_schema_properties
|
|
||||||
|
|
||||||
# Neither schema should include tool_call_id in the tool call schema
|
|
||||||
assert "tool_call_id" not in broken_schema["properties"]
|
|
||||||
|
|
||||||
# But the full input schema should include tool_call_id for validation
|
# But the full input schema should include tool_call_id for validation
|
||||||
broken_input_schema = _schema(broken_tool.get_input_schema())
|
schema = _schema(broken_tool.get_input_schema())
|
||||||
|
assert schema == {
|
||||||
|
"description": "Tool with injected arguments.",
|
||||||
|
"properties": {
|
||||||
|
"text": {
|
||||||
|
"title": "Text",
|
||||||
|
"type": "string",
|
||||||
|
"description": "Text to process",
|
||||||
|
},
|
||||||
|
# Tool call ID **is** required in this case.
|
||||||
|
"tool_call_id": {
|
||||||
|
"title": "Tool Call Id",
|
||||||
|
"type": "string",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": [
|
||||||
|
"text",
|
||||||
|
"tool_call_id",
|
||||||
|
],
|
||||||
|
"title": "broken_tool",
|
||||||
|
"type": "object",
|
||||||
|
}
|
||||||
|
|
||||||
# The broken tool's input schema should also include tool_call_id for validation,
|
tool_call_schema = _schema(broken_tool.tool_call_schema)
|
||||||
# but this is where the bug manifests - it's missing
|
assert tool_call_schema == {
|
||||||
# This assertion documents the current buggy behavior:
|
"description": "Tool that should work with automatic injection",
|
||||||
assert (
|
"properties": {
|
||||||
"tool_call_id" not in broken_input_schema["properties"]
|
"text": {
|
||||||
) # BUG: should be present
|
"description": "Text to process",
|
||||||
|
"title": "Text",
|
||||||
|
"type": "string",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["text"],
|
||||||
|
"title": "broken_tool",
|
||||||
|
"type": "object",
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user