From 17c5a1621f7a6e3944fbdfb2baf98c8801c9d55b Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Mon, 19 May 2025 15:32:31 +0200 Subject: [PATCH] core: Improve Runnable `__or__` method typing annotations (#31273) * It is possible to chain a `Runnable` with an `AsyncIterator` as seen in `test_runnable.py`. * Iterator and AsyncIterator Input/Output of Callables must be put before `Callable[[Other], Any]` otherwise the pattern matching picks the latter. --- libs/core/langchain_core/prompts/structured.py | 8 +++++--- libs/core/langchain_core/runnables/base.py | 12 ++++++++---- .../core/tests/unit_tests/runnables/test_runnable.py | 2 +- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/libs/core/langchain_core/prompts/structured.py b/libs/core/langchain_core/prompts/structured.py index 9ac04709cea..203c738681d 100644 --- a/libs/core/langchain_core/prompts/structured.py +++ b/libs/core/langchain_core/prompts/structured.py @@ -1,6 +1,6 @@ """Structured prompt template for a language model.""" -from collections.abc import Iterator, Mapping, Sequence +from collections.abc import AsyncIterator, Iterator, Mapping, Sequence from typing import ( Any, Callable, @@ -123,8 +123,9 @@ class StructuredPrompt(ChatPromptTemplate): self, other: Union[ Runnable[Any, Other], - Callable[[Any], Other], Callable[[Iterator[Any]], Iterator[Other]], + Callable[[AsyncIterator[Any]], AsyncIterator[Other]], + Callable[[Any], Other], Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]], ], ) -> RunnableSerializable[dict, Other]: @@ -134,8 +135,9 @@ class StructuredPrompt(ChatPromptTemplate): self, *others: Union[ Runnable[Any, Other], - Callable[[Any], Other], Callable[[Iterator[Any]], Iterator[Other]], + Callable[[AsyncIterator[Any]], AsyncIterator[Other]], + Callable[[Any], Other], Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]], ], name: Optional[str] = None, diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index de19c5f5915..b408ef8f116 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -565,8 +565,9 @@ class Runnable(Generic[Input, Output], ABC): self, other: Union[ Runnable[Any, Other], - Callable[[Any], Other], Callable[[Iterator[Any]], Iterator[Other]], + Callable[[AsyncIterator[Any]], AsyncIterator[Other]], + Callable[[Any], Other], Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]], ], ) -> RunnableSerializable[Input, Other]: @@ -577,8 +578,9 @@ class Runnable(Generic[Input, Output], ABC): self, other: Union[ Runnable[Other, Any], - Callable[[Other], Any], Callable[[Iterator[Other]], Iterator[Any]], + Callable[[AsyncIterator[Other]], AsyncIterator[Any]], + Callable[[Other], Any], Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]], ], ) -> RunnableSerializable[Other, Output]: @@ -2960,8 +2962,9 @@ class RunnableSequence(RunnableSerializable[Input, Output]): self, other: Union[ Runnable[Any, Other], - Callable[[Any], Other], Callable[[Iterator[Any]], Iterator[Other]], + Callable[[AsyncIterator[Any]], AsyncIterator[Other]], + Callable[[Any], Other], Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]], ], ) -> RunnableSerializable[Input, Other]: @@ -2988,8 +2991,9 @@ class RunnableSequence(RunnableSerializable[Input, Output]): self, other: Union[ Runnable[Other, Any], - Callable[[Other], Any], Callable[[Iterator[Other]], Iterator[Any]], + Callable[[AsyncIterator[Other]], AsyncIterator[Any]], + Callable[[Other], Any], Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]], ], ) -> RunnableSerializable[Other, Output]: diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index a4b616964ed..a476975ac7c 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -5253,7 +5253,7 @@ async def test_runnable_gen_transform() -> None: yield i + 1 chain: Runnable = RunnableGenerator(gen_indexes, agen_indexes) | plus_one - achain = RunnableGenerator(gen_indexes, agen_indexes) | aplus_one + achain: Runnable = RunnableGenerator(gen_indexes, agen_indexes) | aplus_one assert chain.get_input_jsonschema() == { "title": "gen_indexes_input",