core[patch]: fix mutating tool calls (#24677)

In some cases tool calls are mutated when passed through a tool.
This commit is contained in:
ccurme 2024-07-25 12:46:36 -04:00 committed by GitHub
parent dfbd12b384
commit 58dd69f7f2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 14 additions and 2 deletions

View File

@ -20,6 +20,7 @@ tool for the job.
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import copy
import functools import functools
import inspect import inspect
import json import json
@ -1481,8 +1482,9 @@ def _prep_run_args(
) -> Tuple[Union[str, Dict], Dict]: ) -> Tuple[Union[str, Dict], Dict]:
config = ensure_config(config) config = ensure_config(config)
if _is_tool_call(input): if _is_tool_call(input):
tool_call_id: Optional[str] = cast(ToolCall, input)["id"] input_copy = copy.deepcopy(input)
tool_input: Union[str, dict] = cast(ToolCall, input)["args"] tool_call_id: Optional[str] = cast(ToolCall, input_copy)["id"]
tool_input: Union[str, dict] = cast(ToolCall, input_copy)["args"]
else: else:
tool_call_id = None tool_call_id = None
tool_input = cast(Union[str, dict], input) tool_input = cast(Union[str, dict], input)

View File

@ -977,6 +977,16 @@ class AFooBase(FooBase):
def test_tool_pass_config(tool: BaseTool) -> None: def test_tool_pass_config(tool: BaseTool) -> None:
assert tool.invoke({"bar": "baz"}, {"configurable": {"foo": "not-bar"}}) == "baz" assert tool.invoke({"bar": "baz"}, {"configurable": {"foo": "not-bar"}}) == "baz"
# Test tool calls
tool_call = {
"name": tool.name,
"args": {"bar": "baz"},
"id": "abc123",
"type": "tool_call",
}
_ = tool.invoke(tool_call, {"configurable": {"foo": "not-bar"}})
assert tool_call["args"] == {"bar": "baz"}
@pytest.mark.parametrize( @pytest.mark.parametrize(
"tool", [foo, afoo, simple_foo, asimple_foo, FooBase(), AFooBase()] "tool", [foo, afoo, simple_foo, asimple_foo, FooBase(), AFooBase()]