mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-02 13:08:57 +00:00
Fix QuestionListOutputParser (#9738)
This PR fixes `QuestionListOutputParser` text splitting.

`QuestionListOutputParser` incorrectly splits numbered-list text into lines. If the text doesn't end with `\n`, the regex doesn't capture the last item, so the parser returns only `n - 1` items and `WebResearchRetriever.llm_chain` generates fewer queries than requested in the search prompt.

How to reproduce:

```python
from langchain.retrievers.web_research import QuestionListOutputParser

parser = QuestionListOutputParser()

good = parser.parse(
    """1. This is line one.
2. This is line two.
"""  # <-- !
)

bad = parser.parse(
    """1. This is line one.
2. This is line two."""  # <-- No new line.
)

assert good.lines == ['1. This is line one.\n', '2. This is line two.\n'], good.lines
assert bad.lines == ['1. This is line one.\n', '2. This is line two.'], bad.lines
```

NOTE: The last item will not contain a line break, but this seems OK because the items are stripped in `WebResearchRetriever.clean_search_query()`.
This commit is contained in:
parent
d04fe0d3ea
commit
135cb86215
@ -61,7 +61,7 @@ class QuestionListOutputParser(PydanticOutputParser):
|
||||
super().__init__(pydantic_object=LineList)
|
||||
|
||||
def parse(self, text: str) -> LineList:
|
||||
lines = re.findall(r"\d+\..*?\n", text)
|
||||
lines = re.findall(r"\d+\..*?(?:\n|$)", text)
|
||||
return LineList(lines=lines)
|
||||
|
||||
|
||||
|
@ -0,0 +1,36 @@
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.retrievers.web_research import QuestionListOutputParser
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text,expected",
|
||||
(
|
||||
(
|
||||
"1. Line one.\n",
|
||||
["1. Line one.\n"],
|
||||
),
|
||||
(
|
||||
"1. Line one.",
|
||||
["1. Line one."],
|
||||
),
|
||||
(
|
||||
"1. Line one.\n2. Line two.\n",
|
||||
["1. Line one.\n", "2. Line two.\n"],
|
||||
),
|
||||
(
|
||||
"1. Line one.\n2. Line two.",
|
||||
["1. Line one.\n", "2. Line two."],
|
||||
),
|
||||
(
|
||||
"1. Line one.\n2. Line two.\n3. Line three.",
|
||||
["1. Line one.\n", "2. Line two.\n", "3. Line three."],
|
||||
),
|
||||
),
|
||||
)
|
||||
def test_list_output_parser(text: str, expected: List[str]) -> None:
|
||||
parser = QuestionListOutputParser()
|
||||
result = parser.parse(text)
|
||||
assert result.lines == expected
|
Loading…
Reference in New Issue
Block a user