diff --git a/libs/langchain/langchain/output_parsers/list.py b/libs/langchain/langchain/output_parsers/list.py
index 92850fc31c2..6680cb908aa 100644
--- a/libs/langchain/langchain/output_parsers/list.py
+++ b/libs/langchain/langchain/output_parsers/list.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import re
 from abc import abstractmethod
 from typing import List
 
@@ -34,3 +35,43 @@ class CommaSeparatedListOutputParser(ListOutputParser):
     def parse(self, text: str) -> List[str]:
         """Parse the output of an LLM call."""
         return text.strip().split(", ")
+
+    @property
+    def _type(self) -> str:
+        return "comma-separated-list"
+
+
+class NumberedListOutputParser(ListOutputParser):
+    """Parse a numbered list."""
+
+    def get_format_instructions(self) -> str:
+        return (
+            "Your response should be a numbered list with each item on a new line. "
+            "For example: \n\n1. foo\n\n2. bar\n\n3. baz"
+        )
+
+    def parse(self, text: str) -> List[str]:
+        """Parse a numbered list out of the output of an LLM call.
+
+        An item starts at a ``<digits>.`` marker followed by whitespace and
+        runs to the end of that line; text without a marker is ignored.
+
+        Args:
+            text: Raw LLM output.
+
+        Returns:
+            The item texts in order of appearance, stripped of surrounding
+            whitespace. Empty when no numbered item is found.
+        """
+        pattern = r"\d+\.\s([^\n]+)"
+
+        # Strip each capture: the pattern consumes a single whitespace
+        # character after the marker and stops only at "\n", so extra padding
+        # or the "\r" of a CRLF line ending would otherwise leak into items.
+        items = (match.strip() for match in re.findall(pattern, text))
+        # Drop items that were nothing but whitespace.
+        return [item for item in items if item]
+
+    @property
+    def _type(self) -> str:
+        return "numbered-list"
diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_list_parser.py b/libs/langchain/tests/unit_tests/output_parsers/test_list_parser.py
index 84be4db9464..a25bbc38bd9 100644
--- a/libs/langchain/tests/unit_tests/output_parsers/test_list_parser.py
+++ b/libs/langchain/tests/unit_tests/output_parsers/test_list_parser.py
@@ -1,4 +1,7 @@
-from langchain.output_parsers.list import CommaSeparatedListOutputParser
+from langchain.output_parsers.list import (
+    CommaSeparatedListOutputParser,
+    NumberedListOutputParser,
+)
 
 
 def test_single_item() -> None:
@@ -11,3 +14,25 @@ def test_multiple_items() -> None:
     """Test that a string with multiple comma-separated items is parsed to a list."""
     parser = CommaSeparatedListOutputParser()
     assert parser.parse("foo, bar, baz") == ["foo", "bar", "baz"]
+
+
+def test_numbered_list() -> None:
+    parser = NumberedListOutputParser()
+    text1 = (
+        "Your response should be a numbered list with each item on a new line. "
+        "For example: \n\n1. foo\n\n2. bar\n\n3. baz"
+    )
+
+    text2 = "Items:\n\n1. apple\n\n2. banana\n\n3. cherry"
+
+    text3 = "No items in the list."
+
+    assert parser.parse(text1) == ["foo", "bar", "baz"]
+    assert parser.parse(text2) == ["apple", "banana", "cherry"]
+    assert parser.parse(text3) == []
+
+
+def test_numbered_list_strips_whitespace() -> None:
+    """Test that padding and CRLF line endings do not leak into the items."""
+    parser = NumberedListOutputParser()
+    assert parser.parse("1.  foo \r\n2.\tbar\r\n") == ["foo", "bar"]