From 6eb6c45c981dd8d04dfeb7ac6becdb0f6b863728 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 28 Sep 2023 15:40:22 +0100 Subject: [PATCH] Enable creating Tools from any Runnable --- libs/langchain/langchain/tools/base.py | 22 ++++++++++++++---- .../schema/runnable/test_runnable.py | 23 +++++++++++++++++++ 2 files changed, 40 insertions(+), 5 deletions(-) diff --git a/libs/langchain/langchain/tools/base.py b/libs/langchain/langchain/tools/base.py index 2310927ac25..160071be3e9 100644 --- a/libs/langchain/langchain/tools/base.py +++ b/libs/langchain/langchain/tools/base.py @@ -734,7 +734,7 @@ class StructuredTool(BaseTool): def tool( - *args: Union[str, Callable], + *args: Union[str, Callable, Runnable], return_direct: bool = False, args_schema: Optional[Type[BaseModel]] = None, infer_schema: bool = True, @@ -769,21 +769,31 @@ def tool( """ def _make_with_name(tool_name: str) -> Callable: - def _make_tool(dec_func: Callable) -> BaseTool: - if inspect.iscoroutinefunction(dec_func): + def _make_tool(dec_func: Union[Callable, Runnable]) -> BaseTool: + if isinstance(dec_func, Runnable): + coroutine = dec_func.ainvoke + func = dec_func.invoke + schema = dec_func.input_schema + description = repr(dec_func) + elif inspect.iscoroutinefunction(dec_func): coroutine = dec_func func = None + schema = args_schema + description = None else: coroutine = None func = dec_func + schema = args_schema + description = None if infer_schema or args_schema is not None: return StructuredTool.from_function( func, coroutine, name=tool_name, + description=description, return_direct=return_direct, - args_schema=args_schema, + args_schema=schema, infer_schema=infer_schema, ) # If someone doesn't want a schema applied, we must treat it as @@ -803,7 +813,9 @@ def tool( return _make_tool - if len(args) == 1 and isinstance(args[0], str): + if len(args) == 2 and isinstance(args[0], str) and isinstance(args[1], Runnable): + return _make_with_name(args[0])(args[1]) + elif len(args) == 1 and isinstance(args[0], str): # if the argument is a string, then we use the string as the tool name # Example usage: @tool("search", return_direct=True) return _make_with_name(args[0]) diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index a9103d37253..d46bd8df8e5 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -2,6 +2,7 @@ import sys from operator import itemgetter from typing import Any, Dict, List, Optional, Sequence, Union, cast from uuid import UUID +from langchain.tools.base import BaseTool, tool import pytest from freezegun import freeze_time @@ -2779,3 +2780,25 @@ def test_representation_of_runnables() -> None: " b: RunnableLambda(...)\n" " }" ), "repr where code string contains multiple lambdas gives up" + + +def test_tool_from_runnable() -> None: + prompt = ( + SystemMessagePromptTemplate.from_template("You are a nice assistant.") + + "{question}" + ) + llm = FakeStreamingListLLM(responses=["foo-lish"]) + + chain = prompt | llm | StrOutputParser() + + chain_tool = tool("chain_tool", chain) + + assert isinstance(chain_tool, BaseTool) + assert chain_tool.name == "chain_tool" + assert chain_tool.description.endswith(repr(chain)) + assert chain_tool.args_schema.schema() == chain.input_schema.schema() + assert chain_tool.args_schema.schema() == { + "properties": {"question": {"title": "Question"}}, + "title": "PromptInput", + "type": "object", + }