mirror of
https://github.com/hwchase17/langchain.git
synced 2025-11-23 17:06:54 +00:00
```python
"""python scripts/update_mypy_ruff.py"""
import glob
import tomllib
from pathlib import Path
import toml
import subprocess
import re
ROOT_DIR = Path(__file__).parents[1]

# Matches mypy ``--no-color`` error lines such as
# ``pkg/mod.py:12: error: Incompatible types ...  [assignment]`` and captures
# the reported path, the 1-based line number and the bracketed error code.
_MYPY_ERROR_RE = re.compile(r"^(.*):(\d+): error:.*\[(.*)\]")


def _pin_dev_tools(pyproject: dict, mypy_version: str, ruff_version: str) -> bool:
    """Pin mypy and ruff in the poetry ``typing``/``lint`` groups, in place.

    Returns False when ``tool.poetry.group.typing.dependencies`` or
    ``tool.poetry.group.lint.dependencies`` is missing; ``pyproject`` may then
    be partially updated in memory, so the caller must not write it out.
    """
    try:
        groups = pyproject["tool"]["poetry"]["group"]
        groups["typing"]["dependencies"]["mypy"] = mypy_version
        groups["lint"]["dependencies"]["ruff"] = ruff_version
    except KeyError:
        return False
    return True


def _collect_mypy_errors(output: str) -> dict[tuple[str, int], list[str]]:
    """Group mypy error codes by (reported path, 1-based line number).

    Codes are de-duplicated per line (first-seen order kept) so each code
    appears only once in the generated ``# type: ignore[...]`` comment.
    """
    errors: dict[tuple[str, int], list[str]] = {}
    for log_line in output.split("\n"):
        match = _MYPY_ERROR_RE.match(log_line)
        if match is None:
            continue
        error_path, line_no, error_code = match.groups()
        codes = errors.setdefault((error_path, int(line_no)), [])
        if error_code not in codes:
            codes.append(error_code)
    return errors


def _append_type_ignores(
    project_dir: Path, errors: dict[tuple[str, int], list[str]]
) -> None:
    """Append ``# type: ignore[<codes>]`` to every line mypy reported.

    Each affected file is read and rewritten once. Missing files and line
    numbers outside the file (e.g. mypy's line 0) are skipped.
    """
    codes_by_file: dict[str, dict[int, list[str]]] = {}
    for (error_path, line_no), codes in errors.items():
        codes_by_file.setdefault(error_path, {})[line_no] = codes

    for error_path, codes_by_line in codes_by_file.items():
        full_path = project_dir / error_path
        try:
            with open(full_path, encoding="utf-8") as f:
                file_lines = f.readlines()
        except FileNotFoundError:
            continue
        for line_no, codes in codes_by_line.items():
            index = line_no - 1
            # Guard against line 0 / past-EOF, which would otherwise index
            # from the end of the list and annotate the wrong line.
            if not 0 <= index < len(file_lines):
                continue
            original = file_lines[index]
            # Strip only a real newline: the final line may not have one, and
            # blindly slicing off the last character would corrupt code.
            body = original.rstrip("\n")
            newline = original[len(body):]
            ignore = f"# type: ignore[{', '.join(codes)}]"
            file_lines[index] = f"{body}  {ignore}{newline}"
        with open(full_path, "w", encoding="utf-8") as f:
            f.writelines(file_lines)


def main(mypy_version: str = "^1.10", ruff_version: str = "^0.5") -> None:
    """Bump mypy/ruff for every package under ``libs/`` and silence new errors.

    For each ``libs/**/pyproject.toml`` that has poetry ``typing`` and ``lint``
    dependency groups:
    1. pin ``mypy``/``ruff`` to the given constraints;
    2. re-lock and install with poetry, then run mypy;
    3. append ``# type: ignore[<codes>]`` to every line mypy reports;
    4. run ``ruff format`` and ruff's import-sorting (``I``) fixes.

    Projects lacking those groups are skipped and their files left untouched.

    Args:
        mypy_version: Poetry version constraint written for ``mypy``.
        ruff_version: Poetry version constraint written for ``ruff``.
    """
    pattern = str(ROOT_DIR / "libs/**/pyproject.toml")
    for path in sorted(glob.glob(pattern, recursive=True)):
        print(path)
        with open(path, "rb") as f:
            pyproject = tomllib.load(f)
        if not _pin_dev_tools(pyproject, mypy_version, ruff_version):
            continue
        # TOML documents are UTF-8 by specification.
        with open(path, "w", encoding="utf-8") as f:
            toml.dump(pyproject, f)

        project_dir = Path(path).parent
        # Fixed command strings (no external input); shell=True is only used
        # to chain the commands with ";".
        completed = subprocess.run(
            "poetry lock --no-update; poetry install --with typing; "
            "poetry run mypy . --no-color",
            cwd=project_dir,
            shell=True,
            capture_output=True,
            text=True,
        )
        errors = _collect_mypy_errors(completed.stdout)
        print(len(errors))
        _append_type_ignores(project_dir, errors)

        # ``ruff check`` is the explicit lint subcommand; the bare
        # ``ruff <path>`` form is deprecated/removed in newer Ruff releases.
        subprocess.run(
            "poetry run ruff format .; poetry run ruff check --select I --fix .",
            cwd=project_dir,
            shell=True,
            capture_output=True,
            text=True,
        )


if __name__ == "__main__":
    main()
```
102 lines
3.5 KiB
Python
102 lines
3.5 KiB
Python
"""Test tool utils."""
|
|
|
|
import unittest
|
|
from typing import Any, Type
|
|
from unittest.mock import MagicMock, Mock
|
|
|
|
import pytest
|
|
from langchain.agents.agent import Agent
|
|
from langchain.agents.chat.base import ChatAgent
|
|
from langchain.agents.conversational.base import ConversationalAgent
|
|
from langchain.agents.conversational_chat.base import ConversationalChatAgent
|
|
from langchain.agents.mrkl.base import ZeroShotAgent
|
|
from langchain.agents.react.base import ReActDocstoreAgent, ReActTextWorldAgent
|
|
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent
|
|
from langchain_core.tools import Tool, ToolException, tool
|
|
|
|
from langchain_community.agent_toolkits.load_tools import load_tools
|
|
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
|
|
|
|
|
@pytest.mark.parametrize(
    "agent_cls",
    [
        ZeroShotAgent,
        ChatAgent,
        ConversationalChatAgent,
        ConversationalAgent,
        ReActDocstoreAgent,
        ReActTextWorldAgent,
        SelfAskWithSearchAgent,
    ],
)
def test_single_input_agent_raises_error_on_structured_tool(
    agent_cls: Type[Agent],
) -> None:
    """Test that single-input agents reject multi-input (structured) tools.

    Building each parametrized agent class with a two-argument tool must raise
    ``ValueError`` whose message names the agent class and the tool.
    """

    @tool
    def the_tool(foo: str, bar: str) -> str:
        """Return the concat of foo and bar."""
        return foo + bar

    # ``match`` is a regular expression searched against the message, so the
    # trailing period is escaped to require a literal "." (not any character).
    expected_message = (
        rf"{agent_cls.__name__} does not support"  # type: ignore
        r" multi-input tool the_tool\."
    )
    with pytest.raises(ValueError, match=expected_message):
        agent_cls.from_llm_and_tools(MagicMock(), [the_tool])  # type: ignore
