From bc557a56639cdd1ac3cbd4eb4704ae8701b27e26 Mon Sep 17 00:00:00 2001
From: ccurme
Date: Fri, 23 Aug 2024 13:22:02 -0400
Subject: [PATCH] text-splitters[patch]: fix typing for `keep_separator` (#25706)

---
 .../langchain_text_splitters/character.py     |  2 +-
 .../tests/unit_tests/test_text_splitters.py   | 24 +++++++++++++++++++
 2 files changed, 25 insertions(+), 1 deletion(-)

diff --git a/libs/text-splitters/langchain_text_splitters/character.py b/libs/text-splitters/langchain_text_splitters/character.py
index 12f69484c15..85124b39de9 100644
--- a/libs/text-splitters/langchain_text_splitters/character.py
+++ b/libs/text-splitters/langchain_text_splitters/character.py
@@ -65,7 +65,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
     def __init__(
         self,
         separators: Optional[List[str]] = None,
-        keep_separator: bool = True,
+        keep_separator: Union[bool, Literal["start", "end"]] = True,
         is_separator_regex: bool = False,
         **kwargs: Any,
     ) -> None:
diff --git a/libs/text-splitters/tests/unit_tests/test_text_splitters.py b/libs/text-splitters/tests/unit_tests/test_text_splitters.py
index 1faee31cd0d..95f170d52b7 100644
--- a/libs/text-splitters/tests/unit_tests/test_text_splitters.py
+++ b/libs/text-splitters/tests/unit_tests/test_text_splitters.py
@@ -180,6 +180,30 @@ def test_character_text_splitter_discard_separator_regex(
     assert output == expected_output
 
 
+def test_recursive_character_text_splitter_keep_separators() -> None:
+    split_tags = [",", "."]
+    query = "Apple,banana,orange and tomato."
+    # start
+    splitter = RecursiveCharacterTextSplitter(
+        chunk_size=10,
+        chunk_overlap=0,
+        separators=split_tags,
+        keep_separator="start",
+    )
+    result = splitter.split_text(query)
+    assert result == ["Apple", ",banana", ",orange and tomato", "."]
+
+    # end
+    splitter = RecursiveCharacterTextSplitter(
+        chunk_size=10,
+        chunk_overlap=0,
+        separators=split_tags,
+        keep_separator="end",
+    )
+    result = splitter.split_text(query)
+    assert result == ["Apple,", "banana,", "orange and tomato."]
+
+
 def test_character_text_splitting_args() -> None:
     """Test invalid arguments."""
     with pytest.raises(ValueError):