diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index 18f11bdc26e..cb531e4ecc6 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -17,7 +17,7 @@ from typing import ( Optional, TypeVar, Union, - cast, + cast, Coroutine, ) import pytest @@ -35,7 +35,7 @@ from langchain_core.callbacks.manager import ( CallbackManagerForRetrieverRun, ) from langchain_core.documents import Document -from langchain_core.messages import ToolCall, ToolMessage +from langchain_core.messages import ToolCall, ToolMessage, AIMessage from langchain_core.messages.tool import ToolOutputMixin from langchain_core.retrievers import BaseRetriever from langchain_core.runnables import ( @@ -2604,3 +2604,59 @@ def test_title_property_preserved() -> None: }, "type": "function", } + + +@pytest.mark.asyncio +async def test_tool_with_langgraph() -> None: + """Test that tool execution works correctly with LangGraph.""" + from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage, AIMessage + import json + from blockbuster.blockbuster import blockbuster_skip + + blockbuster_skip.set(True) + + # Setup + prompt = "Help user with his/her requests" + + # Create test tool + async def sleep(**arguments: dict[str, Any]) -> str: + return "good" + + my_tool = [ + StructuredTool( + name="sleep", + description="Sleep for a while", + args_schema={ + "type": "object", + "required": ["seconds"], + "properties": { + "seconds": {"type": "number", "description": "How long to sleep"} + }, + }, + coroutine=sleep, + func=sleep + ) + ] + _tools = {t.name: t for t in my_tool} + + tool_calls = ''' + [{ + "name": "sleep", + "args": {"seconds": 2}, + "id": "call_0_82c17db8-95df-452f-a4c2-03f809022134", + "type": "tool_call"}] + ''' + + # Test execution + messages = [] + _input = "sleep for 2 seconds!" + messages.append(SystemMessage(content=prompt)) + messages.append(HumanMessage(content=_input)) + ai_message = AIMessage(tool_calls=json.loads(tool_calls), content='') + messages.append(ai_message) + + result = await _tools["sleep"].ainvoke(messages[-1].tool_calls[0]["args"]) + + # Assertions + assert "good" == result + assert "run_manager" not in messages[-1].tool_calls[0]["args"]