Implement streaming for all list output parsers (#14981)

<!-- Thank you for contributing to LangChain!

Please title your PR "<package>: <description>", where <package> is
whichever of langchain, community, core, experimental, etc. is being
modified.

Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes if applicable,
  - **Dependencies:** any dependencies required for this change,
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` from the root
of the package you've modified to check this locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc: https://python.langchain.com/docs/contributing/

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
This commit is contained in:
Nuno Campos 2023-12-21 11:30:35 -08:00 committed by GitHub
parent b471166df7
commit 63e512b680
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 365 additions and 58 deletions

View File

@ -2,12 +2,25 @@ from __future__ import annotations
import re
from abc import abstractmethod
from typing import List
from collections import deque
from typing import AsyncIterator, Deque, Iterator, List, TypeVar, Union
from langchain_core.output_parsers.base import BaseOutputParser
from langchain_core.messages import BaseMessage
from langchain_core.output_parsers.transform import BaseTransformOutputParser
T = TypeVar("T")
class ListOutputParser(BaseOutputParser[List[str]]):
def droplastn(iter: Iterator[T], n: int) -> Iterator[T]:
"""Drop the last n elements of an iterator."""
buffer: Deque[T] = deque()
for item in iter:
buffer.append(item)
if len(buffer) > n:
yield buffer.popleft()
class ListOutputParser(BaseTransformOutputParser[List[str]]):
"""Parse the output of an LLM call to a list."""
@property
@ -18,6 +31,74 @@ class ListOutputParser(BaseOutputParser[List[str]]):
def parse(self, text: str) -> List[str]:
"""Parse the output of an LLM call."""
def parse_iter(self, text: str) -> Iterator[re.Match]:
"""Parse the output of an LLM call."""
raise NotImplementedError
def _transform(
self, input: Iterator[Union[str, BaseMessage]]
) -> Iterator[List[str]]:
buffer = ""
for chunk in input:
if isinstance(chunk, BaseMessage):
# extract text
chunk_content = chunk.content
if not isinstance(chunk_content, str):
continue
chunk = chunk_content
# add current chunk to buffer
buffer += chunk
# parse buffer into a list of parts
try:
done_idx = 0
# yield only complete parts
for m in droplastn(self.parse_iter(buffer), 1):
done_idx = m.end()
yield [m.group(1)]
buffer = buffer[done_idx:]
except NotImplementedError:
parts = self.parse(buffer)
# yield only complete parts
if len(parts) > 1:
for part in parts[:-1]:
yield [part]
buffer = parts[-1]
# yield the last part
for part in self.parse(buffer):
yield [part]
async def _atransform(
self, input: AsyncIterator[Union[str, BaseMessage]]
) -> AsyncIterator[List[str]]:
buffer = ""
async for chunk in input:
if isinstance(chunk, BaseMessage):
# extract text
chunk_content = chunk.content
if not isinstance(chunk_content, str):
continue
chunk = chunk_content
# add current chunk to buffer
buffer += chunk
# parse buffer into a list of parts
try:
done_idx = 0
# yield only complete parts
for m in droplastn(self.parse_iter(buffer), 1):
done_idx = m.end()
yield [m.group(1)]
buffer = buffer[done_idx:]
except NotImplementedError:
parts = self.parse(buffer)
# yield only complete parts
if len(parts) > 1:
for part in parts[:-1]:
yield [part]
buffer = parts[-1]
# yield the last part
for part in self.parse(buffer):
yield [part]
class CommaSeparatedListOutputParser(ListOutputParser):
"""Parse the output of an LLM call to a comma-separated list."""
@ -49,6 +130,8 @@ class CommaSeparatedListOutputParser(ListOutputParser):
class NumberedListOutputParser(ListOutputParser):
"""Parse a numbered list."""
pattern = r"\d+\.\s([^\n]+)"
def get_format_instructions(self) -> str:
return (
"Your response should be a numbered list with each item on a new line. "
@ -57,11 +140,11 @@ class NumberedListOutputParser(ListOutputParser):
def parse(self, text: str) -> List[str]:
"""Parse the output of an LLM call."""
pattern = r"\d+\.\s([^\n]+)"
return re.findall(self.pattern, text)
# Extract the text of each item
matches = re.findall(pattern, text)
return matches
def parse_iter(self, text: str) -> Iterator[re.Match]:
"""Parse the output of an LLM call."""
return re.finditer(self.pattern, text)
@property
def _type(self) -> str:
@ -71,13 +154,18 @@ class NumberedListOutputParser(ListOutputParser):
class MarkdownListOutputParser(ListOutputParser):
"""Parse a markdown list."""
pattern = r"-\s([^\n]+)"
def get_format_instructions(self) -> str:
return "Your response should be a markdown list, " "eg: `- foo\n- bar\n- baz`"
def parse(self, text: str) -> List[str]:
"""Parse the output of an LLM call."""
pattern = r"-\s([^\n]+)"
return re.findall(pattern, text)
return re.findall(self.pattern, text)
def parse_iter(self, text: str) -> Iterator[re.Match]:
"""Parse the output of an LLM call."""
return re.finditer(self.pattern, text)
@property
def _type(self) -> str:

View File

@ -0,0 +1,268 @@
from typing import AsyncIterator, Iterable, List, TypeVar, cast
from langchain_core.output_parsers.list import (
CommaSeparatedListOutputParser,
MarkdownListOutputParser,
NumberedListOutputParser,
)
from langchain_core.runnables.utils import aadd, add
def test_single_item() -> None:
"""Test that a string with a single item is parsed to a list with that item."""
parser = CommaSeparatedListOutputParser()
text = "foo"
expected = ["foo"]
assert parser.parse(text) == expected
assert add(parser.transform(t for t in text)) == expected
assert list(parser.transform(t for t in text)) == [[a] for a in expected]
assert list(parser.transform(t for t in text.splitlines(keepends=True))) == [
[a] for a in expected
]
assert list(
parser.transform(" " + t if i > 0 else t for i, t in enumerate(text.split(" ")))
) == [[a] for a in expected]
assert list(parser.transform(iter([text]))) == [[a] for a in expected]
def test_multiple_items() -> None:
"""Test that a string with multiple comma-separated items is parsed to a list."""
parser = CommaSeparatedListOutputParser()
text = "foo, bar, baz"
expected = ["foo", "bar", "baz"]
assert parser.parse(text) == expected
assert add(parser.transform(t for t in text)) == expected
assert list(parser.transform(t for t in text)) == [[a] for a in expected]
assert list(parser.transform(t for t in text.splitlines(keepends=True))) == [
[a] for a in expected
]
assert list(
parser.transform(" " + t if i > 0 else t for i, t in enumerate(text.split(" ")))
) == [[a] for a in expected]
assert list(parser.transform(iter([text]))) == [[a] for a in expected]
def test_numbered_list() -> None:
parser = NumberedListOutputParser()
text1 = (
"Your response should be a numbered list with each item on a new line. "
"For example: \n\n1. foo\n\n2. bar\n\n3. baz"
)
text2 = "Items:\n\n1. apple\n\n2. banana\n\n3. cherry"
text3 = "No items in the list."
for text, expected in [
(text1, ["foo", "bar", "baz"]),
(text2, ["apple", "banana", "cherry"]),
(text3, []),
]:
expectedlist = [[a] for a in cast(List[str], expected)]
assert parser.parse(text) == expected
assert add(parser.transform(t for t in text)) == (expected or None)
assert list(parser.transform(t for t in text)) == expectedlist
assert (
list(parser.transform(t for t in text.splitlines(keepends=True)))
== expectedlist
)
assert (
list(
parser.transform(
" " + t if i > 0 else t for i, t in enumerate(text.split(" "))
)
)
== expectedlist
)
assert list(parser.transform(iter([text]))) == expectedlist
def test_markdown_list() -> None:
parser = MarkdownListOutputParser()
text1 = (
"Your response should be a numbered list with each item on a new line."
"For example: \n- foo\n- bar\n- baz"
)
text2 = "Items:\n- apple\n- banana\n- cherry"
text3 = "No items in the list."
for text, expected in [
(text1, ["foo", "bar", "baz"]),
(text2, ["apple", "banana", "cherry"]),
(text3, []),
]:
expectedlist = [[a] for a in cast(List[str], expected)]
assert parser.parse(text) == expected
assert add(parser.transform(t for t in text)) == (expected or None)
assert list(parser.transform(t for t in text)) == expectedlist
assert (
list(parser.transform(t for t in text.splitlines(keepends=True)))
== expectedlist
)
assert (
list(
parser.transform(
" " + t if i > 0 else t for i, t in enumerate(text.split(" "))
)
)
== expectedlist
)
assert list(parser.transform(iter([text]))) == expectedlist
T = TypeVar("T")
async def aiter_from_iter(iterable: Iterable[T]) -> AsyncIterator[T]:
for item in iterable:
yield item
async def test_single_item_async() -> None:
"""Test that a string with a single item is parsed to a list with that item."""
parser = CommaSeparatedListOutputParser()
text = "foo"
expected = ["foo"]
assert await parser.aparse(text) == expected
assert await aadd(parser.atransform(aiter_from_iter(t for t in text))) == expected
assert [a async for a in parser.atransform(aiter_from_iter(t for t in text))] == [
[a] for a in expected
]
assert [
a
async for a in parser.atransform(
aiter_from_iter(t for t in text.splitlines(keepends=True))
)
] == [[a] for a in expected]
assert [
a
async for a in parser.atransform(
aiter_from_iter(
" " + t if i > 0 else t for i, t in enumerate(text.split(" "))
)
)
] == [[a] for a in expected]
assert [a async for a in parser.atransform(aiter_from_iter([text]))] == [
[a] for a in expected
]
async def test_multiple_items_async() -> None:
"""Test that a string with multiple comma-separated items is parsed to a list."""
parser = CommaSeparatedListOutputParser()
text = "foo, bar, baz"
expected = ["foo", "bar", "baz"]
assert await parser.aparse(text) == expected
assert await aadd(parser.atransform(aiter_from_iter(t for t in text))) == expected
assert [a async for a in parser.atransform(aiter_from_iter(t for t in text))] == [
[a] for a in expected
]
assert [
a
async for a in parser.atransform(
aiter_from_iter(t for t in text.splitlines(keepends=True))
)
] == [[a] for a in expected]
assert [
a
async for a in parser.atransform(
aiter_from_iter(
" " + t if i > 0 else t for i, t in enumerate(text.split(" "))
)
)
] == [[a] for a in expected]
assert [a async for a in parser.atransform(aiter_from_iter([text]))] == [
[a] for a in expected
]
async def test_numbered_list_async() -> None:
parser = NumberedListOutputParser()
text1 = (
"Your response should be a numbered list with each item on a new line. "
"For example: \n\n1. foo\n\n2. bar\n\n3. baz"
)
text2 = "Items:\n\n1. apple\n\n2. banana\n\n3. cherry"
text3 = "No items in the list."
for text, expected in [
(text1, ["foo", "bar", "baz"]),
(text2, ["apple", "banana", "cherry"]),
(text3, []),
]:
expectedlist = [[a] for a in cast(List[str], expected)]
assert await parser.aparse(text) == expected
assert await aadd(parser.atransform(aiter_from_iter(t for t in text))) == (
expected or None
)
assert [
a async for a in parser.atransform(aiter_from_iter(t for t in text))
] == expectedlist
assert [
a
async for a in parser.atransform(
aiter_from_iter(t for t in text.splitlines(keepends=True))
)
] == expectedlist
assert [
a
async for a in parser.atransform(
aiter_from_iter(
" " + t if i > 0 else t for i, t in enumerate(text.split(" "))
)
)
] == expectedlist
assert [
a async for a in parser.atransform(aiter_from_iter([text]))
] == expectedlist
async def test_markdown_list_async() -> None:
parser = MarkdownListOutputParser()
text1 = (
"Your response should be a numbered list with each item on a new line."
"For example: \n- foo\n- bar\n- baz"
)
text2 = "Items:\n- apple\n- banana\n- cherry"
text3 = "No items in the list."
for text, expected in [
(text1, ["foo", "bar", "baz"]),
(text2, ["apple", "banana", "cherry"]),
(text3, []),
]:
expectedlist = [[a] for a in cast(List[str], expected)]
assert await parser.aparse(text) == expected
assert await aadd(parser.atransform(aiter_from_iter(t for t in text))) == (
expected or None
)
assert [
a async for a in parser.atransform(aiter_from_iter(t for t in text))
] == expectedlist
assert [
a
async for a in parser.atransform(
aiter_from_iter(t for t in text.splitlines(keepends=True))
)
] == expectedlist
assert [
a
async for a in parser.atransform(
aiter_from_iter(
" " + t if i > 0 else t for i, t in enumerate(text.split(" "))
)
)
] == expectedlist
assert [
a async for a in parser.atransform(aiter_from_iter([text]))
] == expectedlist

View File

@ -1,49 +0,0 @@
from langchain.output_parsers.list import (
CommaSeparatedListOutputParser,
MarkdownListOutputParser,
NumberedListOutputParser,
)
def test_single_item() -> None:
"""Test that a string with a single item is parsed to a list with that item."""
parser = CommaSeparatedListOutputParser()
assert parser.parse("foo") == ["foo"]
def test_multiple_items() -> None:
"""Test that a string with multiple comma-separated items is parsed to a list."""
parser = CommaSeparatedListOutputParser()
assert parser.parse("foo, bar, baz") == ["foo", "bar", "baz"]
def test_numbered_list() -> None:
parser = NumberedListOutputParser()
text1 = (
"Your response should be a numbered list with each item on a new line. "
"For example: \n\n1. foo\n\n2. bar\n\n3. baz"
)
text2 = "Items:\n\n1. apple\n\n2. banana\n\n3. cherry"
text3 = "No items in the list."
assert parser.parse(text1) == ["foo", "bar", "baz"]
assert parser.parse(text2) == ["apple", "banana", "cherry"]
assert parser.parse(text3) == []
def test_markdown_list() -> None:
parser = MarkdownListOutputParser()
text1 = (
"Your response should be a numbered list with each item on a new line."
"For example: \n- foo\n- bar\n- baz"
)
text2 = "Items:\n- apple\n- banana\n- cherry"
text3 = "No items in the list."
assert parser.parse(text1) == ["foo", "bar", "baz"]
assert parser.parse(text2) == ["apple", "banana", "cherry"]
assert parser.parse(text3) == []