Unset context after step (#30378)

While we are already careful to copy before setting the config, if other
objects hold a reference to the config or context, it wouldn't be
cleared.
This commit is contained in:
William FH
2025-03-19 11:46:23 -07:00
committed by GitHub
parent 37190881d3
commit 4130e6476b
5 changed files with 227 additions and 211 deletions

View File

@@ -6,7 +6,6 @@ import inspect
import json
import warnings
from abc import ABC, abstractmethod
from contextvars import copy_context
from inspect import signature
from typing import (
TYPE_CHECKING,
@@ -52,7 +51,7 @@ from langchain_core.runnables import (
patch_config,
run_in_executor,
)
from langchain_core.runnables.config import _set_config_context
from langchain_core.runnables.config import set_config_context
from langchain_core.runnables.utils import asyncio_accepts_context
from langchain_core.utils.function_calling import (
_parse_google_docstring,
@@ -722,14 +721,15 @@ class ChildTool(BaseTool):
error_to_raise: Union[Exception, KeyboardInterrupt, None] = None
try:
child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input, tool_call_id)
if signature(self._run).parameters.get("run_manager"):
tool_kwargs = tool_kwargs | {"run_manager": run_manager}
if config_param := _get_runnable_config_param(self._run):
tool_kwargs = tool_kwargs | {config_param: config}
response = context.run(self._run, *tool_args, **tool_kwargs)
with set_config_context(child_config) as context:
tool_args, tool_kwargs = self._to_args_and_kwargs(
tool_input, tool_call_id
)
if signature(self._run).parameters.get("run_manager"):
tool_kwargs = tool_kwargs | {"run_manager": run_manager}
if config_param := _get_runnable_config_param(self._run):
tool_kwargs = tool_kwargs | {config_param: config}
response = context.run(self._run, *tool_args, **tool_kwargs)
if self.response_format == "content_and_artifact":
if not isinstance(response, tuple) or len(response) != 2:
msg = (
@@ -832,21 +832,20 @@ class ChildTool(BaseTool):
try:
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input, tool_call_id)
child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
func_to_check = (
self._run if self.__class__._arun is BaseTool._arun else self._arun
)
if signature(func_to_check).parameters.get("run_manager"):
tool_kwargs["run_manager"] = run_manager
if config_param := _get_runnable_config_param(func_to_check):
tool_kwargs[config_param] = config
with set_config_context(child_config) as context:
func_to_check = (
self._run if self.__class__._arun is BaseTool._arun else self._arun
)
if signature(func_to_check).parameters.get("run_manager"):
tool_kwargs["run_manager"] = run_manager
if config_param := _get_runnable_config_param(func_to_check):
tool_kwargs[config_param] = config
coro = context.run(self._arun, *tool_args, **tool_kwargs)
if asyncio_accepts_context():
response = await asyncio.create_task(coro, context=context) # type: ignore
else:
response = await coro
coro = self._arun(*tool_args, **tool_kwargs)
if asyncio_accepts_context():
response = await asyncio.create_task(coro, context=context) # type: ignore
else:
response = await coro
if self.response_format == "content_and_artifact":
if not isinstance(response, tuple) or len(response) != 2:
msg = (