mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-22 06:39:52 +00:00
[Core] Check is async callable (#21714)
To permit proper coercion of objects like the following: ```python class MyAsyncCallable: async def __call__(self, foo): return await ... class MyAsyncGenerator: async def __call__(self, foo): await ... yield ```
This commit is contained in:
parent
7128c2d8ad
commit
ca768c8353
@ -76,6 +76,8 @@ from langchain_core.runnables.utils import (
|
||||
get_lambda_source,
|
||||
get_unique_config_specs,
|
||||
indent_lines_after_first,
|
||||
is_async_callable,
|
||||
is_async_generator,
|
||||
)
|
||||
from langchain_core.utils.aiter import atee, py_anext
|
||||
from langchain_core.utils.iter import safetee
|
||||
@ -3300,7 +3302,7 @@ class RunnableGenerator(Runnable[Input, Output]):
|
||||
self._atransform = atransform
|
||||
func_for_name: Callable = atransform
|
||||
|
||||
if inspect.isasyncgenfunction(transform):
|
||||
if is_async_generator(transform):
|
||||
self._atransform = transform # type: ignore[assignment]
|
||||
func_for_name = transform
|
||||
elif inspect.isgeneratorfunction(transform):
|
||||
@ -3513,7 +3515,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
self.afunc = afunc
|
||||
func_for_name: Callable = afunc
|
||||
|
||||
if inspect.iscoroutinefunction(func) or inspect.isasyncgenfunction(func):
|
||||
if is_async_callable(func) or is_async_generator(func):
|
||||
if afunc is not None:
|
||||
raise TypeError(
|
||||
"Func was provided as a coroutine function, but afunc was "
|
||||
@ -3774,7 +3776,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
|
||||
afunc = f
|
||||
|
||||
if inspect.isasyncgenfunction(afunc):
|
||||
if is_async_generator(afunc):
|
||||
output: Optional[Output] = None
|
||||
async for chunk in cast(
|
||||
AsyncIterator[Output],
|
||||
@ -3992,7 +3994,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
|
||||
afunc = f
|
||||
|
||||
if inspect.isasyncgenfunction(afunc):
|
||||
if is_async_generator(afunc):
|
||||
output: Optional[Output] = None
|
||||
async for chunk in cast(
|
||||
AsyncIterator[Output],
|
||||
@ -4034,7 +4036,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
),
|
||||
):
|
||||
yield chunk
|
||||
elif not inspect.isasyncgenfunction(afunc):
|
||||
elif not is_async_generator(afunc):
|
||||
# Otherwise, just yield it
|
||||
yield cast(Output, output)
|
||||
|
||||
@ -4836,7 +4838,7 @@ def coerce_to_runnable(thing: RunnableLike) -> Runnable[Input, Output]:
|
||||
"""
|
||||
if isinstance(thing, Runnable):
|
||||
return thing
|
||||
elif inspect.isasyncgenfunction(thing) or inspect.isgeneratorfunction(thing):
|
||||
elif is_async_generator(thing) or inspect.isgeneratorfunction(thing):
|
||||
return RunnableGenerator(thing)
|
||||
elif callable(thing):
|
||||
return RunnableLambda(cast(Callable[[Input], Output], thing))
|
||||
|
@ -1,4 +1,5 @@
|
||||
"""Utility code for runnables."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
@ -11,6 +12,8 @@ from itertools import groupby
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterable,
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Dict,
|
||||
@ -27,6 +30,8 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
|
||||
from typing_extensions import TypeGuard
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseConfig, BaseModel
|
||||
from langchain_core.pydantic_v1 import create_model as _create_model_base
|
||||
from langchain_core.runnables.schema import StreamEvent
|
||||
@ -533,3 +538,25 @@ def _create_model_cached(
|
||||
return _create_model_base(
|
||||
__model_name, __config__=_SchemaConfig, **field_definitions
|
||||
)
|
||||
|
||||
|
||||
def is_async_generator(
|
||||
func: Any,
|
||||
) -> TypeGuard[Callable[..., AsyncIterator]]:
|
||||
"""Check if a function is an async generator."""
|
||||
return (
|
||||
inspect.isasyncgenfunction(func)
|
||||
or hasattr(func, "__call__")
|
||||
and inspect.isasyncgenfunction(func.__call__)
|
||||
)
|
||||
|
||||
|
||||
def is_async_callable(
|
||||
func: Any,
|
||||
) -> TypeGuard[Callable[..., Awaitable]]:
|
||||
"""Check if a function is async."""
|
||||
return (
|
||||
asyncio.iscoroutinefunction(func)
|
||||
or hasattr(func, "__call__")
|
||||
and asyncio.iscoroutinefunction(func.__call__)
|
||||
)
|
||||
|
@ -4883,6 +4883,23 @@ async def test_runnable_gen() -> None:
|
||||
assert [p async for p in arunnable.astream(None)] == [1, 2, 3]
|
||||
assert await arunnable.abatch([None, None]) == [6, 6]
|
||||
|
||||
class AsyncGen:
|
||||
async def __call__(self, input: AsyncIterator[Any]) -> AsyncIterator[int]:
|
||||
yield 1
|
||||
yield 2
|
||||
yield 3
|
||||
|
||||
arunnablecallable = RunnableGenerator(AsyncGen())
|
||||
assert await arunnablecallable.ainvoke(None) == 6
|
||||
assert [p async for p in arunnablecallable.astream(None)] == [1, 2, 3]
|
||||
assert await arunnablecallable.abatch([None, None]) == [6, 6]
|
||||
with pytest.raises(NotImplementedError):
|
||||
arunnablecallable.invoke(None)
|
||||
with pytest.raises(NotImplementedError):
|
||||
arunnablecallable.stream(None)
|
||||
with pytest.raises(NotImplementedError):
|
||||
arunnablecallable.batch([None, None])
|
||||
|
||||
|
||||
async def test_runnable_gen_context_config() -> None:
|
||||
"""Test that a generator can call other runnables with config
|
||||
|
Loading…
Reference in New Issue
Block a user