[Core] Feat: update config CVar in tool.invoke (#20808)

This commit is contained in:
William FH 2024-04-24 17:17:21 -07:00 committed by GitHub
parent 2cd907ad7e
commit a936f696a6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 73 additions and 7 deletions

View File

@ -19,10 +19,12 @@ tool for the job.
from __future__ import annotations
import asyncio
import inspect
import uuid
import warnings
from abc import ABC, abstractmethod
from contextvars import copy_context
from functools import partial
from inspect import signature
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union
@ -60,7 +62,12 @@ from langchain_core.runnables import (
RunnableSerializable,
ensure_config,
)
from langchain_core.runnables.config import run_in_executor
from langchain_core.runnables.config import (
patch_config,
run_in_executor,
var_child_runnable_config,
)
from langchain_core.runnables.utils import accepts_context
class SchemaAnnotationError(TypeError):
@ -255,6 +262,7 @@ class ChildTool(BaseTool):
metadata=config.get("metadata"),
run_name=config.get("run_name"),
run_id=config.pop("run_id", None),
config=config,
**kwargs,
)
@ -272,6 +280,7 @@ class ChildTool(BaseTool):
metadata=config.get("metadata"),
run_name=config.get("run_name"),
run_id=config.pop("run_id", None),
config=config,
**kwargs,
)
@ -353,6 +362,7 @@ class ChildTool(BaseTool):
metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None,
run_id: Optional[uuid.UUID] = None,
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Any:
"""Run the tool."""
@ -385,12 +395,20 @@ class ChildTool(BaseTool):
**kwargs,
)
try:
child_config = patch_config(
config,
callbacks=run_manager.get_child(),
)
context = copy_context()
context.run(var_child_runnable_config.set, child_config)
parsed_input = self._parse_input(tool_input)
tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input)
observation = (
self._run(*tool_args, run_manager=run_manager, **tool_kwargs)
context.run(
self._run, *tool_args, run_manager=run_manager, **tool_kwargs
)
if new_arg_supported
else self._run(*tool_args, **tool_kwargs)
else context.run(self._run, *tool_args, **tool_kwargs)
)
except ValidationError as e:
if not self.handle_validation_error:
@ -446,6 +464,7 @@ class ChildTool(BaseTool):
metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None,
run_id: Optional[uuid.UUID] = None,
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Any:
"""Run the tool asynchronously."""
@ -476,11 +495,24 @@ class ChildTool(BaseTool):
parsed_input = self._parse_input(tool_input)
# We then call the tool on the tool input to get an observation
tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input)
observation = (
await self._arun(*tool_args, run_manager=run_manager, **tool_kwargs)
if new_arg_supported
else await self._arun(*tool_args, **tool_kwargs)
child_config = patch_config(
config,
callbacks=run_manager.get_child(),
)
context = copy_context()
context.run(var_child_runnable_config.set, child_config)
coro = (
context.run(
self._arun, *tool_args, run_manager=run_manager, **tool_kwargs
)
if new_arg_supported
else context.run(self._arun, *tool_args, **tool_kwargs)
)
if accepts_context(asyncio.create_task):
observation = await asyncio.create_task(coro, context=context) # type: ignore
else:
observation = await coro
except ValidationError as e:
if not self.handle_validation_error:
raise e

View File

@ -1,6 +1,8 @@
"""Test the base tool implementation."""
import asyncio
import json
import sys
from datetime import datetime
from enum import Enum
from functools import partial
@ -13,6 +15,7 @@ from langchain_core.callbacks import (
CallbackManagerForToolRun,
)
from langchain_core.pydantic_v1 import BaseModel, ValidationError
from langchain_core.runnables import ensure_config
from langchain_core.tools import (
BaseTool,
SchemaAnnotationError,
@ -871,3 +874,34 @@ def test_tool_invoke_optional_args(inputs: dict, expected: Optional[dict]) -> No
else:
with pytest.raises(ValidationError):
foo.invoke(inputs) # type: ignore
def test_tool_pass_context() -> None:
@tool
def foo(bar: str) -> str:
"""The foo."""
config = ensure_config()
assert config["configurable"]["foo"] == "not-bar"
assert bar == "baz"
return bar
assert foo.invoke({"bar": "baz"}, {"configurable": {"foo": "not-bar"}}) == "baz" # type: ignore
@pytest.mark.skipif(
sys.version_info < (3, 11),
reason="requires python3.11 or higher",
)
async def test_async_tool_pass_context() -> None:
@tool
async def foo(bar: str) -> str:
"""The foo."""
await asyncio.sleep(0.0001)
config = ensure_config()
assert config["configurable"]["foo"] == "not-bar"
assert bar == "baz"
return bar
assert (
await foo.ainvoke({"bar": "baz"}, {"configurable": {"foo": "not-bar"}}) == "baz" # type: ignore
)