diff --git a/libs/langchain/langchain/schema/runnable/passthrough.py b/libs/langchain/langchain/schema/runnable/passthrough.py index cdfa09f7294..5dba5b9c8f4 100644 --- a/libs/langchain/langchain/schema/runnable/passthrough.py +++ b/libs/langchain/langchain/schema/runnable/passthrough.py @@ -2,10 +2,12 @@ from __future__ import annotations import asyncio +import inspect import threading from typing import ( Any, AsyncIterator, + Awaitable, Callable, Dict, Iterator, @@ -100,6 +102,26 @@ class RunnablePassthrough(RunnableSerializable[Input, Input]): input_type: Optional[Type[Input]] = None + func: Optional[Callable[[Input], None]] = None + + afunc: Optional[Callable[[Input], Awaitable[None]]] = None + + def __init__( + self, + func: Optional[ + Union[Callable[[Input], None], Callable[[Input], Awaitable[None]]] + ] = None, + afunc: Optional[Callable[[Input], Awaitable[None]]] = None, + *, + input_type: Optional[Type[Input]] = None, + **kwargs: Any, + ) -> None: + if inspect.iscoroutinefunction(func): + afunc = func + func = None + + super().__init__(func=func, afunc=afunc, input_type=input_type, **kwargs) + @classmethod def is_lc_serializable(cls) -> bool: return True @@ -140,6 +162,8 @@ class RunnablePassthrough(RunnableSerializable[Input, Input]): return RunnableAssign(RunnableParallel(kwargs)) def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input: + if self.func is not None: + self.func(input) return self._call_with_config(identity, input, config) async def ainvoke( @@ -148,6 +172,10 @@ class RunnablePassthrough(RunnableSerializable[Input, Input]): config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], ) -> Input: + if self.afunc is not None: + await self.afunc(input, **kwargs) + elif self.func is not None: + self.func(input, **kwargs) return await self._acall_with_config(aidentity, input, config) def transform( @@ -156,7 +184,21 @@ class RunnablePassthrough(RunnableSerializable[Input, Input]): config: Optional[RunnableConfig] = None, **kwargs: Any, ) -> Iterator[Input]: - return self._transform_stream_with_config(input, identity, config) + if self.func is None: + for chunk in self._transform_stream_with_config(input, identity, config): + yield chunk + else: + final = None + + for chunk in self._transform_stream_with_config(input, identity, config): + yield chunk + if final is None: + final = chunk + else: + final = final + chunk + + if final is not None: + self.func(final, **kwargs) async def atransform( self, @@ -164,7 +206,47 @@ class RunnablePassthrough(RunnableSerializable[Input, Input]): config: Optional[RunnableConfig] = None, **kwargs: Any, ) -> AsyncIterator[Input]: - async for chunk in self._atransform_stream_with_config(input, identity, config): + if self.afunc is None and self.func is None: + async for chunk in self._atransform_stream_with_config( + input, identity, config + ): + yield chunk + else: + final = None + + async for chunk in self._atransform_stream_with_config( + input, identity, config + ): + yield chunk + if final is None: + final = chunk + else: + final = final + chunk + + if final is not None: + if self.afunc is not None: + await self.afunc(final, **kwargs) + elif self.func is not None: + self.func(final, **kwargs) + + def stream( + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> Iterator[Input]: + return self.transform(iter([input]), config, **kwargs) + + async def astream( + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> AsyncIterator[Input]: + async def input_aiter() -> AsyncIterator[Input]: + yield input + + async for chunk in self.atransform(input_aiter(), config, **kwargs): yield chunk diff --git a/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr b/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr index 4a7c6383598..b4b607f3fa2 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr +++ b/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr @@ -3281,7 +3281,11 @@ "runnable", "RunnablePassthrough" ], - "kwargs": {} + "kwargs": { + "func": null, + "afunc": null, + "input_type": null + } }, "last": { "lc": 1, diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index 20e316dbb00..944a0eb5257 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -1016,6 +1016,34 @@ def test_configurable_fields_example() -> None: ) +@pytest.mark.asyncio +async def test_passthrough_tap_async(mocker: MockerFixture) -> None: + fake = FakeRunnable() + mock = mocker.Mock() + + seq: Runnable = fake | RunnablePassthrough(mock) + + assert await seq.ainvoke("hello") == 5 + assert mock.call_args_list == [mocker.call(5)] + mock.reset_mock() + + assert [ + part async for part in seq.astream("hello", dict(metadata={"key": "value"})) + ] == [5] + assert mock.call_args_list == [mocker.call(5)] + mock.reset_mock() + + assert seq.invoke("hello") == 5 + assert mock.call_args_list == [mocker.call(5)] + mock.reset_mock() + + assert [part for part in seq.stream("hello", dict(metadata={"key": "value"}))] == [ + 5 + ] + assert mock.call_args_list == [mocker.call(5)] + mock.reset_mock() + + @pytest.mark.asyncio async def test_with_config(mocker: MockerFixture) -> None: fake = FakeRunnable()