diff --git a/libs/langchain/langchain/tools/base.py b/libs/langchain/langchain/tools/base.py index 160071be3e9..b56c1b08826 100644 --- a/libs/langchain/langchain/tools/base.py +++ b/libs/langchain/langchain/tools/base.py @@ -771,8 +771,21 @@ def tool( def _make_with_name(tool_name: str) -> Callable: def _make_tool(dec_func: Union[Callable, Runnable]) -> BaseTool: if isinstance(dec_func, Runnable): - coroutine = dec_func.ainvoke - func = dec_func.invoke + if dec_func.input_schema.schema().get("type") != "object": + raise ValueError("Runnable must have an object schema.") + + async def ainvoke_wrapper( + callbacks: Optional[Callbacks] = None, **kwargs: Any + ) -> Any: + return await dec_func.ainvoke(kwargs, {"callbacks": callbacks}) + + def invoke_wrapper( + callbacks: Optional[Callbacks] = None, **kwargs: Any + ) -> Any: + return dec_func.invoke(kwargs, {"callbacks": callbacks}) + + coroutine = ainvoke_wrapper + func = invoke_wrapper schema = dec_func.input_schema description = repr(dec_func) elif inspect.iscoroutinefunction(dec_func): diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index d46bd8df8e5..e51f55d72fb 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -2782,7 +2782,8 @@ def test_representation_of_runnables() -> None: ), "repr where code string contains multiple lambdas gives up" -def test_tool_from_runnable() -> None: +@pytest.mark.asyncio +async def test_tool_from_runnable() -> None: prompt = ( SystemMessagePromptTemplate.from_template("You are a nice assistant.") + "{question}" @@ -2795,6 +2796,12 @@ def test_tool_from_runnable() -> None: assert isinstance(chain_tool, BaseTool) assert chain_tool.name == "chain_tool" + assert chain_tool.run({"question": "What up"}) == chain.invoke( + {"question": "What up"} + ) + assert await chain_tool.arun({"question": "What up"}) == await chain.ainvoke( + {"question": "What up"} + ) assert chain_tool.description.endswith(repr(chain)) assert chain_tool.args_schema.schema() == chain.input_schema.schema() assert chain_tool.args_schema.schema() == {