chore: Support tool runtime injection when custom args schema is provided (#33999)

Support injection of injected args (like `InjectedToolCallId`,
`ToolRuntime`) when an `args_schema` is specified that doesn't contain
said args.

This allows for pydantic validation of other args while retaining the
ability to inject langchain specific arguments.

fixes https://github.com/langchain-ai/langchain/issues/33646
fixes https://github.com/langchain-ai/langchain/issues/31688

Taking a deep dive here reminded me that we definitely need to revisit
our internal tooling logic, but I don't think we should do that in this
PR.

---------

Co-authored-by: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com>
Co-authored-by: Sydney Runkle <sydneymarierunkle@gmail.com>
This commit is contained in:
William FH
2025-11-18 09:09:59 -08:00
committed by GitHub
parent 990e346c46
commit 32bbe99efc
3 changed files with 138 additions and 2 deletions

View File

@@ -386,6 +386,8 @@ class ToolException(Exception): # noqa: N818
# Accepted forms of a tool's argument schema: a model class or a plain dict
# (presumably a Pydantic model type / JSON-schema dict — confirm at the
# definition of TypeBaseModel).
ArgsSchema = TypeBaseModel | dict[str, Any]
# Shared immutable empty set, returned when a tool has no injected argument
# names to report.
_EMPTY_SET: frozenset[str] = frozenset()
class BaseTool(RunnableSerializable[str | dict | ToolCall, Any]):
"""Base class for all LangChain tools.
@@ -569,6 +571,11 @@ class ChildTool(BaseTool):
self.name, full_schema, fields, fn_description=self.description
)
# Names of the tool arguments that are supplied by injection at call time
# instead of coming from the validated schema fields.
# Wrapped in cached_property, so the value is computed once per instance.
@functools.cached_property
def _injected_args_keys(self) -> frozenset[str]:
# base implementation doesn't manage injected args
return _EMPTY_SET
# --- Runnable ---
@override
@@ -649,6 +656,7 @@ class ChildTool(BaseTool):
if isinstance(input_args, dict):
return tool_input
if issubclass(input_args, BaseModel):
# Check args_schema for InjectedToolCallId
for k, v in get_all_basemodel_annotations(input_args).items():
if _is_injected_arg_type(v, injected_type=InjectedToolCallId):
if tool_call_id is None:
@@ -664,6 +672,7 @@ class ChildTool(BaseTool):
result = input_args.model_validate(tool_input)
result_dict = result.model_dump()
elif issubclass(input_args, BaseModelV1):
# Check args_schema for InjectedToolCallId
for k, v in get_all_basemodel_annotations(input_args).items():
if _is_injected_arg_type(v, injected_type=InjectedToolCallId):
if tool_call_id is None:
@@ -683,9 +692,25 @@ class ChildTool(BaseTool):
f"args_schema must be a Pydantic BaseModel, got {self.args_schema}"
)
raise NotImplementedError(msg)
return {
k: getattr(result, k) for k, v in result_dict.items() if k in tool_input
validated_input = {
k: getattr(result, k) for k in result_dict if k in tool_input
}
for k in self._injected_args_keys:
if k == "tool_call_id":
if tool_call_id is None:
msg = (
"When tool includes an InjectedToolCallId "
"argument, tool must always be invoked with a full "
"model ToolCall of the form: {'args': {...}, "
"'name': '...', 'type': 'tool_call', "
"'tool_call_id': '...'}"
)
raise ValueError(msg)
validated_input[k] = tool_call_id
if k in tool_input:
injected_val = tool_input[k]
validated_input[k] = injected_val
return validated_input
return tool_input
@abstractmethod

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
import functools
import textwrap
from collections.abc import Awaitable, Callable
from inspect import signature
@@ -21,10 +22,12 @@ from langchain_core.callbacks import (
)
from langchain_core.runnables import RunnableConfig, run_in_executor
from langchain_core.tools.base import (
_EMPTY_SET,
FILTERED_ARGS,
ArgsSchema,
BaseTool,
_get_runnable_config_param,
_is_injected_arg_type,
create_schema_from_function,
)
from langchain_core.utils.pydantic import is_basemodel_subclass
@@ -241,6 +244,17 @@ class StructuredTool(BaseTool):
**kwargs,
)
@functools.cached_property
def _injected_args_keys(self) -> frozenset[str]:
    """Collect the parameter names of the wrapped callable that are injected.

    The callable inspected is ``self.func`` when set, otherwise
    ``self.coroutine``. A parameter counts as injected when
    ``_is_injected_arg_type`` accepts its annotation.

    Returns:
        A frozenset of the injected parameter names, or the shared empty
        set when the tool wraps no callable at all.
    """
    target = self.func or self.coroutine
    if target is None:
        return _EMPTY_SET

    injected_names: set[str] = set()
    for param_name, param in signature(target).parameters.items():
        if _is_injected_arg_type(param.annotation):
            injected_names.add(param_name)
    return frozenset(injected_names)
def _filter_schema_args(func: Callable) -> list[str]:
filter_args = list(FILTERED_ARGS)

View File

@@ -6,6 +6,7 @@ import sys
import textwrap
import threading
from collections.abc import Callable
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from functools import partial
@@ -55,6 +56,7 @@ from langchain_core.tools.base import (
InjectedToolArg,
InjectedToolCallId,
SchemaAnnotationError,
_DirectlyInjectedToolArg,
_is_message_content_block,
_is_message_content_type,
get_all_basemodel_annotations,
@@ -2331,6 +2333,101 @@ def test_injected_arg_with_complex_type() -> None:
assert injected_tool.invoke({"x": 5, "foo": Foo()}) == "bar"
@pytest.mark.parametrize("schema_format", ["model", "json_schema"])
def test_tool_allows_extra_runtime_args_with_custom_schema(
    schema_format: Literal["model", "json_schema"],
) -> None:
    """Runtime-only arguments must reach the tool even if the schema omits them."""

    @dataclass
    class MyRuntime(_DirectlyInjectedToolArg):
        some_obj: object

    class InputSchema(BaseModel):
        query: str

    # Exercise both schema forms: the model class itself and its JSON schema.
    args_schema = (
        InputSchema if schema_format == "model" else InputSchema.model_json_schema()
    )

    # Filled in by the tool body so the test can inspect what it received.
    seen: dict[str, Any] = {}

    @tool(args_schema=args_schema)
    def runtime_tool(query: str, runtime: MyRuntime) -> str:
        """Echo the query and capture runtime value."""
        seen["runtime"] = runtime
        return query

    runtime = MyRuntime(some_obj=object())
    output = runtime_tool.invoke({"query": "hello", "runtime": runtime})

    assert output == "hello"
    # Identity, not equality: the exact object passed in must be forwarded.
    assert seen["runtime"] is runtime
def test_tool_injected_tool_call_id_with_custom_schema() -> None:
    """A custom args schema must not block ``InjectedToolCallId`` injection."""

    class InputSchema(BaseModel):
        x: int

    @tool(args_schema=InputSchema)
    def injected_tool(
        x: int, tool_call_id: Annotated[str, InjectedToolCallId]
    ) -> ToolMessage:
        """Tool with injected tool_call_id and custom schema."""
        return ToolMessage(str(x), tool_call_id=tool_call_id)

    # Full ToolCall payload: the id comes from here, since the custom schema
    # does not declare a tool_call_id field.
    full_call = {
        "type": "tool_call",
        "args": {"x": 42},
        "name": "injected_tool",
        "id": "test_call_id",
    }
    expected = ToolMessage("42", tool_call_id="test_call_id")
    assert injected_tool.invoke(full_call) == expected

    # Bare args carry no call id, so the invocation has to be rejected.
    missing_id_error = (
        "When tool includes an InjectedToolCallId argument, "
        "tool must always be invoked with a full model ToolCall"
    )
    with pytest.raises(ValueError, match=missing_id_error):
        injected_tool.invoke({"x": 42})
def test_tool_injected_arg_with_custom_schema() -> None:
    """A custom args schema must not block ``InjectedToolArg`` injection."""

    class CustomContext:
        """Custom context object to be injected."""

        def __init__(self, value: str) -> None:
            self.value = value

    class InputSchema(BaseModel):
        query: str

    # Filled in by the tool body so the test can inspect what it received.
    seen: dict[str, Any] = {}

    @tool(args_schema=InputSchema)
    def search_tool(
        query: str, context: Annotated[CustomContext, InjectedToolArg]
    ) -> str:
        """Search with custom context."""
        seen["context"] = context
        return f"Results for {query} with context {context.value}"

    # `context` is not declared in InputSchema, yet it must be passed through.
    injected_context = CustomContext("test_context")
    output = search_tool.invoke({"query": "hello", "context": injected_context})

    assert output == "Results for hello with context test_context"
    # The very same object must arrive, unmodified.
    assert seen["context"] is injected_context
    assert seen["context"].value == "test_context"
def test_tool_injected_tool_call_id() -> None:
@tool
def foo(x: int, tool_call_id: Annotated[str, InjectedToolCallId]) -> ToolMessage: