diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 9d45d86551f..558495173bc 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -2100,7 +2100,7 @@ class RunnableGenerator(Runnable[Input, Output]): params = inspect.signature(func).parameters first_param = next(iter(params.values()), None) if first_param and first_param.annotation != inspect.Parameter.empty: - return first_param.annotation + return getattr(first_param.annotation, "__args__", (Any,))[0] else: return Any except ValueError: @@ -2112,7 +2112,7 @@ class RunnableGenerator(Runnable[Input, Output]): try: sig = inspect.signature(func) return ( - sig.return_annotation + getattr(sig.return_annotation, "__args__", (Any,))[0] if sig.return_annotation != inspect.Signature.empty else Any ) @@ -2162,7 +2162,7 @@ class RunnableGenerator(Runnable[Input, Output]): final += output return final - async def atransform( + def atransform( self, input: AsyncIterator[Input], config: Optional[RunnableConfig] = None, @@ -2175,7 +2175,7 @@ class RunnableGenerator(Runnable[Input, Output]): input, self._atransform, config, **kwargs ) - async def astream( + def astream( self, input: Input, config: Optional[RunnableConfig] = None, diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index b72144102a5..316a9ecad66 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -1,7 +1,18 @@ import sys from operator import itemgetter -from typing import Any, Dict, List, Optional, Sequence, Union, cast +from typing import ( + Any, + AsyncIterator, + Dict, + Iterator, + List, + Optional, + Sequence, + Union, + cast, +) from uuid import UUID +from langchain.schema.runnable.base import RunnableGenerator import pytest from freezegun import freeze_time @@ -2809,3 +2820,81 @@ async def test_tool_from_runnable() -> None: "title": "PromptInput", "type": "object", } + + +@pytest.mark.asyncio +async def test_runnable_gen() -> None: + """Test that a generator can be used as a runnable.""" + + def gen(input: Iterator[Any]) -> Iterator[int]: + yield 1 + yield 2 + yield 3 + + runnable = RunnableGenerator(gen) + + assert runnable.input_schema.schema() == {"title": "RunnableGeneratorInput"} + assert runnable.output_schema.schema() == { + "title": "RunnableGeneratorOutput", + "type": "integer", + } + + assert runnable.invoke(None) == 6 + assert list(runnable.stream(None)) == [1, 2, 3] + assert runnable.batch([None, None]) == [6, 6] + + async def agen(input: AsyncIterator[Any]) -> AsyncIterator[int]: + yield 1 + yield 2 + yield 3 + + arunnable = RunnableGenerator(agen) + + assert await arunnable.ainvoke(None) == 6 + assert [p async for p in arunnable.astream(None)] == [1, 2, 3] + assert await arunnable.abatch([None, None]) == [6, 6] + + +@pytest.mark.asyncio +async def test_runnable_gen_transform() -> None: + """Test that a generator can be used as a runnable.""" + + def gen_indexes(length_iter: Iterator[int]) -> Iterator[int]: + for i in range(next(length_iter)): + yield i + + async def agen_indexes(length_iter: AsyncIterator[int]) -> AsyncIterator[int]: + async for length in length_iter: + for i in range(length): + yield i + + def plus_one(input: Iterator[int]) -> Iterator[int]: + for i in input: + yield i + 1 + + async def aplus_one(input: AsyncIterator[int]) -> AsyncIterator[int]: + async for i in input: + yield i + 1 + + chain = RunnableGenerator(gen_indexes, agen_indexes) | plus_one + achain = RunnableGenerator(gen_indexes, agen_indexes) | aplus_one + + assert chain.input_schema.schema() == { + "title": "RunnableGeneratorInput", + "type": "integer", + } + assert chain.output_schema.schema() == { + "title": "RunnableGeneratorOutput", + "type": "integer", + } + assert achain.input_schema.schema() == { + "title": "RunnableGeneratorInput", + "type": "integer", + } + assert achain.output_schema.schema() == { + "title": "RunnableGeneratorOutput", + "type": "integer", + } + + assert list(chain.stream(3)) == [1, 2, 3] + assert [p async for p in achain.astream(4)] == [1, 2, 3, 4]