Compare commits

...

1 Commits

Author SHA1 Message Date
William Fu-Hinthorn
832f4e926c REturn exceptions 2023-08-06 15:39:32 -07:00

View File

@@ -35,13 +35,17 @@ async def _gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:
return await coro
async def _gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> list:
async def _gather_with_concurrency(
n: Union[int, None], *coros: Coroutine, return_exceptions: bool = False
) -> list:
if n is None:
return await asyncio.gather(*coros)
return await asyncio.gather(*coros, return_exceptions=return_exceptions)
semaphore = asyncio.Semaphore(n)
return await asyncio.gather(*(_gated_coro(semaphore, c) for c in coros))
return await asyncio.gather(
*(_gated_coro(semaphore, c) for c in coros), return_exceptions=return_exceptions
)
class RunnableConfig(TypedDict, total=False):
@@ -108,6 +112,7 @@ class Runnable(Generic[Input, Output], ABC):
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
return_exceptions: bool = False,
) -> List[Output]:
configs = self._get_config_list(config, len(inputs))
@@ -115,8 +120,18 @@ class Runnable(Generic[Input, Output], ABC):
if len(inputs) == 1:
return [self.invoke(inputs[0], configs[0])]
def handle_exceptions(
input: Input, config: RunnableConfig
) -> Union[Output, Exception]:
try:
return self.invoke(input, config)
except Exception as e:
if return_exceptions:
return e
raise e
with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
return list(executor.map(self.invoke, inputs, configs))
return list(executor.map(handle_exceptions, inputs, configs))
async def abatch(
self,
@@ -124,11 +139,14 @@ class Runnable(Generic[Input, Output], ABC):
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
return_exceptions: bool = False,
) -> List[Output]:
configs = self._get_config_list(config, len(inputs))
coros = map(self.ainvoke, inputs, configs)
return await _gather_with_concurrency(max_concurrency, *coros)
return await _gather_with_concurrency(
max_concurrency, *coros, return_exceptions=return_exceptions
)
def stream(
self, input: Input, config: Optional[RunnableConfig] = None