mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-19 13:23:35 +00:00
[Core] Feat: update config CVar in tool.invoke (#20808)
This commit is contained in:
parent
2cd907ad7e
commit
a936f696a6
@ -19,10 +19,12 @@ tool for the job.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import uuid
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from contextvars import copy_context
|
||||
from functools import partial
|
||||
from inspect import signature
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union
|
||||
@ -60,7 +62,12 @@ from langchain_core.runnables import (
|
||||
RunnableSerializable,
|
||||
ensure_config,
|
||||
)
|
||||
from langchain_core.runnables.config import run_in_executor
|
||||
from langchain_core.runnables.config import (
|
||||
patch_config,
|
||||
run_in_executor,
|
||||
var_child_runnable_config,
|
||||
)
|
||||
from langchain_core.runnables.utils import accepts_context
|
||||
|
||||
|
||||
class SchemaAnnotationError(TypeError):
|
||||
@ -255,6 +262,7 @@ class ChildTool(BaseTool):
|
||||
metadata=config.get("metadata"),
|
||||
run_name=config.get("run_name"),
|
||||
run_id=config.pop("run_id", None),
|
||||
config=config,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -272,6 +280,7 @@ class ChildTool(BaseTool):
|
||||
metadata=config.get("metadata"),
|
||||
run_name=config.get("run_name"),
|
||||
run_id=config.pop("run_id", None),
|
||||
config=config,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -353,6 +362,7 @@ class ChildTool(BaseTool):
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
run_name: Optional[str] = None,
|
||||
run_id: Optional[uuid.UUID] = None,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run the tool."""
|
||||
@ -385,12 +395,20 @@ class ChildTool(BaseTool):
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
child_config = patch_config(
|
||||
config,
|
||||
callbacks=run_manager.get_child(),
|
||||
)
|
||||
context = copy_context()
|
||||
context.run(var_child_runnable_config.set, child_config)
|
||||
parsed_input = self._parse_input(tool_input)
|
||||
tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input)
|
||||
observation = (
|
||||
self._run(*tool_args, run_manager=run_manager, **tool_kwargs)
|
||||
context.run(
|
||||
self._run, *tool_args, run_manager=run_manager, **tool_kwargs
|
||||
)
|
||||
if new_arg_supported
|
||||
else self._run(*tool_args, **tool_kwargs)
|
||||
else context.run(self._run, *tool_args, **tool_kwargs)
|
||||
)
|
||||
except ValidationError as e:
|
||||
if not self.handle_validation_error:
|
||||
@ -446,6 +464,7 @@ class ChildTool(BaseTool):
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
run_name: Optional[str] = None,
|
||||
run_id: Optional[uuid.UUID] = None,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run the tool asynchronously."""
|
||||
@ -476,11 +495,24 @@ class ChildTool(BaseTool):
|
||||
parsed_input = self._parse_input(tool_input)
|
||||
# We then call the tool on the tool input to get an observation
|
||||
tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input)
|
||||
observation = (
|
||||
await self._arun(*tool_args, run_manager=run_manager, **tool_kwargs)
|
||||
if new_arg_supported
|
||||
else await self._arun(*tool_args, **tool_kwargs)
|
||||
child_config = patch_config(
|
||||
config,
|
||||
callbacks=run_manager.get_child(),
|
||||
)
|
||||
context = copy_context()
|
||||
context.run(var_child_runnable_config.set, child_config)
|
||||
coro = (
|
||||
context.run(
|
||||
self._arun, *tool_args, run_manager=run_manager, **tool_kwargs
|
||||
)
|
||||
if new_arg_supported
|
||||
else context.run(self._arun, *tool_args, **tool_kwargs)
|
||||
)
|
||||
if accepts_context(asyncio.create_task):
|
||||
observation = await asyncio.create_task(coro, context=context) # type: ignore
|
||||
else:
|
||||
observation = await coro
|
||||
|
||||
except ValidationError as e:
|
||||
if not self.handle_validation_error:
|
||||
raise e
|
||||
|
@ -1,6 +1,8 @@
|
||||
"""Test the base tool implementation."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
@ -13,6 +15,7 @@ from langchain_core.callbacks import (
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel, ValidationError
|
||||
from langchain_core.runnables import ensure_config
|
||||
from langchain_core.tools import (
|
||||
BaseTool,
|
||||
SchemaAnnotationError,
|
||||
@ -871,3 +874,34 @@ def test_tool_invoke_optional_args(inputs: dict, expected: Optional[dict]) -> No
|
||||
else:
|
||||
with pytest.raises(ValidationError):
|
||||
foo.invoke(inputs) # type: ignore
|
||||
|
||||
|
||||
def test_tool_pass_context() -> None:
|
||||
@tool
|
||||
def foo(bar: str) -> str:
|
||||
"""The foo."""
|
||||
config = ensure_config()
|
||||
assert config["configurable"]["foo"] == "not-bar"
|
||||
assert bar == "baz"
|
||||
return bar
|
||||
|
||||
assert foo.invoke({"bar": "baz"}, {"configurable": {"foo": "not-bar"}}) == "baz" # type: ignore
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 11),
|
||||
reason="requires python3.11 or higher",
|
||||
)
|
||||
async def test_async_tool_pass_context() -> None:
|
||||
@tool
|
||||
async def foo(bar: str) -> str:
|
||||
"""The foo."""
|
||||
await asyncio.sleep(0.0001)
|
||||
config = ensure_config()
|
||||
assert config["configurable"]["foo"] == "not-bar"
|
||||
assert bar == "baz"
|
||||
return bar
|
||||
|
||||
assert (
|
||||
await foo.ainvoke({"bar": "baz"}, {"configurable": {"foo": "not-bar"}}) == "baz" # type: ignore
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user