diff --git a/libs/langchain/langchain/schema/runnable/passthrough.py b/libs/langchain/langchain/schema/runnable/passthrough.py index 782c06c59d5..a97e708b64b 100644 --- a/libs/langchain/langchain/schema/runnable/passthrough.py +++ b/libs/langchain/langchain/schema/runnable/passthrough.py @@ -1,9 +1,18 @@ from __future__ import annotations -from typing import List, Optional +from typing import AsyncIterator, Iterator, List, Optional from langchain.load.serializable import Serializable -from langchain.schema.runnable.base import Input, Runnable, RunnableConfig +from langchain.schema.runnable.base import Input, Runnable +from langchain.schema.runnable.config import RunnableConfig + + +def identity(x: Input) -> Input: + return x + + +async def aidentity(x: Input) -> Input: + return x class RunnablePassthrough(Serializable, Runnable[Input, Input]): @@ -20,4 +29,19 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]): return self.__class__.__module__.split(".")[:-1] def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input: - return self._call_with_config(lambda x: x, input, config) + return self._call_with_config(identity, input, config) + + async def ainvoke( + self, input: Input, config: RunnableConfig | None = None + ) -> Input: + return await self._acall_with_config(aidentity, input, config) + + def transform( + self, input: Iterator[Input], config: RunnableConfig | None = None + ) -> Iterator[Input]: + return self._transform_stream_with_config(input, identity, config) + + def atransform( + self, input: AsyncIterator[Input], config: RunnableConfig | None = None + ) -> AsyncIterator[Input]: + return self._atransform_stream_with_config(input, identity, config) diff --git a/libs/langchain/tests/unit_tests/schema/test_runnable.py b/libs/langchain/tests/unit_tests/schema/test_runnable.py index 4ef0e33d7d4..aa0b7f4c1ee 100644 --- a/libs/langchain/tests/unit_tests/schema/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/test_runnable.py @@ -784,6 +784,13 @@ def test_deep_stream() -> None: assert len(chunks) == len("foo-lish") assert "".join(chunks) == "foo-lish" + chunks = [] + for chunk in (chain | RunnablePassthrough()).stream({"question": "What up"}): + chunks.append(chunk) + + assert len(chunks) == len("foo-lish") + assert "".join(chunks) == "foo-lish" + @pytest.mark.asyncio async def test_deep_astream() -> None: @@ -804,6 +811,13 @@ async def test_deep_astream() -> None: assert len(chunks) == len("foo-lish") assert "".join(chunks) == "foo-lish" + chunks = [] + async for chunk in (chain | RunnablePassthrough()).astream({"question": "What up"}): + chunks.append(chunk) + + assert len(chunks) == len("foo-lish") + assert "".join(chunks) == "foo-lish" + @pytest.fixture() def llm_with_fallbacks() -> RunnableWithFallbacks: