mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-29 18:08:36 +00:00
Harrison/fix text splitter (#1511)
Co-authored-by: ajaysolanky <ajsolanky@gmail.com> Co-authored-by: Ajay Solanky <ajaysolanky@saw-l14668307kd.myfiosgateway.com>
This commit is contained in:
parent
e3354404ad
commit
064741db58
@ -71,12 +71,17 @@ class TextSplitter(ABC):
|
|||||||
def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
|
def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
|
||||||
# We now want to combine these smaller pieces into medium size
|
# We now want to combine these smaller pieces into medium size
|
||||||
# chunks to send to the LLM.
|
# chunks to send to the LLM.
|
||||||
|
separator_len = self._length_function(separator)
|
||||||
|
|
||||||
docs = []
|
docs = []
|
||||||
current_doc: List[str] = []
|
current_doc: List[str] = []
|
||||||
total = 0
|
total = 0
|
||||||
for d in splits:
|
for d in splits:
|
||||||
_len = self._length_function(d)
|
_len = self._length_function(d)
|
||||||
if total + _len >= self._chunk_size:
|
if (
|
||||||
|
total + _len + (separator_len if len(current_doc) > 0 else 0)
|
||||||
|
> self._chunk_size
|
||||||
|
):
|
||||||
if total > self._chunk_size:
|
if total > self._chunk_size:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Created a chunk of size {total}, "
|
f"Created a chunk of size {total}, "
|
||||||
@ -90,12 +95,16 @@ class TextSplitter(ABC):
|
|||||||
# - we have a larger chunk than in the chunk overlap
|
# - we have a larger chunk than in the chunk overlap
|
||||||
# - or if we still have any chunks and the length is long
|
# - or if we still have any chunks and the length is long
|
||||||
while total > self._chunk_overlap or (
|
while total > self._chunk_overlap or (
|
||||||
total + _len > self._chunk_size and total > 0
|
total + _len + (separator_len if len(current_doc) > 0 else 0)
|
||||||
|
> self._chunk_size
|
||||||
|
and total > 0
|
||||||
):
|
):
|
||||||
total -= self._length_function(current_doc[0])
|
total -= self._length_function(current_doc[0]) + (
|
||||||
|
separator_len if len(current_doc) > 1 else 0
|
||||||
|
)
|
||||||
current_doc = current_doc[1:]
|
current_doc = current_doc[1:]
|
||||||
current_doc.append(d)
|
current_doc.append(d)
|
||||||
total += _len
|
total += _len + (separator_len if len(current_doc) > 1 else 0)
|
||||||
doc = self._join_docs(current_doc, separator)
|
doc = self._join_docs(current_doc, separator)
|
||||||
if doc is not None:
|
if doc is not None:
|
||||||
docs.append(doc)
|
docs.append(doc)
|
||||||
|
@ -26,6 +26,15 @@ def test_character_text_splitter_empty_doc() -> None:
|
|||||||
assert output == expected_output
|
assert output == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test_character_text_splitter_separtor_empty_doc() -> None:
|
||||||
|
"""Test edge cases are separators."""
|
||||||
|
text = "f b"
|
||||||
|
splitter = CharacterTextSplitter(separator=" ", chunk_size=2, chunk_overlap=0)
|
||||||
|
output = splitter.split_text(text)
|
||||||
|
expected_output = ["f", "b"]
|
||||||
|
assert output == expected_output
|
||||||
|
|
||||||
|
|
||||||
def test_character_text_splitter_long() -> None:
|
def test_character_text_splitter_long() -> None:
|
||||||
"""Test splitting by character count on long words."""
|
"""Test splitting by character count on long words."""
|
||||||
text = "foo bar baz a a"
|
text = "foo bar baz a a"
|
||||||
@ -99,7 +108,7 @@ Bye!\n\n-H."""
|
|||||||
"Harrison.",
|
"Harrison.",
|
||||||
"How? Are?",
|
"How? Are?",
|
||||||
"You?",
|
"You?",
|
||||||
"Okay then f",
|
"Okay then",
|
||||||
"f f f f.",
|
"f f f f.",
|
||||||
"This is a",
|
"This is a",
|
||||||
"a weird",
|
"a weird",
|
||||||
@ -107,8 +116,8 @@ Bye!\n\n-H."""
|
|||||||
"write, but",
|
"write, but",
|
||||||
"gotta test",
|
"gotta test",
|
||||||
"the",
|
"the",
|
||||||
"splitting",
|
"splittingg",
|
||||||
"gggg",
|
"ggg",
|
||||||
"some how.",
|
"some how.",
|
||||||
"Bye!\n\n-H.",
|
"Bye!\n\n-H.",
|
||||||
]
|
]
|
||||||
|
Loading…
Reference in New Issue
Block a user