class MarkdownListOutputParser(ListOutputParser):
    """Parse a markdown bullet list into a list of strings.

    Only lines that *start* (after optional indentation) with a ``-`` or
    ``*`` bullet followed by spaces/tabs and some non-blank text are treated
    as list items; all other text is ignored.
    """

    def get_format_instructions(self) -> str:
        """Return the prompt text describing the expected output format."""
        return "Your response should be a markdown list, eg: `- foo\n- bar\n- baz`"

    def parse(self, text: str) -> List[str]:
        """Parse the output of an LLM call.

        Args:
            text: Raw text that may contain a markdown bullet list.

        Returns:
            The text of each bullet, in order of appearance, with trailing
            whitespace removed. An empty list if no bullet lines are found.
        """
        # Anchor to the start of a line (re.MULTILINE) so a dash in running
        # prose such as "apples - the red kind" is not mistaken for a bullet.
        # The separator is limited to spaces/tabs so a lone "-" cannot swallow
        # the following line, and `\S` rejects bullets with no content.
        pattern = r"^[ \t]*[-*][ \t]+(\S.*)$"
        matches = re.findall(pattern, text, re.MULTILINE)
        # rstrip drops trailing blanks and the "\r" left behind by CRLF input.
        return [item.rstrip() for item in matches]

    @property
    def _type(self) -> str:
        """Identifier string for this parser type."""
        return "markdown-list"
def test_markdown_list() -> None:
    """MarkdownListOutputParser returns the text of each `- item` line."""
    # Each case pairs an input text with the items expected from parsing it.
    cases = [
        (
            "Your response should be a numbered list with each item on a new line."
            "For example: \n- foo\n- bar\n- baz",
            ["foo", "bar", "baz"],
        ),
        (
            "Items:\n- apple\n- banana\n- cherry",
            ["apple", "banana", "cherry"],
        ),
        # Text without any bullets must produce an empty list.
        ("No items in the list.", []),
    ]

    md_parser = MarkdownListOutputParser()
    for raw_text, expected_items in cases:
        assert md_parser.parse(raw_text) == expected_items