mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-28 09:28:48 +00:00
Implement streaming for all list output parsers (#14981)
<!-- Thank you for contributing to LangChain! Please title your PR "<package>: <description>", where <package> is whichever of langchain, community, core, experimental, etc. is being modified. Replace this entire comment with: - **Description:** a description of the change, - **Issue:** the issue # it fixes if applicable, - **Dependencies:** any dependencies required for this change, - **Twitter handle:** we announce bigger features on Twitter. If your PR gets announced, and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` from the root of the package you've modified to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://python.langchain.com/docs/contributing/ If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/docs/integrations` directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17. -->
This commit is contained in:
parent
b471166df7
commit
63e512b680
@ -2,12 +2,25 @@ from __future__ import annotations
|
||||
|
||||
import re
|
||||
from abc import abstractmethod
|
||||
from typing import List
|
||||
from collections import deque
|
||||
from typing import AsyncIterator, Deque, Iterator, List, TypeVar, Union
|
||||
|
||||
from langchain_core.output_parsers.base import BaseOutputParser
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.output_parsers.transform import BaseTransformOutputParser
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class ListOutputParser(BaseOutputParser[List[str]]):
|
||||
def droplastn(iter: Iterator[T], n: int) -> Iterator[T]:
    """Yield every element of ``iter`` except the final ``n`` ones.

    Elements are held back in a FIFO queue; one is released only once more
    than ``n`` elements are waiting, so the last ``n`` are never emitted.
    """
    pending: Deque[T] = deque()
    for element in iter:
        pending.append(element)
        if len(pending) <= n:
            # Not enough look-ahead yet to know this is not a trailing element.
            continue
        yield pending.popleft()
|
||||
|
||||
|
||||
class ListOutputParser(BaseTransformOutputParser[List[str]]):
|
||||
"""Parse the output of an LLM call to a list."""
|
||||
|
||||
@property
|
||||
@ -18,6 +31,74 @@ class ListOutputParser(BaseOutputParser[List[str]]):
|
||||
def parse(self, text: str) -> List[str]:
    """Parse the output of an LLM call.

    Args:
        text: The raw text of the LLM output.

    Returns:
        The list items extracted from ``text``, as strings.
    """
|
||||
|
||||
def parse_iter(self, text: str) -> Iterator[re.Match]:
    """Parse the output of an LLM call into an iterator of regex matches.

    This base implementation is a placeholder that always raises. The
    streaming methods of this class catch ``NotImplementedError`` and fall
    back to ``parse`` when that happens.

    Args:
        text: The raw text of the LLM output.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError
|
||||
|
||||
def _transform(
    self, input: Iterator[Union[str, BaseMessage]]
) -> Iterator[List[str]]:
    """Stream list items out of a stream of text or message chunks.

    Each item is yielded on its own as a one-element list. While chunks are
    still arriving, the most recent item is held back because it may be
    incomplete; once the input is exhausted, whatever is left in the buffer
    is parsed and emitted.

    Message chunks whose ``content`` is not a string are skipped.
    """
    pending = ""
    for piece in input:
        if isinstance(piece, BaseMessage):
            content = piece.content
            if not isinstance(content, str):
                # Nothing textual to accumulate from this message.
                continue
            piece = content
        pending += piece

        try:
            consumed = 0
            # Every match except the last one is treated as complete.
            for match in droplastn(self.parse_iter(pending), 1):
                consumed = match.end()
                yield [match.group(1)]
            # Keep only the text after the last emitted match.
            pending = pending[consumed:]
        except NotImplementedError:
            # No match-based parser available: fall back to ``parse``.
            items = self.parse(pending)
            if len(items) <= 1:
                continue
            for item in items[:-1]:
                yield [item]
            pending = items[-1]

    # Input exhausted: flush whatever is still buffered.
    for item in self.parse(pending):
        yield [item]
|
||||
|
||||
async def _atransform(
    self, input: AsyncIterator[Union[str, BaseMessage]]
) -> AsyncIterator[List[str]]:
    """Asynchronously stream list items out of text or message chunks.

    Each item is yielded on its own as a one-element list. While chunks are
    still arriving, the most recent item is held back because it may be
    incomplete; once the input is exhausted, whatever is left in the buffer
    is parsed and emitted.

    Message chunks whose ``content`` is not a string are skipped.
    """
    pending = ""
    async for piece in input:
        if isinstance(piece, BaseMessage):
            content = piece.content
            if not isinstance(content, str):
                # Nothing textual to accumulate from this message.
                continue
            piece = content
        pending += piece

        try:
            consumed = 0
            # Every match except the last one is treated as complete.
            for match in droplastn(self.parse_iter(pending), 1):
                consumed = match.end()
                yield [match.group(1)]
            # Keep only the text after the last emitted match.
            pending = pending[consumed:]
        except NotImplementedError:
            # No match-based parser available: fall back to ``parse``.
            items = self.parse(pending)
            if len(items) <= 1:
                continue
            for item in items[:-1]:
                yield [item]
            pending = items[-1]

    # Input exhausted: flush whatever is still buffered.
    for item in self.parse(pending):
        yield [item]
|
||||
|
||||
|
||||
class CommaSeparatedListOutputParser(ListOutputParser):
|
||||
"""Parse the output of an LLM call to a comma-separated list."""
|
||||
@ -49,6 +130,8 @@ class CommaSeparatedListOutputParser(ListOutputParser):
|
||||
class NumberedListOutputParser(ListOutputParser):
|
||||
"""Parse a numbered list."""
|
||||
|
||||
pattern = r"\d+\.\s([^\n]+)"
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
return (
|
||||
"Your response should be a numbered list with each item on a new line. "
|
||||
@ -57,11 +140,11 @@ class NumberedListOutputParser(ListOutputParser):
|
||||
|
||||
def parse(self, text: str) -> List[str]:
    """Parse the output of an LLM call into the items of a numbered list.

    Args:
        text: The raw text of the LLM output.

    Returns:
        The captured text of every match of ``self.pattern`` in ``text``,
        in order of appearance; an empty list when nothing matches.
    """
    # Single source of truth for the item regex: the class-level ``pattern``.
    # The unused local copy and the statements after the return were dead code.
    return re.findall(self.pattern, text)
|
||||
def parse_iter(self, text: str) -> Iterator[re.Match]:
    """Return an iterator over the matches of ``self.pattern`` in ``text``.

    Group 1 of each match holds the text of one list item.
    """
    item_regex = re.compile(self.pattern)
    return item_regex.finditer(text)
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
@ -71,13 +154,18 @@ class NumberedListOutputParser(ListOutputParser):
|
||||
class MarkdownListOutputParser(ListOutputParser):
|
||||
"""Parse a markdown list."""
|
||||
|
||||
pattern = r"-\s([^\n]+)"
|
||||
|
||||
def get_format_instructions(self) -> str:
    """Return the instruction text asking for a markdown-list response."""
    instructions = (
        "Your response should be a markdown list, eg: `- foo\n- bar\n- baz`"
    )
    return instructions
|
||||
|
||||
def parse(self, text: str) -> List[str]:
    """Parse the output of an LLM call into the items of a markdown list.

    Args:
        text: The raw text of the LLM output.

    Returns:
        The captured text of every match of ``self.pattern`` in ``text``,
        in order of appearance; an empty list when nothing matches.
    """
    # Use the class-level ``pattern`` (as ``parse_iter`` does) instead of a
    # shadowing local copy, so both methods always agree on the item regex.
    return re.findall(self.pattern, text)
|
||||
|
||||
def parse_iter(self, text: str) -> Iterator[re.Match]:
    """Return an iterator over the matches of ``self.pattern`` in ``text``.

    Group 1 of each match holds the text of one list item.
    """
    item_regex = re.compile(self.pattern)
    return item_regex.finditer(text)
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
|
268
libs/core/tests/unit_tests/output_parsers/test_list_parser.py
Normal file
268
libs/core/tests/unit_tests/output_parsers/test_list_parser.py
Normal file
@ -0,0 +1,268 @@
|
||||
from typing import AsyncIterator, Iterable, List, TypeVar, cast
|
||||
|
||||
from langchain_core.output_parsers.list import (
|
||||
CommaSeparatedListOutputParser,
|
||||
MarkdownListOutputParser,
|
||||
NumberedListOutputParser,
|
||||
)
|
||||
from langchain_core.runnables.utils import aadd, add
|
||||
|
||||
|
||||
def test_single_item() -> None:
    """A lone item parses to a one-element list, whole or streamed."""
    parser = CommaSeparatedListOutputParser()
    text = "foo"
    expected = ["foo"]
    streamed = [[item] for item in expected]

    assert parser.parse(text) == expected

    # Character-by-character stream, merged back into a single list.
    assert add(parser.transform(ch for ch in text)) == expected
    assert list(parser.transform(ch for ch in text)) == streamed

    # Line-sized chunks.
    line_chunks = (line for line in text.splitlines(keepends=True))
    assert list(parser.transform(line_chunks)) == streamed

    # Word-sized chunks; every word after the first keeps its leading space.
    word_chunks = (
        " " + word if idx > 0 else word for idx, word in enumerate(text.split(" "))
    )
    assert list(parser.transform(word_chunks)) == streamed

    # The whole text delivered as one chunk.
    assert list(parser.transform(iter([text]))) == streamed
|
||||
|
||||
|
||||
def test_multiple_items() -> None:
    """Comma-separated items parse to a list, whole or streamed."""
    parser = CommaSeparatedListOutputParser()
    text = "foo, bar, baz"
    expected = ["foo", "bar", "baz"]
    streamed = [[item] for item in expected]

    assert parser.parse(text) == expected

    # Character-by-character stream, merged back into a single list.
    assert add(parser.transform(ch for ch in text)) == expected
    assert list(parser.transform(ch for ch in text)) == streamed

    # Line-sized chunks.
    line_chunks = (line for line in text.splitlines(keepends=True))
    assert list(parser.transform(line_chunks)) == streamed

    # Word-sized chunks; every word after the first keeps its leading space.
    word_chunks = (
        " " + word if idx > 0 else word for idx, word in enumerate(text.split(" "))
    )
    assert list(parser.transform(word_chunks)) == streamed

    # The whole text delivered as one chunk.
    assert list(parser.transform(iter([text]))) == streamed
|
||||
|
||||
|
||||
def test_numbered_list() -> None:
    """Numbered lists parse identically whole and under several chunkings."""
    parser = NumberedListOutputParser()
    text1 = (
        "Your response should be a numbered list with each item on a new line. "
        "For example: \n\n1. foo\n\n2. bar\n\n3. baz"
    )
    text2 = "Items:\n\n1. apple\n\n2. banana\n\n3. cherry"
    text3 = "No items in the list."

    cases = [
        (text1, ["foo", "bar", "baz"]),
        (text2, ["apple", "banana", "cherry"]),
        (text3, []),
    ]
    for text, expected in cases:
        expectedlist = [[item] for item in cast(List[str], expected)]

        assert parser.parse(text) == expected

        # ``add`` over an empty stream gives None rather than an empty list.
        assert add(parser.transform(ch for ch in text)) == (expected or None)
        assert list(parser.transform(ch for ch in text)) == expectedlist

        line_chunks = (line for line in text.splitlines(keepends=True))
        assert list(parser.transform(line_chunks)) == expectedlist

        word_chunks = (
            " " + word if idx > 0 else word
            for idx, word in enumerate(text.split(" "))
        )
        assert list(parser.transform(word_chunks)) == expectedlist

        assert list(parser.transform(iter([text]))) == expectedlist
|
||||
|
||||
|
||||
def test_markdown_list() -> None:
    """Markdown lists parse identically whole and under several chunkings."""
    parser = MarkdownListOutputParser()
    text1 = (
        "Your response should be a numbered list with each item on a new line."
        "For example: \n- foo\n- bar\n- baz"
    )
    text2 = "Items:\n- apple\n- banana\n- cherry"
    text3 = "No items in the list."

    cases = [
        (text1, ["foo", "bar", "baz"]),
        (text2, ["apple", "banana", "cherry"]),
        (text3, []),
    ]
    for text, expected in cases:
        expectedlist = [[item] for item in cast(List[str], expected)]

        assert parser.parse(text) == expected

        # ``add`` over an empty stream gives None rather than an empty list.
        assert add(parser.transform(ch for ch in text)) == (expected or None)
        assert list(parser.transform(ch for ch in text)) == expectedlist

        line_chunks = (line for line in text.splitlines(keepends=True))
        assert list(parser.transform(line_chunks)) == expectedlist

        word_chunks = (
            " " + word if idx > 0 else word
            for idx, word in enumerate(text.split(" "))
        )
        assert list(parser.transform(word_chunks)) == expectedlist

        assert list(parser.transform(iter([text]))) == expectedlist
|
||||
|
||||
|
||||
T = TypeVar("T")


async def aiter_from_iter(iterable: Iterable[T]) -> AsyncIterator[T]:
    """Expose a synchronous iterable as an async iterator, item by item."""
    for element in iterable:
        yield element
|
||||
|
||||
|
||||
async def test_single_item_async() -> None:
    """A lone item parses to a one-element list via the async API."""
    parser = CommaSeparatedListOutputParser()
    text = "foo"
    expected = ["foo"]
    streamed = [[item] for item in expected]

    assert await parser.aparse(text) == expected

    # Character-by-character stream, merged back into a single list.
    char_stream = aiter_from_iter(ch for ch in text)
    assert await aadd(parser.atransform(char_stream)) == expected

    char_stream = aiter_from_iter(ch for ch in text)
    assert [out async for out in parser.atransform(char_stream)] == streamed

    # Line-sized chunks.
    line_stream = aiter_from_iter(
        line for line in text.splitlines(keepends=True)
    )
    assert [out async for out in parser.atransform(line_stream)] == streamed

    # Word-sized chunks; every word after the first keeps its leading space.
    word_stream = aiter_from_iter(
        " " + word if idx > 0 else word for idx, word in enumerate(text.split(" "))
    )
    assert [out async for out in parser.atransform(word_stream)] == streamed

    # The whole text delivered as one chunk.
    whole_stream = aiter_from_iter([text])
    assert [out async for out in parser.atransform(whole_stream)] == streamed
|
||||
|
||||
|
||||
async def test_multiple_items_async() -> None:
    """Comma-separated items parse to a list via the async API."""
    parser = CommaSeparatedListOutputParser()
    text = "foo, bar, baz"
    expected = ["foo", "bar", "baz"]
    streamed = [[item] for item in expected]

    assert await parser.aparse(text) == expected

    # Character-by-character stream, merged back into a single list.
    char_stream = aiter_from_iter(ch for ch in text)
    assert await aadd(parser.atransform(char_stream)) == expected

    char_stream = aiter_from_iter(ch for ch in text)
    assert [out async for out in parser.atransform(char_stream)] == streamed

    # Line-sized chunks.
    line_stream = aiter_from_iter(
        line for line in text.splitlines(keepends=True)
    )
    assert [out async for out in parser.atransform(line_stream)] == streamed

    # Word-sized chunks; every word after the first keeps its leading space.
    word_stream = aiter_from_iter(
        " " + word if idx > 0 else word for idx, word in enumerate(text.split(" "))
    )
    assert [out async for out in parser.atransform(word_stream)] == streamed

    # The whole text delivered as one chunk.
    whole_stream = aiter_from_iter([text])
    assert [out async for out in parser.atransform(whole_stream)] == streamed
|
||||
|
||||
|
||||
async def test_numbered_list_async() -> None:
    """Numbered lists parse identically via the async API under any chunking."""
    parser = NumberedListOutputParser()
    text1 = (
        "Your response should be a numbered list with each item on a new line. "
        "For example: \n\n1. foo\n\n2. bar\n\n3. baz"
    )
    text2 = "Items:\n\n1. apple\n\n2. banana\n\n3. cherry"
    text3 = "No items in the list."

    cases = [
        (text1, ["foo", "bar", "baz"]),
        (text2, ["apple", "banana", "cherry"]),
        (text3, []),
    ]
    for text, expected in cases:
        expectedlist = [[item] for item in cast(List[str], expected)]

        assert await parser.aparse(text) == expected

        # ``aadd`` over an empty stream gives None rather than an empty list.
        char_stream = aiter_from_iter(ch for ch in text)
        assert await aadd(parser.atransform(char_stream)) == (expected or None)

        char_stream = aiter_from_iter(ch for ch in text)
        assert [
            out async for out in parser.atransform(char_stream)
        ] == expectedlist

        line_stream = aiter_from_iter(
            line for line in text.splitlines(keepends=True)
        )
        assert [
            out async for out in parser.atransform(line_stream)
        ] == expectedlist

        word_stream = aiter_from_iter(
            " " + word if idx > 0 else word
            for idx, word in enumerate(text.split(" "))
        )
        assert [
            out async for out in parser.atransform(word_stream)
        ] == expectedlist

        whole_stream = aiter_from_iter([text])
        assert [
            out async for out in parser.atransform(whole_stream)
        ] == expectedlist
|
||||
|
||||
|
||||
async def test_markdown_list_async() -> None:
    """Markdown lists parse identically via the async API under any chunking."""
    parser = MarkdownListOutputParser()
    text1 = (
        "Your response should be a numbered list with each item on a new line."
        "For example: \n- foo\n- bar\n- baz"
    )
    text2 = "Items:\n- apple\n- banana\n- cherry"
    text3 = "No items in the list."

    cases = [
        (text1, ["foo", "bar", "baz"]),
        (text2, ["apple", "banana", "cherry"]),
        (text3, []),
    ]
    for text, expected in cases:
        expectedlist = [[item] for item in cast(List[str], expected)]

        assert await parser.aparse(text) == expected

        # ``aadd`` over an empty stream gives None rather than an empty list.
        char_stream = aiter_from_iter(ch for ch in text)
        assert await aadd(parser.atransform(char_stream)) == (expected or None)

        char_stream = aiter_from_iter(ch for ch in text)
        assert [
            out async for out in parser.atransform(char_stream)
        ] == expectedlist

        line_stream = aiter_from_iter(
            line for line in text.splitlines(keepends=True)
        )
        assert [
            out async for out in parser.atransform(line_stream)
        ] == expectedlist

        word_stream = aiter_from_iter(
            " " + word if idx > 0 else word
            for idx, word in enumerate(text.split(" "))
        )
        assert [
            out async for out in parser.atransform(word_stream)
        ] == expectedlist

        whole_stream = aiter_from_iter([text])
        assert [
            out async for out in parser.atransform(whole_stream)
        ] == expectedlist
|
@ -1,49 +0,0 @@
|
||||
from langchain.output_parsers.list import (
|
||||
CommaSeparatedListOutputParser,
|
||||
MarkdownListOutputParser,
|
||||
NumberedListOutputParser,
|
||||
)
|
||||
|
||||
|
||||
def test_single_item() -> None:
    """A string holding one item parses to a list containing just that item."""
    parser = CommaSeparatedListOutputParser()
    parsed = parser.parse("foo")
    assert parsed == ["foo"]
|
||||
|
||||
|
||||
def test_multiple_items() -> None:
    """A string of comma-separated items parses to the list of those items."""
    parser = CommaSeparatedListOutputParser()
    parsed = parser.parse("foo, bar, baz")
    assert parsed == ["foo", "bar", "baz"]
|
||||
|
||||
|
||||
def test_numbered_list() -> None:
    """Numbered items are extracted; text without any yields an empty list."""
    parser = NumberedListOutputParser()
    text1 = (
        "Your response should be a numbered list with each item on a new line. "
        "For example: \n\n1. foo\n\n2. bar\n\n3. baz"
    )
    text2 = "Items:\n\n1. apple\n\n2. banana\n\n3. cherry"
    text3 = "No items in the list."

    cases = (
        (text1, ["foo", "bar", "baz"]),
        (text2, ["apple", "banana", "cherry"]),
        (text3, []),
    )
    for text, expected in cases:
        assert parser.parse(text) == expected
|
||||
|
||||
|
||||
def test_markdown_list() -> None:
    """Markdown items are extracted; text without any yields an empty list."""
    parser = MarkdownListOutputParser()
    text1 = (
        "Your response should be a numbered list with each item on a new line."
        "For example: \n- foo\n- bar\n- baz"
    )
    text2 = "Items:\n- apple\n- banana\n- cherry"
    text3 = "No items in the list."

    cases = (
        (text1, ["foo", "bar", "baz"]),
        (text2, ["apple", "banana", "cherry"]),
        (text3, []),
    )
    for text, expected in cases:
        assert parser.parse(text) == expected
|
Loading…
Reference in New Issue
Block a user