diff --git a/libs/langchain/langchain/schema/output_parser.py b/libs/langchain/langchain/schema/output_parser.py index 5f09ee48ae6..89a065e9ad6 100644 --- a/libs/langchain/langchain/schema/output_parser.py +++ b/libs/langchain/langchain/schema/output_parser.py @@ -337,11 +337,17 @@ class BaseTransformOutputParser(BaseOutputParser[T]): class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]): """Base class for an output parser that can handle streaming input.""" + diff: bool = False + + def _diff(self, prev: Optional[T], next: T) -> T: + raise NotImplementedError() + def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]: + prev_parsed = None acc_gen = None for chunk in input: if isinstance(chunk, BaseMessageChunk): - chunk_gen = ChatGenerationChunk(message=chunk) + chunk_gen: Generation = ChatGenerationChunk(message=chunk) elif isinstance(chunk, BaseMessage): chunk_gen = ChatGenerationChunk( message=BaseMessageChunk(**chunk.dict()) @@ -355,16 +361,21 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]): acc_gen += chunk_gen parsed = self.parse_result([acc_gen]) - if parsed is not None: - yield parsed + if parsed is not None and parsed != prev_parsed: + if self.diff: + yield self._diff(prev_parsed, parsed) + else: + yield parsed + prev_parsed = parsed async def _atransform( self, input: AsyncIterator[Union[str, BaseMessage]] ) -> AsyncIterator[T]: + prev_parsed = None acc_gen = None - for chunk in input: + async for chunk in input: if isinstance(chunk, BaseMessageChunk): - chunk_gen = ChatGenerationChunk(message=chunk) + chunk_gen: Generation = ChatGenerationChunk(message=chunk) elif isinstance(chunk, BaseMessage): chunk_gen = ChatGenerationChunk( message=BaseMessageChunk(**chunk.dict()) @@ -378,8 +389,12 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]): acc_gen += chunk_gen parsed = self.parse_result([acc_gen]) - if parsed is not None: - yield parsed + if parsed is not None and parsed != prev_parsed: + if self.diff: + yield self._diff(prev_parsed, parsed) + else: + yield parsed + prev_parsed = parsed class StrOutputParser(BaseTransformOutputParser[str]):