mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-07 13:40:46 +00:00
Pass Callbacks through load_tools (#4298)
- Update the load_tools method to properly accept `callbacks` arguments. - Add a deprecation warning when `callback_manager` is passed - Add two unit tests to check the deprecation warning is raised and to confirm the callback is passed through. Closes issue #4096
This commit is contained in:
parent
0870a45a69
commit
35c9e6ab40
@ -7,6 +7,7 @@ from mypy_extensions import Arg, KwArg
|
|||||||
from langchain.agents.tools import Tool
|
from langchain.agents.tools import Tool
|
||||||
from langchain.base_language import BaseLanguageModel
|
from langchain.base_language import BaseLanguageModel
|
||||||
from langchain.callbacks.base import BaseCallbackManager
|
from langchain.callbacks.base import BaseCallbackManager
|
||||||
|
from langchain.callbacks.manager import Callbacks
|
||||||
from langchain.chains.api import news_docs, open_meteo_docs, podcast_docs, tmdb_docs
|
from langchain.chains.api import news_docs, open_meteo_docs, podcast_docs, tmdb_docs
|
||||||
from langchain.chains.api.base import APIChain
|
from langchain.chains.api.base import APIChain
|
||||||
from langchain.chains.llm_math.base import LLMMathChain
|
from langchain.chains.llm_math.base import LLMMathChain
|
||||||
@ -279,10 +280,26 @@ _EXTRA_OPTIONAL_TOOLS: Dict[str, Tuple[Callable[[KwArg(Any)], BaseTool], List[st
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_callbacks(
|
||||||
|
callback_manager: Optional[BaseCallbackManager], callbacks: Callbacks
|
||||||
|
) -> Callbacks:
|
||||||
|
if callback_manager is not None:
|
||||||
|
warnings.warn(
|
||||||
|
"callback_manager is deprecated. Please use callbacks instead.",
|
||||||
|
DeprecationWarning,
|
||||||
|
)
|
||||||
|
if callbacks is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot specify both callback_manager and callbacks arguments."
|
||||||
|
)
|
||||||
|
return callback_manager
|
||||||
|
return callbacks
|
||||||
|
|
||||||
|
|
||||||
def load_tools(
|
def load_tools(
|
||||||
tool_names: List[str],
|
tool_names: List[str],
|
||||||
llm: Optional[BaseLanguageModel] = None,
|
llm: Optional[BaseLanguageModel] = None,
|
||||||
callback_manager: Optional[BaseCallbackManager] = None,
|
callbacks: Callbacks = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> List[BaseTool]:
|
) -> List[BaseTool]:
|
||||||
"""Load tools based on their name.
|
"""Load tools based on their name.
|
||||||
@ -290,13 +307,16 @@ def load_tools(
|
|||||||
Args:
|
Args:
|
||||||
tool_names: name of tools to load.
|
tool_names: name of tools to load.
|
||||||
llm: Optional language model, may be needed to initialize certain tools.
|
llm: Optional language model, may be needed to initialize certain tools.
|
||||||
callback_manager: Optional callback manager. If not provided, default global callback manager will be used.
|
callbacks: Optional callback manager or list of callback handlers.
|
||||||
|
If not provided, default global callback manager will be used.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of tools.
|
List of tools.
|
||||||
"""
|
"""
|
||||||
tools = []
|
tools = []
|
||||||
|
callbacks = _handle_callbacks(
|
||||||
|
callback_manager=kwargs.get("callback_manager"), callbacks=callbacks
|
||||||
|
)
|
||||||
for name in tool_names:
|
for name in tool_names:
|
||||||
if name == "requests":
|
if name == "requests":
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
@ -316,8 +336,6 @@ def load_tools(
|
|||||||
if llm is None:
|
if llm is None:
|
||||||
raise ValueError(f"Tool {name} requires an LLM to be provided")
|
raise ValueError(f"Tool {name} requires an LLM to be provided")
|
||||||
tool = _LLM_TOOLS[name](llm)
|
tool = _LLM_TOOLS[name](llm)
|
||||||
if callback_manager is not None:
|
|
||||||
tool.callback_manager = callback_manager
|
|
||||||
tools.append(tool)
|
tools.append(tool)
|
||||||
elif name in _EXTRA_LLM_TOOLS:
|
elif name in _EXTRA_LLM_TOOLS:
|
||||||
if llm is None:
|
if llm is None:
|
||||||
@ -331,18 +349,17 @@ def load_tools(
|
|||||||
)
|
)
|
||||||
sub_kwargs = {k: kwargs[k] for k in extra_keys}
|
sub_kwargs = {k: kwargs[k] for k in extra_keys}
|
||||||
tool = _get_llm_tool_func(llm=llm, **sub_kwargs)
|
tool = _get_llm_tool_func(llm=llm, **sub_kwargs)
|
||||||
if callback_manager is not None:
|
|
||||||
tool.callback_manager = callback_manager
|
|
||||||
tools.append(tool)
|
tools.append(tool)
|
||||||
elif name in _EXTRA_OPTIONAL_TOOLS:
|
elif name in _EXTRA_OPTIONAL_TOOLS:
|
||||||
_get_tool_func, extra_keys = _EXTRA_OPTIONAL_TOOLS[name]
|
_get_tool_func, extra_keys = _EXTRA_OPTIONAL_TOOLS[name]
|
||||||
sub_kwargs = {k: kwargs[k] for k in extra_keys if k in kwargs}
|
sub_kwargs = {k: kwargs[k] for k in extra_keys if k in kwargs}
|
||||||
tool = _get_tool_func(**sub_kwargs)
|
tool = _get_tool_func(**sub_kwargs)
|
||||||
if callback_manager is not None:
|
|
||||||
tool.callback_manager = callback_manager
|
|
||||||
tools.append(tool)
|
tools.append(tool)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Got unknown tool {name}")
|
raise ValueError(f"Got unknown tool {name}")
|
||||||
|
if callbacks is not None:
|
||||||
|
for tool in tools:
|
||||||
|
tool.callbacks = callbacks
|
||||||
return tools
|
return tools
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,9 +1,11 @@
|
|||||||
"""Test tool utils."""
|
"""Test tool utils."""
|
||||||
|
import unittest
|
||||||
from typing import Any, Type
|
from typing import Any, Type
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock, Mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from langchain.agents import load_tools
|
||||||
from langchain.agents.agent import Agent
|
from langchain.agents.agent import Agent
|
||||||
from langchain.agents.chat.base import ChatAgent
|
from langchain.agents.chat.base import ChatAgent
|
||||||
from langchain.agents.conversational.base import ConversationalAgent
|
from langchain.agents.conversational.base import ConversationalAgent
|
||||||
@ -12,6 +14,7 @@ from langchain.agents.mrkl.base import ZeroShotAgent
|
|||||||
from langchain.agents.react.base import ReActDocstoreAgent, ReActTextWorldAgent
|
from langchain.agents.react.base import ReActDocstoreAgent, ReActTextWorldAgent
|
||||||
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent
|
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent
|
||||||
from langchain.agents.tools import Tool, tool
|
from langchain.agents.tools import Tool, tool
|
||||||
|
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -62,3 +65,28 @@ def test_tool_no_args_specified_assumes_str() -> None:
|
|||||||
assert some_tool.run({"tool_input": "foobar"}) == "foobar"
|
assert some_tool.run({"tool_input": "foobar"}) == "foobar"
|
||||||
with pytest.raises(ValueError, match="Too many arguments to single-input tool"):
|
with pytest.raises(ValueError, match="Too many arguments to single-input tool"):
|
||||||
some_tool.run({"tool_input": "foobar", "other_input": "bar"})
|
some_tool.run({"tool_input": "foobar", "other_input": "bar"})
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_tools_with_callback_manager_raises_deprecation_warning() -> None:
|
||||||
|
"""Test load_tools raises a deprecation for old callback manager kwarg."""
|
||||||
|
callback_manager = MagicMock()
|
||||||
|
with pytest.warns(DeprecationWarning, match="callback_manager is deprecated"):
|
||||||
|
tools = load_tools(["requests_get"], callback_manager=callback_manager)
|
||||||
|
assert len(tools) == 1
|
||||||
|
assert tools[0].callbacks == callback_manager
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_tools_with_callbacks_is_called() -> None:
|
||||||
|
"""Test callbacks are called when provided to load_tools fn."""
|
||||||
|
callbacks = [FakeCallbackHandler()]
|
||||||
|
tools = load_tools(["requests_get"], callbacks=callbacks) # type: ignore
|
||||||
|
assert len(tools) == 1
|
||||||
|
# Patch the requests.get() method to return a mock response
|
||||||
|
with unittest.mock.patch(
|
||||||
|
"langchain.requests.TextRequestsWrapper.get",
|
||||||
|
return_value=Mock(text="Hello world!"),
|
||||||
|
):
|
||||||
|
result = tools[0].run("https://www.google.com")
|
||||||
|
assert result.text == "Hello world!"
|
||||||
|
assert callbacks[0].tool_starts == 1
|
||||||
|
assert callbacks[0].tool_ends == 1
|
||||||
|
Loading…
Reference in New Issue
Block a user