diff --git a/docs/docs/how_to/MultiQueryRetriever.ipynb b/docs/docs/how_to/MultiQueryRetriever.ipynb index f1377124aa9..1574eb1dd63 100644 --- a/docs/docs/how_to/MultiQueryRetriever.ipynb +++ b/docs/docs/how_to/MultiQueryRetriever.ipynb @@ -153,7 +153,7 @@ "\n", " def parse(self, text: str) -> List[str]:\n", " lines = text.strip().split(\"\\n\")\n", - " return lines\n", + " return list(filter(None, lines)) # Remove empty lines\n", "\n", "\n", "output_parser = LineListOutputParser()\n", diff --git a/libs/langchain/langchain/retrievers/multi_query.py b/libs/langchain/langchain/retrievers/multi_query.py index 3d1a36be476..23ba88e2a53 100644 --- a/libs/langchain/langchain/retrievers/multi_query.py +++ b/libs/langchain/langchain/retrievers/multi_query.py @@ -24,7 +24,7 @@ class LineListOutputParser(BaseOutputParser[List[str]]): def parse(self, text: str) -> List[str]: lines = text.strip().split("\n") - return lines + return list(filter(None, lines)) # Remove empty lines # Default prompt diff --git a/libs/langchain/tests/unit_tests/retrievers/test_multi_query.py b/libs/langchain/tests/unit_tests/retrievers/test_multi_query.py index 8f80e77e79b..d3529e8d97c 100644 --- a/libs/langchain/tests/unit_tests/retrievers/test_multi_query.py +++ b/libs/langchain/tests/unit_tests/retrievers/test_multi_query.py @@ -3,7 +3,7 @@ from typing import List import pytest as pytest from langchain_core.documents import Document -from langchain.retrievers.multi_query import _unique_documents +from langchain.retrievers.multi_query import LineListOutputParser, _unique_documents @pytest.mark.parametrize( @@ -38,3 +38,16 @@ from langchain.retrievers.multi_query import _unique_documents ) def test__unique_documents(documents: List[Document], expected: List[Document]) -> None: assert _unique_documents(documents) == expected + + +@pytest.mark.parametrize( + "text,expected", + [ + ("foo\nbar\nbaz", ["foo", "bar", "baz"]), + ("foo\nbar\nbaz\n", ["foo", "bar", "baz"]), + ("foo\n\nbar", ["foo", "bar"]), + ], +) +def test_line_list_output_parser(text: str, expected: List[str]) -> None: + parser = LineListOutputParser() + assert parser.parse(text) == expected