[Core] Feat: update config CVar in tool.invoke (#20808)

This commit is contained in:
William FH 2024-04-24 17:17:21 -07:00 committed by GitHub
parent 2cd907ad7e
commit a936f696a6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 73 additions and 7 deletions

View File

@ -19,10 +19,12 @@ tool for the job.
from __future__ import annotations from __future__ import annotations
import asyncio
import inspect import inspect
import uuid import uuid
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextvars import copy_context
from functools import partial from functools import partial
from inspect import signature from inspect import signature
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union
@ -60,7 +62,12 @@ from langchain_core.runnables import (
RunnableSerializable, RunnableSerializable,
ensure_config, ensure_config,
) )
from langchain_core.runnables.config import run_in_executor from langchain_core.runnables.config import (
patch_config,
run_in_executor,
var_child_runnable_config,
)
from langchain_core.runnables.utils import accepts_context
class SchemaAnnotationError(TypeError): class SchemaAnnotationError(TypeError):
@ -255,6 +262,7 @@ class ChildTool(BaseTool):
metadata=config.get("metadata"), metadata=config.get("metadata"),
run_name=config.get("run_name"), run_name=config.get("run_name"),
run_id=config.pop("run_id", None), run_id=config.pop("run_id", None),
config=config,
**kwargs, **kwargs,
) )
@ -272,6 +280,7 @@ class ChildTool(BaseTool):
metadata=config.get("metadata"), metadata=config.get("metadata"),
run_name=config.get("run_name"), run_name=config.get("run_name"),
run_id=config.pop("run_id", None), run_id=config.pop("run_id", None),
config=config,
**kwargs, **kwargs,
) )
@ -353,6 +362,7 @@ class ChildTool(BaseTool):
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None, run_name: Optional[str] = None,
run_id: Optional[uuid.UUID] = None, run_id: Optional[uuid.UUID] = None,
config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""Run the tool.""" """Run the tool."""
@ -385,12 +395,20 @@ class ChildTool(BaseTool):
**kwargs, **kwargs,
) )
try: try:
child_config = patch_config(
config,
callbacks=run_manager.get_child(),
)
context = copy_context()
context.run(var_child_runnable_config.set, child_config)
parsed_input = self._parse_input(tool_input) parsed_input = self._parse_input(tool_input)
tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input) tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input)
observation = ( observation = (
self._run(*tool_args, run_manager=run_manager, **tool_kwargs) context.run(
self._run, *tool_args, run_manager=run_manager, **tool_kwargs
)
if new_arg_supported if new_arg_supported
else self._run(*tool_args, **tool_kwargs) else context.run(self._run, *tool_args, **tool_kwargs)
) )
except ValidationError as e: except ValidationError as e:
if not self.handle_validation_error: if not self.handle_validation_error:
@ -446,6 +464,7 @@ class ChildTool(BaseTool):
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None, run_name: Optional[str] = None,
run_id: Optional[uuid.UUID] = None, run_id: Optional[uuid.UUID] = None,
config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""Run the tool asynchronously.""" """Run the tool asynchronously."""
@ -476,11 +495,24 @@ class ChildTool(BaseTool):
parsed_input = self._parse_input(tool_input) parsed_input = self._parse_input(tool_input)
# We then call the tool on the tool input to get an observation # We then call the tool on the tool input to get an observation
tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input) tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input)
observation = ( child_config = patch_config(
await self._arun(*tool_args, run_manager=run_manager, **tool_kwargs) config,
if new_arg_supported callbacks=run_manager.get_child(),
else await self._arun(*tool_args, **tool_kwargs)
) )
context = copy_context()
context.run(var_child_runnable_config.set, child_config)
coro = (
context.run(
self._arun, *tool_args, run_manager=run_manager, **tool_kwargs
)
if new_arg_supported
else context.run(self._arun, *tool_args, **tool_kwargs)
)
if accepts_context(asyncio.create_task):
observation = await asyncio.create_task(coro, context=context) # type: ignore
else:
observation = await coro
except ValidationError as e: except ValidationError as e:
if not self.handle_validation_error: if not self.handle_validation_error:
raise e raise e

View File

@ -1,6 +1,8 @@
"""Test the base tool implementation.""" """Test the base tool implementation."""
import asyncio
import json import json
import sys
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from functools import partial from functools import partial
@ -13,6 +15,7 @@ from langchain_core.callbacks import (
CallbackManagerForToolRun, CallbackManagerForToolRun,
) )
from langchain_core.pydantic_v1 import BaseModel, ValidationError from langchain_core.pydantic_v1 import BaseModel, ValidationError
from langchain_core.runnables import ensure_config
from langchain_core.tools import ( from langchain_core.tools import (
BaseTool, BaseTool,
SchemaAnnotationError, SchemaAnnotationError,
@ -871,3 +874,34 @@ def test_tool_invoke_optional_args(inputs: dict, expected: Optional[dict]) -> No
else: else:
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
foo.invoke(inputs) # type: ignore foo.invoke(inputs) # type: ignore
def test_tool_pass_context() -> None:
@tool
def foo(bar: str) -> str:
"""The foo."""
config = ensure_config()
assert config["configurable"]["foo"] == "not-bar"
assert bar == "baz"
return bar
assert foo.invoke({"bar": "baz"}, {"configurable": {"foo": "not-bar"}}) == "baz" # type: ignore
@pytest.mark.skipif(
sys.version_info < (3, 11),
reason="requires python3.11 or higher",
)
async def test_async_tool_pass_context() -> None:
@tool
async def foo(bar: str) -> str:
"""The foo."""
await asyncio.sleep(0.0001)
config = ensure_config()
assert config["configurable"]["foo"] == "not-bar"
assert bar == "baz"
return bar
assert (
await foo.ainvoke({"bar": "baz"}, {"configurable": {"foo": "not-bar"}}) == "baz" # type: ignore
)