This commit is contained in:
Eugene Yurtsev 2025-03-28 13:51:06 -04:00
parent 92dc3f7341
commit 6d22f40a0b
2 changed files with 92 additions and 66 deletions

View File

@ -650,8 +650,16 @@ class ChildTool(BaseTool):
# pass as a positional argument. # pass as a positional argument.
if isinstance(tool_input, str): if isinstance(tool_input, str):
return (tool_input,), {} return (tool_input,), {}
elif isinstance(tool_input, dict):
# Make a shallow copy of the input to allow downstream code
# to modify the root level of the input without affecting the
# original input.
# This is used by the tool to inject run time information like
# the callback manager.
return (), tool_input.copy()
else: else:
return (), tool_input # This code path is not expected to be reachable.
raise TypeError(f"Invalid tool input type: {type(tool_input)}")
def run( def run(
self, self,
@ -946,14 +954,6 @@ def _prep_run_args(
else: else:
tool_call_id = None tool_call_id = None
tool_input = cast(Union[str, dict], input) tool_input = cast(Union[str, dict], input)
if not isinstance(tool_input, str):
try:
tool_input = tool_input.copy()
except Exception:
import copy
tool_input = copy.deepcopy(tool_input)
return ( return (
tool_input, tool_input,
dict( dict(

View File

@ -2,12 +2,16 @@
import inspect import inspect
import json import json
import pytest
import sys import sys
import textwrap import textwrap
import threading import threading
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from functools import partial from functools import partial
from pydantic import BaseModel, Field, ValidationError
from pydantic.v1 import BaseModel as BaseModelV1
from pydantic.v1 import ValidationError as ValidationErrorV1
from typing import ( from typing import (
Annotated, Annotated,
Any, Any,
@ -19,11 +23,6 @@ from typing import (
Union, Union,
cast, cast,
) )
import pytest
from pydantic import BaseModel, Field, ValidationError
from pydantic.v1 import BaseModel as BaseModelV1
from pydantic.v1 import ValidationError as ValidationErrorV1
from typing_extensions import TypedDict from typing_extensions import TypedDict
from langchain_core import tools from langchain_core import tools
@ -35,7 +34,7 @@ from langchain_core.callbacks.manager import (
CallbackManagerForRetrieverRun, CallbackManagerForRetrieverRun,
) )
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_core.messages import AIMessage, ToolCall, ToolMessage from langchain_core.messages import ToolCall, ToolMessage
from langchain_core.messages.tool import ToolOutputMixin from langchain_core.messages.tool import ToolOutputMixin
from langchain_core.retrievers import BaseRetriever from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import ( from langchain_core.runnables import (
@ -2606,62 +2605,89 @@ def test_title_property_preserved() -> None:
} }
@pytest.mark.asyncio async def test_tool_ainvoke_does_not_mutate_inputs() -> None:
async def test_ainvoke_input_isolation_from_state_graph_context() -> None: """Verify that the inputs are not mutated when invoking a tool asynchronously."""
"""Test that tool execution works correctly with LangGraph."""
import json
from blockbuster.blockbuster import blockbuster_skip def sync_no_op(foo: int) -> str:
from langchain_core.messages import (
HumanMessage,
SystemMessage,
)
blockbuster_skip.set(True)
# Setup
prompt = "Help user with his/her requests"
# Create test tool
async def sleep(**arguments: dict[str, Any]) -> str:
return "good" return "good"
my_tool = [ async def async_no_op(foo: int) -> str:
StructuredTool( return "good"
name="sleep",
description="Sleep for a while", tool = StructuredTool(
args_schema={ name="sample_tool",
"type": "object", description="",
"required": ["seconds"], args_schema={
"properties": { "type": "object",
"seconds": {"type": "number", "description": "How long to sleep"} "required": ["foo"],
}, "properties": {
"seconds": {"type": "number", "description": "How big is foo"}
}, },
coroutine=sleep, },
func=sleep, coroutine=async_no_op,
) func=sync_no_op,
] )
_tools = {t.name: t for t in my_tool}
tool_calls = """ tool_call = {
[{ "name": "sample_tool",
"name": "sleep", "args": {"foo": 2},
"args": {"seconds": 2}, "id": "call_0_82c17db8-95df-452f-a4c2-03f809022134",
"id": "call_0_82c17db8-95df-452f-a4c2-03f809022134", "type": "tool_call",
"type": "tool_call"}] }
"""
# Test execution assert tool.invoke(tool_call["args"]) == "good"
messages: list[Any] = [] assert tool_call == {
_input = "sleep for 2 seconds!" "name": "sample_tool",
messages.append(SystemMessage(content=prompt)) "args": {"foo": 2},
messages.append(HumanMessage(content=_input)) "id": "call_0_82c17db8-95df-452f-a4c2-03f809022134",
ai_message = AIMessage(tool_calls=json.loads(tool_calls), content="") "type": "tool_call",
messages.append(ai_message) }
result = await _tools["sleep"].ainvoke(ai_message.tool_calls[0]["args"]) assert await tool.ainvoke(tool_call["args"]) == "good"
assert tool_call == {
"name": "sample_tool",
"args": {"foo": 2},
"id": "call_0_82c17db8-95df-452f-a4c2-03f809022134",
"type": "tool_call",
}
def test_tool_invoke_does_not_mutate_inputs() -> None:
"""Verify that the inputs are not mutated when invoking a tool synchronously."""
def sync_no_op(foo: int) -> str:
return "good"
async def async_no_op(foo: int) -> str:
return "good"
tool = StructuredTool(
name="sample_tool",
description="",
args_schema={
"type": "object",
"required": ["foo"],
"properties": {
"seconds": {"type": "number", "description": "How big is foo"}
},
},
coroutine=async_no_op,
func=sync_no_op,
)
tool_call = {
"name": "sample_tool",
"args": {"foo": 2},
"id": "call_0_82c17db8-95df-452f-a4c2-03f809022134",
"type": "tool_call",
}
assert tool.invoke(tool_call["args"]) == "good"
assert tool_call == {
"name": "sample_tool",
"args": {"foo": 2},
"id": "call_0_82c17db8-95df-452f-a4c2-03f809022134",
"type": "tool_call",
}
# Assertions
assert result == "good"
assert "run_manager" not in ai_message.tool_calls[0]["args"]