From 0c7a5cb20684386dbf62ef10870cb2d219cf33f5 Mon Sep 17 00:00:00 2001 From: Sasmitha Manathunga <70096033+mmz-001@users.noreply.github.com> Date: Thu, 6 Jul 2023 19:00:03 +0530 Subject: [PATCH] Fix inconsistent behavior of `CharacterTextSplitter` when changing `keep_separator` (#7263) - Description: - When `keep_separator` is `True` the `_split_text_with_regex()` method in `text_splitter` uses regex to split, but when `keep_separator` is `False` it uses `str.split()`. This causes problems when the separator is a special regex character like `.` or `*`. This PR fixes that by using `re.split()` in both cases. - Issue: #7262 - Tag maintainer: @baskaryan --- langchain/text_splitter.py | 2 +- tests/unit_tests/test_text_splitter.py | 25 +++++++++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/langchain/text_splitter.py b/langchain/text_splitter.py index 6b77fc53504..6d2ea2e8733 100644 --- a/langchain/text_splitter.py +++ b/langchain/text_splitter.py @@ -47,7 +47,7 @@ def _split_text_with_regex( splits += _splits[-1:] splits = [_splits[0]] + splits else: - splits = text.split(separator) + splits = re.split(separator, text) else: splits = list(text) return [s for s in splits if s != ""] diff --git a/tests/unit_tests/test_text_splitter.py b/tests/unit_tests/test_text_splitter.py index 75312c2c54e..e8dee47ffb5 100644 --- a/tests/unit_tests/test_text_splitter.py +++ b/tests/unit_tests/test_text_splitter.py @@ -80,6 +80,31 @@ def test_character_text_splitter_longer_words() -> None: assert output == expected_output +def test_character_text_splitter_keep_separator_regex() -> None: + """Test splitting by characters while keeping the separator + that is a regex special character. + """ + text = "foo.bar.baz.123" + splitter = CharacterTextSplitter( + separator=r"\.", chunk_size=1, chunk_overlap=0, keep_separator=True + ) + output = splitter.split_text(text) + expected_output = ["foo", ".bar", ".baz", ".123"] + assert output == expected_output + + +def test_character_text_splitter_discard_separator_regex() -> None: + """Test splitting by characters discarding the separator + that is a regex special character.""" + text = "foo.bar.baz.123" + splitter = CharacterTextSplitter( + separator=r"\.", chunk_size=1, chunk_overlap=0, keep_separator=False + ) + output = splitter.split_text(text) + expected_output = ["foo", "bar", "baz", "123"] + assert output == expected_output + + def test_character_text_splitting_args() -> None: """Test invalid arguments.""" with pytest.raises(ValueError):