DB-GPT/dbgpt/rag/text_splitter/token_splitter.py

"""Token splitter."""
from typing import Callable, List, Optional
from dbgpt._private.pydantic import BaseModel, Field, PrivateAttr
from dbgpt.util.global_helper import globals_helper
from dbgpt.util.splitter_utils import split_by_char, split_by_sep
DEFAULT_METADATA_FORMAT_LEN = 2
DEFAULT_CHUNK_OVERLAP = 20
DEFAULT_CHUNK_SIZE = 1024
class TokenTextSplitter(BaseModel):
"""Implementation of splitting text that looks at word tokens."""
chunk_size: int = Field(
default=DEFAULT_CHUNK_SIZE, description="The token chunk size for each chunk."
)
chunk_overlap: int = Field(
default=DEFAULT_CHUNK_OVERLAP,
description="The token overlap of each chunk when splitting.",
)
separator: str = Field(
default=" ", description="Default separator for splitting into words"
)
backup_separators: List = Field(
default_factory=list, description="Additional separators for splitting."
)
# callback_manager: CallbackManager = Field(
# default_factory=CallbackManager, exclude=True
# )
tokenizer: Callable = Field(
default_factory=globals_helper.tokenizer, # type: ignore
description="Tokenizer for splitting words into tokens.",
exclude=True,
)
_split_fns: List[Callable] = PrivateAttr()

    def __init__(
        self,
        chunk_size: int = DEFAULT_CHUNK_SIZE,
        chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
        tokenizer: Optional[Callable] = None,
        # callback_manager: Optional[CallbackManager] = None,
        separator: str = " ",
        backup_separators: Optional[List[str]] = None,
    ) -> None:
        """Initialize with parameters."""
        if backup_separators is None:
            backup_separators = ["\n"]
        if chunk_overlap > chunk_size:
            raise ValueError(
                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
                f"({chunk_size}), should be smaller."
            )
        # callback_manager = callback_manager or CallbackManager([])
        tokenizer = tokenizer or globals_helper.tokenizer
        all_seps = [separator] + (backup_separators or [])
        super().__init__(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            separator=separator,
            backup_separators=backup_separators,
            # callback_manager=callback_manager,
            tokenizer=tokenizer,
        )
        self._split_fns = [split_by_sep(sep) for sep in all_seps] + [split_by_char()]

    @classmethod
    def class_name(cls) -> str:
        """Return the class name."""
        return "TokenTextSplitter"

    def split_text_metadata_aware(self, text: str, metadata_str: str) -> List[str]:
        """Split text into chunks, reserving space required for the metadata str."""
        metadata_len = len(self.tokenizer(metadata_str)) + DEFAULT_METADATA_FORMAT_LEN
        effective_chunk_size = self.chunk_size - metadata_len
        if effective_chunk_size <= 0:
            raise ValueError(
                f"Metadata length ({metadata_len}) is longer than chunk size "
                f"({self.chunk_size}). Consider increasing the chunk size or "
                "decreasing the size of your metadata to avoid this."
            )
        elif effective_chunk_size < 50:
            print(
                f"Metadata length ({metadata_len}) is close to chunk size "
                f"({self.chunk_size}). Resulting chunks are less than 50 tokens. "
                "Consider increasing the chunk size or decreasing the size of "
                "your metadata to avoid this.",
                flush=True,
            )
        return self._split_text(text, chunk_size=effective_chunk_size)
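
    # Illustrative example (not from the original source): with
    # chunk_size=1024 and a metadata string that tokenizes to 30 tokens, the
    # effective chunk size is 1024 - (30 + DEFAULT_METADATA_FORMAT_LEN) = 992
    # tokens available for the text of each chunk.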

    def split_text(self, text: str) -> List[str]:
        """Split text into chunks."""
        return self._split_text(text, chunk_size=self.chunk_size)

    def _split_text(self, text: str, chunk_size: int) -> List[str]:
        """Split text into chunks up to chunk_size."""
        if text == "":
            return []
        splits = self._split(text, chunk_size)
        chunks = self._merge(splits, chunk_size)
        return chunks

    def _split(self, text: str, chunk_size: int) -> List[str]:
        """Break text into splits that are smaller than chunk size.

        The order of splitting is:
        1. split by separator
        2. split by backup separators (if any)
        3. split by characters

        NOTE: the splits contain the separators.
        """
        if len(self.tokenizer(text)) <= chunk_size:
            return [text]

        for split_fn in self._split_fns:
            splits = split_fn(text)
            if len(splits) > 1:
                break

        new_splits: List[str] = []
        for split in splits:
            split_len = len(self.tokenizer(split))
            if split_len <= chunk_size:
                new_splits.append(split)
            else:
                # recursively split pieces that are still too large
                new_splits.extend(self._split(split, chunk_size=chunk_size))
        return new_splits
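
    # Illustrative trace (not from the original source): with separator " ",
    # _split("a b c", chunk_size=2) first produces separator-preserving
    # pieces such as ["a ", "b ", "c"]; any piece whose token count still
    # exceeds chunk_size is re-split with the backup separators and, as a
    # last resort, character by character.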

    def _merge(self, splits: List[str], chunk_size: int) -> List[str]:
        """Merge splits into chunks.

        The high-level idea is to keep adding splits to a chunk until we
        exceed the chunk size, then we start a new chunk with overlap.

        When we start a new chunk, we pop off the first element of the previous
        chunk until the total length is less than the chunk size.
        """
        chunks: List[str] = []
        cur_chunk: List[str] = []
        cur_len = 0
        for split in splits:
            split_len = len(self.tokenizer(split))
            if split_len > chunk_size:
                print(
                    f"Got a split of size {split_len}, "
                    f"larger than chunk size {chunk_size}."
                )
            # if we exceed the chunk size after adding the new split, then
            # we need to end the current chunk and start a new one
            if cur_len + split_len > chunk_size:
                # end the previous chunk
                chunk = "".join(cur_chunk).strip()
                if chunk:
                    chunks.append(chunk)
                # start a new chunk with overlap
                # keep popping off the first element of the previous chunk until:
                # 1. the current chunk length is less than chunk overlap
                # 2. the total length is less than chunk size
                while cur_len > self.chunk_overlap or cur_len + split_len > chunk_size:
                    # pop off the first element
                    first_chunk = cur_chunk.pop(0)
                    cur_len -= len(self.tokenizer(first_chunk))
            cur_chunk.append(split)
            cur_len += split_len

        # handle the last chunk
        chunk = "".join(cur_chunk).strip()
        if chunk:
            chunks.append(chunk)
        return chunks
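

# Minimal usage sketch (illustrative, not part of the original module). It
# passes an explicit whitespace tokenizer so the demo does not depend on how
# globals_helper.tokenizer is configured; real callers can omit the tokenizer
# argument and rely on the default.
if __name__ == "__main__":

    def _whitespace_tokenizer(text: str) -> List[str]:
        # Stand-in tokenizer: one token per whitespace-separated word.
        return text.split()

    splitter = TokenTextSplitter(
        chunk_size=16,
        chunk_overlap=4,
        tokenizer=_whitespace_tokenizer,
    )
    sample = " ".join(f"word{i}" for i in range(40))
    for i, chunk in enumerate(splitter.split_text(sample)):
        print(i, len(_whitespace_tokenizer(chunk)), chunk)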