Fix QuestionListOutputParser (#9738)

This PR fixes `QuestionListOutputParser` text splitting.

`QuestionListOutputParser` incorrectly splits numbered list text into
lines. If text doesn't end with `\n` , the regex doesn't capture the
last item. So it always returns `n - 1` items, and
`WebResearchRetriever.llm_chain` generates less queries than requested
in the search prompt.

How to reproduce:

```python
from langchain.retrievers.web_research import QuestionListOutputParser

parser = QuestionListOutputParser()

good = parser.parse(
    """1. This is line one.
    2. This is line two.
    """  # <-- !
)

bad = parser.parse(
    """1. This is line one.
    2. This is line two."""    # <-- No new line.
)

assert good.lines == ['1. This is line one.\n', '2. This is line two.\n'], good.lines
assert bad.lines == ['1. This is line one.\n', '2. This is line two.'], bad.lines
```

NOTE: Last item will not contain a line break but this seems ok because
the items are stripped in the
`WebResearchRetriever.clean_search_query()`.
This commit is contained in:
Sergey Kozlov 2023-08-25 14:47:17 +06:00 committed by GitHub
parent d04fe0d3ea
commit 135cb86215
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 37 additions and 1 deletions

View File

@ -61,7 +61,7 @@ class QuestionListOutputParser(PydanticOutputParser):
super().__init__(pydantic_object=LineList)
def parse(self, text: str) -> LineList:
lines = re.findall(r"\d+\..*?\n", text)
lines = re.findall(r"\d+\..*?(?:\n|$)", text)
return LineList(lines=lines)

View File

@ -0,0 +1,36 @@
from typing import List
import pytest
from langchain.retrievers.web_research import QuestionListOutputParser
@pytest.mark.parametrize(
"text,expected",
(
(
"1. Line one.\n",
["1. Line one.\n"],
),
(
"1. Line one.",
["1. Line one."],
),
(
"1. Line one.\n2. Line two.\n",
["1. Line one.\n", "2. Line two.\n"],
),
(
"1. Line one.\n2. Line two.",
["1. Line one.\n", "2. Line two."],
),
(
"1. Line one.\n2. Line two.\n3. Line three.",
["1. Line one.\n", "2. Line two.\n", "3. Line three."],
),
),
)
def test_list_output_parser(text: str, expected: List[str]) -> None:
parser = QuestionListOutputParser()
result = parser.parse(text)
assert result.lines == expected