core[patch]: support tool calls with non-pickleable args in tools (#24741)

Deepcopy raises with non-pickleable args.
This commit is contained in:
ccurme 2024-07-29 13:18:39 -04:00 committed by GitHub
parent df78608741
commit 9998e55936
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 23 additions and 5 deletions

View File

@ -20,7 +20,6 @@ tool for the job.
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import copy
import functools import functools
import inspect import inspect
import json import json
@ -1491,9 +1490,8 @@ def _prep_run_args(
) -> Tuple[Union[str, Dict], Dict]: ) -> Tuple[Union[str, Dict], Dict]:
config = ensure_config(config) config = ensure_config(config)
if _is_tool_call(input): if _is_tool_call(input):
input_copy = copy.deepcopy(input) tool_call_id: Optional[str] = cast(ToolCall, input)["id"]
tool_call_id: Optional[str] = cast(ToolCall, input_copy)["id"] tool_input: Union[str, dict] = cast(ToolCall, input)["args"].copy()
tool_input: Union[str, dict] = cast(ToolCall, input_copy)["args"]
else: else:
tool_call_id = None tool_call_id = None
tool_input = cast(Union[str, dict], input) tool_input = cast(Union[str, dict], input)

View File

@ -4,6 +4,7 @@ import inspect
import json import json
import sys import sys
import textwrap import textwrap
import threading
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from functools import partial from functools import partial
@ -977,7 +978,7 @@ class AFooBase(FooBase):
def test_tool_pass_config(tool: BaseTool) -> None: def test_tool_pass_config(tool: BaseTool) -> None:
assert tool.invoke({"bar": "baz"}, {"configurable": {"foo": "not-bar"}}) == "baz" assert tool.invoke({"bar": "baz"}, {"configurable": {"foo": "not-bar"}}) == "baz"
# Test tool calls # Test we don't mutate tool calls
tool_call = { tool_call = {
"name": tool.name, "name": tool.name,
"args": {"bar": "baz"}, "args": {"bar": "baz"},
@ -988,6 +989,25 @@ def test_tool_pass_config(tool: BaseTool) -> None:
assert tool_call["args"] == {"bar": "baz"} assert tool_call["args"] == {"bar": "baz"}
class FooBaseNonPickleable(FooBase):
def _run(self, bar: Any, bar_config: RunnableConfig, **kwargs: Any) -> Any:
return True
def test_tool_pass_config_non_pickleable() -> None:
tool = FooBaseNonPickleable()
args = {"bar": threading.Lock()}
tool_call = {
"name": tool.name,
"args": args,
"id": "abc123",
"type": "tool_call",
}
_ = tool.invoke(tool_call, {"configurable": {"foo": "not-bar"}})
assert tool_call["args"] == args
@pytest.mark.parametrize( @pytest.mark.parametrize(
"tool", [foo, afoo, simple_foo, asimple_foo, FooBase(), AFooBase()] "tool", [foo, afoo, simple_foo, asimple_foo, FooBase(), AFooBase()]
) )