mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-18 21:09:00 +00:00
text-splitters[patch]: Extend TextSplitter:keep_separator functionality (#21130)
**Description:** Added extra functionality to the `CharacterTextSplitter` and `TextSplitter` classes. The user can now choose whether a kept separator is appended to the end of the previous chunk (`keep_separator='end'`) or prepended to the start of the next chunk (`keep_separator='start'`). Previously a kept separator was always prepended to the next chunk; `keep_separator=True` preserves that behavior and is equivalent to `'start'`. **Issue:** Fixes #20908 --------- Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
This commit is contained in:
parent
b859765752
commit
c3bcfad66d
@ -35,7 +35,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
|
|||||||
chunk_size: int = 4000,
|
chunk_size: int = 4000,
|
||||||
chunk_overlap: int = 200,
|
chunk_overlap: int = 200,
|
||||||
length_function: Callable[[str], int] = len,
|
length_function: Callable[[str], int] = len,
|
||||||
keep_separator: bool = False,
|
keep_separator: Union[bool, Literal["start", "end"]] = False,
|
||||||
add_start_index: bool = False,
|
add_start_index: bool = False,
|
||||||
strip_whitespace: bool = True,
|
strip_whitespace: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -45,7 +45,8 @@ class TextSplitter(BaseDocumentTransformer, ABC):
|
|||||||
chunk_size: Maximum size of chunks to return
|
chunk_size: Maximum size of chunks to return
|
||||||
chunk_overlap: Overlap in characters between chunks
|
chunk_overlap: Overlap in characters between chunks
|
||||||
length_function: Function that measures the length of given chunks
|
length_function: Function that measures the length of given chunks
|
||||||
keep_separator: Whether to keep the separator in the chunks
|
keep_separator: Whether to keep the separator and where to place it
|
||||||
|
in each corresponding chunk (True='start')
|
||||||
add_start_index: If `True`, includes chunk's start index in metadata
|
add_start_index: If `True`, includes chunk's start index in metadata
|
||||||
strip_whitespace: If `True`, strips whitespace from the start and end of
|
strip_whitespace: If `True`, strips whitespace from the start and end of
|
||||||
every document
|
every document
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from typing import Any, List, Optional
|
from typing import Any, List, Literal, Optional, Union
|
||||||
|
|
||||||
from langchain_text_splitters.base import Language, TextSplitter
|
from langchain_text_splitters.base import Language, TextSplitter
|
||||||
|
|
||||||
@ -29,17 +29,25 @@ class CharacterTextSplitter(TextSplitter):
|
|||||||
|
|
||||||
|
|
||||||
def _split_text_with_regex(
|
def _split_text_with_regex(
|
||||||
text: str, separator: str, keep_separator: bool
|
text: str, separator: str, keep_separator: Union[bool, Literal["start", "end"]]
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
# Now that we have the separator, split the text
|
# Now that we have the separator, split the text
|
||||||
if separator:
|
if separator:
|
||||||
if keep_separator:
|
if keep_separator:
|
||||||
# The parentheses in the pattern keep the delimiters in the result.
|
# The parentheses in the pattern keep the delimiters in the result.
|
||||||
_splits = re.split(f"({separator})", text)
|
_splits = re.split(f"({separator})", text)
|
||||||
splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
|
splits = (
|
||||||
|
([_splits[i] + _splits[i + 1] for i in range(0, len(_splits) - 1, 2)])
|
||||||
|
if keep_separator == "end"
|
||||||
|
else ([_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)])
|
||||||
|
)
|
||||||
if len(_splits) % 2 == 0:
|
if len(_splits) % 2 == 0:
|
||||||
splits += _splits[-1:]
|
splits += _splits[-1:]
|
||||||
splits = [_splits[0]] + splits
|
splits = (
|
||||||
|
(splits + [_splits[-1]])
|
||||||
|
if keep_separator == "end"
|
||||||
|
else ([_splits[0]] + splits)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
splits = re.split(separator, text)
|
splits = re.split(separator, text)
|
||||||
else:
|
else:
|
||||||
|
@ -112,6 +112,50 @@ def test_character_text_splitter_keep_separator_regex(
|
|||||||
assert output == expected_output
|
assert output == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"separator, is_separator_regex", [(re.escape("."), True), (".", False)]
|
||||||
|
)
|
||||||
|
def test_character_text_splitter_keep_separator_regex_start(
|
||||||
|
separator: str, is_separator_regex: bool
|
||||||
|
) -> None:
|
||||||
|
"""Test splitting by characters while keeping the separator
|
||||||
|
that is a regex special character and placing it at the start of each chunk.
|
||||||
|
"""
|
||||||
|
text = "foo.bar.baz.123"
|
||||||
|
splitter = CharacterTextSplitter(
|
||||||
|
separator=separator,
|
||||||
|
chunk_size=1,
|
||||||
|
chunk_overlap=0,
|
||||||
|
keep_separator="start",
|
||||||
|
is_separator_regex=is_separator_regex,
|
||||||
|
)
|
||||||
|
output = splitter.split_text(text)
|
||||||
|
expected_output = ["foo", ".bar", ".baz", ".123"]
|
||||||
|
assert output == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"separator, is_separator_regex", [(re.escape("."), True), (".", False)]
|
||||||
|
)
|
||||||
|
def test_character_text_splitter_keep_separator_regex_end(
|
||||||
|
separator: str, is_separator_regex: bool
|
||||||
|
) -> None:
|
||||||
|
"""Test splitting by characters while keeping the separator
|
||||||
|
that is a regex special character and placing it at the end of each chunk.
|
||||||
|
"""
|
||||||
|
text = "foo.bar.baz.123"
|
||||||
|
splitter = CharacterTextSplitter(
|
||||||
|
separator=separator,
|
||||||
|
chunk_size=1,
|
||||||
|
chunk_overlap=0,
|
||||||
|
keep_separator="end",
|
||||||
|
is_separator_regex=is_separator_regex,
|
||||||
|
)
|
||||||
|
output = splitter.split_text(text)
|
||||||
|
expected_output = ["foo.", "bar.", "baz.", "123"]
|
||||||
|
assert output == expected_output
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"separator, is_separator_regex", [(re.escape("."), True), (".", False)]
|
"separator, is_separator_regex", [(re.escape("."), True), (".", False)]
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user