diff --git a/libs/text-splitters/langchain_text_splitters/base.py b/libs/text-splitters/langchain_text_splitters/base.py
index b0fb33caa2d..bdf7ae7be2d 100644
--- a/libs/text-splitters/langchain_text_splitters/base.py
+++ b/libs/text-splitters/langchain_text_splitters/base.py
@@ -35,7 +35,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
         chunk_size: int = 4000,
         chunk_overlap: int = 200,
         length_function: Callable[[str], int] = len,
-        keep_separator: bool = False,
+        keep_separator: Union[bool, Literal["start", "end"]] = False,
         add_start_index: bool = False,
         strip_whitespace: bool = True,
     ) -> None:
@@ -45,7 +45,8 @@ class TextSplitter(BaseDocumentTransformer, ABC):
             chunk_size: Maximum size of chunks to return
             chunk_overlap: Overlap in characters between chunks
             length_function: Function that measures the length of given chunks
-            keep_separator: Whether to keep the separator in the chunks
+            keep_separator: Whether to keep the separator and where to place it
+                in each corresponding chunk (True='start')
             add_start_index: If `True`, includes chunk's start index in metadata
             strip_whitespace: If `True`, strips whitespace from the start and end
                 of every document
diff --git a/libs/text-splitters/langchain_text_splitters/character.py b/libs/text-splitters/langchain_text_splitters/character.py
index 0f2ce97bcb0..6783f98363a 100644
--- a/libs/text-splitters/langchain_text_splitters/character.py
+++ b/libs/text-splitters/langchain_text_splitters/character.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import re
-from typing import Any, List, Optional
+from typing import Any, List, Literal, Optional, Union
 
 from langchain_text_splitters.base import Language, TextSplitter
 
@@ -29,17 +29,25 @@ class CharacterTextSplitter(TextSplitter):
 
 
 def _split_text_with_regex(
-    text: str, separator: str, keep_separator: bool
+    text: str, separator: str, keep_separator: Union[bool, Literal["start", "end"]]
 ) -> List[str]:
     # Now that we have the separator, split the text
     if separator:
         if keep_separator:
             # The parentheses in the pattern keep the delimiters in the result.
             _splits = re.split(f"({separator})", text)
-            splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
+            splits = (
+                ([_splits[i] + _splits[i + 1] for i in range(0, len(_splits) - 1, 2)])
+                if keep_separator == "end"
+                else ([_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)])
+            )
             if len(_splits) % 2 == 0:
                 splits += _splits[-1:]
-            splits = [_splits[0]] + splits
+            splits = (
+                (splits + [_splits[-1]])
+                if keep_separator == "end"
+                else ([_splits[0]] + splits)
+            )
         else:
             splits = re.split(separator, text)
     else:
diff --git a/libs/text-splitters/tests/unit_tests/test_text_splitters.py b/libs/text-splitters/tests/unit_tests/test_text_splitters.py
index 3d88d786fb3..062f4d089d1 100644
--- a/libs/text-splitters/tests/unit_tests/test_text_splitters.py
+++ b/libs/text-splitters/tests/unit_tests/test_text_splitters.py
@@ -112,6 +112,50 @@ def test_character_text_splitter_keep_separator_regex(
     assert output == expected_output
 
 
+@pytest.mark.parametrize(
+    "separator, is_separator_regex", [(re.escape("."), True), (".", False)]
+)
+def test_character_text_splitter_keep_separator_regex_start(
+    separator: str, is_separator_regex: bool
+) -> None:
+    """Test splitting by characters while keeping the separator
+    that is a regex special character and placing it at the start of each chunk.
+    """
+    text = "foo.bar.baz.123"
+    splitter = CharacterTextSplitter(
+        separator=separator,
+        chunk_size=1,
+        chunk_overlap=0,
+        keep_separator="start",
+        is_separator_regex=is_separator_regex,
+    )
+    output = splitter.split_text(text)
+    expected_output = ["foo", ".bar", ".baz", ".123"]
+    assert output == expected_output
+
+
+@pytest.mark.parametrize(
+    "separator, is_separator_regex", [(re.escape("."), True), (".", False)]
+)
+def test_character_text_splitter_keep_separator_regex_end(
+    separator: str, is_separator_regex: bool
+) -> None:
+    """Test splitting by characters while keeping the separator
+    that is a regex special character and placing it at the end of each chunk.
+    """
+    text = "foo.bar.baz.123"
+    splitter = CharacterTextSplitter(
+        separator=separator,
+        chunk_size=1,
+        chunk_overlap=0,
+        keep_separator="end",
+        is_separator_regex=is_separator_regex,
+    )
+    output = splitter.split_text(text)
+    expected_output = ["foo.", "bar.", "baz.", "123"]
+    assert output == expected_output
+
+
 @pytest.mark.parametrize(
     "separator, is_separator_regex", [(re.escape("."), True), (".", False)]
 )