mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-20 05:43:55 +00:00
[Core] Feat: update config CVar in tool.invoke (#20808)
This commit is contained in:
parent
2cd907ad7e
commit
a936f696a6
@ -19,10 +19,12 @@ tool for the job.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
import uuid
|
import uuid
|
||||||
import warnings
|
import warnings
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from contextvars import copy_context
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union
|
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union
|
||||||
@ -60,7 +62,12 @@ from langchain_core.runnables import (
|
|||||||
RunnableSerializable,
|
RunnableSerializable,
|
||||||
ensure_config,
|
ensure_config,
|
||||||
)
|
)
|
||||||
from langchain_core.runnables.config import run_in_executor
|
from langchain_core.runnables.config import (
|
||||||
|
patch_config,
|
||||||
|
run_in_executor,
|
||||||
|
var_child_runnable_config,
|
||||||
|
)
|
||||||
|
from langchain_core.runnables.utils import accepts_context
|
||||||
|
|
||||||
|
|
||||||
class SchemaAnnotationError(TypeError):
|
class SchemaAnnotationError(TypeError):
|
||||||
@ -255,6 +262,7 @@ class ChildTool(BaseTool):
|
|||||||
metadata=config.get("metadata"),
|
metadata=config.get("metadata"),
|
||||||
run_name=config.get("run_name"),
|
run_name=config.get("run_name"),
|
||||||
run_id=config.pop("run_id", None),
|
run_id=config.pop("run_id", None),
|
||||||
|
config=config,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -272,6 +280,7 @@ class ChildTool(BaseTool):
|
|||||||
metadata=config.get("metadata"),
|
metadata=config.get("metadata"),
|
||||||
run_name=config.get("run_name"),
|
run_name=config.get("run_name"),
|
||||||
run_id=config.pop("run_id", None),
|
run_id=config.pop("run_id", None),
|
||||||
|
config=config,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -353,6 +362,7 @@ class ChildTool(BaseTool):
|
|||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
run_name: Optional[str] = None,
|
run_name: Optional[str] = None,
|
||||||
run_id: Optional[uuid.UUID] = None,
|
run_id: Optional[uuid.UUID] = None,
|
||||||
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Run the tool."""
|
"""Run the tool."""
|
||||||
@ -385,12 +395,20 @@ class ChildTool(BaseTool):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
|
child_config = patch_config(
|
||||||
|
config,
|
||||||
|
callbacks=run_manager.get_child(),
|
||||||
|
)
|
||||||
|
context = copy_context()
|
||||||
|
context.run(var_child_runnable_config.set, child_config)
|
||||||
parsed_input = self._parse_input(tool_input)
|
parsed_input = self._parse_input(tool_input)
|
||||||
tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input)
|
tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input)
|
||||||
observation = (
|
observation = (
|
||||||
self._run(*tool_args, run_manager=run_manager, **tool_kwargs)
|
context.run(
|
||||||
|
self._run, *tool_args, run_manager=run_manager, **tool_kwargs
|
||||||
|
)
|
||||||
if new_arg_supported
|
if new_arg_supported
|
||||||
else self._run(*tool_args, **tool_kwargs)
|
else context.run(self._run, *tool_args, **tool_kwargs)
|
||||||
)
|
)
|
||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
if not self.handle_validation_error:
|
if not self.handle_validation_error:
|
||||||
@ -446,6 +464,7 @@ class ChildTool(BaseTool):
|
|||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
run_name: Optional[str] = None,
|
run_name: Optional[str] = None,
|
||||||
run_id: Optional[uuid.UUID] = None,
|
run_id: Optional[uuid.UUID] = None,
|
||||||
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Run the tool asynchronously."""
|
"""Run the tool asynchronously."""
|
||||||
@ -476,11 +495,24 @@ class ChildTool(BaseTool):
|
|||||||
parsed_input = self._parse_input(tool_input)
|
parsed_input = self._parse_input(tool_input)
|
||||||
# We then call the tool on the tool input to get an observation
|
# We then call the tool on the tool input to get an observation
|
||||||
tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input)
|
tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input)
|
||||||
observation = (
|
child_config = patch_config(
|
||||||
await self._arun(*tool_args, run_manager=run_manager, **tool_kwargs)
|
config,
|
||||||
if new_arg_supported
|
callbacks=run_manager.get_child(),
|
||||||
else await self._arun(*tool_args, **tool_kwargs)
|
|
||||||
)
|
)
|
||||||
|
context = copy_context()
|
||||||
|
context.run(var_child_runnable_config.set, child_config)
|
||||||
|
coro = (
|
||||||
|
context.run(
|
||||||
|
self._arun, *tool_args, run_manager=run_manager, **tool_kwargs
|
||||||
|
)
|
||||||
|
if new_arg_supported
|
||||||
|
else context.run(self._arun, *tool_args, **tool_kwargs)
|
||||||
|
)
|
||||||
|
if accepts_context(asyncio.create_task):
|
||||||
|
observation = await asyncio.create_task(coro, context=context) # type: ignore
|
||||||
|
else:
|
||||||
|
observation = await coro
|
||||||
|
|
||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
if not self.handle_validation_error:
|
if not self.handle_validation_error:
|
||||||
raise e
|
raise e
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
"""Test the base tool implementation."""
|
"""Test the base tool implementation."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
import sys
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from functools import partial
|
from functools import partial
|
||||||
@ -13,6 +15,7 @@ from langchain_core.callbacks import (
|
|||||||
CallbackManagerForToolRun,
|
CallbackManagerForToolRun,
|
||||||
)
|
)
|
||||||
from langchain_core.pydantic_v1 import BaseModel, ValidationError
|
from langchain_core.pydantic_v1 import BaseModel, ValidationError
|
||||||
|
from langchain_core.runnables import ensure_config
|
||||||
from langchain_core.tools import (
|
from langchain_core.tools import (
|
||||||
BaseTool,
|
BaseTool,
|
||||||
SchemaAnnotationError,
|
SchemaAnnotationError,
|
||||||
@ -871,3 +874,34 @@ def test_tool_invoke_optional_args(inputs: dict, expected: Optional[dict]) -> No
|
|||||||
else:
|
else:
|
||||||
with pytest.raises(ValidationError):
|
with pytest.raises(ValidationError):
|
||||||
foo.invoke(inputs) # type: ignore
|
foo.invoke(inputs) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_pass_context() -> None:
|
||||||
|
@tool
|
||||||
|
def foo(bar: str) -> str:
|
||||||
|
"""The foo."""
|
||||||
|
config = ensure_config()
|
||||||
|
assert config["configurable"]["foo"] == "not-bar"
|
||||||
|
assert bar == "baz"
|
||||||
|
return bar
|
||||||
|
|
||||||
|
assert foo.invoke({"bar": "baz"}, {"configurable": {"foo": "not-bar"}}) == "baz" # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
sys.version_info < (3, 11),
|
||||||
|
reason="requires python3.11 or higher",
|
||||||
|
)
|
||||||
|
async def test_async_tool_pass_context() -> None:
|
||||||
|
@tool
|
||||||
|
async def foo(bar: str) -> str:
|
||||||
|
"""The foo."""
|
||||||
|
await asyncio.sleep(0.0001)
|
||||||
|
config = ensure_config()
|
||||||
|
assert config["configurable"]["foo"] == "not-bar"
|
||||||
|
assert bar == "baz"
|
||||||
|
return bar
|
||||||
|
|
||||||
|
assert (
|
||||||
|
await foo.ainvoke({"bar": "baz"}, {"configurable": {"foo": "not-bar"}}) == "baz" # type: ignore
|
||||||
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user