diff --git a/libs/core/langchain_core/output_parsers/list.py b/libs/core/langchain_core/output_parsers/list.py
index 858ba86c79f..ebaca8f8ca9 100644
--- a/libs/core/langchain_core/output_parsers/list.py
+++ b/libs/core/langchain_core/output_parsers/list.py
@@ -1,9 +1,11 @@
 from __future__ import annotations
 
+import csv
 import re
 from abc import abstractmethod
 from collections import deque
 from collections.abc import AsyncIterator, Iterator
+from io import StringIO
 from typing import Optional as Optional
 from typing import TypeVar, Union
 
@@ -162,7 +164,14 @@ class CommaSeparatedListOutputParser(ListOutputParser):
         Returns:
             A list of strings.
         """
-        return [part.strip() for part in text.split(",")]
+        try:
+            reader = csv.reader(
+                StringIO(text), quotechar='"', delimiter=",", skipinitialspace=True
+            )
+            return [item for sublist in reader for item in sublist]
+        except csv.Error:
+            # Parsing as CSV failed: fall back to the plain comma split.
+            return [part.strip() for part in text.split(",")]
 
     @property
     def _type(self) -> str:
diff --git a/libs/core/tests/unit_tests/output_parsers/test_list_parser.py b/libs/core/tests/unit_tests/output_parsers/test_list_parser.py
index 3f43edfa2ae..11bd11b6a0b 100644
--- a/libs/core/tests/unit_tests/output_parsers/test_list_parser.py
+++ b/libs/core/tests/unit_tests/output_parsers/test_list_parser.py
@@ -64,6 +64,25 @@ def test_multiple_items() -> None:
     assert list(parser.transform(iter([text]))) == [[a] for a in expected]
 
 
+def test_multiple_items_with_comma() -> None:
+    """Test that a quoted item containing a comma is parsed as one list element,
+    both by parse() and by transform() over differently chunked input."""
+    parser = CommaSeparatedListOutputParser()
+    text = '"foo, foo2",bar,baz'
+    expected = ["foo, foo2", "bar", "baz"]
+
+    assert parser.parse(text) == expected
+    assert add(parser.transform(t for t in text)) == expected
+    assert list(parser.transform(t for t in text)) == [[a] for a in expected]
+    assert list(parser.transform(t for t in text.splitlines(keepends=True))) == [
+        [a] for a in expected
+    ]
+    assert list(
+        parser.transform(" " + t if i > 0 else t for i, t in enumerate(text.split(" ")))
+    ) == [[a] for a in expected]
+    assert list(parser.transform(iter([text]))) == [[a] for a in expected]
+
+
 def test_numbered_list() -> None:
     parser = NumberedListOutputParser()
     text1 = (