diff --git a/libs/core/langchain_core/tools.py b/libs/core/langchain_core/tools.py index 363d12b950f..0e576a39d26 100644 --- a/libs/core/langchain_core/tools.py +++ b/libs/core/langchain_core/tools.py @@ -19,10 +19,12 @@ tool for the job. from __future__ import annotations +import asyncio import inspect import uuid import warnings from abc import ABC, abstractmethod +from contextvars import copy_context from functools import partial from inspect import signature from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union @@ -60,7 +62,12 @@ from langchain_core.runnables import ( RunnableSerializable, ensure_config, ) -from langchain_core.runnables.config import run_in_executor +from langchain_core.runnables.config import ( + patch_config, + run_in_executor, + var_child_runnable_config, +) +from langchain_core.runnables.utils import accepts_context class SchemaAnnotationError(TypeError): @@ -255,6 +262,7 @@ class ChildTool(BaseTool): metadata=config.get("metadata"), run_name=config.get("run_name"), run_id=config.pop("run_id", None), + config=config, **kwargs, ) @@ -272,6 +280,7 @@ class ChildTool(BaseTool): metadata=config.get("metadata"), run_name=config.get("run_name"), run_id=config.pop("run_id", None), + config=config, **kwargs, ) @@ -353,6 +362,7 @@ class ChildTool(BaseTool): metadata: Optional[Dict[str, Any]] = None, run_name: Optional[str] = None, run_id: Optional[uuid.UUID] = None, + config: Optional[RunnableConfig] = None, **kwargs: Any, ) -> Any: """Run the tool.""" @@ -385,12 +395,20 @@ class ChildTool(BaseTool): **kwargs, ) try: + child_config = patch_config( + config, + callbacks=run_manager.get_child(), + ) + context = copy_context() + context.run(var_child_runnable_config.set, child_config) parsed_input = self._parse_input(tool_input) tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input) observation = ( - self._run(*tool_args, run_manager=run_manager, **tool_kwargs) + context.run( + self._run, *tool_args, run_manager=run_manager, **tool_kwargs + ) if new_arg_supported - else self._run(*tool_args, **tool_kwargs) + else context.run(self._run, *tool_args, **tool_kwargs) ) except ValidationError as e: if not self.handle_validation_error: @@ -446,6 +464,7 @@ class ChildTool(BaseTool): metadata: Optional[Dict[str, Any]] = None, run_name: Optional[str] = None, run_id: Optional[uuid.UUID] = None, + config: Optional[RunnableConfig] = None, **kwargs: Any, ) -> Any: """Run the tool asynchronously.""" @@ -476,11 +495,24 @@ class ChildTool(BaseTool): parsed_input = self._parse_input(tool_input) # We then call the tool on the tool input to get an observation tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input) - observation = ( - await self._arun(*tool_args, run_manager=run_manager, **tool_kwargs) - if new_arg_supported - else await self._arun(*tool_args, **tool_kwargs) + child_config = patch_config( + config, + callbacks=run_manager.get_child(), ) + context = copy_context() + context.run(var_child_runnable_config.set, child_config) + coro = ( + context.run( + self._arun, *tool_args, run_manager=run_manager, **tool_kwargs + ) + if new_arg_supported + else context.run(self._arun, *tool_args, **tool_kwargs) + ) + if accepts_context(asyncio.create_task): + observation = await asyncio.create_task(coro, context=context) # type: ignore + else: + observation = await coro + except ValidationError as e: if not self.handle_validation_error: raise e diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index ed2484f56a3..df52ae59ca0 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -1,6 +1,8 @@ """Test the base tool implementation.""" +import asyncio import json +import sys from datetime import datetime from enum import Enum from functools import partial @@ -13,6 +15,7 @@ from langchain_core.callbacks import ( CallbackManagerForToolRun, ) from langchain_core.pydantic_v1 import BaseModel, ValidationError +from langchain_core.runnables import ensure_config from langchain_core.tools import ( BaseTool, SchemaAnnotationError, @@ -871,3 +874,34 @@ def test_tool_invoke_optional_args(inputs: dict, expected: Optional[dict]) -> No else: with pytest.raises(ValidationError): foo.invoke(inputs) # type: ignore + + +def test_tool_pass_context() -> None: + @tool + def foo(bar: str) -> str: + """The foo.""" + config = ensure_config() + assert config["configurable"]["foo"] == "not-bar" + assert bar == "baz" + return bar + + assert foo.invoke({"bar": "baz"}, {"configurable": {"foo": "not-bar"}}) == "baz" # type: ignore + + +@pytest.mark.skipif( + sys.version_info < (3, 11), + reason="requires python3.11 or higher", +) +async def test_async_tool_pass_context() -> None: + @tool + async def foo(bar: str) -> str: + """The foo.""" + await asyncio.sleep(0.0001) + config = ensure_config() + assert config["configurable"]["foo"] == "not-bar" + assert bar == "baz" + return bar + + assert ( + await foo.ainvoke({"bar": "baz"}, {"configurable": {"foo": "not-bar"}}) == "baz" # type: ignore + )