core: Improve Runnable __or__ method typing annotations (#31273)

* It is possible to chain a `Runnable` with an `AsyncIterator` as seen
in `test_runnable.py`.
* Iterator and AsyncIterator Input/Output of Callables must be put
before `Callable[[Other], Any]` otherwise the pattern matching picks the
latter.
This commit is contained in:
Christophe Bornet 2025-05-19 15:32:31 +02:00 committed by GitHub
parent e1af509966
commit 17c5a1621f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 14 additions and 8 deletions

View File

@ -1,6 +1,6 @@
"""Structured prompt template for a language model."""
from collections.abc import Iterator, Mapping, Sequence
from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
from typing import (
Any,
Callable,
@ -123,8 +123,9 @@ class StructuredPrompt(ChatPromptTemplate):
self,
other: Union[
Runnable[Any, Other],
Callable[[Any], Other],
Callable[[Iterator[Any]], Iterator[Other]],
Callable[[AsyncIterator[Any]], AsyncIterator[Other]],
Callable[[Any], Other],
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
],
) -> RunnableSerializable[dict, Other]:
@ -134,8 +135,9 @@ class StructuredPrompt(ChatPromptTemplate):
self,
*others: Union[
Runnable[Any, Other],
Callable[[Any], Other],
Callable[[Iterator[Any]], Iterator[Other]],
Callable[[AsyncIterator[Any]], AsyncIterator[Other]],
Callable[[Any], Other],
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
],
name: Optional[str] = None,

View File

@ -565,8 +565,9 @@ class Runnable(Generic[Input, Output], ABC):
self,
other: Union[
Runnable[Any, Other],
Callable[[Any], Other],
Callable[[Iterator[Any]], Iterator[Other]],
Callable[[AsyncIterator[Any]], AsyncIterator[Other]],
Callable[[Any], Other],
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
],
) -> RunnableSerializable[Input, Other]:
@ -577,8 +578,9 @@ class Runnable(Generic[Input, Output], ABC):
self,
other: Union[
Runnable[Other, Any],
Callable[[Other], Any],
Callable[[Iterator[Other]], Iterator[Any]],
Callable[[AsyncIterator[Other]], AsyncIterator[Any]],
Callable[[Other], Any],
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]],
],
) -> RunnableSerializable[Other, Output]:
@ -2960,8 +2962,9 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
self,
other: Union[
Runnable[Any, Other],
Callable[[Any], Other],
Callable[[Iterator[Any]], Iterator[Other]],
Callable[[AsyncIterator[Any]], AsyncIterator[Other]],
Callable[[Any], Other],
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
],
) -> RunnableSerializable[Input, Other]:
@ -2988,8 +2991,9 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
self,
other: Union[
Runnable[Other, Any],
Callable[[Other], Any],
Callable[[Iterator[Other]], Iterator[Any]],
Callable[[AsyncIterator[Other]], AsyncIterator[Any]],
Callable[[Other], Any],
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]],
],
) -> RunnableSerializable[Other, Output]:

View File

@ -5253,7 +5253,7 @@ async def test_runnable_gen_transform() -> None:
yield i + 1
chain: Runnable = RunnableGenerator(gen_indexes, agen_indexes) | plus_one
achain = RunnableGenerator(gen_indexes, agen_indexes) | aplus_one
achain: Runnable = RunnableGenerator(gen_indexes, agen_indexes) | aplus_one
assert chain.get_input_jsonschema() == {
"title": "gen_indexes_input",