diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py index 3caf190ef81..f93c21d5697 100644 --- a/libs/core/langchain_core/tools/base.py +++ b/libs/core/langchain_core/tools/base.py @@ -648,7 +648,9 @@ class ChildTool(BaseTool): if isinstance(input_args, dict): return tool_input if issubclass(input_args, BaseModel): - for k, v in get_all_basemodel_annotations(input_args).items(): + # Check args_schema for injected arguments + schema_annotations = get_all_basemodel_annotations(input_args) + for k, v in schema_annotations.items(): if ( _is_injected_arg_type(v, injected_type=InjectedToolCallId) and k not in tool_input @@ -663,10 +665,58 @@ class ChildTool(BaseTool): ) raise ValueError(msg) tool_input[k] = tool_call_id + + # Also check function signature for injected arguments not in schema + func_to_check = None + if hasattr(self, 'func') and self.func is not None: + # For StructuredTool, check the original function + func_to_check = self.func + elif hasattr(self, 'coroutine') and self.coroutine is not None: + # For async StructuredTool, check the original coroutine + func_to_check = self.coroutine + else: + # For other tools, check the _run method + func_to_check = self._run + + if func_to_check: + try: + func_annotations = get_type_hints(func_to_check, include_extras=True) + for param_name, param_type in func_annotations.items(): + # Check for InjectedToolCallId specifically + if ( + _is_injected_arg_type(param_type, injected_type=InjectedToolCallId) + and param_name not in schema_annotations + and param_name not in tool_input + ): + if tool_call_id is None: + msg = ( + "When tool includes an InjectedToolCallId " + "argument, tool must always be invoked with a full " + "model ToolCall of the form: {'args': {...}, " + "'name': '...', 'type': 'tool_call', " + "'tool_call_id': '...'}" + ) + raise ValueError(msg) + tool_input[param_name] = tool_call_id + # Check for any other injected arguments (for extensibility) + elif ( + _is_injected_arg_type(param_type) + and param_name not in schema_annotations + and param_name not in tool_input + ): + # For other injected types, we don't inject values automatically + # but we allow them to be missing from validation + pass + except (TypeError, AttributeError): + # Handle cases where get_type_hints fails + pass + result = input_args.model_validate(tool_input) result_dict = result.model_dump() elif issubclass(input_args, BaseModelV1): - for k, v in get_all_basemodel_annotations(input_args).items(): + # Check args_schema for injected arguments + schema_annotations = get_all_basemodel_annotations(input_args) + for k, v in schema_annotations.items(): if ( _is_injected_arg_type(v, injected_type=InjectedToolCallId) and k not in tool_input @@ -681,6 +731,52 @@ class ChildTool(BaseTool): ) raise ValueError(msg) tool_input[k] = tool_call_id + + # Also check function signature for injected arguments not in schema + func_to_check = None + if hasattr(self, 'func') and self.func is not None: + # For StructuredTool, check the original function + func_to_check = self.func + elif hasattr(self, 'coroutine') and self.coroutine is not None: + # For async StructuredTool, check the original coroutine + func_to_check = self.coroutine + else: + # For other tools, check the _run method + func_to_check = self._run + + if func_to_check: + try: + func_annotations = get_type_hints(func_to_check, include_extras=True) + for param_name, param_type in func_annotations.items(): + # Check for InjectedToolCallId specifically + if ( + _is_injected_arg_type(param_type, injected_type=InjectedToolCallId) + and param_name not in schema_annotations + and param_name not in tool_input + ): + if tool_call_id is None: + msg = ( + "When tool includes an InjectedToolCallId " + "argument, tool must always be invoked with a full " + "model ToolCall of the form: {'args': {...}, " + "'name': '...', 'type': 'tool_call', " + "'tool_call_id': '...'}" + ) + raise ValueError(msg) + tool_input[param_name] = tool_call_id + # Check for any other injected arguments (for extensibility) + elif ( + _is_injected_arg_type(param_type) + and param_name not in schema_annotations + and param_name not in tool_input + ): + # For other injected types, we don't inject values automatically + # but we allow them to be missing from validation + pass + except (TypeError, AttributeError): + # Handle cases where get_type_hints fails + pass + result = input_args.parse_obj(tool_input) result_dict = result.dict() else: diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index ec6c6a3308b..23d352cdb0b 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -2680,3 +2680,77 @@ def test_tool_args_schema_with_annotated_type() -> None: "type": "array", } } + + +def test_injected_args_exclusion_from_schema_bug() -> None: + """Test that demonstrates the bug where injected args must be in args_schema. + + ISSUE: InjectedToolCallId parameters must be explicitly included in the + tool's args_schema for injection to work, even though they should be + automatically excluded from the LLM schema. + + EXPECTED BEHAVIOR: + - InjectedToolCallId should work without being in args_schema + - These parameters should be automatically injected by the tool framework + - The LLM should never see these parameters in the tool schema + + ACTUAL BEHAVIOR: + - InjectedToolCallId must be included in args_schema + - Without them in args_schema, injection fails with missing parameter error + """ + + # Test case 1: Tool with injected arg NOT in args_schema (should work but currently fails) + class BrokenToolSchema(BaseModel): + """This should work but currently doesn't - injection fails.""" + + text: str = Field(description="Text to process") + # InjectedToolCallId should NOT need to be here + + @tool( + "broken_tool", + description="Tool that should work with automatic injection", + args_schema=BrokenToolSchema, + ) + def broken_tool( + text: str, + tool_call_id: Annotated[str, InjectedToolCallId], + ) -> str: + """This tool should work but currently fails because injected params aren't in args_schema.""" + return f"Processed '{text}' (call_id: {tool_call_id})" + + # Test tool call structure + tool_call = { + "name": "test_tool", + "args": {"text": "test data"}, + "id": "test_123", + "type": "tool_call", + } + + # The broken tool should work but currently fails + with pytest.raises(ValidationError, match="1 validation error"): + # This should work: the tool_call_id should be automatically injected + # from the tool call structure, but currently fails because tool_call_id + # is not in the args_schema + broken_tool.invoke(tool_call) + + # Verify that tool_call_schema correctly excludes injected args for both tools + broken_schema = _get_tool_call_json_schema(broken_tool) + + # Both should have the same tool call schema (excluding injected args) + expected_schema_properties = { + "text": {"description": "Text to process", "title": "Text", "type": "string"} + } + assert broken_schema["properties"] == expected_schema_properties + + # Neither schema should include tool_call_id in the tool call schema + assert "tool_call_id" not in broken_schema["properties"] + + # But the full input schema should include tool_call_id for validation + broken_input_schema = _schema(broken_tool.get_input_schema()) + + # The broken tool's input schema should also include tool_call_id for validation, + # but this is where the bug manifests - it's missing + # This assertion documents the current buggy behavior: + assert ( + "tool_call_id" not in broken_input_schema["properties"] + ) # BUG: should be present