diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events.py b/libs/core/tests/unit_tests/runnables/test_runnable_events.py index 7394873f695..aa7ff7c1d4d 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events.py @@ -1654,3 +1654,124 @@ async def test_runnable_with_message_history() -> None: AIMessage(content="world", id="ai4"), ] } + + +EXPECTED_EVENTS = [ + { + "data": {"input": 1}, + "event": "on_chain_start", + "metadata": {}, + "name": "add_one_proxy", + "run_id": "", + "tags": [], + }, + { + "data": {}, + "event": "on_chain_start", + "metadata": {}, + "name": "add_one", + "run_id": "", + "tags": [], + }, + { + "data": {"chunk": 2}, + "event": "on_chain_stream", + "metadata": {}, + "name": "add_one", + "run_id": "", + "tags": [], + }, + { + "data": {"input": 1, "output": 2}, + "event": "on_chain_end", + "metadata": {}, + "name": "add_one", + "run_id": "", + "tags": [], + }, + { + "data": {"chunk": 2}, + "event": "on_chain_stream", + "metadata": {}, + "name": "add_one_proxy", + "run_id": "", + "tags": [], + }, + { + "data": {"output": 2}, + "event": "on_chain_end", + "metadata": {}, + "name": "add_one_proxy", + "run_id": "", + "tags": [], + }, +] + + +@pytest.mark.xfail( + reason="This test is failing due to missing functionality." + "Need to implement logic in _transform_stream_with_config that mimics the async " + "variant that uses tap_output_iter" +) +async def test_sync_in_async_stream_lambdas() -> None: + """Test invoking nested runnable lambda.""" + + def add_one_(x: int) -> int: + return x + 1 + + add_one = RunnableLambda(add_one_) + + async def add_one_proxy_(x: int, config: RunnableConfig) -> int: + streaming = add_one.stream(x, config) + results = [result for result in streaming] + return results[0] + + add_one_proxy = RunnableLambda(add_one_proxy_) # type: ignore + + events = await _collect_events(add_one_proxy.astream_events(1, version="v1")) + assert events == EXPECTED_EVENTS + + +async def test_async_in_async_stream_lambdas() -> None: + """Test invoking nested runnable lambda.""" + + async def add_one(x: int) -> int: + return x + 1 + + add_one_ = RunnableLambda(add_one) # type: ignore + + async def add_one_proxy(x: int, config: RunnableConfig) -> int: + # Use sync streaming + streaming = add_one_.astream(x, config) + results = [result async for result in streaming] + return results[0] + + add_one_proxy_ = RunnableLambda(add_one_proxy) # type: ignore + + events = await _collect_events(add_one_proxy_.astream_events(1, version="v1")) + assert events == EXPECTED_EVENTS + + +@pytest.mark.xfail( + reason="This test is failing due to missing functionality." + "Need to implement logic in _transform_stream_with_config that mimics the async " + "variant that uses tap_output_iter" +) +async def test_sync_in_sync_lambdas() -> None: + """Test invoking nested runnable lambda.""" + + def add_one(x: int) -> int: + return x + 1 + + add_one_ = RunnableLambda(add_one) + + def add_one_proxy(x: int, config: RunnableConfig) -> int: + # Use sync streaming + streaming = add_one_.stream(x, config) + results = [result for result in streaming] + return results[0] + + add_one_proxy_ = RunnableLambda(add_one_proxy) + + events = await _collect_events(add_one_proxy_.astream_events(1, version="v1")) + assert events == EXPECTED_EVENTS