mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-22 20:43:08 +00:00
core: run_in_executor: Wrap StopIteration in RuntimeError (#22997)
- StopIteration can't be set on an asyncio.Future it raises a TypeError and leaves the Future pending forever so we need to convert it to a RuntimeError
This commit is contained in:
parent
d96f67b06f
commit
bd4b68cd54
@ -542,13 +542,21 @@ async def run_in_executor(
|
||||
Returns:
|
||||
Output: The output of the function.
|
||||
"""
|
||||
|
||||
def wrapper() -> T:
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except StopIteration as exc:
|
||||
# StopIteration can't be set on an asyncio.Future
|
||||
# it raises a TypeError and leaves the Future pending forever
|
||||
# so we need to convert it to a RuntimeError
|
||||
raise RuntimeError from exc
|
||||
|
||||
if executor_or_config is None or isinstance(executor_or_config, dict):
|
||||
# Use default executor with context copied from current context
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None,
|
||||
cast(Callable[..., T], partial(copy_context().run, func, *args, **kwargs)),
|
||||
cast(Callable[..., T], partial(copy_context().run, wrapper)),
|
||||
)
|
||||
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
executor_or_config, partial(func, **kwargs), *args
|
||||
)
|
||||
return await asyncio.get_running_loop().run_in_executor(executor_or_config, wrapper)
|
||||
|
@ -1,10 +1,16 @@
|
||||
from typing import Any, cast
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_core.callbacks.manager import CallbackManager
|
||||
from langchain_core.callbacks.stdout import StdOutCallbackHandler
|
||||
from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
from langchain_core.runnables import RunnableBinding, RunnablePassthrough
|
||||
from langchain_core.runnables.config import RunnableConfig, merge_configs
|
||||
from langchain_core.runnables.config import (
|
||||
RunnableConfig,
|
||||
merge_configs,
|
||||
run_in_executor,
|
||||
)
|
||||
from langchain_core.tracers.stdout import ConsoleCallbackHandler
|
||||
|
||||
|
||||
@ -43,3 +49,14 @@ def test_config_arbitrary_keys() -> None:
|
||||
config = cast(RunnableBinding, bound).config
|
||||
|
||||
assert config.get("my_custom_key") == "my custom value"
|
||||
|
||||
|
||||
async def test_run_in_executor() -> None:
|
||||
def raises_stop_iter() -> Any:
|
||||
return next(iter([]))
|
||||
|
||||
with pytest.raises(StopIteration):
|
||||
raises_stop_iter()
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
await run_in_executor(None, raises_stop_iter)
|
||||
|
Loading…
Reference in New Issue
Block a user