From e86fd946c84bcf504d05d74aa97ea153b09ba0a9 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Sat, 27 Jan 2024 08:09:29 -0800 Subject: [PATCH] In stream_event and stream_log handle closed streams (#16661) If, e.g., the stream iterator is interrupted, then adding more events to the send_stream will raise an exception that we should catch (and handle where appropriate). --- .../core/langchain_core/tracers/log_stream.py | 102 +++++++++--------- 1 file changed, 52 insertions(+), 50 deletions(-) diff --git a/libs/core/langchain_core/tracers/log_stream.py b/libs/core/langchain_core/tracers/log_stream.py index 72d7f590a2f..c6b8da943a3 100644 --- a/libs/core/langchain_core/tracers/log_stream.py +++ b/libs/core/langchain_core/tracers/log_stream.py @@ -20,7 +20,7 @@ from typing import ( from uuid import UUID import jsonpatch # type: ignore[import] -from anyio import create_memory_object_stream +from anyio import BrokenResourceError, ClosedResourceError, create_memory_object_stream from typing_extensions import NotRequired, TypedDict from langchain_core.load import dumps @@ -223,6 +223,14 @@ class LogStreamCallbackHandler(BaseTracer): def __aiter__(self) -> AsyncIterator[RunLogPatch]: return self.receive_stream.__aiter__() + def send(self, *ops: Dict[str, Any]) -> bool: + """Send a patch to the stream, return False if the stream is closed.""" + try: + self.send_stream.send_nowait(RunLogPatch(*ops)) + return True + except (ClosedResourceError, BrokenResourceError): + return False + async def tap_output_aiter( self, run_id: UUID, output: AsyncIterator[T] ) -> AsyncIterator[T]: @@ -233,15 +241,14 @@ class LogStreamCallbackHandler(BaseTracer): # if we can't find the run silently ignore # eg. 
because this run wasn't included in the log if key := self._key_map_by_run_id.get(run_id): - self.send_stream.send_nowait( - RunLogPatch( - { - "op": "add", - "path": f"/logs/{key}/streamed_output/-", - "value": chunk, - } - ) - ) + if not self.send( + { + "op": "add", + "path": f"/logs/{key}/streamed_output/-", + "value": chunk, + } + ): + break yield chunk @@ -285,22 +292,21 @@ class LogStreamCallbackHandler(BaseTracer): """Start a run.""" if self.root_id is None: self.root_id = run.id - self.send_stream.send_nowait( - RunLogPatch( - { - "op": "replace", - "path": "", - "value": RunState( - id=str(run.id), - streamed_output=[], - final_output=None, - logs={}, - name=run.name, - type=run.run_type, - ), - } - ) - ) + if not self.send( + { + "op": "replace", + "path": "", + "value": RunState( + id=str(run.id), + streamed_output=[], + final_output=None, + logs={}, + name=run.name, + type=run.run_type, + ), + } + ): + return if not self.include_run(run): return @@ -331,14 +337,12 @@ class LogStreamCallbackHandler(BaseTracer): entry["inputs"] = _get_standardized_inputs(run, self._schema_format) # Add the run to the stream - self.send_stream.send_nowait( - RunLogPatch( - { - "op": "add", - "path": f"/logs/{self._key_map_by_run_id[run.id]}", - "value": entry, - } - ) + self.send( + { + "op": "add", + "path": f"/logs/{self._key_map_by_run_id[run.id]}", + "value": entry, + } ) def _on_run_update(self, run: Run) -> None: @@ -382,7 +386,7 @@ class LogStreamCallbackHandler(BaseTracer): ] ) - self.send_stream.send_nowait(RunLogPatch(*ops)) + self.send(*ops) finally: if run.id == self.root_id: if self.auto_close: @@ -400,21 +404,19 @@ class LogStreamCallbackHandler(BaseTracer): if index is None: return - self.send_stream.send_nowait( - RunLogPatch( - { - "op": "add", - "path": f"/logs/{index}/streamed_output_str/-", - "value": token, - }, - { - "op": "add", - "path": f"/logs/{index}/streamed_output/-", - "value": chunk.message - if isinstance(chunk, ChatGenerationChunk) - else 
token, - }, - ) + self.send( + { + "op": "add", + "path": f"/logs/{index}/streamed_output_str/-", + "value": token, + }, + { + "op": "add", + "path": f"/logs/{index}/streamed_output/-", + "value": chunk.message + if isinstance(chunk, ChatGenerationChunk) + else token, + }, )