mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-18 12:58:59 +00:00
add a unit test
This commit is contained in:
parent
913c8b71d9
commit
ed2428f902
@ -17,7 +17,7 @@ from typing import (
|
||||
Optional,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
cast, Coroutine,
|
||||
)
|
||||
|
||||
import pytest
|
||||
@ -35,7 +35,7 @@ from langchain_core.callbacks.manager import (
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.messages import ToolCall, ToolMessage
|
||||
from langchain_core.messages import ToolCall, ToolMessage, AIMessage
|
||||
from langchain_core.messages.tool import ToolOutputMixin
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.runnables import (
|
||||
@ -2604,3 +2604,59 @@ def test_title_property_preserved() -> None:
|
||||
},
|
||||
"type": "function",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_with_langgraph() -> None:
|
||||
"""Test that tool execution works correctly with LangGraph."""
|
||||
from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage, AIMessage
|
||||
import json
|
||||
from blockbuster.blockbuster import blockbuster_skip
|
||||
|
||||
blockbuster_skip.set(True)
|
||||
|
||||
# Setup
|
||||
prompt = "Help user with his/her requests"
|
||||
|
||||
# Create test tool
|
||||
async def sleep(**arguments: dict[str, Any]) -> str:
|
||||
return "good"
|
||||
|
||||
my_tool = [
|
||||
StructuredTool(
|
||||
name="sleep",
|
||||
description="Sleep for a while",
|
||||
args_schema={
|
||||
"type": "object",
|
||||
"required": ["seconds"],
|
||||
"properties": {
|
||||
"seconds": {"type": "number", "description": "How long to sleep"}
|
||||
},
|
||||
},
|
||||
coroutine=sleep,
|
||||
func=sleep
|
||||
)
|
||||
]
|
||||
_tools = {t.name: t for t in my_tool}
|
||||
|
||||
tool_calls = '''
|
||||
[{
|
||||
"name": "sleep",
|
||||
"args": {"seconds": 2},
|
||||
"id": "call_0_82c17db8-95df-452f-a4c2-03f809022134",
|
||||
"type": "tool_call"}]
|
||||
'''
|
||||
|
||||
# Test execution
|
||||
messages = []
|
||||
_input = "sleep for 2 seconds!"
|
||||
messages.append(SystemMessage(content=prompt))
|
||||
messages.append(HumanMessage(content=_input))
|
||||
ai_message = AIMessage(tool_calls=json.loads(tool_calls), content='')
|
||||
messages.append(ai_message)
|
||||
|
||||
result = await _tools["sleep"].ainvoke(messages[-1].tool_calls[0]["args"])
|
||||
|
||||
# Assertions
|
||||
assert "good" == result
|
||||
assert "run_manager" not in messages[-1].tool_calls[0]["args"]
|
||||
|
Loading…
Reference in New Issue
Block a user