mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-04 20:28:10 +00:00
Implement diff
This commit is contained in:
parent
5cbe2b7b6a
commit
4e28a7a513
@ -337,11 +337,17 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
|
|||||||
class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
|
class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
|
||||||
"""Base class for an output parser that can handle streaming input."""
|
"""Base class for an output parser that can handle streaming input."""
|
||||||
|
|
||||||
|
diff: bool = False
|
||||||
|
|
||||||
|
def _diff(self, prev: Optional[T], next: T) -> T:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]:
|
def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]:
|
||||||
|
prev_parsed = None
|
||||||
acc_gen = None
|
acc_gen = None
|
||||||
for chunk in input:
|
for chunk in input:
|
||||||
if isinstance(chunk, BaseMessageChunk):
|
if isinstance(chunk, BaseMessageChunk):
|
||||||
chunk_gen = ChatGenerationChunk(message=chunk)
|
chunk_gen: Generation = ChatGenerationChunk(message=chunk)
|
||||||
elif isinstance(chunk, BaseMessage):
|
elif isinstance(chunk, BaseMessage):
|
||||||
chunk_gen = ChatGenerationChunk(
|
chunk_gen = ChatGenerationChunk(
|
||||||
message=BaseMessageChunk(**chunk.dict())
|
message=BaseMessageChunk(**chunk.dict())
|
||||||
@ -355,16 +361,21 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
|
|||||||
acc_gen += chunk_gen
|
acc_gen += chunk_gen
|
||||||
|
|
||||||
parsed = self.parse_result([acc_gen])
|
parsed = self.parse_result([acc_gen])
|
||||||
if parsed is not None:
|
if parsed is not None and parsed != prev_parsed:
|
||||||
yield parsed
|
if self.diff:
|
||||||
|
yield self._diff(prev_parsed, parsed)
|
||||||
|
else:
|
||||||
|
yield parsed
|
||||||
|
prev_parsed = parsed
|
||||||
|
|
||||||
async def _atransform(
|
async def _atransform(
|
||||||
self, input: AsyncIterator[Union[str, BaseMessage]]
|
self, input: AsyncIterator[Union[str, BaseMessage]]
|
||||||
) -> AsyncIterator[T]:
|
) -> AsyncIterator[T]:
|
||||||
|
prev_parsed = None
|
||||||
acc_gen = None
|
acc_gen = None
|
||||||
for chunk in input:
|
async for chunk in input:
|
||||||
if isinstance(chunk, BaseMessageChunk):
|
if isinstance(chunk, BaseMessageChunk):
|
||||||
chunk_gen = ChatGenerationChunk(message=chunk)
|
chunk_gen: Generation = ChatGenerationChunk(message=chunk)
|
||||||
elif isinstance(chunk, BaseMessage):
|
elif isinstance(chunk, BaseMessage):
|
||||||
chunk_gen = ChatGenerationChunk(
|
chunk_gen = ChatGenerationChunk(
|
||||||
message=BaseMessageChunk(**chunk.dict())
|
message=BaseMessageChunk(**chunk.dict())
|
||||||
@ -378,8 +389,12 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
|
|||||||
acc_gen += chunk_gen
|
acc_gen += chunk_gen
|
||||||
|
|
||||||
parsed = self.parse_result([acc_gen])
|
parsed = self.parse_result([acc_gen])
|
||||||
if parsed is not None:
|
if parsed is not None and parsed != prev_parsed:
|
||||||
yield parsed
|
if self.diff:
|
||||||
|
yield self._diff(prev_parsed, parsed)
|
||||||
|
else:
|
||||||
|
yield parsed
|
||||||
|
prev_parsed = parsed
|
||||||
|
|
||||||
|
|
||||||
class StrOutputParser(BaseTransformOutputParser[str]):
|
class StrOutputParser(BaseTransformOutputParser[str]):
|
||||||
|
Loading…
Reference in New Issue
Block a user