diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py index 64648273e0f..e2da7a5685f 100644 --- a/libs/core/langchain_core/tools/base.py +++ b/libs/core/langchain_core/tools/base.py @@ -386,6 +386,8 @@ class ToolException(Exception): # noqa: N818 ArgsSchema = TypeBaseModel | dict[str, Any] +_EMPTY_SET: frozenset[str] = frozenset() + class BaseTool(RunnableSerializable[str | dict | ToolCall, Any]): """Base class for all LangChain tools. @@ -569,6 +571,11 @@ class ChildTool(BaseTool): self.name, full_schema, fields, fn_description=self.description ) + @functools.cached_property + def _injected_args_keys(self) -> frozenset[str]: + # base implementation doesn't manage injected args + return _EMPTY_SET + # --- Runnable --- @override @@ -649,6 +656,7 @@ class ChildTool(BaseTool): if isinstance(input_args, dict): return tool_input if issubclass(input_args, BaseModel): + # Check args_schema for InjectedToolCallId for k, v in get_all_basemodel_annotations(input_args).items(): if _is_injected_arg_type(v, injected_type=InjectedToolCallId): if tool_call_id is None: @@ -664,6 +672,7 @@ class ChildTool(BaseTool): result = input_args.model_validate(tool_input) result_dict = result.model_dump() elif issubclass(input_args, BaseModelV1): + # Check args_schema for InjectedToolCallId for k, v in get_all_basemodel_annotations(input_args).items(): if _is_injected_arg_type(v, injected_type=InjectedToolCallId): if tool_call_id is None: @@ -683,9 +692,25 @@ class ChildTool(BaseTool): f"args_schema must be a Pydantic BaseModel, got {self.args_schema}" ) raise NotImplementedError(msg) - return { - k: getattr(result, k) for k, v in result_dict.items() if k in tool_input + validated_input = { + k: getattr(result, k) for k in result_dict if k in tool_input } + for k in self._injected_args_keys: + if k == "tool_call_id": + if tool_call_id is None: + msg = ( + "When tool includes an InjectedToolCallId " + "argument, tool must always be invoked with a full " + "model ToolCall of the form: {'args': {...}, " + "'name': '...', 'type': 'tool_call', " + "'tool_call_id': '...'}" + ) + raise ValueError(msg) + validated_input[k] = tool_call_id + if k in tool_input: + injected_val = tool_input[k] + validated_input[k] = injected_val + return validated_input return tool_input @abstractmethod diff --git a/libs/core/langchain_core/tools/structured.py b/libs/core/langchain_core/tools/structured.py index 43e981570a0..2b613f59461 100644 --- a/libs/core/langchain_core/tools/structured.py +++ b/libs/core/langchain_core/tools/structured.py @@ -2,6 +2,7 @@ from __future__ import annotations +import functools import textwrap from collections.abc import Awaitable, Callable from inspect import signature @@ -21,10 +22,12 @@ from langchain_core.callbacks import ( ) from langchain_core.runnables import RunnableConfig, run_in_executor from langchain_core.tools.base import ( + _EMPTY_SET, FILTERED_ARGS, ArgsSchema, BaseTool, _get_runnable_config_param, + _is_injected_arg_type, create_schema_from_function, ) from langchain_core.utils.pydantic import is_basemodel_subclass @@ -241,6 +244,17 @@ class StructuredTool(BaseTool): **kwargs, ) + @functools.cached_property + def _injected_args_keys(self) -> frozenset[str]: + fn = self.func or self.coroutine + if fn is None: + return _EMPTY_SET + return frozenset( + k + for k, v in signature(fn).parameters.items() + if _is_injected_arg_type(v.annotation) + ) + def _filter_schema_args(func: Callable) -> list[str]: filter_args = list(FILTERED_ARGS) diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index 6c7dc25bba7..87b3e004650 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -6,6 +6,7 @@ import sys import textwrap import threading from collections.abc import Callable +from dataclasses import dataclass from datetime import datetime from enum import Enum from functools import partial @@ -55,6 +56,7 @@ from langchain_core.tools.base import ( InjectedToolArg, InjectedToolCallId, SchemaAnnotationError, + _DirectlyInjectedToolArg, _is_message_content_block, _is_message_content_type, get_all_basemodel_annotations, @@ -2331,6 +2333,101 @@ def test_injected_arg_with_complex_type() -> None: assert injected_tool.invoke({"x": 5, "foo": Foo()}) == "bar" +@pytest.mark.parametrize("schema_format", ["model", "json_schema"]) +def test_tool_allows_extra_runtime_args_with_custom_schema( + schema_format: Literal["model", "json_schema"], +) -> None: + """Ensure runtime args are preserved even if not in the args schema.""" + + class InputSchema(BaseModel): + query: str + + captured: dict[str, Any] = {} + + @dataclass + class MyRuntime(_DirectlyInjectedToolArg): + some_obj: object + + args_schema = ( + InputSchema if schema_format == "model" else InputSchema.model_json_schema() + ) + + @tool(args_schema=args_schema) + def runtime_tool(query: str, runtime: MyRuntime) -> str: + """Echo the query and capture runtime value.""" + captured["runtime"] = runtime + return query + + runtime_obj = object() + runtime = MyRuntime(some_obj=runtime_obj) + assert runtime_tool.invoke({"query": "hello", "runtime": runtime}) == "hello" + assert captured["runtime"] is runtime + + +def test_tool_injected_tool_call_id_with_custom_schema() -> None: + """Ensure InjectedToolCallId works with custom args schema.""" + + class InputSchema(BaseModel): + x: int + + @tool(args_schema=InputSchema) + def injected_tool( + x: int, tool_call_id: Annotated[str, InjectedToolCallId] + ) -> ToolMessage: + """Tool with injected tool_call_id and custom schema.""" + return ToolMessage(str(x), tool_call_id=tool_call_id) + + # Test that tool_call_id is properly injected even though not in custom schema + result = injected_tool.invoke( + { + "type": "tool_call", + "args": {"x": 42}, + "name": "injected_tool", + "id": "test_call_id", + } + ) + assert result == ToolMessage("42", tool_call_id="test_call_id") + + # Test that it still raises error when invoked without a ToolCall + with pytest.raises( + ValueError, + match="When tool includes an InjectedToolCallId argument, " + "tool must always be invoked with a full model ToolCall", + ): + injected_tool.invoke({"x": 42}) + + +def test_tool_injected_arg_with_custom_schema() -> None: + """Ensure InjectedToolArg works with custom args schema.""" + + class InputSchema(BaseModel): + query: str + + class CustomContext: + """Custom context object to be injected.""" + + def __init__(self, value: str) -> None: + self.value = value + + captured: dict[str, Any] = {} + + @tool(args_schema=InputSchema) + def search_tool( + query: str, context: Annotated[CustomContext, InjectedToolArg] + ) -> str: + """Search with custom context.""" + captured["context"] = context + return f"Results for {query} with context {context.value}" + + # Test that context is properly injected even though not in custom schema + ctx = CustomContext("test_context") + result = search_tool.invoke({"query": "hello", "context": ctx}) + + assert result == "Results for hello with context test_context" + assert captured["context"] is ctx + assert captured["context"].value == "test_context" + + def test_tool_injected_tool_call_id() -> None: @tool def foo(x: int, tool_call_id: Annotated[str, InjectedToolCallId]) -> ToolMessage: