mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-22 14:49:29 +00:00
[Core] Check is async callable (#21714)
To permit proper coercion of objects like the following: ```python class MyAsyncCallable: async def __call__(self, foo): return await ... class MyAsyncGenerator: async def __call__(self, foo): await ... yield ```
This commit is contained in:
parent
7128c2d8ad
commit
ca768c8353
@ -76,6 +76,8 @@ from langchain_core.runnables.utils import (
|
|||||||
get_lambda_source,
|
get_lambda_source,
|
||||||
get_unique_config_specs,
|
get_unique_config_specs,
|
||||||
indent_lines_after_first,
|
indent_lines_after_first,
|
||||||
|
is_async_callable,
|
||||||
|
is_async_generator,
|
||||||
)
|
)
|
||||||
from langchain_core.utils.aiter import atee, py_anext
|
from langchain_core.utils.aiter import atee, py_anext
|
||||||
from langchain_core.utils.iter import safetee
|
from langchain_core.utils.iter import safetee
|
||||||
@ -3300,7 +3302,7 @@ class RunnableGenerator(Runnable[Input, Output]):
|
|||||||
self._atransform = atransform
|
self._atransform = atransform
|
||||||
func_for_name: Callable = atransform
|
func_for_name: Callable = atransform
|
||||||
|
|
||||||
if inspect.isasyncgenfunction(transform):
|
if is_async_generator(transform):
|
||||||
self._atransform = transform # type: ignore[assignment]
|
self._atransform = transform # type: ignore[assignment]
|
||||||
func_for_name = transform
|
func_for_name = transform
|
||||||
elif inspect.isgeneratorfunction(transform):
|
elif inspect.isgeneratorfunction(transform):
|
||||||
@ -3513,7 +3515,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
self.afunc = afunc
|
self.afunc = afunc
|
||||||
func_for_name: Callable = afunc
|
func_for_name: Callable = afunc
|
||||||
|
|
||||||
if inspect.iscoroutinefunction(func) or inspect.isasyncgenfunction(func):
|
if is_async_callable(func) or is_async_generator(func):
|
||||||
if afunc is not None:
|
if afunc is not None:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
"Func was provided as a coroutine function, but afunc was "
|
"Func was provided as a coroutine function, but afunc was "
|
||||||
@ -3774,7 +3776,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
|
|
||||||
afunc = f
|
afunc = f
|
||||||
|
|
||||||
if inspect.isasyncgenfunction(afunc):
|
if is_async_generator(afunc):
|
||||||
output: Optional[Output] = None
|
output: Optional[Output] = None
|
||||||
async for chunk in cast(
|
async for chunk in cast(
|
||||||
AsyncIterator[Output],
|
AsyncIterator[Output],
|
||||||
@ -3992,7 +3994,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
|
|
||||||
afunc = f
|
afunc = f
|
||||||
|
|
||||||
if inspect.isasyncgenfunction(afunc):
|
if is_async_generator(afunc):
|
||||||
output: Optional[Output] = None
|
output: Optional[Output] = None
|
||||||
async for chunk in cast(
|
async for chunk in cast(
|
||||||
AsyncIterator[Output],
|
AsyncIterator[Output],
|
||||||
@ -4034,7 +4036,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
),
|
),
|
||||||
):
|
):
|
||||||
yield chunk
|
yield chunk
|
||||||
elif not inspect.isasyncgenfunction(afunc):
|
elif not is_async_generator(afunc):
|
||||||
# Otherwise, just yield it
|
# Otherwise, just yield it
|
||||||
yield cast(Output, output)
|
yield cast(Output, output)
|
||||||
|
|
||||||
@ -4836,7 +4838,7 @@ def coerce_to_runnable(thing: RunnableLike) -> Runnable[Input, Output]:
|
|||||||
"""
|
"""
|
||||||
if isinstance(thing, Runnable):
|
if isinstance(thing, Runnable):
|
||||||
return thing
|
return thing
|
||||||
elif inspect.isasyncgenfunction(thing) or inspect.isgeneratorfunction(thing):
|
elif is_async_generator(thing) or inspect.isgeneratorfunction(thing):
|
||||||
return RunnableGenerator(thing)
|
return RunnableGenerator(thing)
|
||||||
elif callable(thing):
|
elif callable(thing):
|
||||||
return RunnableLambda(cast(Callable[[Input], Output], thing))
|
return RunnableLambda(cast(Callable[[Input], Output], thing))
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
"""Utility code for runnables."""
|
"""Utility code for runnables."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import ast
|
import ast
|
||||||
@ -11,6 +12,8 @@ from itertools import groupby
|
|||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
AsyncIterable,
|
AsyncIterable,
|
||||||
|
AsyncIterator,
|
||||||
|
Awaitable,
|
||||||
Callable,
|
Callable,
|
||||||
Coroutine,
|
Coroutine,
|
||||||
Dict,
|
Dict,
|
||||||
@ -27,6 +30,8 @@ from typing import (
|
|||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from typing_extensions import TypeGuard
|
||||||
|
|
||||||
from langchain_core.pydantic_v1 import BaseConfig, BaseModel
|
from langchain_core.pydantic_v1 import BaseConfig, BaseModel
|
||||||
from langchain_core.pydantic_v1 import create_model as _create_model_base
|
from langchain_core.pydantic_v1 import create_model as _create_model_base
|
||||||
from langchain_core.runnables.schema import StreamEvent
|
from langchain_core.runnables.schema import StreamEvent
|
||||||
@ -533,3 +538,25 @@ def _create_model_cached(
|
|||||||
return _create_model_base(
|
return _create_model_base(
|
||||||
__model_name, __config__=_SchemaConfig, **field_definitions
|
__model_name, __config__=_SchemaConfig, **field_definitions
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def is_async_generator(
|
||||||
|
func: Any,
|
||||||
|
) -> TypeGuard[Callable[..., AsyncIterator]]:
|
||||||
|
"""Check if a function is an async generator."""
|
||||||
|
return (
|
||||||
|
inspect.isasyncgenfunction(func)
|
||||||
|
or hasattr(func, "__call__")
|
||||||
|
and inspect.isasyncgenfunction(func.__call__)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def is_async_callable(
|
||||||
|
func: Any,
|
||||||
|
) -> TypeGuard[Callable[..., Awaitable]]:
|
||||||
|
"""Check if a function is async."""
|
||||||
|
return (
|
||||||
|
asyncio.iscoroutinefunction(func)
|
||||||
|
or hasattr(func, "__call__")
|
||||||
|
and asyncio.iscoroutinefunction(func.__call__)
|
||||||
|
)
|
||||||
|
@ -4883,6 +4883,23 @@ async def test_runnable_gen() -> None:
|
|||||||
assert [p async for p in arunnable.astream(None)] == [1, 2, 3]
|
assert [p async for p in arunnable.astream(None)] == [1, 2, 3]
|
||||||
assert await arunnable.abatch([None, None]) == [6, 6]
|
assert await arunnable.abatch([None, None]) == [6, 6]
|
||||||
|
|
||||||
|
class AsyncGen:
|
||||||
|
async def __call__(self, input: AsyncIterator[Any]) -> AsyncIterator[int]:
|
||||||
|
yield 1
|
||||||
|
yield 2
|
||||||
|
yield 3
|
||||||
|
|
||||||
|
arunnablecallable = RunnableGenerator(AsyncGen())
|
||||||
|
assert await arunnablecallable.ainvoke(None) == 6
|
||||||
|
assert [p async for p in arunnablecallable.astream(None)] == [1, 2, 3]
|
||||||
|
assert await arunnablecallable.abatch([None, None]) == [6, 6]
|
||||||
|
with pytest.raises(NotImplementedError):
|
||||||
|
arunnablecallable.invoke(None)
|
||||||
|
with pytest.raises(NotImplementedError):
|
||||||
|
arunnablecallable.stream(None)
|
||||||
|
with pytest.raises(NotImplementedError):
|
||||||
|
arunnablecallable.batch([None, None])
|
||||||
|
|
||||||
|
|
||||||
async def test_runnable_gen_context_config() -> None:
|
async def test_runnable_gen_context_config() -> None:
|
||||||
"""Test that a generator can call other runnables with config
|
"""Test that a generator can call other runnables with config
|
||||||
|
Loading…
Reference in New Issue
Block a user