mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-03 03:38:06 +00:00
Implement diff
This commit is contained in:
parent
5cbe2b7b6a
commit
4e28a7a513
@ -337,11 +337,17 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
|
||||
class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
|
||||
"""Base class for an output parser that can handle streaming input."""
|
||||
|
||||
diff: bool = False
|
||||
|
||||
def _diff(self, prev: Optional[T], next: T) -> T:
|
||||
raise NotImplementedError()
|
||||
|
||||
def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]:
|
||||
prev_parsed = None
|
||||
acc_gen = None
|
||||
for chunk in input:
|
||||
if isinstance(chunk, BaseMessageChunk):
|
||||
chunk_gen = ChatGenerationChunk(message=chunk)
|
||||
chunk_gen: Generation = ChatGenerationChunk(message=chunk)
|
||||
elif isinstance(chunk, BaseMessage):
|
||||
chunk_gen = ChatGenerationChunk(
|
||||
message=BaseMessageChunk(**chunk.dict())
|
||||
@ -355,16 +361,21 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
|
||||
acc_gen += chunk_gen
|
||||
|
||||
parsed = self.parse_result([acc_gen])
|
||||
if parsed is not None:
|
||||
yield parsed
|
||||
if parsed is not None and parsed != prev_parsed:
|
||||
if self.diff:
|
||||
yield self._diff(prev_parsed, parsed)
|
||||
else:
|
||||
yield parsed
|
||||
prev_parsed = parsed
|
||||
|
||||
async def _atransform(
|
||||
self, input: AsyncIterator[Union[str, BaseMessage]]
|
||||
) -> AsyncIterator[T]:
|
||||
prev_parsed = None
|
||||
acc_gen = None
|
||||
for chunk in input:
|
||||
async for chunk in input:
|
||||
if isinstance(chunk, BaseMessageChunk):
|
||||
chunk_gen = ChatGenerationChunk(message=chunk)
|
||||
chunk_gen: Generation = ChatGenerationChunk(message=chunk)
|
||||
elif isinstance(chunk, BaseMessage):
|
||||
chunk_gen = ChatGenerationChunk(
|
||||
message=BaseMessageChunk(**chunk.dict())
|
||||
@ -378,8 +389,12 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
|
||||
acc_gen += chunk_gen
|
||||
|
||||
parsed = self.parse_result([acc_gen])
|
||||
if parsed is not None:
|
||||
yield parsed
|
||||
if parsed is not None and parsed != prev_parsed:
|
||||
if self.diff:
|
||||
yield self._diff(prev_parsed, parsed)
|
||||
else:
|
||||
yield parsed
|
||||
prev_parsed = parsed
|
||||
|
||||
|
||||
class StrOutputParser(BaseTransformOutputParser[str]):
|
||||
|
Loading…
Reference in New Issue
Block a user