mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-19 11:08:55 +00:00
fix text splitter (#375)
This commit is contained in:
parent
3474f39e21
commit
e7b625fe03
@ -1,9 +1,12 @@
|
||||
"""Functionality for splitting text."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Iterable, List
|
||||
|
||||
# Module-scoped logger: named after this module instead of using the root
# logger, so applications can filter or silence these warnings independently.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TextSplitter(ABC):
|
||||
"""Interface for splitting text into chunks."""
|
||||
@ -37,13 +40,20 @@ class TextSplitter(ABC):
|
||||
current_doc: List[str] = []
|
||||
total = 0
|
||||
for d in splits:
|
||||
if total >= self._chunk_size:
|
||||
docs.append(self._separator.join(current_doc))
|
||||
while total > self._chunk_overlap:
|
||||
total -= self._length_function(current_doc[0])
|
||||
current_doc = current_doc[1:]
|
||||
_len = self._length_function(d)
|
||||
if total + _len >= self._chunk_size:
|
||||
if total > self._chunk_size:
|
||||
logger.warning(
|
||||
f"Created a chunk of size {total}, "
|
||||
f"which is longer than the specified {self._chunk_size}"
|
||||
)
|
||||
if len(current_doc) > 0:
|
||||
docs.append(self._separator.join(current_doc))
|
||||
while total > self._chunk_overlap:
|
||||
total -= self._length_function(current_doc[0])
|
||||
current_doc = current_doc[1:]
|
||||
current_doc.append(d)
|
||||
total += self._length_function(d)
|
||||
total += _len
|
||||
docs.append(self._separator.join(current_doc))
|
||||
return docs
|
||||
|
||||
|
@ -7,12 +7,21 @@ from langchain.text_splitter import CharacterTextSplitter
|
||||
def test_character_text_splitter() -> None:
    """Test splitting by character count."""
    text = "foo bar baz 123"
    splitter = CharacterTextSplitter(separator=" ", chunk_size=7, chunk_overlap=3)
    output = splitter.split_text(text)
    # Each expected chunk pairs adjacent words, and the trailing word of one
    # chunk reappears at the start of the next.
    expected_output = ["foo bar", "bar baz", "baz 123"]
    assert output == expected_output
|
||||
|
||||
|
||||
def test_character_text_splitter_long() -> None:
    """Test splitting by character count on long words."""
    chunker = CharacterTextSplitter(separator=" ", chunk_size=3, chunk_overlap=1)
    chunks = chunker.split_text("foo bar baz a a")
    # The 3-character words are expected one per chunk; only the two
    # single-character words share a chunk.
    assert chunks == ["foo", "bar", "baz", "a a"]
|
||||
|
||||
|
||||
def test_character_text_splitter_longer_words() -> None:
|
||||
"""Test splitting by characters when splits not found easily."""
|
||||
text = "foo bar baz 123"
|
||||
|
Loading…
Reference in New Issue
Block a user