mirror of
https://github.com/hwchase17/langchain.git
synced 2025-04-29 20:35:43 +00:00
x
This commit is contained in:
parent
92dc3f7341
commit
6d22f40a0b
@ -650,8 +650,16 @@ class ChildTool(BaseTool):
|
|||||||
# pass as a positional argument.
|
# pass as a positional argument.
|
||||||
if isinstance(tool_input, str):
|
if isinstance(tool_input, str):
|
||||||
return (tool_input,), {}
|
return (tool_input,), {}
|
||||||
|
elif isinstance(tool_input, dict):
|
||||||
|
# Make a shallow copy of the input to allow downstream code
|
||||||
|
# to modify the root level of the input without affecting the
|
||||||
|
# original input.
|
||||||
|
# This is used by the tool to inject run time information like
|
||||||
|
# the callback manager.
|
||||||
|
return (), tool_input.copy()
|
||||||
else:
|
else:
|
||||||
return (), tool_input
|
# This code path is not expected to be reachable.
|
||||||
|
raise TypeError(f"Invalid tool input type: {type(tool_input)}")
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
@ -946,14 +954,6 @@ def _prep_run_args(
|
|||||||
else:
|
else:
|
||||||
tool_call_id = None
|
tool_call_id = None
|
||||||
tool_input = cast(Union[str, dict], input)
|
tool_input = cast(Union[str, dict], input)
|
||||||
if not isinstance(tool_input, str):
|
|
||||||
try:
|
|
||||||
tool_input = tool_input.copy()
|
|
||||||
except Exception:
|
|
||||||
import copy
|
|
||||||
|
|
||||||
tool_input = copy.deepcopy(tool_input)
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
tool_input,
|
tool_input,
|
||||||
dict(
|
dict(
|
||||||
|
@ -2,12 +2,16 @@
|
|||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
|
import pytest
|
||||||
import sys
|
import sys
|
||||||
import textwrap
|
import textwrap
|
||||||
import threading
|
import threading
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from pydantic import BaseModel, Field, ValidationError
|
||||||
|
from pydantic.v1 import BaseModel as BaseModelV1
|
||||||
|
from pydantic.v1 import ValidationError as ValidationErrorV1
|
||||||
from typing import (
|
from typing import (
|
||||||
Annotated,
|
Annotated,
|
||||||
Any,
|
Any,
|
||||||
@ -19,11 +23,6 @@ from typing import (
|
|||||||
Union,
|
Union,
|
||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
import pytest
|
|
||||||
from pydantic import BaseModel, Field, ValidationError
|
|
||||||
from pydantic.v1 import BaseModel as BaseModelV1
|
|
||||||
from pydantic.v1 import ValidationError as ValidationErrorV1
|
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from langchain_core import tools
|
from langchain_core import tools
|
||||||
@ -35,7 +34,7 @@ from langchain_core.callbacks.manager import (
|
|||||||
CallbackManagerForRetrieverRun,
|
CallbackManagerForRetrieverRun,
|
||||||
)
|
)
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.messages import AIMessage, ToolCall, ToolMessage
|
from langchain_core.messages import ToolCall, ToolMessage
|
||||||
from langchain_core.messages.tool import ToolOutputMixin
|
from langchain_core.messages.tool import ToolOutputMixin
|
||||||
from langchain_core.retrievers import BaseRetriever
|
from langchain_core.retrievers import BaseRetriever
|
||||||
from langchain_core.runnables import (
|
from langchain_core.runnables import (
|
||||||
@ -2606,62 +2605,89 @@ def test_title_property_preserved() -> None:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
async def test_tool_ainvoke_does_not_mutate_inputs() -> None:
|
||||||
async def test_ainvoke_input_isolation_from_state_graph_context() -> None:
|
"""Verify that the inputs are not mutated when invoking a tool asynchronously."""
|
||||||
"""Test that tool execution works correctly with LangGraph."""
|
|
||||||
import json
|
|
||||||
|
|
||||||
from blockbuster.blockbuster import blockbuster_skip
|
def sync_no_op(foo: int) -> str:
|
||||||
|
|
||||||
from langchain_core.messages import (
|
|
||||||
HumanMessage,
|
|
||||||
SystemMessage,
|
|
||||||
)
|
|
||||||
|
|
||||||
blockbuster_skip.set(True)
|
|
||||||
|
|
||||||
# Setup
|
|
||||||
prompt = "Help user with his/her requests"
|
|
||||||
|
|
||||||
# Create test tool
|
|
||||||
async def sleep(**arguments: dict[str, Any]) -> str:
|
|
||||||
return "good"
|
return "good"
|
||||||
|
|
||||||
my_tool = [
|
async def async_no_op(foo: int) -> str:
|
||||||
StructuredTool(
|
return "good"
|
||||||
name="sleep",
|
|
||||||
description="Sleep for a while",
|
tool = StructuredTool(
|
||||||
args_schema={
|
name="sample_tool",
|
||||||
"type": "object",
|
description="",
|
||||||
"required": ["seconds"],
|
args_schema={
|
||||||
"properties": {
|
"type": "object",
|
||||||
"seconds": {"type": "number", "description": "How long to sleep"}
|
"required": ["foo"],
|
||||||
},
|
"properties": {
|
||||||
|
"seconds": {"type": "number", "description": "How big is foo"}
|
||||||
},
|
},
|
||||||
coroutine=sleep,
|
},
|
||||||
func=sleep,
|
coroutine=async_no_op,
|
||||||
)
|
func=sync_no_op,
|
||||||
]
|
)
|
||||||
_tools = {t.name: t for t in my_tool}
|
|
||||||
|
|
||||||
tool_calls = """
|
tool_call = {
|
||||||
[{
|
"name": "sample_tool",
|
||||||
"name": "sleep",
|
"args": {"foo": 2},
|
||||||
"args": {"seconds": 2},
|
"id": "call_0_82c17db8-95df-452f-a4c2-03f809022134",
|
||||||
"id": "call_0_82c17db8-95df-452f-a4c2-03f809022134",
|
"type": "tool_call",
|
||||||
"type": "tool_call"}]
|
}
|
||||||
"""
|
|
||||||
|
|
||||||
# Test execution
|
assert tool.invoke(tool_call["args"]) == "good"
|
||||||
messages: list[Any] = []
|
assert tool_call == {
|
||||||
_input = "sleep for 2 seconds!"
|
"name": "sample_tool",
|
||||||
messages.append(SystemMessage(content=prompt))
|
"args": {"foo": 2},
|
||||||
messages.append(HumanMessage(content=_input))
|
"id": "call_0_82c17db8-95df-452f-a4c2-03f809022134",
|
||||||
ai_message = AIMessage(tool_calls=json.loads(tool_calls), content="")
|
"type": "tool_call",
|
||||||
messages.append(ai_message)
|
}
|
||||||
|
|
||||||
result = await _tools["sleep"].ainvoke(ai_message.tool_calls[0]["args"])
|
assert await tool.ainvoke(tool_call["args"]) == "good"
|
||||||
|
|
||||||
|
assert tool_call == {
|
||||||
|
"name": "sample_tool",
|
||||||
|
"args": {"foo": 2},
|
||||||
|
"id": "call_0_82c17db8-95df-452f-a4c2-03f809022134",
|
||||||
|
"type": "tool_call",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_invoke_does_not_mutate_inputs() -> None:
|
||||||
|
"""Verify that the inputs are not mutated when invoking a tool synchronously."""
|
||||||
|
|
||||||
|
def sync_no_op(foo: int) -> str:
|
||||||
|
return "good"
|
||||||
|
|
||||||
|
async def async_no_op(foo: int) -> str:
|
||||||
|
return "good"
|
||||||
|
|
||||||
|
tool = StructuredTool(
|
||||||
|
name="sample_tool",
|
||||||
|
description="",
|
||||||
|
args_schema={
|
||||||
|
"type": "object",
|
||||||
|
"required": ["foo"],
|
||||||
|
"properties": {
|
||||||
|
"seconds": {"type": "number", "description": "How big is foo"}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
coroutine=async_no_op,
|
||||||
|
func=sync_no_op,
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_call = {
|
||||||
|
"name": "sample_tool",
|
||||||
|
"args": {"foo": 2},
|
||||||
|
"id": "call_0_82c17db8-95df-452f-a4c2-03f809022134",
|
||||||
|
"type": "tool_call",
|
||||||
|
}
|
||||||
|
|
||||||
|
assert tool.invoke(tool_call["args"]) == "good"
|
||||||
|
assert tool_call == {
|
||||||
|
"name": "sample_tool",
|
||||||
|
"args": {"foo": 2},
|
||||||
|
"id": "call_0_82c17db8-95df-452f-a4c2-03f809022134",
|
||||||
|
"type": "tool_call",
|
||||||
|
}
|
||||||
|
|
||||||
# Assertions
|
|
||||||
assert result == "good"
|
|
||||||
assert "run_manager" not in ai_message.tool_calls[0]["args"]
|
|
||||||
|
Loading…
Reference in New Issue
Block a user