Split on the separator as a regular expression in the discard-separator branch.

The keep_separator branch of _split_text_with_regex is already built on
_splits produced with the separator kept; this patch changes the other branch
from text.split(separator) to re.split(separator, text) and adds unit tests
that split "foo.bar.baz.123" on the pattern r"\." with keep_separator=True
and keep_separator=False.

NOTE(review): after this change the separator is passed to re.split as a
pattern, so a literal separator containing regex metacharacters (for example
"." or "|") will split differently than it did with str.split -- confirm that
callers are expected to escape such separators.

diff --git a/langchain/text_splitter.py b/langchain/text_splitter.py
index 6b77fc53504..6d2ea2e8733 100644
--- a/langchain/text_splitter.py
+++ b/langchain/text_splitter.py
@@ -47,7 +47,7 @@ def _split_text_with_regex(
                 splits += _splits[-1:]
             splits = [_splits[0]] + splits
         else:
-            splits = text.split(separator)
+            splits = re.split(separator, text)
     else:
         splits = list(text)
     return [s for s in splits if s != ""]
diff --git a/tests/unit_tests/test_text_splitter.py b/tests/unit_tests/test_text_splitter.py
index 75312c2c54e..e8dee47ffb5 100644
--- a/tests/unit_tests/test_text_splitter.py
+++ b/tests/unit_tests/test_text_splitter.py
@@ -80,6 +80,31 @@ def test_character_text_splitter_longer_words() -> None:
     assert output == expected_output
 
 
+def test_character_text_splitter_keep_separator_regex() -> None:
+    """Test splitting by characters while keeping the separator
+    that is a regex special character.
+    """
+    text = "foo.bar.baz.123"
+    splitter = CharacterTextSplitter(
+        separator=r"\.", chunk_size=1, chunk_overlap=0, keep_separator=True
+    )
+    output = splitter.split_text(text)
+    expected_output = ["foo", ".bar", ".baz", ".123"]
+    assert output == expected_output
+
+
+def test_character_text_splitter_discard_separator_regex() -> None:
+    """Test splitting by characters discarding the separator
+    that is a regex special character."""
+    text = "foo.bar.baz.123"
+    splitter = CharacterTextSplitter(
+        separator=r"\.", chunk_size=1, chunk_overlap=0, keep_separator=False
+    )
+    output = splitter.split_text(text)
+    expected_output = ["foo", "bar", "baz", "123"]
+    assert output == expected_output
+
+
 def test_character_text_splitting_args() -> None:
     """Test invalid arguments."""
     with pytest.raises(ValueError):