mirror of
https://github.com/hwchase17/langchain.git
synced 2025-12-26 17:36:35 +00:00
chore: Support tool runtime injection when custom args schema is prov… (#33999)
Support injection of injected args (like `InjectedToolCallId`, `ToolRuntime`) when an `args_schema` is specified that doesn't contain said args. This allows for pydantic validation of other args while retaining the ability to inject langchain specific arguments. fixes https://github.com/langchain-ai/langchain/issues/33646 fixes https://github.com/langchain-ai/langchain/issues/31688 Taking a deep dive here reminded me that we definitely need to revisit our internal tooling logic, but I don't think we should do that in this PR. --------- Co-authored-by: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com> Co-authored-by: Sydney Runkle <sydneymarierunkle@gmail.com>
This commit is contained in:
@@ -386,6 +386,8 @@ class ToolException(Exception): # noqa: N818
|
||||
|
||||
ArgsSchema = TypeBaseModel | dict[str, Any]
|
||||
|
||||
_EMPTY_SET: frozenset[str] = frozenset()
|
||||
|
||||
|
||||
class BaseTool(RunnableSerializable[str | dict | ToolCall, Any]):
|
||||
"""Base class for all LangChain tools.
|
||||
@@ -569,6 +571,11 @@ class ChildTool(BaseTool):
|
||||
self.name, full_schema, fields, fn_description=self.description
|
||||
)
|
||||
|
||||
@functools.cached_property
|
||||
def _injected_args_keys(self) -> frozenset[str]:
|
||||
# base implementation doesn't manage injected args
|
||||
return _EMPTY_SET
|
||||
|
||||
# --- Runnable ---
|
||||
|
||||
@override
|
||||
@@ -649,6 +656,7 @@ class ChildTool(BaseTool):
|
||||
if isinstance(input_args, dict):
|
||||
return tool_input
|
||||
if issubclass(input_args, BaseModel):
|
||||
# Check args_schema for InjectedToolCallId
|
||||
for k, v in get_all_basemodel_annotations(input_args).items():
|
||||
if _is_injected_arg_type(v, injected_type=InjectedToolCallId):
|
||||
if tool_call_id is None:
|
||||
@@ -664,6 +672,7 @@ class ChildTool(BaseTool):
|
||||
result = input_args.model_validate(tool_input)
|
||||
result_dict = result.model_dump()
|
||||
elif issubclass(input_args, BaseModelV1):
|
||||
# Check args_schema for InjectedToolCallId
|
||||
for k, v in get_all_basemodel_annotations(input_args).items():
|
||||
if _is_injected_arg_type(v, injected_type=InjectedToolCallId):
|
||||
if tool_call_id is None:
|
||||
@@ -683,9 +692,25 @@ class ChildTool(BaseTool):
|
||||
f"args_schema must be a Pydantic BaseModel, got {self.args_schema}"
|
||||
)
|
||||
raise NotImplementedError(msg)
|
||||
return {
|
||||
k: getattr(result, k) for k, v in result_dict.items() if k in tool_input
|
||||
validated_input = {
|
||||
k: getattr(result, k) for k in result_dict if k in tool_input
|
||||
}
|
||||
for k in self._injected_args_keys:
|
||||
if k == "tool_call_id":
|
||||
if tool_call_id is None:
|
||||
msg = (
|
||||
"When tool includes an InjectedToolCallId "
|
||||
"argument, tool must always be invoked with a full "
|
||||
"model ToolCall of the form: {'args': {...}, "
|
||||
"'name': '...', 'type': 'tool_call', "
|
||||
"'tool_call_id': '...'}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
validated_input[k] = tool_call_id
|
||||
if k in tool_input:
|
||||
injected_val = tool_input[k]
|
||||
validated_input[k] = injected_val
|
||||
return validated_input
|
||||
return tool_input
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import textwrap
|
||||
from collections.abc import Awaitable, Callable
|
||||
from inspect import signature
|
||||
@@ -21,10 +22,12 @@ from langchain_core.callbacks import (
|
||||
)
|
||||
from langchain_core.runnables import RunnableConfig, run_in_executor
|
||||
from langchain_core.tools.base import (
|
||||
_EMPTY_SET,
|
||||
FILTERED_ARGS,
|
||||
ArgsSchema,
|
||||
BaseTool,
|
||||
_get_runnable_config_param,
|
||||
_is_injected_arg_type,
|
||||
create_schema_from_function,
|
||||
)
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
@@ -241,6 +244,17 @@ class StructuredTool(BaseTool):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@functools.cached_property
|
||||
def _injected_args_keys(self) -> frozenset[str]:
|
||||
fn = self.func or self.coroutine
|
||||
if fn is None:
|
||||
return _EMPTY_SET
|
||||
return frozenset(
|
||||
k
|
||||
for k, v in signature(fn).parameters.items()
|
||||
if _is_injected_arg_type(v.annotation)
|
||||
)
|
||||
|
||||
|
||||
def _filter_schema_args(func: Callable) -> list[str]:
|
||||
filter_args = list(FILTERED_ARGS)
|
||||
|
||||
@@ -6,6 +6,7 @@ import sys
|
||||
import textwrap
|
||||
import threading
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
@@ -55,6 +56,7 @@ from langchain_core.tools.base import (
|
||||
InjectedToolArg,
|
||||
InjectedToolCallId,
|
||||
SchemaAnnotationError,
|
||||
_DirectlyInjectedToolArg,
|
||||
_is_message_content_block,
|
||||
_is_message_content_type,
|
||||
get_all_basemodel_annotations,
|
||||
@@ -2331,6 +2333,101 @@ def test_injected_arg_with_complex_type() -> None:
|
||||
assert injected_tool.invoke({"x": 5, "foo": Foo()}) == "bar"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("schema_format", ["model", "json_schema"])
|
||||
def test_tool_allows_extra_runtime_args_with_custom_schema(
|
||||
schema_format: Literal["model", "json_schema"],
|
||||
) -> None:
|
||||
"""Ensure runtime args are preserved even if not in the args schema."""
|
||||
|
||||
class InputSchema(BaseModel):
|
||||
query: str
|
||||
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
@dataclass
|
||||
class MyRuntime(_DirectlyInjectedToolArg):
|
||||
some_obj: object
|
||||
|
||||
args_schema = (
|
||||
InputSchema if schema_format == "model" else InputSchema.model_json_schema()
|
||||
)
|
||||
|
||||
@tool(args_schema=args_schema)
|
||||
def runtime_tool(query: str, runtime: MyRuntime) -> str:
|
||||
"""Echo the query and capture runtime value."""
|
||||
captured["runtime"] = runtime
|
||||
return query
|
||||
|
||||
runtime_obj = object()
|
||||
runtime = MyRuntime(some_obj=runtime_obj)
|
||||
assert runtime_tool.invoke({"query": "hello", "runtime": runtime}) == "hello"
|
||||
assert captured["runtime"] is runtime
|
||||
|
||||
|
||||
def test_tool_injected_tool_call_id_with_custom_schema() -> None:
|
||||
"""Ensure InjectedToolCallId works with custom args schema."""
|
||||
|
||||
class InputSchema(BaseModel):
|
||||
x: int
|
||||
|
||||
@tool(args_schema=InputSchema)
|
||||
def injected_tool(
|
||||
x: int, tool_call_id: Annotated[str, InjectedToolCallId]
|
||||
) -> ToolMessage:
|
||||
"""Tool with injected tool_call_id and custom schema."""
|
||||
return ToolMessage(str(x), tool_call_id=tool_call_id)
|
||||
|
||||
# Test that tool_call_id is properly injected even though not in custom schema
|
||||
result = injected_tool.invoke(
|
||||
{
|
||||
"type": "tool_call",
|
||||
"args": {"x": 42},
|
||||
"name": "injected_tool",
|
||||
"id": "test_call_id",
|
||||
}
|
||||
)
|
||||
assert result == ToolMessage("42", tool_call_id="test_call_id")
|
||||
|
||||
# Test that it still raises error when invoked without a ToolCall
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="When tool includes an InjectedToolCallId argument, "
|
||||
"tool must always be invoked with a full model ToolCall",
|
||||
):
|
||||
injected_tool.invoke({"x": 42})
|
||||
|
||||
|
||||
def test_tool_injected_arg_with_custom_schema() -> None:
|
||||
"""Ensure InjectedToolArg works with custom args schema."""
|
||||
|
||||
class InputSchema(BaseModel):
|
||||
query: str
|
||||
|
||||
class CustomContext:
|
||||
"""Custom context object to be injected."""
|
||||
|
||||
def __init__(self, value: str) -> None:
|
||||
self.value = value
|
||||
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
@tool(args_schema=InputSchema)
|
||||
def search_tool(
|
||||
query: str, context: Annotated[CustomContext, InjectedToolArg]
|
||||
) -> str:
|
||||
"""Search with custom context."""
|
||||
captured["context"] = context
|
||||
return f"Results for {query} with context {context.value}"
|
||||
|
||||
# Test that context is properly injected even though not in custom schema
|
||||
ctx = CustomContext("test_context")
|
||||
result = search_tool.invoke({"query": "hello", "context": ctx})
|
||||
|
||||
assert result == "Results for hello with context test_context"
|
||||
assert captured["context"] is ctx
|
||||
assert captured["context"].value == "test_context"
|
||||
|
||||
|
||||
def test_tool_injected_tool_call_id() -> None:
|
||||
@tool
|
||||
def foo(x: int, tool_call_id: Annotated[str, InjectedToolCallId]) -> ToolMessage:
|
||||
|
||||
Reference in New Issue
Block a user