mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-04 12:39:32 +00:00
add numbered list parser (#9837)
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
@@ -34,3 +35,29 @@ class CommaSeparatedListOutputParser(ListOutputParser):
|
|||||||
def parse(self, text: str) -> List[str]:
    """Parse the output of an LLM call.

    Args:
        text: Raw model output holding comma-separated values.

    Returns:
        The values in order of appearance, each with surrounding
        whitespace removed.
    """
    # Split on the bare comma and strip every piece, so "a,b" and "a,  b"
    # are separated exactly like the canonical "a, b".
    return [item.strip() for item in text.split(",")]
|
||||||
|
|
||||||
|
@property
def _type(self) -> str:
    """Return the type string of this parser."""
    parser_type = "comma-separated-list"
    return parser_type
|
||||||
|
|
||||||
|
|
||||||
|
class NumberedListOutputParser(ListOutputParser):
    """Parse a numbered list.

    Extracts the text that follows each ``<digits>. `` marker in the model
    output and returns the items in order of appearance.
    """

    def get_format_instructions(self) -> str:
        """Return the instruction text describing the expected list format."""
        return (
            "Your response should be a numbered list with each item on a new line. "
            "For example: \n\n1. foo\n\n2. bar\n\n3. baz"
        )

    def parse(self, text: str) -> List[str]:
        """Parse the output of an LLM call.

        Args:
            text: Raw model output, expected to contain entries such as
                ``1. foo``.

        Returns:
            The item texts in order of appearance with surrounding whitespace
            removed; an empty list when no numbered item is found.
        """
        # Digits, a dot, one whitespace character, then the rest of the line.
        pattern = r"\d+\.\s([^\n]+)"

        # Strip every capture so CRLF line endings ("foo\r") and extra padding
        # after the marker ("1.   foo") do not leak into the items; captures
        # that are whitespace only are dropped.
        stripped = (item.strip() for item in re.findall(pattern, text))
        return [item for item in stripped if item]

    @property
    def _type(self) -> str:
        """Return the type string of this parser."""
        return "numbered-list"
|
||||||
|
@@ -1,4 +1,7 @@
|
|||||||
from langchain.output_parsers.list import CommaSeparatedListOutputParser
|
from langchain.output_parsers.list import (
|
||||||
|
CommaSeparatedListOutputParser,
|
||||||
|
NumberedListOutputParser,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_single_item() -> None:
|
def test_single_item() -> None:
|
||||||
@@ -11,3 +14,19 @@ def test_multiple_items() -> None:
|
|||||||
"""Test that a string with multiple comma-separated items is parsed to a list."""
|
"""Test that a string with multiple comma-separated items is parsed to a list."""
|
||||||
parser = CommaSeparatedListOutputParser()
|
parser = CommaSeparatedListOutputParser()
|
||||||
assert parser.parse("foo, bar, baz") == ["foo", "bar", "baz"]
|
assert parser.parse("foo, bar, baz") == ["foo", "bar", "baz"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_numbered_list() -> None:
    """Check numbered-list parsing on texts with and without list items."""
    list_parser = NumberedListOutputParser()

    # Each case pairs an input text with the items expected from it.
    cases = [
        (
            "Your response should be a numbered list with each item on a new line. "
            "For example: \n\n1. foo\n\n2. bar\n\n3. baz",
            ["foo", "bar", "baz"],
        ),
        (
            "Items:\n\n1. apple\n\n2. banana\n\n3. cherry",
            ["apple", "banana", "cherry"],
        ),
        ("No items in the list.", []),
    ]

    for raw_text, expected_items in cases:
        assert list_parser.parse(raw_text) == expected_items
|
||||||
|
Reference in New Issue
Block a user