mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-23 04:53:09 +00:00
core: run_in_executor: Wrap StopIteration in RuntimeError (#22997)
- StopIteration can't be set on an asyncio.Future it raises a TypeError and leaves the Future pending forever so we need to convert it to a RuntimeError
This commit is contained in:
parent
d96f67b06f
commit
bd4b68cd54
@ -542,13 +542,21 @@ async def run_in_executor(
|
|||||||
Returns:
|
Returns:
|
||||||
Output: The output of the function.
|
Output: The output of the function.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def wrapper() -> T:
|
||||||
|
try:
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
except StopIteration as exc:
|
||||||
|
# StopIteration can't be set on an asyncio.Future
|
||||||
|
# it raises a TypeError and leaves the Future pending forever
|
||||||
|
# so we need to convert it to a RuntimeError
|
||||||
|
raise RuntimeError from exc
|
||||||
|
|
||||||
if executor_or_config is None or isinstance(executor_or_config, dict):
|
if executor_or_config is None or isinstance(executor_or_config, dict):
|
||||||
# Use default executor with context copied from current context
|
# Use default executor with context copied from current context
|
||||||
return await asyncio.get_running_loop().run_in_executor(
|
return await asyncio.get_running_loop().run_in_executor(
|
||||||
None,
|
None,
|
||||||
cast(Callable[..., T], partial(copy_context().run, func, *args, **kwargs)),
|
cast(Callable[..., T], partial(copy_context().run, wrapper)),
|
||||||
)
|
)
|
||||||
|
|
||||||
return await asyncio.get_running_loop().run_in_executor(
|
return await asyncio.get_running_loop().run_in_executor(executor_or_config, wrapper)
|
||||||
executor_or_config, partial(func, **kwargs), *args
|
|
||||||
)
|
|
||||||
|
@ -1,10 +1,16 @@
|
|||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from langchain_core.callbacks.manager import CallbackManager
|
from langchain_core.callbacks.manager import CallbackManager
|
||||||
from langchain_core.callbacks.stdout import StdOutCallbackHandler
|
from langchain_core.callbacks.stdout import StdOutCallbackHandler
|
||||||
from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||||
from langchain_core.runnables import RunnableBinding, RunnablePassthrough
|
from langchain_core.runnables import RunnableBinding, RunnablePassthrough
|
||||||
from langchain_core.runnables.config import RunnableConfig, merge_configs
|
from langchain_core.runnables.config import (
|
||||||
|
RunnableConfig,
|
||||||
|
merge_configs,
|
||||||
|
run_in_executor,
|
||||||
|
)
|
||||||
from langchain_core.tracers.stdout import ConsoleCallbackHandler
|
from langchain_core.tracers.stdout import ConsoleCallbackHandler
|
||||||
|
|
||||||
|
|
||||||
@ -43,3 +49,14 @@ def test_config_arbitrary_keys() -> None:
|
|||||||
config = cast(RunnableBinding, bound).config
|
config = cast(RunnableBinding, bound).config
|
||||||
|
|
||||||
assert config.get("my_custom_key") == "my custom value"
|
assert config.get("my_custom_key") == "my custom value"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_run_in_executor() -> None:
|
||||||
|
def raises_stop_iter() -> Any:
|
||||||
|
return next(iter([]))
|
||||||
|
|
||||||
|
with pytest.raises(StopIteration):
|
||||||
|
raises_stop_iter()
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
await run_in_executor(None, raises_stop_iter)
|
||||||
|
Loading…
Reference in New Issue
Block a user