Move the defensive copy of dict tool inputs out of _prep_run_args and into
ChildTool._to_args_and_kwargs (shallow copy; an input that is neither str nor
dict now raises TypeError), and replace the LangGraph-flavoured isolation test
with two focused tests checking that tool.invoke / tool.ainvoke leave the
caller's dict untouched.

NOTE(review): leading whitespace in this patch was reconstructed from Python
syntax. Hunk line counts match the @@ headers, but confirm the indentation of
context and removed lines (in particular the copy block removed from
_prep_run_args and the body of the old tool_calls string) against the target
tree before applying.
NOTE(review): the JSON-schema property key in both new tests is "foo", to agree
with "required": ["foo"] and the no-op signatures. The post-image blob id on
the test file's index line was not recomputed for that change.

diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py
index 6e26e39e8bd..0148528a7fd 100644
--- a/libs/core/langchain_core/tools/base.py
+++ b/libs/core/langchain_core/tools/base.py
@@ -650,8 +650,16 @@ class ChildTool(BaseTool):
         # pass as a positional argument.
         if isinstance(tool_input, str):
             return (tool_input,), {}
+        elif isinstance(tool_input, dict):
+            # Make a shallow copy of the input to allow downstream code
+            # to modify the root level of the input without affecting the
+            # original input.
+            # This is used by the tool to inject run time information like
+            # the callback manager.
+            return (), tool_input.copy()
         else:
