mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-05 21:12:48 +00:00
multiple: langchain 0.2 in master (#21191)
0.2rc migrations - [x] Move memory - [x] Move remaining retrievers - [x] graph_qa chains - [x] some dependency from evaluation code potentially on math utils - [x] Move openapi chain from `langchain.chains.api.openapi` to `langchain_community.chains.openapi` - [x] Migrate `langchain.chains.ernie_functions` to `langchain_community.chains.ernie_functions` - [x] migrate `langchain/chains/llm_requests.py` to `langchain_community.chains.llm_requests` - [x] Moving `langchain_community.cross_enoders.base:BaseCrossEncoder` -> `langchain_community.retrievers.document_compressors.cross_encoder:BaseCrossEncoder` (namespace not ideal, but it needs to be moved to `langchain` to avoid circular deps) - [x] unit tests langchain -- add pytest.mark.community to some unit tests that will stay in langchain - [x] unit tests community -- move unit tests that depend on community to community - [x] mv integration tests that depend on community to community - [x] mypy checks Other todo - [x] Make deprecation warnings not noisy (need to use warn deprecated and check that things are implemented properly) - [x] Update deprecation messages with timeline for code removal (likely we actually won't be removing things until 0.4 release) -- will give people more time to transition their code. - [ ] Add information to deprecation warning to show users how to migrate their code base using langchain-cli - [ ] Remove any unnecessary requirements in langchain (e.g., is SQLALchemy required?) --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
100
libs/community/tests/unit_tests/agents/test_tools.py
Normal file
100
libs/community/tests/unit_tests/agents/test_tools.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""Test tool utils."""
|
||||
import unittest
|
||||
from typing import Any, Type
|
||||
from unittest.mock import MagicMock, Mock
|
||||
|
||||
import pytest
|
||||
from langchain.agents.agent import Agent
|
||||
from langchain.agents.chat.base import ChatAgent
|
||||
from langchain.agents.conversational.base import ConversationalAgent
|
||||
from langchain.agents.conversational_chat.base import ConversationalChatAgent
|
||||
from langchain.agents.mrkl.base import ZeroShotAgent
|
||||
from langchain.agents.react.base import ReActDocstoreAgent, ReActTextWorldAgent
|
||||
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent
|
||||
from langchain_core.tools import Tool, ToolException, tool
|
||||
|
||||
from langchain_community.agent_toolkits.load_tools import load_tools
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"agent_cls",
|
||||
[
|
||||
ZeroShotAgent,
|
||||
ChatAgent,
|
||||
ConversationalChatAgent,
|
||||
ConversationalAgent,
|
||||
ReActDocstoreAgent,
|
||||
ReActTextWorldAgent,
|
||||
SelfAskWithSearchAgent,
|
||||
],
|
||||
)
|
||||
def test_single_input_agent_raises_error_on_structured_tool(
|
||||
agent_cls: Type[Agent],
|
||||
) -> None:
|
||||
"""Test that older agents raise errors on older tools."""
|
||||
|
||||
@tool
|
||||
def the_tool(foo: str, bar: str) -> str:
|
||||
"""Return the concat of foo and bar."""
|
||||
return foo + bar
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=f"{agent_cls.__name__} does not support" # type: ignore
|
||||
f" multi-input tool {the_tool.name}.",
|
||||
):
|
||||
agent_cls.from_llm_and_tools(MagicMock(), [the_tool]) # type: ignore
|
||||
|
||||
|
||||
def test_tool_no_args_specified_assumes_str() -> None:
|
||||
"""Older tools could assume *args and **kwargs were passed in."""
|
||||
|
||||
def ambiguous_function(*args: Any, **kwargs: Any) -> str:
|
||||
"""An ambiguously defined function."""
|
||||
return args[0]
|
||||
|
||||
some_tool = Tool(
|
||||
name="chain_run",
|
||||
description="Run the chain",
|
||||
func=ambiguous_function,
|
||||
)
|
||||
expected_args = {"tool_input": {"type": "string"}}
|
||||
assert some_tool.args == expected_args
|
||||
assert some_tool.run("foobar") == "foobar"
|
||||
assert some_tool.run({"tool_input": "foobar"}) == "foobar"
|
||||
with pytest.raises(ToolException, match="Too many arguments to single-input tool"):
|
||||
some_tool.run({"tool_input": "foobar", "other_input": "bar"})
|
||||
|
||||
|
||||
def test_load_tools_with_callback_manager_raises_deprecation_warning() -> None:
|
||||
"""Test load_tools raises a deprecation for old callback manager kwarg."""
|
||||
callback_manager = MagicMock()
|
||||
with pytest.warns(DeprecationWarning, match="callback_manager is deprecated"):
|
||||
tools = load_tools(
|
||||
["requests_get"],
|
||||
callback_manager=callback_manager,
|
||||
allow_dangerous_tools=True,
|
||||
)
|
||||
assert len(tools) == 1
|
||||
assert tools[0].callbacks == callback_manager
|
||||
|
||||
|
||||
def test_load_tools_with_callbacks_is_called() -> None:
|
||||
"""Test callbacks are called when provided to load_tools fn."""
|
||||
callbacks = [FakeCallbackHandler()]
|
||||
tools = load_tools(
|
||||
["requests_get"], # type: ignore
|
||||
callbacks=callbacks, # type: ignore
|
||||
allow_dangerous_tools=True,
|
||||
)
|
||||
assert len(tools) == 1
|
||||
# Patch the requests.get() method to return a mock response
|
||||
with unittest.mock.patch(
|
||||
"langchain.requests.TextRequestsWrapper.get",
|
||||
return_value=Mock(text="Hello world!"),
|
||||
):
|
||||
result = tools[0].run("https://www.google.com")
|
||||
assert result.text == "Hello world!"
|
||||
assert callbacks[0].tool_starts == 1
|
||||
assert callbacks[0].tool_ends == 1
|
Reference in New Issue
Block a user