mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-17 02:03:44 +00:00
core[patch]: support tool calls with non-pickleable args in tools (#24741)
Deepcopy raises with non-pickleable args.
This commit is contained in:
parent
df78608741
commit
9998e55936
@ -20,7 +20,6 @@ tool for the job.
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import copy
|
|
||||||
import functools
|
import functools
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
@ -1491,9 +1490,8 @@ def _prep_run_args(
|
|||||||
) -> Tuple[Union[str, Dict], Dict]:
|
) -> Tuple[Union[str, Dict], Dict]:
|
||||||
config = ensure_config(config)
|
config = ensure_config(config)
|
||||||
if _is_tool_call(input):
|
if _is_tool_call(input):
|
||||||
input_copy = copy.deepcopy(input)
|
tool_call_id: Optional[str] = cast(ToolCall, input)["id"]
|
||||||
tool_call_id: Optional[str] = cast(ToolCall, input_copy)["id"]
|
tool_input: Union[str, dict] = cast(ToolCall, input)["args"].copy()
|
||||||
tool_input: Union[str, dict] = cast(ToolCall, input_copy)["args"]
|
|
||||||
else:
|
else:
|
||||||
tool_call_id = None
|
tool_call_id = None
|
||||||
tool_input = cast(Union[str, dict], input)
|
tool_input = cast(Union[str, dict], input)
|
||||||
|
@ -4,6 +4,7 @@ import inspect
|
|||||||
import json
|
import json
|
||||||
import sys
|
import sys
|
||||||
import textwrap
|
import textwrap
|
||||||
|
import threading
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from functools import partial
|
from functools import partial
|
||||||
@ -977,7 +978,7 @@ class AFooBase(FooBase):
|
|||||||
def test_tool_pass_config(tool: BaseTool) -> None:
|
def test_tool_pass_config(tool: BaseTool) -> None:
|
||||||
assert tool.invoke({"bar": "baz"}, {"configurable": {"foo": "not-bar"}}) == "baz"
|
assert tool.invoke({"bar": "baz"}, {"configurable": {"foo": "not-bar"}}) == "baz"
|
||||||
|
|
||||||
# Test tool calls
|
# Test we don't mutate tool calls
|
||||||
tool_call = {
|
tool_call = {
|
||||||
"name": tool.name,
|
"name": tool.name,
|
||||||
"args": {"bar": "baz"},
|
"args": {"bar": "baz"},
|
||||||
@ -988,6 +989,25 @@ def test_tool_pass_config(tool: BaseTool) -> None:
|
|||||||
assert tool_call["args"] == {"bar": "baz"}
|
assert tool_call["args"] == {"bar": "baz"}
|
||||||
|
|
||||||
|
|
||||||
|
class FooBaseNonPickleable(FooBase):
|
||||||
|
def _run(self, bar: Any, bar_config: RunnableConfig, **kwargs: Any) -> Any:
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_pass_config_non_pickleable() -> None:
|
||||||
|
tool = FooBaseNonPickleable()
|
||||||
|
|
||||||
|
args = {"bar": threading.Lock()}
|
||||||
|
tool_call = {
|
||||||
|
"name": tool.name,
|
||||||
|
"args": args,
|
||||||
|
"id": "abc123",
|
||||||
|
"type": "tool_call",
|
||||||
|
}
|
||||||
|
_ = tool.invoke(tool_call, {"configurable": {"foo": "not-bar"}})
|
||||||
|
assert tool_call["args"] == args
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"tool", [foo, afoo, simple_foo, asimple_foo, FooBase(), AFooBase()]
|
"tool", [foo, afoo, simple_foo, asimple_foo, FooBase(), AFooBase()]
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user