refactor: RAG Refactor (#985)

Co-authored-by: Aralhi <xiaoping0501@gmail.com>
Co-authored-by: csunny <cfqsunny@163.com>
This commit is contained in:
Aries-ckt
2024-01-03 09:45:26 +08:00
committed by GitHub
parent 90775aad50
commit 9ad70a2961
206 changed files with 5766 additions and 2419 deletions

View File

@@ -0,0 +1,41 @@
from typing import Iterable, List
from dbgpt.rag.chunk import Document, Chunk
from dbgpt.rag.text_splitter.text_splitter import TextSplitter
def _single_document_split(
document: Document, pre_separator: str
) -> Iterable[Document]:
    content = document.content
    for i, part in enumerate(content.split(pre_separator)):
        metadata = document.metadata.copy()
        if "source" in metadata:
            metadata["source"] = metadata["source"] + "_pre_split_" + str(i)
        yield Chunk(content=part, metadata=metadata)
class PreTextSplitter(TextSplitter):
"""Split text by pre separator"""
def __init__(self, pre_separator: str, text_splitter_impl: TextSplitter):
"""Initialize with Knowledge arguments.
Args:
pre_separator: pre separator
text_splitter_impl: text splitter impl
"""
self.pre_separator = pre_separator
self._impl = text_splitter_impl
def split_text(self, text: str, **kwargs) -> List[str]:
"""Split text by pre separator"""
return self._impl.split_text(text)
def split_documents(self, documents: Iterable[Document], **kwargs) -> List[Chunk]:
"""Split documents by pre separator"""
def generator() -> Iterable[Document]:
for doc in documents:
yield from _single_document_split(doc, pre_separator=self.pre_separator)
return self._impl.split_documents(generator())
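A minimal usage sketch for the splitter above: pre-cut a document on an explicit marker, then let an inner splitter chunk each piece by size. The import path for PreTextSplitter, the form-feed marker, and the chunk sizes are illustrative assumptions, and Document is assumed to accept content/metadata keyword arguments the way Chunk does here.

from dbgpt.rag.chunk import Document
from dbgpt.rag.text_splitter.text_splitter import CharacterTextSplitter
# Assumed module path for the new class shown above.
from dbgpt.rag.text_splitter.pre_text_splitter import PreTextSplitter

doc = Document(content="page one text\fpage two text", metadata={"source": "demo.txt"})
splitter = PreTextSplitter(
    pre_separator="\f",  # assumed page-break marker in the source document
    text_splitter_impl=CharacterTextSplitter(separator=" ", chunk_size=16, chunk_overlap=0),
)
# Each chunk keeps the document metadata, with "source" suffixed by "_pre_split_<i>".
chunks = splitter.split_documents([doc])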

View File

@@ -0,0 +1,65 @@
from dbgpt.rag.chunk import Chunk
from dbgpt.rag.text_splitter.text_splitter import (
CharacterTextSplitter,
MarkdownHeaderTextSplitter,
)
def test_md_header_text_splitter() -> None:
"""unit test markdown splitter by header"""
markdown_document = (
"# dbgpt\n\n"
" ## description\n\n"
"my name is dbgpt\n\n"
" ## content\n\n"
"my name is aries"
)
headers_to_split_on = [
("#", "Header 1"),
("##", "Header 2"),
]
markdown_splitter = MarkdownHeaderTextSplitter(
headers_to_split_on=headers_to_split_on,
)
output = markdown_splitter.split_text(markdown_document)
expected_output = [
Chunk(
content="{'Header 1': 'dbgpt', 'Header 2': 'description'}, my name is dbgpt",
metadata={"Header 1": "dbgpt", "Header 2": "description"},
),
Chunk(
content="{'Header 1': 'dbgpt', 'Header 2': 'content'}, my name is aries",
metadata={"Header 1": "dbgpt", "Header 2": "content"},
),
]
    assert [chunk.content for chunk in output] == [
        chunk.content for chunk in expected_output
    ]
def test_merge_splits() -> None:
"""Test merging splits with a given separator."""
splitter = CharacterTextSplitter(separator=" ", chunk_size=9, chunk_overlap=2)
splits = ["foo", "bar", "baz"]
expected_output = ["foo bar", "baz"]
output = splitter._merge_splits(splits, separator=" ")
assert output == expected_output
def test_character_text_splitter() -> None:
"""Test splitting by character count."""
text = "foo bar baz 123"
splitter = CharacterTextSplitter(separator=" ", chunk_size=7, chunk_overlap=3)
output = splitter.split_text(text)
expected_output = ["foo bar", "bar baz", "baz 123"]
assert output == expected_output
def test_character_text_splitter_empty_doc() -> None:
"""Test splitting by character count doesn't create empty documents."""
text = "db gpt"
splitter = CharacterTextSplitter(separator=" ", chunk_size=2, chunk_overlap=0)
output = splitter.split_text(text)
expected_output = ["db", "gpt"]
assert output == expected_output
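A hypothetical additional test sketch for ParagraphTextSplitter from the same module: empty lines are dropped and each remaining line becomes its own split.

from dbgpt.rag.text_splitter.text_splitter import ParagraphTextSplitter

def test_paragraph_text_splitter() -> None:
    """Sketch: each non-empty line becomes one paragraph split."""
    text = "paragraph one\nparagraph two\n\nparagraph three"
    splitter = ParagraphTextSplitter(separator="\n")
    assert splitter.split_text(text) == [
        "paragraph one",
        "paragraph two",
        "paragraph three",
    ]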

View File

@@ -0,0 +1,730 @@
import copy
import logging
import re
from abc import abstractmethod, ABC
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Tuple,
TypedDict,
Union,
)
from dbgpt.rag.chunk import Document, Chunk
logger = logging.getLogger(__name__)
class TextSplitter(ABC):
"""Interface for splitting text into chunks.
Refer to https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/text_splitter.py
"""
outgoing_edges = 1
def __init__(
self,
chunk_size: int = 4000,
chunk_overlap: int = 200,
length_function: Callable[[str], int] = len,
filters: list = [],
separator: str = "",
):
"""Create a new TextSplitter."""
if chunk_overlap > chunk_size:
raise ValueError(
f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
f"({chunk_size}), should be smaller."
)
self._chunk_size = chunk_size
self._chunk_overlap = chunk_overlap
self._length_function = length_function
self._filter = filters
self._separator = separator
@abstractmethod
def split_text(self, text: str, **kwargs) -> List[str]:
"""Split text into multiple components."""
def create_documents(
self,
texts: List[str],
metadatas: Optional[List[dict]] = None,
separator: Optional[str] = None,
**kwargs,
) -> List[Chunk]:
"""Create documents from a list of texts."""
_metadatas = metadatas or [{}] * len(texts)
chunks = []
for i, text in enumerate(texts):
for chunk in self.split_text(text, separator, **kwargs):
new_doc = Chunk(content=chunk, metadata=copy.deepcopy(_metadatas[i]))
chunks.append(new_doc)
return chunks
def split_documents(self, documents: List[Document], **kwargs) -> List[Chunk]:
"""Split documents."""
texts = [doc.content for doc in documents]
metadatas = [doc.metadata for doc in documents]
return self.create_documents(texts, metadatas, **kwargs)
def _join_docs(self, docs: List[str], separator: str, **kwargs) -> Optional[str]:
text = separator.join(docs)
text = text.strip()
if text == "":
return None
else:
return text
def _merge_splits(
self,
splits: Iterable[str],
separator: str,
chunk_size: Optional[int] = None,
chunk_overlap: Optional[int] = None,
) -> List[str]:
# We now want to combine these smaller pieces into medium size
# chunks to send to the LLM.
if chunk_size is None:
chunk_size = self._chunk_size
if chunk_overlap is None:
chunk_overlap = self._chunk_overlap
if separator is None:
separator = self._separator
separator_len = self._length_function(separator)
docs = []
current_doc: List[str] = []
total = 0
for d in splits:
_len = self._length_function(d)
if (
total + _len + (separator_len if len(current_doc) > 0 else 0)
> chunk_size
):
if total > chunk_size:
logger.warning(
f"Created a chunk of size {total}, "
f"which is longer than the specified {chunk_size}"
)
if len(current_doc) > 0:
doc = self._join_docs(current_doc, separator)
if doc is not None:
docs.append(doc)
                    # Keep popping from the front while:
                    # - the accumulated length still exceeds the chunk overlap, or
                    # - adding the next split would still overflow the chunk size
                    #   and the current chunk is not empty
while total > chunk_overlap or (
total + _len + (separator_len if len(current_doc) > 0 else 0)
> chunk_size
and total > 0
):
total -= self._length_function(current_doc[0]) + (
separator_len if len(current_doc) > 1 else 0
)
current_doc = current_doc[1:]
current_doc.append(d)
total += _len + (separator_len if len(current_doc) > 1 else 0)
doc = self._join_docs(current_doc, separator)
if doc is not None:
docs.append(doc)
return docs
def clean(self, documents: List[dict], filters: List[str]):
for special_character in filters:
for doc in documents:
doc["content"] = doc["content"].replace(special_character, "")
return documents
def run( # type: ignore
self,
documents: Union[dict, List[dict]],
meta: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None, # type: ignore
separator: Optional[str] = None,
chunk_size: Optional[int] = None,
chunk_overlap: Optional[int] = None,
filters: Optional[List[str]] = None,
):
if separator is None:
separator = self._separator
if chunk_size is None:
chunk_size = self._chunk_size
if chunk_overlap is None:
chunk_overlap = self._chunk_overlap
if filters is None:
filters = self._filter
ret = []
        if isinstance(documents, dict):  # single document
text_splits = self.split_text(
documents["content"],
separator=separator,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
for i, txt in enumerate(text_splits):
doc = copy.deepcopy(documents)
doc["content"] = txt
if "meta" not in doc.keys() or doc["meta"] is None:
doc["meta"] = {}
doc["meta"]["_split_id"] = i
ret.append(doc)
        elif isinstance(documents, list):  # list of documents
for document in documents:
text_splits = self.split_text(
document["content"],
separator=separator,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
for i, txt in enumerate(text_splits):
doc = copy.deepcopy(document)
doc["content"] = txt
if "meta" not in doc.keys() or doc["meta"] is None:
doc["meta"] = {}
doc["meta"]["_split_id"] = i
ret.append(doc)
if filters is not None and len(filters) > 0:
ret = self.clean(ret, filters)
result = {"documents": ret}
return result, "output_1"
class CharacterTextSplitter(TextSplitter):
"""Implementation of splitting text that looks at characters.
Refer to https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/text_splitter.py
"""
def __init__(self, separator: str = "\n\n", filters: list = [], **kwargs: Any):
"""Create a new TextSplitter."""
super().__init__(**kwargs)
self._separator = separator
self._filter = filters
def split_text(
self, text: str, separator: Optional[str] = None, **kwargs
) -> List[str]:
"""Split incoming text and return chunks."""
# First we naively split the large input into a bunch of smaller ones.
if separator is None:
separator = self._separator
if separator:
splits = text.split(separator)
else:
splits = list(text)
return self._merge_splits(splits, separator, **kwargs)
class RecursiveCharacterTextSplitter(TextSplitter):
"""Implementation of splitting text that looks at characters.
Recursively tries to split by different characters to find one
that works.
Refer to https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/text_splitter.py
"""
def __init__(self, separators: Optional[List[str]] = None, **kwargs: Any):
"""Create a new TextSplitter."""
super().__init__(**kwargs)
self._separators = separators or ["###", "\n", " ", ""]
def split_text(
self, text: str, separator: Optional[str] = None, **kwargs
) -> List[str]:
"""Split incoming text and return chunks."""
final_chunks = []
# Get appropriate separator to use
separator = self._separators[-1]
for _s in self._separators:
if _s == "":
separator = _s
break
if _s in text:
separator = _s
break
# Now that we have the separator, split the text
if separator:
splits = text.split(separator)
else:
splits = list(text)
# Now go merging things, recursively splitting longer texts.
_good_splits = []
for s in splits:
if self._length_function(s) < self._chunk_size:
_good_splits.append(s)
else:
if _good_splits:
merged_text = self._merge_splits(
_good_splits,
separator,
chunk_size=kwargs.get("chunk_size", None),
chunk_overlap=kwargs.get("chunk_overlap", None),
)
final_chunks.extend(merged_text)
_good_splits = []
other_info = self.split_text(s)
final_chunks.extend(other_info)
if _good_splits:
merged_text = self._merge_splits(
_good_splits,
separator,
chunk_size=kwargs.get("chunk_size", None),
chunk_overlap=kwargs.get("chunk_overlap", None),
)
final_chunks.extend(merged_text)
return final_chunks
class SpacyTextSplitter(TextSplitter):
"""Implementation of splitting text that looks at sentences using Spacy.
Refer to https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/text_splitter.py
"""
def __init__(self, pipeline: str = "zh_core_web_sm", **kwargs: Any) -> None:
"""Initialize the spacy text splitter."""
super().__init__(**kwargs)
try:
import spacy
except ImportError:
raise ImportError(
"Spacy is not installed, please install it with `pip install spacy`."
)
try:
self._tokenizer = spacy.load(pipeline)
        except OSError:
spacy.cli.download(pipeline)
self._tokenizer = spacy.load(pipeline)
def split_text(
self, text: str, separator: Optional[str] = None, **kwargs
) -> List[str]:
"""Split incoming text and return chunks."""
if len(text) > 1000000:
self._tokenizer.max_length = len(text) + 100
splits = (str(s) for s in self._tokenizer(text).sents)
return self._merge_splits(splits, separator, **kwargs)
class HeaderType(TypedDict):
"""Header type as typed dict."""
level: int
name: str
data: str
class LineType(TypedDict):
"""Line type as typed dict."""
metadata: Dict[str, str]
content: str
class MarkdownHeaderTextSplitter(TextSplitter):
"""Implementation of splitting markdown files based on specified headers.
Refer to https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/text_splitter.py
"""
outgoing_edges = 1
def __init__(
self,
headers_to_split_on: List[Tuple[str, str]] = [
("#", "Header 1"),
("##", "Header 2"),
("###", "Header 3"),
("####", "Header 4"),
("#####", "Header 5"),
("######", "Header 6"),
],
return_each_line: bool = False,
filters: list = [],
chunk_size: int = 4000,
chunk_overlap: int = 200,
length_function: Callable[[str], int] = len,
separator="\n",
):
"""Create a new MarkdownHeaderTextSplitter.
Args:
headers_to_split_on: Headers we want to track
return_each_line: Return each line w/ associated headers
"""
# Output line-by-line or aggregated into chunks w/ common headers
self.return_each_line = return_each_line
self._chunk_size = chunk_size
# Given the headers we want to split on,
# (e.g., "#, ##, etc") order by length
self.headers_to_split_on = sorted(
headers_to_split_on, key=lambda split: len(split[0]), reverse=True
)
self._filter = filters
self._length_function = length_function
self._separator = separator
self._chunk_overlap = chunk_overlap
def create_documents(
self,
texts: List[str],
metadatas: Optional[List[dict]] = None,
separator: Optional[str] = None,
**kwargs,
) -> List[Chunk]:
"""Create documents from a list of texts."""
_metadatas = metadatas or [{}] * len(texts)
chunks = []
for i, text in enumerate(texts):
for chunk in self.split_text(text, separator, **kwargs):
metadata = chunk.metadata or {}
metadata.update(_metadatas[i])
new_doc = Chunk(content=chunk.content, metadata=metadata)
chunks.append(new_doc)
return chunks
def aggregate_lines_to_chunks(self, lines: List[LineType]) -> List[Chunk]:
"""Combine lines with common metadata into chunks
Args:
lines: Line of text / associated header metadata
"""
aggregated_chunks: List[LineType] = []
for line in lines:
if (
aggregated_chunks
and aggregated_chunks[-1]["metadata"] == line["metadata"]
):
# If the last line in the aggregated list
# has the same metadata as the current line,
                # append the current content to the last line's content
aggregated_chunks[-1]["content"] += " \n" + line["content"]
else:
# Otherwise, append the current line to the aggregated list
line["content"] = f"{line['metadata']}, " + line["content"]
aggregated_chunks.append(line)
return [
Chunk(content=chunk["content"], metadata=chunk["metadata"])
for chunk in aggregated_chunks
]
def split_text(
self,
text: str,
separator: Optional[str] = None,
chunk_size: Optional[int] = None,
chunk_overlap: Optional[int] = None,
) -> List[Chunk]:
"""Split markdown file
Args:
text: Markdown file"""
if separator is None:
separator = self._separator
if chunk_size is None:
chunk_size = self._chunk_size
if chunk_overlap is None:
chunk_overlap = self._chunk_overlap
# Split the input text by newline character ("\n").
lines = text.split(separator)
# Final output
lines_with_metadata: List[LineType] = []
# Content and metadata of the chunk currently being processed
current_content: List[str] = []
current_metadata: Dict[str, str] = {}
# Keep track of the nested header structure
# header_stack: List[Dict[str, Union[int, str]]] = []
header_stack: List[HeaderType] = []
initial_metadata: Dict[str, str] = {}
for line in lines:
stripped_line = line.strip()
# Check each line against each of the header types (e.g., #, ##)
for sep, name in self.headers_to_split_on:
# Check if line starts with a header that we intend to split on
if stripped_line.startswith(sep) and (
# Header with no text OR header is followed by space
                    # Both are valid conditions that sep is being used as a header
len(stripped_line) == len(sep)
or stripped_line[len(sep)] == " "
):
# Ensure we are tracking the header as metadata
if name is not None:
# Get the current header level
current_header_level = sep.count("#")
# Pop out headers of lower or same level from the stack
while (
header_stack
and header_stack[-1]["level"] >= current_header_level
):
# We have encountered a new header
# at the same or higher level
popped_header = header_stack.pop()
# Clear the metadata for the
# popped header in initial_metadata
if popped_header["name"] in initial_metadata:
initial_metadata.pop(popped_header["name"])
# Push the current header to the stack
header: HeaderType = {
"level": current_header_level,
"name": name,
"data": stripped_line[len(sep) :].strip(),
}
header_stack.append(header)
# Update initial_metadata with the current header
initial_metadata[name] = header["data"]
# Add the previous line to the lines_with_metadata
# only if current_content is not empty
if current_content:
lines_with_metadata.append(
{
"content": separator.join(current_content),
"metadata": current_metadata.copy(),
}
)
current_content.clear()
break
else:
if stripped_line:
current_content.append(stripped_line)
elif current_content:
lines_with_metadata.append(
{
"content": separator.join(current_content),
"metadata": current_metadata.copy(),
}
)
current_content.clear()
current_metadata = initial_metadata.copy()
if current_content:
lines_with_metadata.append(
{
"content": separator.join(current_content),
"metadata": current_metadata,
}
)
# lines_with_metadata has each line with associated header metadata
# aggregate these into chunks based on common metadata
if not self.return_each_line:
return self.aggregate_lines_to_chunks(lines_with_metadata)
else:
return [
                Chunk(content=chunk["content"], metadata=chunk["metadata"])
for chunk in lines_with_metadata
]
def clean(self, documents: List[dict], filters: Optional[List[str]] = None):
if filters is None:
filters = self._filter
for special_character in filters:
for doc in documents:
doc["content"] = doc["content"].replace(special_character, "")
return documents
def _join_docs(self, docs: List[str], separator: str) -> Optional[str]:
text = separator.join(docs)
text = text.strip()
if text == "":
return None
else:
return text
def _merge_splits(
self,
documents: List[dict],
separator: Optional[str] = None,
chunk_size: Optional[int] = None,
        chunk_overlap: Optional[int] = None,
) -> List[str]:
# We now want to combine these smaller pieces into medium size
# chunks to send to the LLM.
if chunk_size is None:
chunk_size = self._chunk_size
if chunk_overlap is None:
chunk_overlap = self._chunk_overlap
if separator is None:
separator = self._separator
separator_len = self._length_function(separator)
docs = []
current_doc: List[str] = []
total = 0
for doc in documents:
if doc["metadata"] != {}:
head = sorted(
doc["metadata"].items(), key=lambda x: x[0], reverse=True
)[0][1]
                d = head + separator + doc["content"]
            else:
                d = doc["content"]
_len = self._length_function(d)
if (
total + _len + (separator_len if len(current_doc) > 0 else 0)
> chunk_size
):
if total > chunk_size:
logger.warning(
f"Created a chunk of size {total}, "
f"which is longer than the specified {chunk_size}"
)
if len(current_doc) > 0:
                    merged = self._join_docs(current_doc, separator)
                    if merged is not None:
                        docs.append(merged)
                    # Keep popping from the front while:
                    # - the accumulated length still exceeds the chunk overlap, or
                    # - adding the next split would still overflow the chunk size
                    #   and the current chunk is not empty
while total > chunk_overlap or (
total + _len + (separator_len if len(current_doc) > 0 else 0)
> chunk_size
and total > 0
):
total -= self._length_function(current_doc[0]) + (
separator_len if len(current_doc) > 1 else 0
)
current_doc = current_doc[1:]
current_doc.append(d)
total += _len + (separator_len if len(current_doc) > 1 else 0)
doc = self._join_docs(current_doc, separator)
if doc is not None:
docs.append(doc)
return docs
def run(
self,
documents: Union[dict, List[dict]],
meta: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None,
filters: Optional[List[str]] = None,
chunk_size: Optional[int] = None,
chunk_overlap: Optional[int] = None,
separator: Optional[str] = None,
):
if filters is None:
filters = self._filter
if chunk_size is None:
chunk_size = self._chunk_size
if chunk_overlap is None:
chunk_overlap = self._chunk_overlap
if separator is None:
separator = self._separator
ret = []
        if isinstance(documents, list):
for document in documents:
text_splits = self.split_text(
document["content"], separator, chunk_size, chunk_overlap
)
for i, txt in enumerate(text_splits):
doc = {}
doc["content"] = txt
if "meta" not in doc.keys() or doc["meta"] is None:
doc["meta"] = {}
doc["meta"]["_split_id"] = i
ret.append(doc)
        elif isinstance(documents, dict):
text_splits = self.split_text(
documents["content"], separator, chunk_size, chunk_overlap
)
for i, txt in enumerate(text_splits):
doc = {}
doc["content"] = txt
if "meta" not in doc.keys() or doc["meta"] is None:
doc["meta"] = {}
doc["meta"]["_split_id"] = i
ret.append(doc)
if filters is not None and len(filters) > 0:
ret = self.clean(ret, filters)
result = {"documents": ret}
return result, "output_1"
class ParagraphTextSplitter(CharacterTextSplitter):
"""Implementation of splitting text that looks at paragraphs."""
def __init__(
self,
separator="\n",
chunk_size: Optional[int] = 0,
chunk_overlap: Optional[int] = 0,
):
self._separator = separator
if self._separator is None:
self._separator = "\n"
self._chunk_size = chunk_size
self._chunk_overlap = chunk_overlap
self._is_paragraph = chunk_overlap
def split_text(
self, text: str, separator: Optional[str] = "\n", **kwargs
) -> List[str]:
paragraphs = text.strip().split(self._separator)
paragraphs = [p.strip() for p in paragraphs if p.strip() != ""]
return paragraphs
class SeparatorTextSplitter(CharacterTextSplitter):
"""SeparatorTextSplitter"""
def __init__(self, separator: str = "\n", filters: list = [], **kwargs: Any):
"""Create a new TextSplitter."""
super().__init__(**kwargs)
self._separator = separator
self._filter = filters
def split_text(
self, text: str, separator: Optional[str] = None, **kwargs
) -> List[str]:
"""Split incoming text and return chunks."""
if separator is None:
separator = self._separator
if separator:
splits = text.split(separator)
else:
splits = list(text)
return self._merge_splits(splits, separator, chunk_overlap=0, **kwargs)
class PageTextSplitter(TextSplitter):
"""PageTextSplitter"""
def __init__(self, separator: str = "\n\n", filters: list = [], **kwargs: Any):
"""Create a new TextSplitter."""
super().__init__(**kwargs)
self._separator = separator
self._filter = filters
def split_text(
self, text: str, separator: Optional[str] = None, **kwargs
) -> List[str]:
"""Split incoming text and return chunks."""
return text
def create_documents(
self,
texts: List[str],
metadatas: Optional[List[dict]] = None,
separator: Optional[str] = None,
**kwargs,
) -> List[Chunk]:
"""Create documents from a list of texts."""
_metadatas = metadatas or [{}] * len(texts)
chunks = []
for i, text in enumerate(texts):
new_doc = Chunk(content=text, metadata=copy.deepcopy(_metadatas[i]))
chunks.append(new_doc)
return chunks
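A minimal usage sketch of the recursive splitter defined above; the sample markdown and the size parameters are illustrative, and Document is assumed to accept content/metadata keyword arguments the way Chunk does here.

from dbgpt.rag.chunk import Document
from dbgpt.rag.text_splitter.text_splitter import RecursiveCharacterTextSplitter

doc = Document(
    content="### intro\nDB-GPT provides text splitters.\n### usage\nSplit long text into chunks.",
    metadata={"source": "example.md"},
)
# Falls back through the separators ("###", "\n", " ", "") until pieces fit chunk_size.
splitter = RecursiveCharacterTextSplitter(chunk_size=32, chunk_overlap=0)
chunks = splitter.split_documents([doc])  # List[Chunk]; metadata is copied onto every chunk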

View File

@@ -0,0 +1,183 @@
"""Token splitter."""
from typing import Callable, List, Optional
from dbgpt._private.pydantic import Field, PrivateAttr, BaseModel
from dbgpt.util.global_helper import globals_helper
from dbgpt.util.splitter_utils import split_by_sep, split_by_char
DEFAULT_METADATA_FORMAT_LEN = 2
DEFAULT_CHUNK_OVERLAP = 20
DEFAULT_CHUNK_SIZE = 1024
class TokenTextSplitter(BaseModel):
"""Implementation of splitting text that looks at word tokens."""
chunk_size: int = Field(
default=DEFAULT_CHUNK_SIZE, description="The token chunk size for each chunk."
)
chunk_overlap: int = Field(
default=DEFAULT_CHUNK_OVERLAP,
description="The token overlap of each chunk when splitting.",
)
separator: str = Field(
default=" ", description="Default separator for splitting into words"
)
backup_separators: List = Field(
default_factory=list, description="Additional separators for splitting."
)
# callback_manager: CallbackManager = Field(
# default_factory=CallbackManager, exclude=True
# )
tokenizer: Callable = Field(
default_factory=globals_helper.tokenizer, # type: ignore
description="Tokenizer for splitting words into tokens.",
exclude=True,
)
_split_fns: List[Callable] = PrivateAttr()
def __init__(
self,
chunk_size: int = DEFAULT_CHUNK_SIZE,
chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
tokenizer: Optional[Callable] = None,
# callback_manager: Optional[CallbackManager] = None,
separator: str = " ",
backup_separators: Optional[List[str]] = ["\n"],
):
"""Initialize with parameters."""
if chunk_overlap > chunk_size:
raise ValueError(
f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
f"({chunk_size}), should be smaller."
)
# callback_manager = callback_manager or CallbackManager([])
tokenizer = tokenizer or globals_helper.tokenizer
all_seps = [separator] + (backup_separators or [])
self._split_fns = [split_by_sep(sep) for sep in all_seps] + [split_by_char()]
super().__init__(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
separator=separator,
backup_separators=backup_separators,
# callback_manager=callback_manager,
tokenizer=tokenizer,
)
@classmethod
def class_name(cls) -> str:
return "TokenTextSplitter"
def split_text_metadata_aware(self, text: str, metadata_str: str) -> List[str]:
"""Split text into chunks, reserving space required for metadata str."""
metadata_len = len(self.tokenizer(metadata_str)) + DEFAULT_METADATA_FORMAT_LEN
effective_chunk_size = self.chunk_size - metadata_len
if effective_chunk_size <= 0:
raise ValueError(
f"Metadata length ({metadata_len}) is longer than chunk size "
f"({self.chunk_size}). Consider increasing the chunk size or "
"decreasing the size of your metadata to avoid this."
)
elif effective_chunk_size < 50:
print(
f"Metadata length ({metadata_len}) is close to chunk size "
f"({self.chunk_size}). Resulting chunks are less than 50 tokens. "
"Consider increasing the chunk size or decreasing the size of "
"your metadata to avoid this.",
flush=True,
)
return self._split_text(text, chunk_size=effective_chunk_size)
def split_text(self, text: str) -> List[str]:
"""Split text into chunks."""
return self._split_text(text, chunk_size=self.chunk_size)
def _split_text(self, text: str, chunk_size: int) -> List[str]:
"""Split text into chunks up to chunk_size."""
if text == "":
return []
splits = self._split(text, chunk_size)
chunks = self._merge(splits, chunk_size)
return chunks
def _split(self, text: str, chunk_size: int) -> List[str]:
"""Break text into splits that are smaller than chunk size.
The order of splitting is:
1. split by separator
2. split by backup separators (if any)
3. split by characters
NOTE: the splits contain the separators.
"""
if len(self.tokenizer(text)) <= chunk_size:
return [text]
for split_fn in self._split_fns:
splits = split_fn(text)
if len(splits) > 1:
break
new_splits = []
for split in splits:
split_len = len(self.tokenizer(split))
if split_len <= chunk_size:
new_splits.append(split)
else:
# recursively split
new_splits.extend(self._split(split, chunk_size=chunk_size))
return new_splits
def _merge(self, splits: List[str], chunk_size: int) -> List[str]:
"""Merge splits into chunks.
The high-level idea is to keep adding splits to a chunk until we
exceed the chunk size, then we start a new chunk with overlap.
When we start a new chunk, we pop off the first element of the previous
chunk until the total length is less than the chunk size.
"""
chunks: List[str] = []
cur_chunk: List[str] = []
cur_len = 0
for split in splits:
split_len = len(self.tokenizer(split))
if split_len > chunk_size:
print(
f"Got a split of size {split_len}, ",
f"larger than chunk size {chunk_size}.",
)
# if we exceed the chunk size after adding the new split, then
# we need to end the current chunk and start a new one
if cur_len + split_len > chunk_size:
# end the previous chunk
chunk = "".join(cur_chunk).strip()
if chunk:
chunks.append(chunk)
# start a new chunk with overlap
# keep popping off the first element of the previous chunk until:
# 1. the current chunk length is less than chunk overlap
# 2. the total length is less than chunk size
while cur_len > self.chunk_overlap or cur_len + split_len > chunk_size:
# pop off the first element
first_chunk = cur_chunk.pop(0)
cur_len -= len(self.tokenizer(first_chunk))
cur_chunk.append(split)
cur_len += split_len
# handle the last chunk
chunk = "".join(cur_chunk).strip()
if chunk:
chunks.append(chunk)
return chunks
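A minimal usage sketch of the token splitter above; the module path is an assumption, the default tokenizer comes from globals_helper, and the sample strings and sizes are illustrative.

from dbgpt.rag.text_splitter.token_splitter import TokenTextSplitter  # assumed module path

splitter = TokenTextSplitter(chunk_size=64, chunk_overlap=8)
chunks = splitter.split_text("DB-GPT splits long passages into token-bounded chunks ...")
# split_text_metadata_aware reserves room for a metadata string inside each chunk.
chunks_with_meta = splitter.split_text_metadata_aware(
    "DB-GPT splits long passages into token-bounded chunks ...",
    metadata_str="source: example.md",
)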