mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-01 11:02:37 +00:00
[Core] Check is async callable (#21714)
To permit proper coercion of objects like the following: ```python class MyAsyncCallable: async def __call__(self, foo): return await ... class MyAsyncGenerator: async def __call__(self, foo): await ... yield ```
This commit is contained in:
@@ -4883,6 +4883,23 @@ async def test_runnable_gen() -> None:
|
||||
assert [p async for p in arunnable.astream(None)] == [1, 2, 3]
|
||||
assert await arunnable.abatch([None, None]) == [6, 6]
|
||||
|
||||
class AsyncGen:
|
||||
async def __call__(self, input: AsyncIterator[Any]) -> AsyncIterator[int]:
|
||||
yield 1
|
||||
yield 2
|
||||
yield 3
|
||||
|
||||
arunnablecallable = RunnableGenerator(AsyncGen())
|
||||
assert await arunnablecallable.ainvoke(None) == 6
|
||||
assert [p async for p in arunnablecallable.astream(None)] == [1, 2, 3]
|
||||
assert await arunnablecallable.abatch([None, None]) == [6, 6]
|
||||
with pytest.raises(NotImplementedError):
|
||||
arunnablecallable.invoke(None)
|
||||
with pytest.raises(NotImplementedError):
|
||||
arunnablecallable.stream(None)
|
||||
with pytest.raises(NotImplementedError):
|
||||
arunnablecallable.batch([None, None])
|
||||
|
||||
|
||||
async def test_runnable_gen_context_config() -> None:
|
||||
"""Test that a generator can call other runnables with config
|
||||
|
Reference in New Issue
Block a user