"""Token splitter."""
from typing import Callable, List, Optional

from dbgpt._private.pydantic import BaseModel, Field, PrivateAttr
from dbgpt.util.global_helper import globals_helper
from dbgpt.util.splitter_utils import split_by_char, split_by_sep

DEFAULT_METADATA_FORMAT_LEN = 2
DEFAULT_CHUNK_OVERLAP = 20
DEFAULT_CHUNK_SIZE = 1024

class TokenTextSplitter(BaseModel):
    """Implementation of splitting text that looks at word tokens."""

    chunk_size: int = Field(
        default=DEFAULT_CHUNK_SIZE, description="The token chunk size for each chunk."
    )
    chunk_overlap: int = Field(
        default=DEFAULT_CHUNK_OVERLAP,
        description="The token overlap of each chunk when splitting.",
    )
    separator: str = Field(
        default=" ", description="Default separator for splitting into words"
    )
    backup_separators: List = Field(
        default_factory=list, description="Additional separators for splitting."
    )
    # callback_manager: CallbackManager = Field(
    #     default_factory=CallbackManager, exclude=True
    # )
    tokenizer: Callable = Field(
        default_factory=globals_helper.tokenizer,  # type: ignore
        description="Tokenizer for splitting words into tokens.",
        exclude=True,
    )

    _split_fns: List[Callable] = PrivateAttr()

    def __init__(
        self,
        chunk_size: int = DEFAULT_CHUNK_SIZE,
        chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
        tokenizer: Optional[Callable] = None,
        # callback_manager: Optional[CallbackManager] = None,
        separator: str = " ",
        backup_separators: Optional[List[str]] = None,
    ):
        """Initialize with parameters."""
        if backup_separators is None:
            backup_separators = ["\n"]
        if chunk_overlap > chunk_size:
            raise ValueError(
                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
                f"({chunk_size}), should be smaller."
            )
        # callback_manager = callback_manager or CallbackManager([])
        tokenizer = tokenizer or globals_helper.tokenizer

        all_seps = [separator] + (backup_separators or [])

        super().__init__(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            separator=separator,
            backup_separators=backup_separators,
            # callback_manager=callback_manager,
            tokenizer=tokenizer,
        )
        self._split_fns = [split_by_sep(sep) for sep in all_seps] + [split_by_char()]

    @classmethod
    def class_name(cls) -> str:
        """Return the class name."""
        return "TokenTextSplitter"

    def split_text_metadata_aware(self, text: str, metadata_str: str) -> List[str]:
        """Split text into chunks, reserving space required for metadata str."""
        metadata_len = len(self.tokenizer(metadata_str)) + DEFAULT_METADATA_FORMAT_LEN
        effective_chunk_size = self.chunk_size - metadata_len
        if effective_chunk_size <= 0:
            raise ValueError(
                f"Metadata length ({metadata_len}) is longer than chunk size "
                f"({self.chunk_size}). Consider increasing the chunk size or "
                "decreasing the size of your metadata to avoid this."
            )
        elif effective_chunk_size < 50:
            print(
                f"Metadata length ({metadata_len}) is close to chunk size "
                f"({self.chunk_size}). Resulting chunks are less than 50 tokens. "
                "Consider increasing the chunk size or decreasing the size of "
                "your metadata to avoid this.",
                flush=True,
            )

        return self._split_text(text, chunk_size=effective_chunk_size)
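
    # Illustrative arithmetic for the metadata budget above (the numbers are
    # hypothetical, not from the source): with the default chunk_size of 1024
    # and a metadata string that tokenizes to 100 tokens, metadata_len is
    # 100 + DEFAULT_METADATA_FORMAT_LEN = 102, so the text itself is split
    # against an effective_chunk_size of 1024 - 102 = 922 tokens.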

    def split_text(self, text: str) -> List[str]:
        """Split text into chunks."""
        return self._split_text(text, chunk_size=self.chunk_size)

    def _split_text(self, text: str, chunk_size: int) -> List[str]:
        """Split text into chunks up to chunk_size."""
        if text == "":
            return []

        splits = self._split(text, chunk_size)
        chunks = self._merge(splits, chunk_size)
        return chunks
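
    # Sketch of the expected behavior of the two-phase pipeline above
    # (assuming a tokenizer that yields roughly one token per word): with
    # chunk_size=5 and chunk_overlap=2, a ten-word sentence is first broken
    # into word-sized splits and then merged back into chunks of about five
    # words, each new chunk re-using the last couple of words of the
    # previous one.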

    def _split(self, text: str, chunk_size: int) -> List[str]:
        """Break text into splits that are smaller than chunk size.

        The order of splitting is:
        1. split by separator
        2. split by backup separators (if any)
        3. split by characters

        NOTE: the splits contain the separators.
        """
        if len(self.tokenizer(text)) <= chunk_size:
            return [text]

        for split_fn in self._split_fns:
            splits = split_fn(text)
            if len(splits) > 1:
                break

        new_splits = []
        for split in splits:
            split_len = len(self.tokenizer(split))
            if split_len <= chunk_size:
                new_splits.append(split)
            else:
                # recursively split
                new_splits.extend(self._split(split, chunk_size=chunk_size))
        return new_splits
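
    # For instance (assuming a roughly word-level tokenizer), a paragraph
    # that exceeds chunk_size is first split on the default " " separator;
    # any resulting piece that is still too large -- such as a single very
    # long unbroken string -- is recursively split again using the backup
    # separators and, as a last resort, character by character via
    # split_by_char().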

    def _merge(self, splits: List[str], chunk_size: int) -> List[str]:
        """Merge splits into chunks.

        The high-level idea is to keep adding splits to a chunk until we
        exceed the chunk size, then we start a new chunk with overlap.

        When we start a new chunk, we pop off the first element of the previous
        chunk until the total length is less than the chunk size.
        """
        chunks: List[str] = []

        cur_chunk: List[str] = []
        cur_len = 0
        for split in splits:
            split_len = len(self.tokenizer(split))
            if split_len > chunk_size:
                print(
                    f"Got a split of size {split_len}, "
                    f"larger than chunk size {chunk_size}."
                )

            # if we exceed the chunk size after adding the new split, then
            # we need to end the current chunk and start a new one
            if cur_len + split_len > chunk_size:
                # end the previous chunk
                chunk = "".join(cur_chunk).strip()
                if chunk:
                    chunks.append(chunk)

                # start a new chunk with overlap
                # keep popping off the first element of the previous chunk until:
                # 1. the current chunk length is less than chunk overlap
                # 2. the total length is less than chunk size
                while cur_len > self.chunk_overlap or cur_len + split_len > chunk_size:
                    # pop off the first element
                    first_chunk = cur_chunk.pop(0)
                    cur_len -= len(self.tokenizer(first_chunk))

            cur_chunk.append(split)
            cur_len += split_len

        # handle the last chunk
        chunk = "".join(cur_chunk).strip()
        if chunk:
            chunks.append(chunk)

        return chunks
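

# Minimal usage sketch, not part of the original module. It assumes the dbgpt
# package (and hence globals_helper.tokenizer) is importable in the current
# environment; the sample text and chunk sizes below are purely illustrative.
if __name__ == "__main__":
    splitter = TokenTextSplitter(chunk_size=64, chunk_overlap=8)
    sample = "DB-GPT splits long documents into token-bounded chunks. " * 40

    for i, chunk in enumerate(splitter.split_text(sample)):
        print(f"chunk {i}: {len(splitter.tokenizer(chunk))} tokens")

    # Reserve room for a (hypothetical) metadata string in every chunk.
    metadata = "source: example.txt"
    chunks = splitter.split_text_metadata_aware(sample, metadata_str=metadata)
    print(f"{len(chunks)} metadata-aware chunks")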