From bd4b68cd541719e21a81070a394f724d3600b4a2 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Mon, 17 Jun 2024 13:40:01 -0700 Subject: [PATCH] core: run_in_executor: Wrap StopIteration in RuntimeError (#22997) - StopIteration can't be set on an asyncio.Future it raises a TypeError and leaves the Future pending forever so we need to convert it to a RuntimeError --- libs/core/langchain_core/runnables/config.py | 16 ++++++++++++---- .../tests/unit_tests/runnables/test_config.py | 19 ++++++++++++++++++- 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/libs/core/langchain_core/runnables/config.py b/libs/core/langchain_core/runnables/config.py index 15f2006aab6..4300da87725 100644 --- a/libs/core/langchain_core/runnables/config.py +++ b/libs/core/langchain_core/runnables/config.py @@ -542,13 +542,21 @@ async def run_in_executor( Returns: Output: The output of the function. """ + + def wrapper() -> T: + try: + return func(*args, **kwargs) + except StopIteration as exc: + # StopIteration can't be set on an asyncio.Future + # it raises a TypeError and leaves the Future pending forever + # so we need to convert it to a RuntimeError + raise RuntimeError from exc + if executor_or_config is None or isinstance(executor_or_config, dict): # Use default executor with context copied from current context return await asyncio.get_running_loop().run_in_executor( None, - cast(Callable[..., T], partial(copy_context().run, func, *args, **kwargs)), + cast(Callable[..., T], partial(copy_context().run, wrapper)), ) - return await asyncio.get_running_loop().run_in_executor( - executor_or_config, partial(func, **kwargs), *args - ) + return await asyncio.get_running_loop().run_in_executor(executor_or_config, wrapper) diff --git a/libs/core/tests/unit_tests/runnables/test_config.py b/libs/core/tests/unit_tests/runnables/test_config.py index 5ea4a4c58ec..6e68b40484d 100644 --- a/libs/core/tests/unit_tests/runnables/test_config.py +++ b/libs/core/tests/unit_tests/runnables/test_config.py @@ -1,10 +1,16 @@ from typing import Any, cast +import pytest + from langchain_core.callbacks.manager import CallbackManager from langchain_core.callbacks.stdout import StdOutCallbackHandler from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler from langchain_core.runnables import RunnableBinding, RunnablePassthrough -from langchain_core.runnables.config import RunnableConfig, merge_configs +from langchain_core.runnables.config import ( + RunnableConfig, + merge_configs, + run_in_executor, +) from langchain_core.tracers.stdout import ConsoleCallbackHandler @@ -43,3 +49,14 @@ def test_config_arbitrary_keys() -> None: config = cast(RunnableBinding, bound).config assert config.get("my_custom_key") == "my custom value" + + +async def test_run_in_executor() -> None: + def raises_stop_iter() -> Any: + return next(iter([])) + + with pytest.raises(StopIteration): + raises_stop_iter() + + with pytest.raises(RuntimeError): + await run_in_executor(None, raises_stop_iter)