diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py index f93c21d5697..b5d86c830da 100644 --- a/libs/core/langchain_core/tools/base.py +++ b/libs/core/langchain_core/tools/base.py @@ -384,6 +384,10 @@ class ToolException(Exception): # noqa: N818 ArgsSchema = Union[TypeBaseModel, dict[str, Any]] +"""Used to define the schema for tool arguments. + +This schema should not include injected arguments like `InjectedToolCallId`. +""" class BaseTool(RunnableSerializable[Union[str, dict, ToolCall], Any]): @@ -575,11 +579,17 @@ class ChildTool(BaseTool): ) -> type[BaseModel]: """The tool's input schema. + All required arguments will be included as part of this schema, including + any injected arguments like `InjectedToolCallId`. + + A schema without injected arguments can be obtained using the + `tool_call_schema` property instead. + Args: config: The configuration for the tool. Returns: - The input schema for the tool. + The full input schema for the tool, including injected arguments. """ if self.args_schema is not None: if isinstance(self.args_schema, dict): @@ -626,31 +636,29 @@ class ChildTool(BaseTool): InjectedToolCallId is required but not provided. NotImplementedError: If args_schema is not a supported type. """ - input_args = self.args_schema + args_schema = self.args_schema if isinstance(tool_input, str): - if input_args is not None: - if isinstance(input_args, dict): + if args_schema is not None: + if isinstance(args_schema, dict): msg = ( "String tool inputs are not allowed when " "using tools with JSON schema args_schema." ) raise ValueError(msg) - key_ = next(iter(get_fields(input_args).keys())) - if issubclass(input_args, BaseModel): - input_args.model_validate({key_: tool_input}) - elif issubclass(input_args, BaseModelV1): - input_args.parse_obj({key_: tool_input}) + key_ = next(iter(get_fields(args_schema).keys())) + if issubclass(args_schema, BaseModel): + args_schema.model_validate({key_: tool_input}) + elif issubclass(args_schema, BaseModelV1): + args_schema.parse_obj({key_: tool_input}) else: - msg = f"args_schema must be a Pydantic BaseModel, got {input_args}" + msg = f"args_schema must be a Pydantic BaseModel, got {args_schema}" raise TypeError(msg) return tool_input - if input_args is not None: - if isinstance(input_args, dict): + if args_schema is not None: + if isinstance(args_schema, dict): return tool_input - if issubclass(input_args, BaseModel): - # Check args_schema for injected arguments - schema_annotations = get_all_basemodel_annotations(input_args) - for k, v in schema_annotations.items(): + if issubclass(args_schema, BaseModel): + for k, v in get_all_basemodel_annotations(args_schema).items(): if ( _is_injected_arg_type(v, injected_type=InjectedToolCallId) and k not in tool_input @@ -665,58 +673,10 @@ class ChildTool(BaseTool): ) raise ValueError(msg) tool_input[k] = tool_call_id - - # Also check function signature for injected arguments not in schema - func_to_check = None - if hasattr(self, 'func') and self.func is not None: - # For StructuredTool, check the original function - func_to_check = self.func - elif hasattr(self, 'coroutine') and self.coroutine is not None: - # For async StructuredTool, check the original coroutine - func_to_check = self.coroutine - else: - # For other tools, check the _run method - func_to_check = self._run - - if func_to_check: - try: - func_annotations = get_type_hints(func_to_check, include_extras=True) - for param_name, param_type in func_annotations.items(): - # Check for InjectedToolCallId specifically - if ( - _is_injected_arg_type(param_type, injected_type=InjectedToolCallId) - and param_name not in schema_annotations - and param_name not in tool_input - ): - if tool_call_id is None: - msg = ( - "When tool includes an InjectedToolCallId " - "argument, tool must always be invoked with a full " - "model ToolCall of the form: {'args': {...}, " - "'name': '...', 'type': 'tool_call', " - "'tool_call_id': '...'}" - ) - raise ValueError(msg) - tool_input[param_name] = tool_call_id - # Check for any other injected arguments (for extensibility) - elif ( - _is_injected_arg_type(param_type) - and param_name not in schema_annotations - and param_name not in tool_input - ): - # For other injected types, we don't inject values automatically - # but we allow them to be missing from validation - pass - except (TypeError, AttributeError): - # Handle cases where get_type_hints fails - pass - - result = input_args.model_validate(tool_input) + result = args_schema.model_validate(tool_input) result_dict = result.model_dump() - elif issubclass(input_args, BaseModelV1): - # Check args_schema for injected arguments - schema_annotations = get_all_basemodel_annotations(input_args) - for k, v in schema_annotations.items(): + elif issubclass(args_schema, BaseModelV1): + for k, v in get_all_basemodel_annotations(args_schema).items(): if ( _is_injected_arg_type(v, injected_type=InjectedToolCallId) and k not in tool_input @@ -731,53 +691,7 @@ class ChildTool(BaseTool): ) raise ValueError(msg) tool_input[k] = tool_call_id - - # Also check function signature for injected arguments not in schema - func_to_check = None - if hasattr(self, 'func') and self.func is not None: - # For StructuredTool, check the original function - func_to_check = self.func - elif hasattr(self, 'coroutine') and self.coroutine is not None: - # For async StructuredTool, check the original coroutine - func_to_check = self.coroutine - else: - # For other tools, check the _run method - func_to_check = self._run - - if func_to_check: - try: - func_annotations = get_type_hints(func_to_check, include_extras=True) - for param_name, param_type in func_annotations.items(): - # Check for InjectedToolCallId specifically - if ( - _is_injected_arg_type(param_type, injected_type=InjectedToolCallId) - and param_name not in schema_annotations - and param_name not in tool_input - ): - if tool_call_id is None: - msg = ( - "When tool includes an InjectedToolCallId " - "argument, tool must always be invoked with a full " - "model ToolCall of the form: {'args': {...}, " - "'name': '...', 'type': 'tool_call', " - "'tool_call_id': '...'}" - ) - raise ValueError(msg) - tool_input[param_name] = tool_call_id - # Check for any other injected arguments (for extensibility) - elif ( - _is_injected_arg_type(param_type) - and param_name not in schema_annotations - and param_name not in tool_input - ): - # For other injected types, we don't inject values automatically - # but we allow them to be missing from validation - pass - except (TypeError, AttributeError): - # Handle cases where get_type_hints fails - pass - - result = input_args.parse_obj(tool_input) + result = args_schema.parse_obj(tool_input) result_dict = result.dict() else: msg = ( diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index 23d352cdb0b..8416f295340 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -2682,26 +2682,11 @@ def test_tool_args_schema_with_annotated_type() -> None: } -def test_injected_args_exclusion_from_schema_bug() -> None: - """Test that demonstrates the bug where injected args must be in args_schema. +def test_args_schema_without_injected_arguments() -> None: + """Test that injected args work correctly in tool schemas.""" - ISSUE: InjectedToolCallId parameters must be explicitly included in the - tool's args_schema for injection to work, even though they should be - automatically excluded from the LLM schema. - - EXPECTED BEHAVIOR: - - InjectedToolCallId should work without being in args_schema - - These parameters should be automatically injected by the tool framework - - The LLM should never see these parameters in the tool schema - - ACTUAL BEHAVIOR: - - InjectedToolCallId must be included in args_schema - - Without them in args_schema, injection fails with missing parameter error - """ - - # Test case 1: Tool with injected arg NOT in args_schema (should work but currently fails) - class BrokenToolSchema(BaseModel): - """This should work but currently doesn't - injection fails.""" + class CustomSchema(BaseModel): + """Schema for LLM inputs.""" text: str = Field(description="Text to process") # InjectedToolCallId should NOT need to be here @@ -2709,13 +2694,14 @@ def test_injected_args_exclusion_from_schema_bug() -> None: @tool( "broken_tool", description="Tool that should work with automatic injection", - args_schema=BrokenToolSchema, + parse_docstring=True, + args_schema=CustomSchema, ) def broken_tool( text: str, tool_call_id: Annotated[str, InjectedToolCallId], ) -> str: - """This tool should work but currently fails because injected params aren't in args_schema.""" + """Tool with injected arguments.""" return f"Processed '{text}' (call_id: {tool_call_id})" # Test tool call structure @@ -2726,31 +2712,45 @@ def test_injected_args_exclusion_from_schema_bug() -> None: "type": "tool_call", } - # The broken tool should work but currently fails - with pytest.raises(ValidationError, match="1 validation error"): - # This should work: the tool_call_id should be automatically injected - # from the tool call structure, but currently fails because tool_call_id - # is not in the args_schema - broken_tool.invoke(tool_call) - - # Verify that tool_call_schema correctly excludes injected args for both tools - broken_schema = _get_tool_call_json_schema(broken_tool) - - # Both should have the same tool call schema (excluding injected args) - expected_schema_properties = { - "text": {"description": "Text to process", "title": "Text", "type": "string"} - } - assert broken_schema["properties"] == expected_schema_properties - - # Neither schema should include tool_call_id in the tool call schema - assert "tool_call_id" not in broken_schema["properties"] + result = broken_tool.invoke(tool_call) + assert isinstance(result, ToolMessage) + assert result.content == "Processed 'test data' (call_id: test_123)" # But the full input schema should include tool_call_id for validation - broken_input_schema = _schema(broken_tool.get_input_schema()) + schema = _schema(broken_tool.get_input_schema()) + assert schema == { + "description": "Tool with injected arguments.", + "properties": { + "text": { + "title": "Text", + "type": "string", + "description": "Text to process", + }, + # Tool call ID **is** required in this case. + "tool_call_id": { + "title": "Tool Call Id", + "type": "string", + }, + }, + "required": [ + "text", + "tool_call_id", + ], + "title": "broken_tool", + "type": "object", + } - # The broken tool's input schema should also include tool_call_id for validation, - # but this is where the bug manifests - it's missing - # This assertion documents the current buggy behavior: - assert ( - "tool_call_id" not in broken_input_schema["properties"] - ) # BUG: should be present + tool_call_schema = _schema(broken_tool.tool_call_schema) + assert tool_call_schema == { + "description": "Tool that should work with automatic injection", + "properties": { + "text": { + "description": "Text to process", + "title": "Text", + "type": "string", + } + }, + "required": ["text"], + "title": "broken_tool", + "type": "object", + }