core[patch]: Replace memory stream implementation used by LogStreamCallbackHandler (#17185)

This PR replaces the memory stream implementation used by the 
LogStreamCallbackHandler.

This implementation resolves an issue in which streamed logs and
streamed events originating from sync code would arrive only after the
entire sync code had finished executing (rather than arriving in real
time as they were generated).

One example is streaming tokens from an LLM within a tool. If
the tool was an async tool, but the LLM was invoked via stream (sync
variant) rather than astream (async variant), then the tokens would fail
to stream in real time and would all arrive bunched up after the tool
invocation completed.
This commit is contained in:
Eugene Yurtsev
2024-02-12 21:57:38 -05:00
committed by GitHub
parent 37ef6ac113
commit 93472ee9e6
4 changed files with 239 additions and 14 deletions

View File

@@ -738,7 +738,7 @@ def test_validation_error_handling_callable() -> None:
],
)
def test_validation_error_handling_non_validation_error(
handler: Union[bool, str, Callable[[ValidationError], str]]
handler: Union[bool, str, Callable[[ValidationError], str]],
) -> None:
"""Test that validation errors are handled correctly."""
@@ -800,7 +800,7 @@ async def test_async_validation_error_handling_callable() -> None:
],
)
async def test_async_validation_error_handling_non_validation_error(
handler: Union[bool, str, Callable[[ValidationError], str]]
handler: Union[bool, str, Callable[[ValidationError], str]],
) -> None:
"""Test that validation errors are handled correctly."""

View File

@@ -0,0 +1,122 @@
import asyncio
import math
import time
from concurrent.futures import ThreadPoolExecutor
from typing import AsyncIterator
from langchain_core.tracers.memory_stream import _MemoryStream
async def test_same_event_loop() -> None:
    """Test that the memory stream works when the same event loop is used.

    This is the easy case: the producer and the consumer both run as
    coroutines on the loop that is handed to the stream.
    """
    reader_loop = asyncio.get_running_loop()
    channel = _MemoryStream[dict](reader_loop)
    writer = channel.get_send_stream()
    reader = channel.get_receive_stream()

    async def producer() -> None:
        """Produce three items, 100 ms apart, then close the stream."""
        tic = time.perf_counter()
        for i in range(3):
            await asyncio.sleep(0.10)
            toc = time.perf_counter()
            await writer.send(
                {
                    "item": i,
                    "produce_time": toc - tic,
                }
            )
        await writer.aclose()

    async def consumer() -> AsyncIterator[dict]:
        """Yield every received item, tagged with its elapsed receive time."""
        tic = time.perf_counter()
        async for item in reader:
            toc = time.perf_counter()
            yield {
                "receive_time": toc - tic,
                **item,
            }

    # Keep a strong reference to the task: the event loop only holds a weak
    # reference, so an unreferenced task may be garbage-collected mid-run.
    producer_task = asyncio.create_task(producer())

    items = [item async for item in consumer()]

    # Surface any exception raised inside the producer instead of dropping it.
    await producer_task

    # Guard against a vacuous pass: every produced item must arrive, in order.
    assert [item["item"] for item in items] == [0, 1, 2]

    for item in items:
        delta_time = item["receive_time"] - item["produce_time"]
        # Allow a generous 10ms of delay.
        # The test verifies that the producer and consumer make progress
        # concurrently on the same event loop, i.e. that each item is received
        # right after it is produced rather than after the producer finishes.
        # abs_tol absorbs scheduling overhead between the two coroutines.
        # If items were only delivered once the producer had finished, the
        # delta_time would approach the producer's total sleep time
        # (0.10 s * 3 items = 300 ms), far above the tolerance.
        assert math.isclose(delta_time, 0, abs_tol=0.010), f"delta_time: {delta_time}"
async def test_queue_for_streaming_via_sync_call() -> None:
    """Test via async -> sync -> async path.

    The producer coroutine is driven by its own event loop (``asyncio.run``)
    inside a worker thread, while the consumer reads on the loop running this
    test.
    """
    reader_loop = asyncio.get_running_loop()
    channel = _MemoryStream[dict](reader_loop)
    writer = channel.get_send_stream()
    reader = channel.get_receive_stream()

    async def producer() -> None:
        """Produce three items, 100 ms apart, then close the stream."""
        tic = time.perf_counter()
        for i in range(3):
            await asyncio.sleep(0.10)
            toc = time.perf_counter()
            await writer.send(
                {
                    "item": i,
                    "produce_time": toc - tic,
                }
            )
        await writer.aclose()

    def sync_call() -> None:
        """Blocking sync call that runs the producer on a fresh event loop."""
        asyncio.run(producer())

    async def consumer() -> AsyncIterator[dict]:
        """Yield every received item, tagged with its elapsed receive time."""
        tic = time.perf_counter()
        async for item in reader:
            toc = time.perf_counter()
            yield {
                "receive_time": toc - tic,
                **item,
            }

    with ThreadPoolExecutor() as executor:
        # Keep the future so a failure in the worker thread is not swallowed.
        future = executor.submit(sync_call)
        # Consume inside the ``with`` block: leaving it would block this loop
        # on executor shutdown until the producer thread has fully finished.
        items = [item async for item in consumer()]
        # Re-raise any exception from the worker without blocking the loop.
        await asyncio.wrap_future(future)

    # Guard against a vacuous pass: every produced item must arrive, in order.
    assert [item["item"] for item in items] == [0, 1, 2]

    for item in items:
        delta_time = item["receive_time"] - item["produce_time"]
        # Allow a generous 10ms of delay.
        # The test is meant to verify that the producer and consumer are running
        # in parallel despite the fact that the producer is running from another
        # thread.
        # abs_tol is used to allow for some delay in the producer and consumer
        # due to overhead.
        # If items were only delivered once the sync call had finished, the
        # delta_time would approach the producer's total sleep time
        # (0.10 s * 3 items = 300 ms), far above the tolerance.
        assert math.isclose(delta_time, 0, abs_tol=0.010), f"delta_time: {delta_time}"
async def test_closed_stream() -> None:
    """Check that a stream closed before any send yields nothing to the reader."""
    loop = asyncio.get_event_loop()
    stream = _MemoryStream[str](loop)
    send_stream = stream.get_send_stream()
    receive_stream = stream.get_receive_stream()

    # Close the writing side right away, without sending a single chunk.
    await send_stream.aclose()

    received = []
    async for chunk in receive_stream:
        received.append(chunk)
    assert received == []