|
|
|
|
|
|
def test_tool_no_args_specified_assumes_str() -> None:
    """Older tools could assume *args and **kwargs were passed in.

    A ``Tool`` wrapping such a callable should advertise a single string
    ``tool_input`` argument, accept that value bare or keyed by name, and
    raise ``ToolException`` when given extra keys.
    """

    def ambiguous_function(*args: Any, **kwargs: Any) -> str:
        """An ambiguously defined function."""
        return args[0]

    str_tool = Tool(
        func=ambiguous_function,
        name="chain_run",
        description="Run the chain",
    )

    assert str_tool.args == {"tool_input": {"type": "string"}}

    # The same value comes back whether passed bare or as a one-key dict.
    assert str_tool.run("foobar") == "foobar"
    assert str_tool.run({"tool_input": "foobar"}) == "foobar"

    extra_keys = {"tool_input": "foobar", "other_input": "bar"}
    with pytest.raises(ToolException, match="Too many arguments to single-input tool"):
        str_tool.run(extra_keys)
|
|
|
|
|
|
def test_load_tools_with_callback_manager_raises_deprecation_warning() -> None:
    """Test load_tools raises a deprecation for old callback manager kwarg.

    The legacy ``callback_manager`` value must still end up as the single
    loaded tool's ``callbacks``.
    """
    legacy_manager = MagicMock()
    expected_warning = pytest.warns(
        DeprecationWarning, match="callback_manager is deprecated"
    )
    with expected_warning:
        tools = load_tools(
            ["requests_get"],
            callback_manager=legacy_manager,
            allow_dangerous_tools=True,
        )

    assert len(tools) == 1
    (loaded_tool,) = tools
    assert loaded_tool.callbacks == legacy_manager
|
|
|
|
|
|
def test_load_tools_with_callbacks_is_called() -> None:
    """Test callbacks are called when provided to load_tools fn.

    Running the loaded ``requests_get`` tool (with the HTTP call mocked out)
    must fire exactly one tool-start and one tool-end event on the handler.
    """
    handler = FakeCallbackHandler()
    callbacks = [handler]
    tools = load_tools(
        ["requests_get"],  # type: ignore
        callbacks=callbacks,  # type: ignore
        allow_dangerous_tools=True,
    )
    assert len(tools) == 1
    requests_get_tool = tools[0]

    # Patch the requests.get() method to return a mock response
    fake_response = Mock(text="Hello world!")
    patch_get = unittest.mock.patch(
        "langchain.requests.TextRequestsWrapper.get",
        return_value=fake_response,
    )
    with patch_get:
        result = requests_get_tool.run("https://www.google.com")

    assert result.text == "Hello world!"
    assert handler.tool_starts == 1
    assert handler.tool_ends == 1
|