diff --git a/libs/core/langchain_core/tools.py b/libs/core/langchain_core/tools.py index 4a552f301a5..cb65bf6165d 100644 --- a/libs/core/langchain_core/tools.py +++ b/libs/core/langchain_core/tools.py @@ -20,6 +20,7 @@ tool for the job. from __future__ import annotations import asyncio +import functools import inspect import json import textwrap @@ -548,6 +549,9 @@ class ChildTool(BaseTool): tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input) if signature(self._run).parameters.get("run_manager"): tool_kwargs["run_manager"] = run_manager + + if config_param := _get_runnable_config_param(self._run): + tool_kwargs[config_param] = config response = context.run(self._run, *tool_args, **tool_kwargs) if self.response_format == "content_and_raw_output": if not isinstance(response, tuple) or len(response) != 2: @@ -627,10 +631,14 @@ class ChildTool(BaseTool): child_config = patch_config(config, callbacks=run_manager.get_child()) context = copy_context() context.run(_set_config_context, child_config) - if self.__class__._arun is BaseTool._arun or signature( - self._arun - ).parameters.get("run_manager"): + func_to_check = ( + self._run if self.__class__._arun is BaseTool._arun else self._arun + ) + if signature(func_to_check).parameters.get("run_manager"): tool_kwargs["run_manager"] = run_manager + if config_param := _get_runnable_config_param(func_to_check): + tool_kwargs[config_param] = config + coro = context.run(self._arun, *tool_args, **tool_kwargs) if accepts_context(asyncio.create_task): response = await asyncio.create_task(coro, context=context) # type: ignore @@ -724,6 +732,7 @@ class Tool(BaseTool): def _run( self, *args: Any, + config: RunnableConfig, run_manager: Optional[CallbackManagerForToolRun] = None, **kwargs: Any, ) -> Any: @@ -731,12 +740,15 @@ class Tool(BaseTool): if self.func: if run_manager and signature(self.func).parameters.get("callbacks"): kwargs["callbacks"] = run_manager.get_child() + if config_param := _get_runnable_config_param(self.func): + kwargs[config_param] = config return self.func(*args, **kwargs) raise NotImplementedError("Tool does not support sync invocation.") async def _arun( self, *args: Any, + config: RunnableConfig, run_manager: Optional[AsyncCallbackManagerForToolRun] = None, **kwargs: Any, ) -> Any: @@ -744,11 +756,15 @@ class Tool(BaseTool): if self.coroutine: if run_manager and signature(self.coroutine).parameters.get("callbacks"): kwargs["callbacks"] = run_manager.get_child() + if config_param := _get_runnable_config_param(self.coroutine): + kwargs[config_param] = config return await self.coroutine(*args, **kwargs) # NOTE: this code is unreachable since _arun is only called if coroutine is not # None. - return await super()._arun(*args, run_manager=run_manager, **kwargs) + return await super()._arun( + *args, config=config, run_manager=run_manager, **kwargs + ) # TODO: this is for backwards compatibility, remove in future def __init__( @@ -822,6 +838,7 @@ class StructuredTool(BaseTool): def _run( self, *args: Any, + config: RunnableConfig, run_manager: Optional[CallbackManagerForToolRun] = None, **kwargs: Any, ) -> Any: @@ -829,12 +846,15 @@ class StructuredTool(BaseTool): if self.func: if run_manager and signature(self.func).parameters.get("callbacks"): kwargs["callbacks"] = run_manager.get_child() + if config_param := _get_runnable_config_param(self.func): + kwargs[config_param] = config return self.func(*args, **kwargs) raise NotImplementedError("StructuredTool does not support sync invocation.") async def _arun( self, *args: Any, + config: RunnableConfig, run_manager: Optional[AsyncCallbackManagerForToolRun] = None, **kwargs: Any, ) -> Any: @@ -842,11 +862,15 @@ class StructuredTool(BaseTool): if self.coroutine: if run_manager and signature(self.coroutine).parameters.get("callbacks"): kwargs["callbacks"] = run_manager.get_child() + if config_param := _get_runnable_config_param(self.coroutine): + kwargs[config_param] = config return await self.coroutine(*args, **kwargs) # NOTE: this code is unreachable since _arun is only called if coroutine is not # None. - return await super()._arun(*args, run_manager=run_manager, **kwargs) + return await super()._arun( + *args, config=config, run_manager=run_manager, **kwargs + ) @classmethod def from_function( @@ -923,12 +947,21 @@ class StructuredTool(BaseTool): description_ = f"{description_.strip()}" _args_schema = args_schema if _args_schema is None and infer_schema: + if config_param := _get_runnable_config_param(source_function): + filter_args: Tuple[str, ...] = ( + config_param, + "run_manager", + "callbacks", + ) + else: + filter_args = ("run_manager", "callbacks") # schema name is appended within function _args_schema = create_schema_from_function( name, source_function, parse_docstring=parse_docstring, error_on_invalid_docstring=error_on_invalid_docstring, + filter_args=filter_args, ) return cls( name=name, @@ -1112,7 +1145,7 @@ def tool( ) # If someone doesn't want a schema applied, we must treat it as # a simple string->string function - if func.__doc__ is None: + if dec_func.__doc__ is None: raise ValueError( "Function must have a docstring if " "description not provided and infer_schema is False." @@ -1447,3 +1480,17 @@ def convert_runnable_to_tool( description=description, args_schema=args_schema, ) + + +def _get_runnable_config_param(func: Callable) -> Optional[str]: + if isinstance(func, functools.partial): + func = func.func + try: + type_hints = get_type_hints(func) + except Exception: + return None + else: + for name, type_ in type_hints.items(): + if type_ is RunnableConfig: + return name + return None diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index 478707ea555..e83a35d35e8 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -1,6 +1,5 @@ """Test the base tool implementation.""" -import asyncio import inspect import json import sys @@ -19,7 +18,12 @@ from langchain_core.callbacks import ( ) from langchain_core.messages import ToolMessage from langchain_core.pydantic_v1 import BaseModel, ValidationError -from langchain_core.runnables import Runnable, RunnableLambda, ensure_config +from langchain_core.runnables import ( + Runnable, + RunnableConfig, + RunnableLambda, + ensure_config, +) from langchain_core.tools import ( BaseTool, SchemaAnnotationError, @@ -914,7 +918,6 @@ async def test_async_tool_pass_context() -> None: @tool async def foo(bar: str) -> str: """The foo.""" - await asyncio.sleep(0.0001) config = ensure_config() assert config["configurable"]["foo"] == "not-bar" assert bar == "baz" @@ -925,6 +928,64 @@ async def test_async_tool_pass_context() -> None: ) +def assert_bar(bar: Any, bar_config: RunnableConfig) -> Any: + assert bar_config["configurable"]["foo"] == "not-bar" + assert bar == "baz" + return bar + + +@tool +def foo(bar: Any, bar_config: RunnableConfig) -> Any: + """The foo.""" + return assert_bar(bar, bar_config) + + +@tool +async def afoo(bar: Any, bar_config: RunnableConfig) -> Any: + """The foo.""" + return assert_bar(bar, bar_config) + + +@tool(infer_schema=False) +def simple_foo(bar: Any, bar_config: RunnableConfig) -> Any: + """The foo.""" + return assert_bar(bar, bar_config) + + +@tool(infer_schema=False) +async def asimple_foo(bar: Any, bar_config: RunnableConfig) -> Any: + """The foo.""" + return assert_bar(bar, bar_config) + + +class FooBase(BaseTool): + name: str = "Foo" + description: str = "Foo" + + def _run(self, bar: Any, bar_config: RunnableConfig, **kwargs: Any) -> Any: + return assert_bar(bar, bar_config) + + +class AFooBase(FooBase): + async def _arun(self, bar: Any, bar_config: RunnableConfig, **kwargs: Any) -> Any: + return assert_bar(bar, bar_config) + + +@pytest.mark.parametrize("tool", [foo, simple_foo, FooBase(), AFooBase()]) +def test_tool_pass_config(tool: BaseTool) -> None: + assert tool.invoke({"bar": "baz"}, {"configurable": {"foo": "not-bar"}}) == "baz" + + +@pytest.mark.parametrize( + "tool", [foo, afoo, simple_foo, asimple_foo, FooBase(), AFooBase()] +) +async def test_async_tool_pass_config(tool: BaseTool) -> None: + assert ( + await tool.ainvoke({"bar": "baz"}, {"configurable": {"foo": "not-bar"}}) + == "baz" + ) + + def test_tool_description() -> None: def foo(bar: str) -> str: """The foo."""