mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-16 15:04:13 +00:00
Make Tools own model, add ToolKit Concept (#1095)
Follow-up of @hinthornw's PR: - Migrate the Tool abstraction to a separate file (`BaseTool`). - `Tool` implementation of `BaseTool` takes in function and coroutine to more easily maintain backwards compatibility - Add a Toolkit abstraction that can own the generation of tools around a shared concept or state --------- Co-authored-by: William FH <13333726+hinthornw@users.noreply.github.com> Co-authored-by: Harrison Chase <hw.chase.17@gmail.com> Co-authored-by: Francisco Ingham <fpingham@gmail.com> Co-authored-by: Dhruv Anand <105786647+dhruv-anand-aintech@users.noreply.github.com> Co-authored-by: cragwolfe <cragcw@gmail.com> Co-authored-by: Anton Troynikov <atroyn@users.noreply.github.com> Co-authored-by: Oliver Klingefjord <oliver@klingefjord.com> Co-authored-by: William Fu-Hinthorn <whinthorn@Williams-MBP-3.attlocal.net> Co-authored-by: Bruno Bornsztein <bruno.bornsztein@gmail.com>
This commit is contained in:
@@ -4,7 +4,8 @@ from typing import Any, List, Mapping, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.agents import AgentExecutor, Tool, initialize_agent
|
||||
from langchain.agents import AgentExecutor, initialize_agent
|
||||
from langchain.agents.tools import Tool
|
||||
from langchain.callbacks.base import CallbackManager
|
||||
from langchain.llms.base import LLM
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||
@@ -42,8 +43,16 @@ def _get_agent(**kwargs: Any) -> AgentExecutor:
|
||||
]
|
||||
fake_llm = FakeListLLM(responses=responses)
|
||||
tools = [
|
||||
Tool("Search", lambda x: x, "Useful for searching"),
|
||||
Tool("Lookup", lambda x: x, "Useful for looking up things in a table"),
|
||||
Tool(
|
||||
name="Search",
|
||||
func=lambda x: x,
|
||||
description="Useful for searching",
|
||||
),
|
||||
Tool(
|
||||
name="Lookup",
|
||||
func=lambda x: x,
|
||||
description="Useful for looking up things in a table",
|
||||
),
|
||||
]
|
||||
agent = initialize_agent(
|
||||
tools, fake_llm, agent="zero-shot-react-description", verbose=True, **kwargs
|
||||
@@ -79,7 +88,12 @@ def test_agent_with_callbacks_global() -> None:
|
||||
]
|
||||
fake_llm = FakeListLLM(responses=responses, callback_manager=manager, verbose=True)
|
||||
tools = [
|
||||
Tool("Search", lambda x: x, "Useful for searching"),
|
||||
Tool(
|
||||
name="Search",
|
||||
func=lambda x: x,
|
||||
description="Useful for searching",
|
||||
callback_manager=manager,
|
||||
),
|
||||
]
|
||||
agent = initialize_agent(
|
||||
tools,
|
||||
@@ -118,7 +132,12 @@ def test_agent_with_callbacks_local() -> None:
|
||||
]
|
||||
fake_llm = FakeListLLM(responses=responses, callback_manager=manager, verbose=True)
|
||||
tools = [
|
||||
Tool("Search", lambda x: x, "Useful for searching"),
|
||||
Tool(
|
||||
name="Search",
|
||||
func=lambda x: x,
|
||||
description="Useful for searching",
|
||||
callback_manager=manager,
|
||||
),
|
||||
]
|
||||
agent = initialize_agent(
|
||||
tools,
|
||||
@@ -159,7 +178,11 @@ def test_agent_with_callbacks_not_verbose() -> None:
|
||||
]
|
||||
fake_llm = FakeListLLM(responses=responses, callback_manager=manager)
|
||||
tools = [
|
||||
Tool("Search", lambda x: x, "Useful for searching"),
|
||||
Tool(
|
||||
name="Search",
|
||||
func=lambda x: x,
|
||||
description="Useful for searching",
|
||||
),
|
||||
]
|
||||
agent = initialize_agent(
|
||||
tools,
|
||||
@@ -186,7 +209,12 @@ def test_agent_tool_return_direct() -> None:
|
||||
]
|
||||
fake_llm = FakeListLLM(responses=responses)
|
||||
tools = [
|
||||
Tool("Search", lambda x: x, "Useful for searching", return_direct=True),
|
||||
Tool(
|
||||
name="Search",
|
||||
func=lambda x: x,
|
||||
description="Useful for searching",
|
||||
return_direct=True,
|
||||
),
|
||||
]
|
||||
agent = initialize_agent(
|
||||
tools,
|
||||
@@ -204,7 +232,12 @@ def test_agent_with_new_prefix_suffix() -> None:
|
||||
responses=["FooBarBaz\nAction: Search\nAction Input: misalignment"]
|
||||
)
|
||||
tools = [
|
||||
Tool("Search", lambda x: x, "Useful for searching", return_direct=True),
|
||||
Tool(
|
||||
name="Search",
|
||||
func=lambda x: x,
|
||||
description="Useful for searching",
|
||||
return_direct=True,
|
||||
),
|
||||
]
|
||||
prefix = "FooBarBaz"
|
||||
|
||||
|
@@ -58,8 +58,8 @@ def test_predict_until_observation_normal() -> None:
|
||||
outputs = ["foo\nAction 1: Search[foo]"]
|
||||
fake_llm = FakeListLLM(responses=outputs)
|
||||
tools = [
|
||||
Tool("Search", lambda x: x),
|
||||
Tool("Lookup", lambda x: x),
|
||||
Tool(name="Search", func=lambda x: x, description="foo"),
|
||||
Tool(name="Lookup", func=lambda x: x, description="bar"),
|
||||
]
|
||||
agent = ReActDocstoreAgent.from_llm_and_tools(fake_llm, tools)
|
||||
output = agent.plan([], input="")
|
||||
@@ -72,8 +72,8 @@ def test_predict_until_observation_repeat() -> None:
|
||||
outputs = ["foo", " Search[foo]"]
|
||||
fake_llm = FakeListLLM(responses=outputs)
|
||||
tools = [
|
||||
Tool("Search", lambda x: x),
|
||||
Tool("Lookup", lambda x: x),
|
||||
Tool(name="Search", func=lambda x: x, description="foo"),
|
||||
Tool(name="Lookup", func=lambda x: x, description="bar"),
|
||||
]
|
||||
agent = ReActDocstoreAgent.from_llm_and_tools(fake_llm, tools)
|
||||
output = agent.plan([], input="")
|
||||
|
@@ -2,6 +2,7 @@
|
||||
import pytest
|
||||
|
||||
from langchain.agents.tools import Tool, tool
|
||||
from langchain.schema import AgentAction
|
||||
|
||||
|
||||
def test_unnamed_decorator() -> None:
|
||||
@@ -65,3 +66,42 @@ def test_missing_docstring() -> None:
|
||||
@tool
|
||||
def search_api(query: str) -> str:
|
||||
return "API result"
|
||||
|
||||
|
||||
def test_create_tool_posistional_args() -> None:
|
||||
"""Test that positional arguments are allowed."""
|
||||
test_tool = Tool("test_name", lambda x: x, "test_description")
|
||||
assert test_tool("foo") == "foo"
|
||||
assert test_tool.name == "test_name"
|
||||
assert test_tool.description == "test_description"
|
||||
|
||||
|
||||
def test_create_tool_keyword_args() -> None:
|
||||
"""Test that keyword arguments are allowed."""
|
||||
test_tool = Tool(name="test_name", func=lambda x: x, description="test_description")
|
||||
assert test_tool("foo") == "foo"
|
||||
assert test_tool.name == "test_name"
|
||||
assert test_tool.description == "test_description"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_async_tool() -> None:
|
||||
"""Test that async tools are allowed."""
|
||||
|
||||
async def _test_func(x: str) -> str:
|
||||
return x
|
||||
|
||||
test_tool = Tool(
|
||||
name="test_name",
|
||||
func=lambda x: x,
|
||||
description="test_description",
|
||||
coroutine=_test_func,
|
||||
)
|
||||
assert test_tool("foo") == "foo"
|
||||
assert test_tool.name == "test_name"
|
||||
assert test_tool.description == "test_description"
|
||||
assert test_tool.coroutine is not None
|
||||
assert (
|
||||
await test_tool.arun(AgentAction(tool_input="foo", tool="test_name", log=""))
|
||||
== "foo"
|
||||
)
|
||||
|
Reference in New Issue
Block a user