Implement streaming for all list output parsers (#14981)

<!-- Thank you for contributing to LangChain!

Please title your PR "<package>: <description>", where <package> is
whichever of langchain, community, core, experimental, etc. is being
modified.

Replace this entire comment with:
  - **Description:** a description of the change,
  - **Issue:** the issue # it fixes, if applicable,
  - **Dependencies:** any dependencies required for this change,
  - **Twitter handle:** we announce bigger features on Twitter. If your PR
    gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` from the root
of the package you've modified to check this locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc: https://python.langchain.com/docs/contributing/

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
This commit is contained in:
Nuno Campos 2023-12-21 11:30:35 -08:00 committed by GitHub
parent b471166df7
commit 63e512b680
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 365 additions and 58 deletions

View File

@ -2,12 +2,25 @@ from __future__ import annotations
import re import re
from abc import abstractmethod from abc import abstractmethod
from typing import List from collections import deque
from typing import AsyncIterator, Deque, Iterator, List, TypeVar, Union
from langchain_core.output_parsers.base import BaseOutputParser from langchain_core.messages import BaseMessage
from langchain_core.output_parsers.transform import BaseTransformOutputParser
T = TypeVar("T")
def droplastn(iter: Iterator[T], n: int) -> Iterator[T]:
    """Drop the last n elements of an iterator.

    Items are held back in a FIFO buffer and released only once more than
    ``n`` of them are pending, so the final ``n`` items are never yielded.
    The source is consumed lazily, one item per step.

    Args:
        iter: Source iterator.
        n: Number of trailing elements to withhold. With ``n <= 0`` every
            element is passed through.

    Yields:
        The elements of ``iter`` in their original order, minus the last ``n``.
    """
    buffer: Deque[T] = deque()
    for item in iter:
        buffer.append(item)
        # Only release the oldest item once n newer ones are queued behind it.
        if len(buffer) > n:
            yield buffer.popleft()
class ListOutputParser(BaseTransformOutputParser[List[str]]):
"""Parse the output of an LLM call to a list.""" """Parse the output of an LLM call to a list."""
@property @property
@ -18,6 +31,74 @@ class ListOutputParser(BaseOutputParser[List[str]]):
def parse(self, text: str) -> List[str]: def parse(self, text: str) -> List[str]:
"""Parse the output of an LLM call.""" """Parse the output of an LLM call."""
def parse_iter(self, text: str) -> Iterator[re.Match]:
    """Return regex matches for the list items found in ``text``.

    Optional hook used by the streaming methods of this class: they read
    ``group(1)`` of each match as the item text and ``end()`` as the offset
    up to which the buffer has been consumed. This base implementation does
    not provide it.

    Raises:
        NotImplementedError: Always; the streaming methods catch this and
            fall back to ``parse``.
    """
    raise NotImplementedError
def _transform(
    self, input: Iterator[Union[str, BaseMessage]]
) -> Iterator[List[str]]:
    """Incrementally parse a stream of chunks into single-item lists.

    Text is accumulated in a buffer. After each chunk the buffer is re-parsed
    and every item except the last is emitted, because the last one may still
    be extended by a later chunk. Once the input is exhausted, whatever is
    left in the buffer is parsed with ``parse`` and emitted.

    Args:
        input: Chunks of output, either plain strings or messages. Messages
            whose ``content`` is not a string are skipped.

    Yields:
        One single-element list per parsed item.
    """
    buffer = ""
    for chunk in input:
        if isinstance(chunk, BaseMessage):
            # extract text
            chunk_content = chunk.content
            if not isinstance(chunk_content, str):
                continue
            chunk = chunk_content
        # add current chunk to buffer
        buffer += chunk
        # parse buffer into a list of parts
        try:
            done_idx = 0
            # yield only complete parts
            for m in droplastn(self.parse_iter(buffer), 1):
                done_idx = m.end()
                yield [m.group(1)]
            # discard the text already emitted; the held-back match stays
            buffer = buffer[done_idx:]
        except NotImplementedError:
            # parse_iter is not provided: fall back to parse()
            parts = self.parse(buffer)
            # yield only complete parts
            if len(parts) > 1:
                for part in parts[:-1]:
                    yield [part]
                # NOTE(review): the buffer is replaced by the *parsed* last
                # part, so anything parse() normalises away (e.g. trailing
                # whitespace) is presumably lost before the next chunk is
                # appended -- confirm against the concrete parse().
                buffer = parts[-1]
    # yield the last part
    for part in self.parse(buffer):
        yield [part]
async def _atransform(
    self, input: AsyncIterator[Union[str, BaseMessage]]
) -> AsyncIterator[List[str]]:
    """Incrementally parse an async stream of chunks into single-item lists.

    Text is accumulated in a buffer. After each chunk the buffer is re-parsed
    and every item except the last is emitted, because the last one may still
    be extended by a later chunk. Once the input is exhausted, whatever is
    left in the buffer is parsed with ``parse`` and emitted.

    Args:
        input: Chunks of output, either plain strings or messages. Messages
            whose ``content`` is not a string are skipped.

    Yields:
        One single-element list per parsed item.
    """
    buffer = ""
    async for chunk in input:
        if isinstance(chunk, BaseMessage):
            # extract text
            chunk_content = chunk.content
            if not isinstance(chunk_content, str):
                continue
            chunk = chunk_content
        # add current chunk to buffer
        buffer += chunk
        # parse buffer into a list of parts
        try:
            done_idx = 0
            # yield only complete parts
            for m in droplastn(self.parse_iter(buffer), 1):
                done_idx = m.end()
                yield [m.group(1)]
            # discard the text already emitted; the held-back match stays
            buffer = buffer[done_idx:]
        except NotImplementedError:
            # parse_iter is not provided: fall back to parse()
            parts = self.parse(buffer)
            # yield only complete parts
            if len(parts) > 1:
                for part in parts[:-1]:
                    yield [part]
                # NOTE(review): the buffer is replaced by the *parsed* last
                # part, so anything parse() normalises away (e.g. trailing
                # whitespace) is presumably lost before the next chunk is
                # appended -- confirm against the concrete parse().
                buffer = parts[-1]
    # yield the last part
    for part in self.parse(buffer):
        yield [part]
class CommaSeparatedListOutputParser(ListOutputParser): class CommaSeparatedListOutputParser(ListOutputParser):
"""Parse the output of an LLM call to a comma-separated list.""" """Parse the output of an LLM call to a comma-separated list."""
@ -49,6 +130,8 @@ class CommaSeparatedListOutputParser(ListOutputParser):
class NumberedListOutputParser(ListOutputParser): class NumberedListOutputParser(ListOutputParser):
"""Parse a numbered list.""" """Parse a numbered list."""
pattern = r"\d+\.\s([^\n]+)"
def get_format_instructions(self) -> str: def get_format_instructions(self) -> str:
return ( return (
"Your response should be a numbered list with each item on a new line. " "Your response should be a numbered list with each item on a new line. "
@ -57,11 +140,11 @@ class NumberedListOutputParser(ListOutputParser):
def parse(self, text: str) -> List[str]: def parse(self, text: str) -> List[str]:
"""Parse the output of an LLM call.""" """Parse the output of an LLM call."""
pattern = r"\d+\.\s([^\n]+)" return re.findall(self.pattern, text)
# Extract the text of each item def parse_iter(self, text: str) -> Iterator[re.Match]:
matches = re.findall(pattern, text) """Parse the output of an LLM call."""
return matches return re.finditer(self.pattern, text)
@property @property
def _type(self) -> str: def _type(self) -> str:
@ -71,13 +154,18 @@ class NumberedListOutputParser(ListOutputParser):
class MarkdownListOutputParser(ListOutputParser): class MarkdownListOutputParser(ListOutputParser):
"""Parse a markdown list.""" """Parse a markdown list."""
pattern = r"-\s([^\n]+)"
def get_format_instructions(self) -> str: def get_format_instructions(self) -> str:
return "Your response should be a markdown list, " "eg: `- foo\n- bar\n- baz`" return "Your response should be a markdown list, " "eg: `- foo\n- bar\n- baz`"
def parse(self, text: str) -> List[str]: def parse(self, text: str) -> List[str]:
"""Parse the output of an LLM call.""" """Parse the output of an LLM call."""
pattern = r"-\s([^\n]+)" return re.findall(self.pattern, text)
return re.findall(pattern, text)
def parse_iter(self, text: str) -> Iterator[re.Match]:
    """Return an iterator over the matches of ``self.pattern`` in ``text``.

    Args:
        text: The text to scan.

    Returns:
        The match objects produced by ``re.finditer``, in order of
        appearance.
    """
    return re.finditer(self.pattern, text)
@property @property
def _type(self) -> str: def _type(self) -> str:

View File

@ -0,0 +1,268 @@
from typing import AsyncIterator, Iterable, List, TypeVar, cast
from langchain_core.output_parsers.list import (
CommaSeparatedListOutputParser,
MarkdownListOutputParser,
NumberedListOutputParser,
)
from langchain_core.runnables.utils import aadd, add
def test_single_item() -> None:
    """Test that a string with a single item is parsed to a list with that item."""
    parser = CommaSeparatedListOutputParser()
    text = "foo"
    expected = ["foo"]
    # Streaming is expected to emit one single-element list per item.
    expectedlist = [[a] for a in expected]

    assert parser.parse(text) == expected
    # Stream character by character.
    assert add(parser.transform(t for t in text)) == expected
    assert list(parser.transform(t for t in text)) == expectedlist
    # Stream line by line.
    assert (
        list(parser.transform(t for t in text.splitlines(keepends=True)))
        == expectedlist
    )
    # Stream word by word, each later chunk carrying its leading space.
    assert (
        list(
            parser.transform(
                " " + t if i > 0 else t for i, t in enumerate(text.split(" "))
            )
        )
        == expectedlist
    )
    # Whole text as a single chunk.
    assert list(parser.transform(iter([text]))) == expectedlist
def test_multiple_items() -> None:
    """Test that a string with multiple comma-separated items is parsed to a list."""
    parser = CommaSeparatedListOutputParser()
    text = "foo, bar, baz"
    expected = ["foo", "bar", "baz"]
    # Streaming is expected to emit one single-element list per item.
    expectedlist = [[a] for a in expected]

    assert parser.parse(text) == expected
    # Stream character by character.
    assert add(parser.transform(t for t in text)) == expected
    assert list(parser.transform(t for t in text)) == expectedlist
    # Stream line by line.
    assert (
        list(parser.transform(t for t in text.splitlines(keepends=True)))
        == expectedlist
    )
    # Stream word by word, each later chunk carrying its leading space.
    assert (
        list(
            parser.transform(
                " " + t if i > 0 else t for i, t in enumerate(text.split(" "))
            )
        )
        == expectedlist
    )
    # Whole text as a single chunk.
    assert list(parser.transform(iter([text]))) == expectedlist
def test_numbered_list() -> None:
    """Test numbered-list parsing, both in one shot and under several chunkings."""
    parser = NumberedListOutputParser()
    text1 = (
        "Your response should be a numbered list with each item on a new line. "
        "For example: \n\n1. foo\n\n2. bar\n\n3. baz"
    )
    text2 = "Items:\n\n1. apple\n\n2. banana\n\n3. cherry"
    text3 = "No items in the list."

    for text, expected in [
        (text1, ["foo", "bar", "baz"]),
        (text2, ["apple", "banana", "cherry"]),
        (text3, []),
    ]:
        # Streaming is expected to emit one single-element list per item.
        expectedlist = [[a] for a in cast(List[str], expected)]
        assert parser.parse(text) == expected
        # With no items nothing is streamed, so the combined result is None.
        assert add(parser.transform(t for t in text)) == (expected or None)
        # Stream character by character.
        assert list(parser.transform(t for t in text)) == expectedlist
        # Stream line by line.
        assert (
            list(parser.transform(t for t in text.splitlines(keepends=True)))
            == expectedlist
        )
        # Stream word by word, each later chunk carrying its leading space.
        assert (
            list(
                parser.transform(
                    " " + t if i > 0 else t for i, t in enumerate(text.split(" "))
                )
            )
            == expectedlist
        )
        # Whole text as a single chunk.
        assert list(parser.transform(iter([text]))) == expectedlist
def test_markdown_list() -> None:
    """Test markdown-list parsing, both in one shot and under several chunkings."""
    parser = MarkdownListOutputParser()
    text1 = (
        "Your response should be a numbered list with each item on a new line."
        "For example: \n- foo\n- bar\n- baz"
    )
    text2 = "Items:\n- apple\n- banana\n- cherry"
    text3 = "No items in the list."

    for text, expected in [
        (text1, ["foo", "bar", "baz"]),
        (text2, ["apple", "banana", "cherry"]),
        (text3, []),
    ]:
        # Streaming is expected to emit one single-element list per item.
        expectedlist = [[a] for a in cast(List[str], expected)]
        assert parser.parse(text) == expected
        # With no items nothing is streamed, so the combined result is None.
        assert add(parser.transform(t for t in text)) == (expected or None)
        # Stream character by character.
        assert list(parser.transform(t for t in text)) == expectedlist
        # Stream line by line.
        assert (
            list(parser.transform(t for t in text.splitlines(keepends=True)))
            == expectedlist
        )
        # Stream word by word, each later chunk carrying its leading space.
        assert (
            list(
                parser.transform(
                    " " + t if i > 0 else t for i, t in enumerate(text.split(" "))
                )
            )
            == expectedlist
        )
        # Whole text as a single chunk.
        assert list(parser.transform(iter([text]))) == expectedlist
T = TypeVar("T")
async def aiter_from_iter(iterable: Iterable[T]) -> AsyncIterator[T]:
    """Expose a synchronous iterable as an async iterator.

    Args:
        iterable: Any synchronous iterable.

    Yields:
        The items of ``iterable``, in order.
    """
    for item in iterable:
        yield item
async def test_single_item_async() -> None:
    """Test that a string with a single item is parsed to a list with that item."""
    parser = CommaSeparatedListOutputParser()
    text = "foo"
    expected = ["foo"]
    # Streaming is expected to emit one single-element list per item.
    expectedlist = [[a] for a in expected]

    assert await parser.aparse(text) == expected
    # Stream character by character.
    assert await aadd(parser.atransform(aiter_from_iter(t for t in text))) == expected
    assert [
        a async for a in parser.atransform(aiter_from_iter(t for t in text))
    ] == expectedlist
    # Stream line by line.
    assert [
        a
        async for a in parser.atransform(
            aiter_from_iter(t for t in text.splitlines(keepends=True))
        )
    ] == expectedlist
    # Stream word by word, each later chunk carrying its leading space.
    assert [
        a
        async for a in parser.atransform(
            aiter_from_iter(
                " " + t if i > 0 else t for i, t in enumerate(text.split(" "))
            )
        )
    ] == expectedlist
    # Whole text as a single chunk.
    assert [
        a async for a in parser.atransform(aiter_from_iter([text]))
    ] == expectedlist
async def test_multiple_items_async() -> None:
    """Test that a string with multiple comma-separated items is parsed to a list."""
    parser = CommaSeparatedListOutputParser()
    text = "foo, bar, baz"
    expected = ["foo", "bar", "baz"]
    # Streaming is expected to emit one single-element list per item.
    expectedlist = [[a] for a in expected]

    assert await parser.aparse(text) == expected
    # Stream character by character.
    assert await aadd(parser.atransform(aiter_from_iter(t for t in text))) == expected
    assert [
        a async for a in parser.atransform(aiter_from_iter(t for t in text))
    ] == expectedlist
    # Stream line by line.
    assert [
        a
        async for a in parser.atransform(
            aiter_from_iter(t for t in text.splitlines(keepends=True))
        )
    ] == expectedlist
    # Stream word by word, each later chunk carrying its leading space.
    assert [
        a
        async for a in parser.atransform(
            aiter_from_iter(
                " " + t if i > 0 else t for i, t in enumerate(text.split(" "))
            )
        )
    ] == expectedlist
    # Whole text as a single chunk.
    assert [
        a async for a in parser.atransform(aiter_from_iter([text]))
    ] == expectedlist
async def test_numbered_list_async() -> None:
    """Test async numbered-list parsing, in one shot and under several chunkings."""
    parser = NumberedListOutputParser()
    text1 = (
        "Your response should be a numbered list with each item on a new line. "
        "For example: \n\n1. foo\n\n2. bar\n\n3. baz"
    )
    text2 = "Items:\n\n1. apple\n\n2. banana\n\n3. cherry"
    text3 = "No items in the list."

    for text, expected in [
        (text1, ["foo", "bar", "baz"]),
        (text2, ["apple", "banana", "cherry"]),
        (text3, []),
    ]:
        # Streaming is expected to emit one single-element list per item.
        expectedlist = [[a] for a in cast(List[str], expected)]
        assert await parser.aparse(text) == expected
        # With no items nothing is streamed, so the combined result is None.
        assert await aadd(parser.atransform(aiter_from_iter(t for t in text))) == (
            expected or None
        )
        # Stream character by character.
        assert [
            a async for a in parser.atransform(aiter_from_iter(t for t in text))
        ] == expectedlist
        # Stream line by line.
        assert [
            a
            async for a in parser.atransform(
                aiter_from_iter(t for t in text.splitlines(keepends=True))
            )
        ] == expectedlist
        # Stream word by word, each later chunk carrying its leading space.
        assert [
            a
            async for a in parser.atransform(
                aiter_from_iter(
                    " " + t if i > 0 else t for i, t in enumerate(text.split(" "))
                )
            )
        ] == expectedlist
        # Whole text as a single chunk.
        assert [
            a async for a in parser.atransform(aiter_from_iter([text]))
        ] == expectedlist
async def test_markdown_list_async() -> None:
    """Test async markdown-list parsing, in one shot and under several chunkings."""
    parser = MarkdownListOutputParser()
    text1 = (
        "Your response should be a numbered list with each item on a new line."
        "For example: \n- foo\n- bar\n- baz"
    )
    text2 = "Items:\n- apple\n- banana\n- cherry"
    text3 = "No items in the list."

    for text, expected in [
        (text1, ["foo", "bar", "baz"]),
        (text2, ["apple", "banana", "cherry"]),
        (text3, []),
    ]:
        # Streaming is expected to emit one single-element list per item.
        expectedlist = [[a] for a in cast(List[str], expected)]
        assert await parser.aparse(text) == expected
        # With no items nothing is streamed, so the combined result is None.
        assert await aadd(parser.atransform(aiter_from_iter(t for t in text))) == (
            expected or None
        )
        # Stream character by character.
        assert [
            a async for a in parser.atransform(aiter_from_iter(t for t in text))
        ] == expectedlist
        # Stream line by line.
        assert [
            a
            async for a in parser.atransform(
                aiter_from_iter(t for t in text.splitlines(keepends=True))
            )
        ] == expectedlist
        # Stream word by word, each later chunk carrying its leading space.
        assert [
            a
            async for a in parser.atransform(
                aiter_from_iter(
                    " " + t if i > 0 else t for i, t in enumerate(text.split(" "))
                )
            )
        ] == expectedlist
        # Whole text as a single chunk.
        assert [
            a async for a in parser.atransform(aiter_from_iter([text]))
        ] == expectedlist

View File

@ -1,49 +0,0 @@
from langchain.output_parsers.list import (
CommaSeparatedListOutputParser,
MarkdownListOutputParser,
NumberedListOutputParser,
)
def test_single_item() -> None:
    """Test that a string with a single item is parsed to a list with that item."""
    parser = CommaSeparatedListOutputParser()
    assert parser.parse("foo") == ["foo"]
def test_multiple_items() -> None:
    """Test that a string with multiple comma-separated items is parsed to a list."""
    parser = CommaSeparatedListOutputParser()
    assert parser.parse("foo, bar, baz") == ["foo", "bar", "baz"]
def test_numbered_list() -> None:
    """Test that numbered items are extracted and surrounding prose is ignored."""
    parser = NumberedListOutputParser()
    text1 = (
        "Your response should be a numbered list with each item on a new line. "
        "For example: \n\n1. foo\n\n2. bar\n\n3. baz"
    )
    text2 = "Items:\n\n1. apple\n\n2. banana\n\n3. cherry"
    text3 = "No items in the list."

    assert parser.parse(text1) == ["foo", "bar", "baz"]
    assert parser.parse(text2) == ["apple", "banana", "cherry"]
    assert parser.parse(text3) == []
def test_markdown_list() -> None:
    """Test that markdown bullets are extracted and surrounding prose is ignored."""
    parser = MarkdownListOutputParser()
    text1 = (
        "Your response should be a numbered list with each item on a new line."
        "For example: \n- foo\n- bar\n- baz"
    )
    text2 = "Items:\n- apple\n- banana\n- cherry"
    text3 = "No items in the list."

    assert parser.parse(text1) == ["foo", "bar", "baz"]
    assert parser.parse(text2) == ["apple", "banana", "cherry"]
    assert parser.parse(text3) == []