From ca768c8353c5404d570d114e98d544e1ffadba76 Mon Sep 17 00:00:00 2001 From: William FH <13333726+hinthornw@users.noreply.github.com> Date: Wed, 15 May 2024 10:49:49 -0700 Subject: [PATCH] [Core] Check is async callable (#21714) To permit proper coercion of objects like the following: ```python class MyAsyncCallable: async def __call__(self, foo): return await ... class MyAsyncGenerator: async def __call__(self, foo): await ... yield ``` --- libs/core/langchain_core/runnables/base.py | 14 +++++----- libs/core/langchain_core/runnables/utils.py | 27 +++++++++++++++++++ .../unit_tests/runnables/test_runnable.py | 17 ++++++++++++ 3 files changed, 52 insertions(+), 6 deletions(-) diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index c2dfa420183..2ac53cdccd6 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -76,6 +76,8 @@ from langchain_core.runnables.utils import ( get_lambda_source, get_unique_config_specs, indent_lines_after_first, + is_async_callable, + is_async_generator, ) from langchain_core.utils.aiter import atee, py_anext from langchain_core.utils.iter import safetee @@ -3300,7 +3302,7 @@ class RunnableGenerator(Runnable[Input, Output]): self._atransform = atransform func_for_name: Callable = atransform - if inspect.isasyncgenfunction(transform): + if is_async_generator(transform): self._atransform = transform # type: ignore[assignment] func_for_name = transform elif inspect.isgeneratorfunction(transform): @@ -3513,7 +3515,7 @@ class RunnableLambda(Runnable[Input, Output]): self.afunc = afunc func_for_name: Callable = afunc - if inspect.iscoroutinefunction(func) or inspect.isasyncgenfunction(func): + if is_async_callable(func) or is_async_generator(func): if afunc is not None: raise TypeError( "Func was provided as a coroutine function, but afunc was " @@ -3774,7 +3776,7 @@ class RunnableLambda(Runnable[Input, Output]): afunc = f - if inspect.isasyncgenfunction(afunc): + if is_async_generator(afunc): output: Optional[Output] = None async for chunk in cast( AsyncIterator[Output], @@ -3992,7 +3994,7 @@ class RunnableLambda(Runnable[Input, Output]): afunc = f - if inspect.isasyncgenfunction(afunc): + if is_async_generator(afunc): output: Optional[Output] = None async for chunk in cast( AsyncIterator[Output], @@ -4034,7 +4036,7 @@ class RunnableLambda(Runnable[Input, Output]): ), ): yield chunk - elif not inspect.isasyncgenfunction(afunc): + elif not is_async_generator(afunc): # Otherwise, just yield it yield cast(Output, output) @@ -4836,7 +4838,7 @@ def coerce_to_runnable(thing: RunnableLike) -> Runnable[Input, Output]: """ if isinstance(thing, Runnable): return thing - elif inspect.isasyncgenfunction(thing) or inspect.isgeneratorfunction(thing): + elif is_async_generator(thing) or inspect.isgeneratorfunction(thing): return RunnableGenerator(thing) elif callable(thing): return RunnableLambda(cast(Callable[[Input], Output], thing)) diff --git a/libs/core/langchain_core/runnables/utils.py b/libs/core/langchain_core/runnables/utils.py index f77f756e666..3214ca66639 100644 --- a/libs/core/langchain_core/runnables/utils.py +++ b/libs/core/langchain_core/runnables/utils.py @@ -1,4 +1,5 @@ """Utility code for runnables.""" + from __future__ import annotations import ast @@ -11,6 +12,8 @@ from itertools import groupby from typing import ( Any, AsyncIterable, + AsyncIterator, + Awaitable, Callable, Coroutine, Dict, @@ -27,6 +30,8 @@ from typing import ( Union, ) +from typing_extensions import TypeGuard + from langchain_core.pydantic_v1 import BaseConfig, BaseModel from langchain_core.pydantic_v1 import create_model as _create_model_base from langchain_core.runnables.schema import StreamEvent @@ -533,3 +538,25 @@ def _create_model_cached( return _create_model_base( __model_name, __config__=_SchemaConfig, **field_definitions ) + + +def is_async_generator( + func: Any, +) -> TypeGuard[Callable[..., AsyncIterator]]: + """Check if a function is an async generator.""" + return ( + inspect.isasyncgenfunction(func) + or hasattr(func, "__call__") + and inspect.isasyncgenfunction(func.__call__) + ) + + +def is_async_callable( + func: Any, +) -> TypeGuard[Callable[..., Awaitable]]: + """Check if a function is async.""" + return ( + asyncio.iscoroutinefunction(func) + or hasattr(func, "__call__") + and asyncio.iscoroutinefunction(func.__call__) + ) diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index 37e6bff2e3f..dd643a5e54c 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -4883,6 +4883,23 @@ async def test_runnable_gen() -> None: assert [p async for p in arunnable.astream(None)] == [1, 2, 3] assert await arunnable.abatch([None, None]) == [6, 6] + class AsyncGen: + async def __call__(self, input: AsyncIterator[Any]) -> AsyncIterator[int]: + yield 1 + yield 2 + yield 3 + + arunnablecallable = RunnableGenerator(AsyncGen()) + assert await arunnablecallable.ainvoke(None) == 6 + assert [p async for p in arunnablecallable.astream(None)] == [1, 2, 3] + assert await arunnablecallable.abatch([None, None]) == [6, 6] + with pytest.raises(NotImplementedError): + arunnablecallable.invoke(None) + with pytest.raises(NotImplementedError): + arunnablecallable.stream(None) + with pytest.raises(NotImplementedError): + arunnablecallable.batch([None, None]) + async def test_runnable_gen_context_config() -> None: """Test that a generator can call other runnables with config