This commit is contained in:
Eugene Yurtsev 2025-03-28 13:51:06 -04:00
parent 92dc3f7341
commit 6d22f40a0b
2 changed files with 92 additions and 66 deletions

View File

@ -650,8 +650,16 @@ class ChildTool(BaseTool):
# pass as a positional argument.
if isinstance(tool_input, str):
return (tool_input,), {}
elif isinstance(tool_input, dict):
# Make a shallow copy of the input to allow downstream code
# to modify the root level of the input without affecting the
# original input.
# This is used by the tool to inject run time information like
# the callback manager.
return (), tool_input.copy()
else:
return (), tool_input
# This code path is not expected to be reachable.
raise TypeError(f"Invalid tool input type: {type(tool_input)}")
def run(
self,
@ -946,14 +954,6 @@ def _prep_run_args(
else:
tool_call_id = None
tool_input = cast(Union[str, dict], input)
if not isinstance(tool_input, str):
try:
tool_input = tool_input.copy()
except Exception:
import copy
tool_input = copy.deepcopy(tool_input)
return (
tool_input,
dict(

View File

@ -2,12 +2,16 @@
import inspect
import json
import pytest
import sys
import textwrap
import threading
from datetime import datetime
from enum import Enum
from functools import partial
from pydantic import BaseModel, Field, ValidationError
from pydantic.v1 import BaseModel as BaseModelV1
from pydantic.v1 import ValidationError as ValidationErrorV1
from typing import (
Annotated,
Any,
@ -19,11 +23,6 @@ from typing import (
Union,
cast,
)
import pytest
from pydantic import BaseModel, Field, ValidationError
from pydantic.v1 import BaseModel as BaseModelV1
from pydantic.v1 import ValidationError as ValidationErrorV1
from typing_extensions import TypedDict
from langchain_core import tools
@ -35,7 +34,7 @@ from langchain_core.callbacks.manager import (
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.messages import AIMessage, ToolCall, ToolMessage
from langchain_core.messages import ToolCall, ToolMessage
from langchain_core.messages.tool import ToolOutputMixin
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import (
@ -2606,62 +2605,89 @@ def test_title_property_preserved() -> None:
}
@pytest.mark.asyncio
async def test_ainvoke_input_isolation_from_state_graph_context() -> None:
"""Test that tool execution works correctly with LangGraph."""
import json
async def test_tool_ainvoke_does_not_mutate_inputs() -> None:
"""Verify that the inputs are not mutated when invoking a tool asynchronously."""
from blockbuster.blockbuster import blockbuster_skip
from langchain_core.messages import (
HumanMessage,
SystemMessage,
)
blockbuster_skip.set(True)
# Setup
prompt = "Help user with his/her requests"
# Create test tool
async def sleep(**arguments: dict[str, Any]) -> str:
def sync_no_op(foo: int) -> str:
return "good"
my_tool = [
StructuredTool(
name="sleep",
description="Sleep for a while",
args_schema={
"type": "object",
"required": ["seconds"],
"properties": {
"seconds": {"type": "number", "description": "How long to sleep"}
},
async def async_no_op(foo: int) -> str:
return "good"
tool = StructuredTool(
name="sample_tool",
description="",
args_schema={
"type": "object",
"required": ["foo"],
"properties": {
"seconds": {"type": "number", "description": "How big is foo"}
},
coroutine=sleep,
func=sleep,
)
]
_tools = {t.name: t for t in my_tool}
},
coroutine=async_no_op,
func=sync_no_op,
)
tool_calls = """
[{
"name": "sleep",
"args": {"seconds": 2},
"id": "call_0_82c17db8-95df-452f-a4c2-03f809022134",
"type": "tool_call"}]
"""
tool_call = {
"name": "sample_tool",
"args": {"foo": 2},
"id": "call_0_82c17db8-95df-452f-a4c2-03f809022134",
"type": "tool_call",
}
# Test execution
messages: list[Any] = []
_input = "sleep for 2 seconds!"
messages.append(SystemMessage(content=prompt))
messages.append(HumanMessage(content=_input))
ai_message = AIMessage(tool_calls=json.loads(tool_calls), content="")
messages.append(ai_message)
assert tool.invoke(tool_call["args"]) == "good"
assert tool_call == {
"name": "sample_tool",
"args": {"foo": 2},
"id": "call_0_82c17db8-95df-452f-a4c2-03f809022134",
"type": "tool_call",
}
result = await _tools["sleep"].ainvoke(ai_message.tool_calls[0]["args"])
assert await tool.ainvoke(tool_call["args"]) == "good"
assert tool_call == {
"name": "sample_tool",
"args": {"foo": 2},
"id": "call_0_82c17db8-95df-452f-a4c2-03f809022134",
"type": "tool_call",
}
def test_tool_invoke_does_not_mutate_inputs() -> None:
"""Verify that the inputs are not mutated when invoking a tool synchronously."""
def sync_no_op(foo: int) -> str:
return "good"
async def async_no_op(foo: int) -> str:
return "good"
tool = StructuredTool(
name="sample_tool",
description="",
args_schema={
"type": "object",
"required": ["foo"],
"properties": {
"seconds": {"type": "number", "description": "How big is foo"}
},
},
coroutine=async_no_op,
func=sync_no_op,
)
tool_call = {
"name": "sample_tool",
"args": {"foo": 2},
"id": "call_0_82c17db8-95df-452f-a4c2-03f809022134",
"type": "tool_call",
}
assert tool.invoke(tool_call["args"]) == "good"
assert tool_call == {
"name": "sample_tool",
"args": {"foo": 2},
"id": "call_0_82c17db8-95df-452f-a4c2-03f809022134",
"type": "tool_call",
}
# Assertions
assert result == "good"
assert "run_manager" not in ai_message.tool_calls[0]["args"]