"""Token splitter.""" from typing import Callable, List, Optional from dbgpt._private.pydantic import BaseModel, Field, PrivateAttr from dbgpt.util.global_helper import globals_helper from dbgpt.util.splitter_utils import split_by_char, split_by_sep DEFAULT_METADATA_FORMAT_LEN = 2 DEFAULT_CHUNK_OVERLAP = 20 DEFAULT_CHUNK_SIZE = 1024 class TokenTextSplitter(BaseModel): """Implementation of splitting text that looks at word tokens.""" chunk_size: int = Field( default=DEFAULT_CHUNK_SIZE, description="The token chunk size for each chunk." ) chunk_overlap: int = Field( default=DEFAULT_CHUNK_OVERLAP, description="The token overlap of each chunk when splitting.", ) separator: str = Field( default=" ", description="Default separator for splitting into words" ) backup_separators: List = Field( default_factory=list, description="Additional separators for splitting." ) # callback_manager: CallbackManager = Field( # default_factory=CallbackManager, exclude=True # ) tokenizer: Callable = Field( default_factory=globals_helper.tokenizer, # type: ignore description="Tokenizer for splitting words into tokens.", exclude=True, ) _split_fns: List[Callable] = PrivateAttr() def __init__( self, chunk_size: int = DEFAULT_CHUNK_SIZE, chunk_overlap: int = DEFAULT_CHUNK_OVERLAP, tokenizer: Optional[Callable] = None, # callback_manager: Optional[CallbackManager] = None, separator: str = " ", backup_separators=None, ): """Initialize with parameters.""" if backup_separators is None: backup_separators = ["\n"] if chunk_overlap > chunk_size: raise ValueError( f"Got a larger chunk overlap ({chunk_overlap}) than chunk size " f"({chunk_size}), should be smaller." ) # callback_manager = callback_manager or CallbackManager([]) tokenizer = tokenizer or globals_helper.tokenizer all_seps = [separator] + (backup_separators or []) super().__init__( chunk_size=chunk_size, chunk_overlap=chunk_overlap, separator=separator, backup_separators=backup_separators, # callback_manager=callback_manager, tokenizer=tokenizer, ) self._split_fns = [split_by_sep(sep) for sep in all_seps] + [split_by_char()] @classmethod def class_name(cls) -> str: """Return the class name.""" return "TokenTextSplitter" def split_text_metadata_aware(self, text: str, metadata_str: str) -> List[str]: """Split text into chunks, reserving space required for metadata str.""" metadata_len = len(self.tokenizer(metadata_str)) + DEFAULT_METADATA_FORMAT_LEN effective_chunk_size = self.chunk_size - metadata_len if effective_chunk_size <= 0: raise ValueError( f"Metadata length ({metadata_len}) is longer than chunk size " f"({self.chunk_size}). Consider increasing the chunk size or " "decreasing the size of your metadata to avoid this." ) elif effective_chunk_size < 50: print( f"Metadata length ({metadata_len}) is close to chunk size " f"({self.chunk_size}). Resulting chunks are less than 50 tokens. " "Consider increasing the chunk size or decreasing the size of " "your metadata to avoid this.", flush=True, ) return self._split_text(text, chunk_size=effective_chunk_size) def split_text(self, text: str) -> List[str]: """Split text into chunks.""" return self._split_text(text, chunk_size=self.chunk_size) def _split_text(self, text: str, chunk_size: int) -> List[str]: """Split text into chunks up to chunk_size.""" if text == "": return [] splits = self._split(text, chunk_size) chunks = self._merge(splits, chunk_size) return chunks def _split(self, text: str, chunk_size: int) -> List[str]: """Break text into splits that are smaller than chunk size. The order of splitting is: 1. 

    def split_text(self, text: str) -> List[str]:
        """Split text into chunks."""
        return self._split_text(text, chunk_size=self.chunk_size)

    def _split_text(self, text: str, chunk_size: int) -> List[str]:
        """Split text into chunks up to chunk_size."""
        if text == "":
            return []

        splits = self._split(text, chunk_size)
        chunks = self._merge(splits, chunk_size)
        return chunks

    def _split(self, text: str, chunk_size: int) -> List[str]:
        """Break text into splits that are smaller than chunk size.

        The order of splitting is:
        1. split by separator
        2. split by backup separators (if any)
        3. split by characters

        NOTE: the splits contain the separators.
        """
        if len(self.tokenizer(text)) <= chunk_size:
            return [text]

        for split_fn in self._split_fns:
            splits = split_fn(text)
            if len(splits) > 1:
                break

        new_splits = []
        for split in splits:
            split_len = len(self.tokenizer(split))
            if split_len <= chunk_size:
                new_splits.append(split)
            else:
                # recursively split
                new_splits.extend(self._split(split, chunk_size=chunk_size))
        return new_splits

    def _merge(self, splits: List[str], chunk_size: int) -> List[str]:
        """Merge splits into chunks.

        The high-level idea is to keep adding splits to a chunk until we
        exceed the chunk size, then we start a new chunk with overlap.

        When we start a new chunk, we pop off the first element of the
        previous chunk until the total length is less than the chunk size.
        """
        chunks: List[str] = []

        cur_chunk: List[str] = []
        cur_len = 0
        for split in splits:
            split_len = len(self.tokenizer(split))
            if split_len > chunk_size:
                print(
                    f"Got a split of size {split_len}, "
                    f"larger than chunk size {chunk_size}.",
                    flush=True,
                )

            # if we exceed the chunk size after adding the new split, then
            # we need to end the current chunk and start a new one
            if cur_len + split_len > chunk_size:
                # end the previous chunk
                chunk = "".join(cur_chunk).strip()
                if chunk:
                    chunks.append(chunk)

                # start a new chunk with overlap
                # keep popping off the first element of the previous chunk until:
                # 1. the current chunk length is less than chunk overlap
                # 2. the total length is less than chunk size
                # (guard on cur_chunk so an oversized split cannot pop from an
                # empty list and raise IndexError)
                while cur_chunk and (
                    cur_len > self.chunk_overlap or cur_len + split_len > chunk_size
                ):
                    # pop off the first element
                    first_chunk = cur_chunk.pop(0)
                    cur_len -= len(self.tokenizer(first_chunk))

            cur_chunk.append(split)
            cur_len += split_len

        # handle the last chunk
        chunk = "".join(cur_chunk).strip()
        if chunk:
            chunks.append(chunk)

        return chunks
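

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. It assumes the
    # default tokenizer from globals_helper is available in this environment;
    # exact chunk boundaries depend on that tokenizer, so the token counts
    # printed here are illustrative rather than guaranteed.
    splitter = TokenTextSplitter(chunk_size=64, chunk_overlap=8)
    sample_text = "DB-GPT splits long documents into token-bounded chunks. " * 20
    for i, chunk in enumerate(splitter.split_text(sample_text)):
        print(f"chunk {i}: {len(splitter.tokenizer(chunk))} tokens")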