core: Tap output of sync iterators for astream_events (#21842)

Thank you for contributing to LangChain!

- [ ] **PR title**: "package: description"
- Where "package" is whichever of langchain, community, core,
experimental, etc. is being modified. Use "docs: ..." for purely docs
changes, "templates: ..." for template changes, "infra: ..." for CI
changes.
  - Example: "community: add foobar LLM"


- [ ] **PR message**: ***Delete this entire checklist*** and replace
with
    - **Description:** a description of the change
    - **Issue:** the issue # it fixes, if applicable
    - **Dependencies:** any dependencies required for this change
- **Twitter handle:** if your PR gets announced, and you'd like a
mention, we'll gladly shout you out!


- [ ] **Add tests and docs**: If you're adding a new integration, please
include
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in the
`docs/docs/integrations` directory.


- [ ] **Lint and test**: Run `make format`, `make lint` and `make test`
from the root of the package(s) you've modified. See contribution
guidelines for more: https://python.langchain.com/docs/contributing/

Additional guidelines:
- Make sure optional dependencies are imported within a function.
- Please do not add dependencies to pyproject.toml files (even optional
ones) unless they are required for unit tests.
- Most PRs should not touch more than one package.
- Changes should be backwards compatible.
- If you are adding something to community, do not re-import it in
langchain.

If no one reviews your PR within a few days, please @-mention one of
baskaryan, efriis, eyurtsev, hwchase17.
This commit is contained in:
Nuno Campos 2024-05-17 16:57:41 -07:00 committed by GitHub
parent 9a39f92aba
commit b1e7b40b6a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 78 additions and 30 deletions

View File

@ -1716,6 +1716,9 @@ class Runnable(Generic[Input, Output], ABC):
"""Helper method to transform an Iterator of Input values into an Iterator of """Helper method to transform an Iterator of Input values into an Iterator of
Output values, with callbacks. Output values, with callbacks.
Use this to implement `stream()` or `transform()` in Runnable subclasses.""" Use this to implement `stream()` or `transform()` in Runnable subclasses."""
# Mixin that is used by both the astream_log and astream_events implementations
from langchain_core.tracers._streaming import _StreamingCallbackHandler
# tee the input so we can iterate over it twice # tee the input so we can iterate over it twice
input_for_tracing, input_for_transform = tee(input, 2) input_for_tracing, input_for_transform = tee(input, 2)
# Start the input iterator to ensure the input runnable starts before this one # Start the input iterator to ensure the input runnable starts before this one
@ -1742,6 +1745,17 @@ class Runnable(Generic[Input, Output], ABC):
context = copy_context() context = copy_context()
context.run(var_child_runnable_config.set, child_config) context.run(var_child_runnable_config.set, child_config)
iterator = context.run(transformer, input_for_transform, **kwargs) # type: ignore[arg-type] iterator = context.run(transformer, input_for_transform, **kwargs) # type: ignore[arg-type]
if stream_handler := next(
(
cast(_StreamingCallbackHandler, h)
for h in run_manager.handlers
# instance check OK here, it's a mixin
if isinstance(h, _StreamingCallbackHandler) # type: ignore[misc]
),
None,
):
# populates streamed_output in astream_log() output if needed
iterator = stream_handler.tap_output_iter(run_manager.run_id, iterator)
try: try:
while True: while True:
chunk: Output = context.run(next, iterator) # type: ignore chunk: Output = context.run(next, iterator) # type: ignore

View File

@ -1,6 +1,6 @@
"""Internal tracers used for stream_log and astream events implementations.""" """Internal tracers used for stream_log and astream events implementations."""
import abc import abc
from typing import AsyncIterator, TypeVar from typing import AsyncIterator, Iterator, TypeVar
from uuid import UUID from uuid import UUID
T = TypeVar("T") T = TypeVar("T")
@ -22,6 +22,10 @@ class _StreamingCallbackHandler(abc.ABC):
) -> AsyncIterator[T]: ) -> AsyncIterator[T]:
"""Used for internal astream_log and astream events implementations.""" """Used for internal astream_log and astream events implementations."""
@abc.abstractmethod
def tap_output_iter(self, run_id: UUID, output: Iterator[T]) -> Iterator[T]:
"""Wrap the synchronous output iterator of the run identified by `run_id`.

Synchronous counterpart of the async tap declared on this mixin. Used for
the internal astream_log and astream events implementations: the iterator
returned here is the one the caller consumes in place of `output`.
"""
__all__ = [ __all__ = [
"_StreamingCallbackHandler", "_StreamingCallbackHandler",

View File

@ -9,6 +9,7 @@ from typing import (
Any, Any,
AsyncIterator, AsyncIterator,
Dict, Dict,
Iterator,
List, List,
Optional, Optional,
Sequence, Sequence,
@ -102,10 +103,10 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
self.send_stream = memory_stream.get_send_stream() self.send_stream = memory_stream.get_send_stream()
self.receive_stream = memory_stream.get_receive_stream() self.receive_stream = memory_stream.get_receive_stream()
async def _send(self, event: StreamEvent, event_type: str) -> None: def _send(self, event: StreamEvent, event_type: str) -> None:
"""Send an event to the stream.""" """Send an event to the stream."""
if self.root_event_filter.include_event(event, event_type): if self.root_event_filter.include_event(event, event_type):
await self.send_stream.send(event) self.send_stream.send_nowait(event)
def __aiter__(self) -> AsyncIterator[Any]: def __aiter__(self) -> AsyncIterator[Any]:
"""Iterate over the receive stream.""" """Iterate over the receive stream."""
@ -119,7 +120,26 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
run_info = self.run_map.get(run_id) run_info = self.run_map.get(run_id)
if run_info is None: if run_info is None:
raise AssertionError(f"Run ID {run_id} not found in run map.") raise AssertionError(f"Run ID {run_id} not found in run map.")
await self._send( self._send(
{
"event": f"on_{run_info['run_type']}_stream",
"data": {"chunk": chunk},
"run_id": str(run_id),
"name": run_info["name"],
"tags": run_info["tags"],
"metadata": run_info["metadata"],
},
run_info["run_type"],
)
yield chunk
def tap_output_iter(self, run_id: UUID, output: Iterator[T]) -> Iterator[T]:
"""Tap the output iter."""
for chunk in output:
run_info = self.run_map.get(run_id)
if run_info is None:
raise AssertionError(f"Run ID {run_id} not found in run map.")
self._send(
{ {
"event": f"on_{run_info['run_type']}_stream", "event": f"on_{run_info['run_type']}_stream",
"data": {"chunk": chunk}, "data": {"chunk": chunk},
@ -155,7 +175,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
"inputs": {"messages": messages}, "inputs": {"messages": messages},
} }
await self._send( self._send(
{ {
"event": "on_chat_model_start", "event": "on_chat_model_start",
"data": { "data": {
@ -192,7 +212,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
"inputs": {"prompts": prompts}, "inputs": {"prompts": prompts},
} }
await self._send( self._send(
{ {
"event": "on_llm_start", "event": "on_llm_start",
"data": { "data": {
@ -241,7 +261,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
else: else:
raise ValueError(f"Unexpected run type: {run_info['run_type']}") raise ValueError(f"Unexpected run type: {run_info['run_type']}")
await self._send( self._send(
{ {
"event": event, "event": event,
"data": { "data": {
@ -295,7 +315,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
else: else:
raise ValueError(f"Unexpected run type: {run_info['run_type']}") raise ValueError(f"Unexpected run type: {run_info['run_type']}")
await self._send( self._send(
{ {
"event": event, "event": event,
"data": {"output": output, "input": inputs_}, "data": {"output": output, "input": inputs_},
@ -340,7 +360,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
self.run_map[run_id] = run_info self.run_map[run_id] = run_info
await self._send( self._send(
{ {
"event": f"on_{run_type_}_start", "event": f"on_{run_type_}_start",
"data": data, "data": data,
@ -373,7 +393,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
"input": inputs, "input": inputs,
} }
await self._send( self._send(
{ {
"event": event, "event": event,
"data": data, "data": data,
@ -408,7 +428,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
"inputs": inputs, "inputs": inputs,
} }
await self._send( self._send(
{ {
"event": "on_tool_start", "event": "on_tool_start",
"data": { "data": {
@ -432,7 +452,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
) )
inputs = run_info["inputs"] inputs = run_info["inputs"]
await self._send( self._send(
{ {
"event": "on_tool_end", "event": "on_tool_end",
"data": { "data": {
@ -470,7 +490,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
"inputs": {"query": query}, "inputs": {"query": query},
} }
await self._send( self._send(
{ {
"event": "on_retriever_start", "event": "on_retriever_start",
"data": { "data": {
@ -492,7 +512,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
"""Run when Retriever ends running.""" """Run when Retriever ends running."""
run_info = self.run_map.pop(run_id) run_info = self.run_map.pop(run_id)
await self._send( self._send(
{ {
"event": "on_retriever_end", "event": "on_retriever_end",
"data": { "data": {

View File

@ -8,6 +8,7 @@ from typing import (
Any, Any,
AsyncIterator, AsyncIterator,
Dict, Dict,
Iterator,
List, List,
Literal, Literal,
Optional, Optional,
@ -252,6 +253,25 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler):
yield chunk yield chunk
def tap_output_iter(self, run_id: UUID, output: Iterator[T]) -> Iterator[T]:
    """Tap a synchronous output iterator to stream its values to the log.

    Every chunk produced by ``output`` is appended to the
    ``/logs/{key}/streamed_output`` entry of the matching run and is then
    yielded unchanged.

    Args:
        run_id: ID of the run that produces ``output``.
        output: The run's synchronous output iterator.

    Yields:
        The chunks of ``output``, unmodified and in their original order.
        Iteration stops, without yielding the pending chunk, as soon as
        ``self.send`` returns a falsy value.
    """
    for chunk in output:
        # root run is handled in .astream_log()
        if run_id != self.root_id:
            # if we can't find the run silently ignore
            # eg. because this run wasn't included in the log
            key = self._key_map_by_run_id.get(run_id)
            if key and not self.send(
                {
                    "op": "add",
                    "path": f"/logs/{key}/streamed_output/-",
                    "value": chunk,
                }
            ):
                # a falsy result from send() ends the tap before the yield
                break

        yield chunk
def include_run(self, run: Run) -> bool: def include_run(self, run: Run) -> bool:
if run.id == self.root_id: if run.id == self.root_id:
return False return False

View File

@ -1650,27 +1650,22 @@ EXPECTED_EVENTS = [
] ]
@pytest.mark.xfail(
reason="This test is failing due to missing functionality."
"Need to implement logic in _transform_stream_with_config that mimics the async "
"variant that uses tap_output_iter"
)
async def test_sync_in_async_stream_lambdas() -> None: async def test_sync_in_async_stream_lambdas() -> None:
"""Test invoking nested runnable lambda.""" """Test invoking nested runnable lambda."""
def add_one_(x: int) -> int: def add_one(x: int) -> int:
return x + 1 return x + 1
add_one = RunnableLambda(add_one_) add_one_ = RunnableLambda(add_one)
async def add_one_proxy_(x: int, config: RunnableConfig) -> int: async def add_one_proxy(x: int, config: RunnableConfig) -> int:
streaming = add_one.stream(x, config) streaming = add_one_.stream(x, config)
results = [result for result in streaming] results = [result for result in streaming]
return results[0] return results[0]
add_one_proxy = RunnableLambda(add_one_proxy_) # type: ignore add_one_proxy_ = RunnableLambda(add_one_proxy) # type: ignore
events = await _collect_events(add_one_proxy.astream_events(1, version="v2")) events = await _collect_events(add_one_proxy_.astream_events(1, version="v2"))
assert events == EXPECTED_EVENTS assert events == EXPECTED_EVENTS
@ -1694,11 +1689,6 @@ async def test_async_in_async_stream_lambdas() -> None:
assert events == EXPECTED_EVENTS assert events == EXPECTED_EVENTS
@pytest.mark.xfail(
reason="This test is failing due to missing functionality."
"Need to implement logic in _transform_stream_with_config that mimics the async "
"variant that uses tap_output_iter"
)
async def test_sync_in_sync_lambdas() -> None: async def test_sync_in_sync_lambdas() -> None:
"""Test invoking nested runnable lambda.""" """Test invoking nested runnable lambda."""