-            return (), tool_input
+            # This code path is not expected to be reachable.
+            raise TypeError(f"Invalid tool input type: {type(tool_input)}")
 
     def run(
         self,
@@ -946,14 +954,6 @@ def _prep_run_args(
     else:
         tool_call_id = None
         tool_input = cast(Union[str, dict], input)
-    if not isinstance(tool_input, str):
-        try:
-            tool_input = tool_input.copy()
-        except Exception:
-            import copy
-
-            tool_input = copy.deepcopy(tool_input)
-
     return (
         tool_input,
         dict(
diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py
index f48b45eaf80..c253d24271a 100644
--- a/libs/core/tests/unit_tests/test_tools.py
+++ b/libs/core/tests/unit_tests/test_tools.py
@@ -2,12 +2,16 @@
 
 import inspect
 import json
+import pytest
 import sys
 import textwrap
 import threading
 from datetime import datetime
 from enum import Enum
 from functools import partial
+from pydantic import BaseModel, Field, ValidationError
+from pydantic.v1 import BaseModel as BaseModelV1
+from pydantic.v1 import ValidationError as ValidationErrorV1
 from typing import (
     Annotated,
     Any,
@@ -19,11 +23,6 @@ from typing import (
     Union,
     cast,
 )
-
-import pytest
-from pydantic import BaseModel, Field, ValidationError
-from pydantic.v1 import BaseModel as BaseModelV1
-from pydantic.v1 import ValidationError as ValidationErrorV1
 from typing_extensions import TypedDict
 
 from langchain_core import tools
@@ -35,7 +34,7 @@ from langchain_core.callbacks.manager import (
     CallbackManagerForRetrieverRun,
 )
 from langchain_core.documents import Document
-from langchain_core.messages import AIMessage, ToolCall, ToolMessage
+from langchain_core.messages import ToolCall, ToolMessage
 from langchain_core.messages.tool import ToolOutputMixin
 from langchain_core.retrievers import BaseRetriever
 from langchain_core.runnables import (
@@ -2606,62 +2605,89 @@ def test_title_property_preserved() -> None:
     }
 
 
-@pytest.mark.asyncio
-async def test_ainvoke_input_isolation_from_state_graph_context() -> None:
-    """Test that tool execution works correctly with LangGraph."""
-    import json
+async def test_tool_ainvoke_does_not_mutate_inputs() -> None:
+    """Verify that the inputs are not mutated when invoking a tool asynchronously."""
 
-    from blockbuster.blockbuster import blockbuster_skip
-
-    from langchain_core.messages import (
-        HumanMessage,
-        SystemMessage,
-    )
-
-    blockbuster_skip.set(True)
-
-    # Setup
-    prompt = "Help user with his/her requests"
-
-    # Create test tool
-    async def sleep(**arguments: dict[str, Any]) -> str:
+    def sync_no_op(foo: int) -> str:
         return "good"
 
-    my_tool = [
-        StructuredTool(
-            name="sleep",
-            description="Sleep for a while",
-            args_schema={
-                "type": "object",
-                "required": ["seconds"],
-                "properties": {
-                    "seconds": {"type": "number", "description": "How long to sleep"}
-                },
+    async def async_no_op(foo: int) -> str:
+        return "good"
+
+    tool = StructuredTool(
+        name="sample_tool",
+        description="",
+        args_schema={
+            "type": "object",
+            "required": ["foo"],
+            "properties": {
+                "foo": {"type": "number", "description": "How big is foo"}
             },
-            coroutine=sleep,
-            func=sleep,
-        )
-    ]
-    _tools = {t.name: t for t in my_tool}
+        },
+        coroutine=async_no_op,
+        func=sync_no_op,
+    )
 
-    tool_calls = """
-    [{
-        "name": "sleep",
-        "args": {"seconds": 2},
-        "id": "call_0_82c17db8-95df-452f-a4c2-03f809022134",
-        "type": "tool_call"}]
-    """
+    tool_call = {
+        "name": "sample_tool",
+        "args": {"foo": 2},
+        "id": "call_0_82c17db8-95df-452f-a4c2-03f809022134",
+        "type": "tool_call",
+    }
 
-    # Test execution
-    messages: list[Any] = []
-    _input = "sleep for 2 seconds!"
-    messages.append(SystemMessage(content=prompt))
-    messages.append(HumanMessage(content=_input))
-    ai_message = AIMessage(tool_calls=json.loads(tool_calls), content="")
-    messages.append(ai_message)
+    assert tool.invoke(tool_call["args"]) == "good"
+    assert tool_call == {
+        "name": "sample_tool",
+        "args": {"foo": 2},
+        "id": "call_0_82c17db8-95df-452f-a4c2-03f809022134",
+        "type": "tool_call",
+    }
 
-    result = await _tools["sleep"].ainvoke(ai_message.tool_calls[0]["args"])
+    assert await tool.ainvoke(tool_call["args"]) == "good"
+
+    assert tool_call == {
+        "name": "sample_tool",
+        "args": {"foo": 2},
+        "id": "call_0_82c17db8-95df-452f-a4c2-03f809022134",
+        "type": "tool_call",
+    }
+
+
+def test_tool_invoke_does_not_mutate_inputs() -> None:
+    """Verify that the inputs are not mutated when invoking a tool synchronously."""
+
+    def sync_no_op(foo: int) -> str:
+        return "good"
+
+    async def async_no_op(foo: int) -> str:
+        return "good"
+
+    tool = StructuredTool(
+        name="sample_tool",
+        description="",
+        args_schema={
+            "type": "object",
+            "required": ["foo"],
+            "properties": {
+                "foo": {"type": "number", "description": "How big is foo"}
+            },
+        },
+        coroutine=async_no_op,
+        func=sync_no_op,
+    )
+
+    tool_call = {
+        "name": "sample_tool",
+        "args": {"foo": 2},
+        "id": "call_0_82c17db8-95df-452f-a4c2-03f809022134",
+        "type": "tool_call",
+    }
+
+    assert tool.invoke(tool_call["args"]) == "good"
+    assert tool_call == {
+        "name": "sample_tool",
+        "args": {"foo": 2},
+        "id": "call_0_82c17db8-95df-452f-a4c2-03f809022134",
+        "type": "tool_call",
+    }
 
-    # Assertions
-    assert result == "good"
-    assert "run_manager" not in ai_message.tool_calls[0]["args"]