From 63e512b680c65506c730dd6a7f9c8d488dcf0678 Mon Sep 17 00:00:00 2001
From: Nuno Campos
Date: Thu, 21 Dec 2023 11:30:35 -0800
Subject: [PATCH] Implement streaming for all list output parsers (#14981)

---
Reviewer notes (text between the "---" line and the first "diff --git" line
is not part of the commit and is skipped when the patch is applied):

* list.py: ListOutputParser now subclasses BaseTransformOutputParser and
  gains _transform/_atransform. Both accumulate incoming str/BaseMessage
  chunks in a text buffer (message chunks whose content is not a str are
  skipped) and yield every completed item as a one-element list. The
  trailing, possibly incomplete item stays in the buffer and is emitted by a
  final parse() once the input is exhausted.
* droplastn(iter, n) yields everything except the final n items of an
  iterator; the transform methods call it with n=1 to withhold the last
  regex match.
* NumberedListOutputParser and MarkdownListOutputParser move their regex to
  a class-level `pattern` and implement parse_iter with re.finditer. The
  base parse_iter raises NotImplementedError, which the transform methods
  catch in order to fall back to parse().
* The list parser tests move from libs/langchain to libs/core and
  additionally exercise transform/atransform with several chunkings of the
  same text.

Review remarks (deliberately not applied, so the hunks stay as committed):

* droplastn's parameter name `iter` shadows the builtin of the same name.
* _transform and _atransform duplicate the same buffering logic.

 .../langchain_core/output_parsers/list.py     | 106 ++++++-
 .../output_parsers/test_list_parser.py        | 268 ++++++++++++++++++
 .../output_parsers/test_list_parser.py        |  49 ----
 3 files changed, 365 insertions(+), 58 deletions(-)
 create mode 100644 libs/core/tests/unit_tests/output_parsers/test_list_parser.py
 delete mode 100644 libs/langchain/tests/unit_tests/output_parsers/test_list_parser.py

diff --git a/libs/core/langchain_core/output_parsers/list.py b/libs/core/langchain_core/output_parsers/list.py
index 1ad75b24bb4..ac9f6a5f2bf 100644
--- a/libs/core/langchain_core/output_parsers/list.py
+++ b/libs/core/langchain_core/output_parsers/list.py
@@ -2,12 +2,25 @@ from __future__ import annotations
 
 import re
 from abc import abstractmethod
-from typing import List
+from collections import deque
+from typing import AsyncIterator, Deque, Iterator, List, TypeVar, Union
 
-from langchain_core.output_parsers.base import BaseOutputParser
+from langchain_core.messages import BaseMessage
+from langchain_core.output_parsers.transform import BaseTransformOutputParser
+
+T = TypeVar("T")
 
 
-class ListOutputParser(BaseOutputParser[List[str]]):
+def droplastn(iter: Iterator[T], n: int) -> Iterator[T]:
+    """Drop the last n elements of an iterator."""
+    buffer: Deque[T] = deque()
+    for item in iter:
+        buffer.append(item)
+        if len(buffer) > n:
+            yield buffer.popleft()
+
+
+class ListOutputParser(BaseTransformOutputParser[List[str]]):
     """Parse the output of an LLM call to a list."""
 
     @property
@@ -18,6 +31,74 @@ class ListOutputParser(BaseOutputParser[List[str]]):
     def parse(self, text: str) -> List[str]:
         """Parse the output of an LLM call."""
 
+    def parse_iter(self, text: str) -> Iterator[re.Match]:
+        """Parse the output of an LLM call."""
+        raise NotImplementedError
+
+    def _transform(
+        self, input: Iterator[Union[str, BaseMessage]]
+    ) -> Iterator[List[str]]:
+        buffer = ""
+        for chunk in input:
+            if isinstance(chunk, BaseMessage):
+                # extract text
+                chunk_content = chunk.content
+                if not isinstance(chunk_content, str):
+                    continue
+                chunk = chunk_content
+            # add current chunk to buffer
+            buffer += chunk
+            # parse buffer into a list of parts
+            try:
+                done_idx = 0
+                # yield only complete parts
+                for m in droplastn(self.parse_iter(buffer), 1):
+                    done_idx = m.end()
+                    yield [m.group(1)]
+                buffer = buffer[done_idx:]
+            except NotImplementedError:
+                parts = self.parse(buffer)
+                # yield only complete parts
+                if len(parts) > 1:
+                    for part in parts[:-1]:
+                        yield [part]
+                    buffer = parts[-1]
+        # yield the last part
+        for part in self.parse(buffer):
+            yield [part]
+
+    async def _atransform(
+        self, input: AsyncIterator[Union[str, BaseMessage]]
+    ) -> AsyncIterator[List[str]]:
+        buffer = ""
+        async for chunk in input:
+            if isinstance(chunk, BaseMessage):
+                # extract text
+                chunk_content = chunk.content
+                if not isinstance(chunk_content, str):
+                    continue
+                chunk = chunk_content
+            # add current chunk to buffer
+            buffer += chunk
+            # parse buffer into a list of parts
+            try:
+                done_idx = 0
+                # yield only complete parts
+                for m in droplastn(self.parse_iter(buffer), 1):
+                    done_idx = m.end()
+                    yield [m.group(1)]
+                buffer = buffer[done_idx:]
+            except NotImplementedError:
+                parts = self.parse(buffer)
+                # yield only complete parts
+                if len(parts) > 1:
+                    for part in parts[:-1]:
+                        yield [part]
+                    buffer = parts[-1]
+        # yield the last part
+        for part in self.parse(buffer):
+            yield [part]
+
 
 class CommaSeparatedListOutputParser(ListOutputParser):
     """Parse the output of an LLM call to a comma-separated list."""
@@ -49,6 +130,8 @@ class CommaSeparatedListOutputParser(ListOutputParser):
 class NumberedListOutputParser(ListOutputParser):
     """Parse a numbered list."""
 
+    pattern = r"\d+\.\s([^\n]+)"
+
     def get_format_instructions(self) -> str:
         return (
             "Your response should be a numbered list with each item on a new line. "
@@ -57,11 +140,11 @@ class NumberedListOutputParser(ListOutputParser):
 
     def parse(self, text: str) -> List[str]:
         """Parse the output of an LLM call."""
-        pattern = r"\d+\.\s([^\n]+)"
+        return re.findall(self.pattern, text)
 
-        # Extract the text of each item
-        matches = re.findall(pattern, text)
-        return matches
+    def parse_iter(self, text: str) -> Iterator[re.Match]:
+        """Parse the output of an LLM call."""
+        return re.finditer(self.pattern, text)
 
     @property
     def _type(self) -> str:
@@ -71,13 +154,18 @@ class NumberedListOutputParser(ListOutputParser):
 class MarkdownListOutputParser(ListOutputParser):
     """Parse a markdown list."""
 
+    pattern = r"-\s([^\n]+)"
+
     def get_format_instructions(self) -> str:
         return "Your response should be a markdown list, " "eg: `- foo\n- bar\n- baz`"
 
     def parse(self, text: str) -> List[str]:
         """Parse the output of an LLM call."""
-        pattern = r"-\s([^\n]+)"
-        return re.findall(pattern, text)
+        return re.findall(self.pattern, text)
+
+    def parse_iter(self, text: str) -> Iterator[re.Match]:
+        """Parse the output of an LLM call."""
+        return re.finditer(self.pattern, text)
 
     @property
     def _type(self) -> str:
diff --git a/libs/core/tests/unit_tests/output_parsers/test_list_parser.py b/libs/core/tests/unit_tests/output_parsers/test_list_parser.py
new file mode 100644
index 00000000000..4e2f506264b
--- /dev/null
+++ b/libs/core/tests/unit_tests/output_parsers/test_list_parser.py
@@ -0,0 +1,268 @@
+from typing import AsyncIterator, Iterable, List, TypeVar, cast
+
+from langchain_core.output_parsers.list import (
+    CommaSeparatedListOutputParser,
+    MarkdownListOutputParser,
+    NumberedListOutputParser,
+)
+from langchain_core.runnables.utils import aadd, add
+
+
+def test_single_item() -> None:
+    """Test that a string with a single item is parsed to a list with that item."""
+    parser = CommaSeparatedListOutputParser()
+    text = "foo"
+    expected = ["foo"]
+
+    assert parser.parse(text) == expected
+    assert add(parser.transform(t for t in text)) == expected
+    assert list(parser.transform(t for t in text)) == [[a] for a in expected]
+    assert list(parser.transform(t for t in text.splitlines(keepends=True))) == [
+        [a] for a in expected
+    ]
+    assert list(
+        parser.transform(" " + t if i > 0 else t for i, t in enumerate(text.split(" ")))
+    ) == [[a] for a in expected]
+    assert list(parser.transform(iter([text]))) == [[a] for a in expected]
+
+
+def test_multiple_items() -> None:
+    """Test that a string with multiple comma-separated items is parsed to a list."""
+    parser = CommaSeparatedListOutputParser()
+    text = "foo, bar, baz"
+    expected = ["foo", "bar", "baz"]
+
+    assert parser.parse(text) == expected
+    assert add(parser.transform(t for t in text)) == expected
+    assert list(parser.transform(t for t in text)) == [[a] for a in expected]
+    assert list(parser.transform(t for t in text.splitlines(keepends=True))) == [
+        [a] for a in expected
+    ]
+    assert list(
+        parser.transform(" " + t if i > 0 else t for i, t in enumerate(text.split(" ")))
+    ) == [[a] for a in expected]
+    assert list(parser.transform(iter([text]))) == [[a] for a in expected]
+
+
+def test_numbered_list() -> None:
+    parser = NumberedListOutputParser()
+    text1 = (
+        "Your response should be a numbered list with each item on a new line. "
+        "For example: \n\n1. foo\n\n2. bar\n\n3. baz"
+    )
+
+    text2 = "Items:\n\n1. apple\n\n2. banana\n\n3. cherry"
+
+    text3 = "No items in the list."
+
+    for text, expected in [
+        (text1, ["foo", "bar", "baz"]),
+        (text2, ["apple", "banana", "cherry"]),
+        (text3, []),
+    ]:
+        expectedlist = [[a] for a in cast(List[str], expected)]
+        assert parser.parse(text) == expected
+        assert add(parser.transform(t for t in text)) == (expected or None)
+        assert list(parser.transform(t for t in text)) == expectedlist
+        assert (
+            list(parser.transform(t for t in text.splitlines(keepends=True)))
+            == expectedlist
+        )
+        assert (
+            list(
+                parser.transform(
+                    " " + t if i > 0 else t for i, t in enumerate(text.split(" "))
+                )
+            )
+            == expectedlist
+        )
+        assert list(parser.transform(iter([text]))) == expectedlist
+
+
+def test_markdown_list() -> None:
+    parser = MarkdownListOutputParser()
+    text1 = (
+        "Your response should be a numbered list with each item on a new line."
+        "For example: \n- foo\n- bar\n- baz"
+    )
+
+    text2 = "Items:\n- apple\n- banana\n- cherry"
+
+    text3 = "No items in the list."
+
+    for text, expected in [
+        (text1, ["foo", "bar", "baz"]),
+        (text2, ["apple", "banana", "cherry"]),
+        (text3, []),
+    ]:
+        expectedlist = [[a] for a in cast(List[str], expected)]
+        assert parser.parse(text) == expected
+        assert add(parser.transform(t for t in text)) == (expected or None)
+        assert list(parser.transform(t for t in text)) == expectedlist
+        assert (
+            list(parser.transform(t for t in text.splitlines(keepends=True)))
+            == expectedlist
+        )
+        assert (
+            list(
+                parser.transform(
+                    " " + t if i > 0 else t for i, t in enumerate(text.split(" "))
+                )
+            )
+            == expectedlist
+        )
+        assert list(parser.transform(iter([text]))) == expectedlist
+
+
+T = TypeVar("T")
+
+
+async def aiter_from_iter(iterable: Iterable[T]) -> AsyncIterator[T]:
+    for item in iterable:
+        yield item
+
+
+async def test_single_item_async() -> None:
+    """Test that a string with a single item is parsed to a list with that item."""
+    parser = CommaSeparatedListOutputParser()
+    text = "foo"
+    expected = ["foo"]
+
+    assert await parser.aparse(text) == expected
+    assert await aadd(parser.atransform(aiter_from_iter(t for t in text))) == expected
+    assert [a async for a in parser.atransform(aiter_from_iter(t for t in text))] == [
+        [a] for a in expected
+    ]
+    assert [
+        a
+        async for a in parser.atransform(
+            aiter_from_iter(t for t in text.splitlines(keepends=True))
+        )
+    ] == [[a] for a in expected]
+    assert [
+        a
+        async for a in parser.atransform(
+            aiter_from_iter(
+                " " + t if i > 0 else t for i, t in enumerate(text.split(" "))
+            )
+        )
+    ] == [[a] for a in expected]
+    assert [a async for a in parser.atransform(aiter_from_iter([text]))] == [
+        [a] for a in expected
+    ]
+
+
+async def test_multiple_items_async() -> None:
+    """Test that a string with multiple comma-separated items is parsed to a list."""
+    parser = CommaSeparatedListOutputParser()
+    text = "foo, bar, baz"
+    expected = ["foo", "bar", "baz"]
+
+    assert await parser.aparse(text) == expected
+    assert await aadd(parser.atransform(aiter_from_iter(t for t in text))) == expected
+    assert [a async for a in parser.atransform(aiter_from_iter(t for t in text))] == [
+        [a] for a in expected
+    ]
+    assert [
+        a
+        async for a in parser.atransform(
+            aiter_from_iter(t for t in text.splitlines(keepends=True))
+        )
+    ] == [[a] for a in expected]
+    assert [
+        a
+        async for a in parser.atransform(
+            aiter_from_iter(
+                " " + t if i > 0 else t for i, t in enumerate(text.split(" "))
+            )
+        )
+    ] == [[a] for a in expected]
+    assert [a async for a in parser.atransform(aiter_from_iter([text]))] == [
+        [a] for a in expected
+    ]
+
+
+async def test_numbered_list_async() -> None:
+    parser = NumberedListOutputParser()
+    text1 = (
+        "Your response should be a numbered list with each item on a new line. "
+        "For example: \n\n1. foo\n\n2. bar\n\n3. baz"
+    )
+
+    text2 = "Items:\n\n1. apple\n\n2. banana\n\n3. cherry"
+
+    text3 = "No items in the list."
+
+    for text, expected in [
+        (text1, ["foo", "bar", "baz"]),
+        (text2, ["apple", "banana", "cherry"]),
+        (text3, []),
+    ]:
+        expectedlist = [[a] for a in cast(List[str], expected)]
+        assert await parser.aparse(text) == expected
+        assert await aadd(parser.atransform(aiter_from_iter(t for t in text))) == (
+            expected or None
+        )
+        assert [
+            a async for a in parser.atransform(aiter_from_iter(t for t in text))
+        ] == expectedlist
+        assert [
+            a
+            async for a in parser.atransform(
+                aiter_from_iter(t for t in text.splitlines(keepends=True))
+            )
+        ] == expectedlist
+        assert [
+            a
+            async for a in parser.atransform(
+                aiter_from_iter(
+                    " " + t if i > 0 else t for i, t in enumerate(text.split(" "))
+                )
+            )
+        ] == expectedlist
+        assert [
+            a async for a in parser.atransform(aiter_from_iter([text]))
+        ] == expectedlist
+
+
+async def test_markdown_list_async() -> None:
+    parser = MarkdownListOutputParser()
+    text1 = (
+        "Your response should be a numbered list with each item on a new line."
+        "For example: \n- foo\n- bar\n- baz"
+    )
+
+    text2 = "Items:\n- apple\n- banana\n- cherry"
+
+    text3 = "No items in the list."
+
+    for text, expected in [
+        (text1, ["foo", "bar", "baz"]),
+        (text2, ["apple", "banana", "cherry"]),
+        (text3, []),
+    ]:
+        expectedlist = [[a] for a in cast(List[str], expected)]
+        assert await parser.aparse(text) == expected
+        assert await aadd(parser.atransform(aiter_from_iter(t for t in text))) == (
+            expected or None
+        )
+        assert [
+            a async for a in parser.atransform(aiter_from_iter(t for t in text))
+        ] == expectedlist
+        assert [
+            a
+            async for a in parser.atransform(
+                aiter_from_iter(t for t in text.splitlines(keepends=True))
+            )
+        ] == expectedlist
+        assert [
+            a
+            async for a in parser.atransform(
+                aiter_from_iter(
+                    " " + t if i > 0 else t for i, t in enumerate(text.split(" "))
+                )
+            )
+        ] == expectedlist
+        assert [
+            a async for a in parser.atransform(aiter_from_iter([text]))
+        ] == expectedlist
diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_list_parser.py b/libs/langchain/tests/unit_tests/output_parsers/test_list_parser.py
deleted file mode 100644
index c85a31b9578..00000000000
--- a/libs/langchain/tests/unit_tests/output_parsers/test_list_parser.py
+++ /dev/null
@@ -1,49 +0,0 @@
-from langchain.output_parsers.list import (
-    CommaSeparatedListOutputParser,
-    MarkdownListOutputParser,
-    NumberedListOutputParser,
-)
-
-
-def test_single_item() -> None:
-    """Test that a string with a single item is parsed to a list with that item."""
-    parser = CommaSeparatedListOutputParser()
-    assert parser.parse("foo") == ["foo"]
-
-
-def test_multiple_items() -> None:
-    """Test that a string with multiple comma-separated items is parsed to a list."""
-    parser = CommaSeparatedListOutputParser()
-    assert parser.parse("foo, bar, baz") == ["foo", "bar", "baz"]
-
-
-def test_numbered_list() -> None:
-    parser = NumberedListOutputParser()
-    text1 = (
-        "Your response should be a numbered list with each item on a new line. "
-        "For example: \n\n1. foo\n\n2. bar\n\n3. baz"
-    )
-
-    text2 = "Items:\n\n1. apple\n\n2. banana\n\n3. cherry"
-
-    text3 = "No items in the list."
-
-    assert parser.parse(text1) == ["foo", "bar", "baz"]
-    assert parser.parse(text2) == ["apple", "banana", "cherry"]
-    assert parser.parse(text3) == []
-
-
-def test_markdown_list() -> None:
-    parser = MarkdownListOutputParser()
-    text1 = (
-        "Your response should be a numbered list with each item on a new line."
-        "For example: \n- foo\n- bar\n- baz"
-    )
-
-    text2 = "Items:\n- apple\n- banana\n- cherry"
-
-    text3 = "No items in the list."
-
-    assert parser.parse(text1) == ["foo", "bar", "baz"]
-    assert parser.parse(text2) == ["apple", "banana", "cherry"]
-    assert parser.parse(text3) == []