Mirror of https://github.com/csunny/DB-GPT.git
refactor: RAG Refactor (#985)
Co-authored-by: Aralhi <xiaoping0501@gmail.com>
Co-authored-by: csunny <cfqsunny@163.com>
0
dbgpt/rag/text_splitter/__init__.py
Normal file
41
dbgpt/rag/text_splitter/pre_text_splitter.py
Normal file
@@ -0,0 +1,41 @@
from typing import Iterable, List

from dbgpt.rag.chunk import Document, Chunk
from dbgpt.rag.text_splitter.text_splitter import TextSplitter


def _single_document_split(
    document: Document, pre_separator: str
) -> Iterable[Document]:
    content = document.content
    for i, content in enumerate(content.split(pre_separator)):
        metadata = document.metadata.copy()
        if "source" in metadata:
            metadata["source"] = metadata["source"] + "_pre_split_" + str(i)
        yield Chunk(content=content, metadata=metadata)


class PreTextSplitter(TextSplitter):
    """Split text by pre separator"""

    def __init__(self, pre_separator: str, text_splitter_impl: TextSplitter):
        """Initialize with the pre separator and the wrapped text splitter.

        Args:
            pre_separator: pre separator
            text_splitter_impl: text splitter impl
        """
        self.pre_separator = pre_separator
        self._impl = text_splitter_impl

    def split_text(self, text: str, **kwargs) -> List[str]:
        """Split text by pre separator"""
        return self._impl.split_text(text)

    def split_documents(self, documents: Iterable[Document], **kwargs) -> List[Chunk]:
        """Split documents by pre separator"""

        def generator() -> Iterable[Document]:
            for doc in documents:
                yield from _single_document_split(doc, pre_separator=self.pre_separator)

        return self._impl.split_documents(generator())

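Usage sketch (editorial, not part of the diff): how PreTextSplitter is expected to be wired up, assuming a Document with `content` and `metadata` as defined in dbgpt.rag.chunk. The "#####" separator, file name, and sizes below are illustrative; the pre separator cuts the document first, then the wrapped splitter chunks each piece.

from dbgpt.rag.chunk import Document
from dbgpt.rag.text_splitter.pre_text_splitter import PreTextSplitter
from dbgpt.rag.text_splitter.text_splitter import CharacterTextSplitter

# Illustrative document; "#####" marks the pre-split boundaries.
doc = Document(
    content="intro text#####first section body#####second section body",
    metadata={"source": "example.txt"},
)
splitter = PreTextSplitter(
    pre_separator="#####",
    text_splitter_impl=CharacterTextSplitter(separator=" ", chunk_size=16, chunk_overlap=0),
)
# Each pre-split piece gets a "_pre_split_<i>" suffix on its source metadata,
# then the wrapped CharacterTextSplitter produces the final chunks.
for chunk in splitter.split_documents([doc]):
    print(chunk.metadata["source"], "->", chunk.content)
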
0
dbgpt/rag/text_splitter/tests/__init__.py
Normal file
65
dbgpt/rag/text_splitter/tests/test_splitters.py
Normal file
@@ -0,0 +1,65 @@
from dbgpt.rag.chunk import Chunk
from dbgpt.rag.text_splitter.text_splitter import (
    CharacterTextSplitter,
    MarkdownHeaderTextSplitter,
)


def test_md_header_text_splitter() -> None:
    """unit test markdown splitter by header"""

    markdown_document = (
        "# dbgpt\n\n"
        " ## description\n\n"
        "my name is dbgpt\n\n"
        " ## content\n\n"
        "my name is aries"
    )
    headers_to_split_on = [
        ("#", "Header 1"),
        ("##", "Header 2"),
    ]
    markdown_splitter = MarkdownHeaderTextSplitter(
        headers_to_split_on=headers_to_split_on,
    )
    output = markdown_splitter.split_text(markdown_document)
    expected_output = [
        Chunk(
            content="{'Header 1': 'dbgpt', 'Header 2': 'description'}, my name is dbgpt",
            metadata={"Header 1": "dbgpt", "Header 2": "description"},
        ),
        Chunk(
            content="{'Header 1': 'dbgpt', 'Header 2': 'content'}, my name is aries",
            metadata={"Header 1": "dbgpt", "Header 2": "content"},
        ),
    ]
    assert [output.content for output in output] == [
        output.content for output in expected_output
    ]


def test_merge_splits() -> None:
    """Test merging splits with a given separator."""
    splitter = CharacterTextSplitter(separator=" ", chunk_size=9, chunk_overlap=2)
    splits = ["foo", "bar", "baz"]
    expected_output = ["foo bar", "baz"]
    output = splitter._merge_splits(splits, separator=" ")
    assert output == expected_output


def test_character_text_splitter() -> None:
    """Test splitting by character count."""
    text = "foo bar baz 123"
    splitter = CharacterTextSplitter(separator=" ", chunk_size=7, chunk_overlap=3)
    output = splitter.split_text(text)
    expected_output = ["foo bar", "bar baz", "baz 123"]
    assert output == expected_output


def test_character_text_splitter_empty_doc() -> None:
    """Test splitting by character count doesn't create empty documents."""
    text = "db gpt"
    splitter = CharacterTextSplitter(separator=" ", chunk_size=2, chunk_overlap=0)
    output = splitter.split_text(text)
    expected_output = ["db", "gpt"]
    assert output == expected_output

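Editorial note on the overlap expectation in test_character_text_splitter: "bar" appears in two chunks because _merge_splits only pops words from the front of the running chunk while its length still exceeds chunk_overlap, so with chunk_overlap=3 the three-character word "bar" is kept as the start of the next chunk. A small sketch contrasting the two settings:

from dbgpt.rag.text_splitter.text_splitter import CharacterTextSplitter

text = "foo bar baz 123"
# chunk_overlap=3 carries the last short word into the next chunk.
print(CharacterTextSplitter(separator=" ", chunk_size=7, chunk_overlap=3).split_text(text))
# -> ['foo bar', 'bar baz', 'baz 123']
# chunk_overlap=0 produces disjoint chunks.
print(CharacterTextSplitter(separator=" ", chunk_size=7, chunk_overlap=0).split_text(text))
# -> ['foo bar', 'baz 123']
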
730
dbgpt/rag/text_splitter/text_splitter.py
Normal file
@@ -0,0 +1,730 @@
import copy
import logging
import re
from abc import abstractmethod, ABC
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Optional,
    Tuple,
    TypedDict,
    Union,
)

from dbgpt.rag.chunk import Document, Chunk

logger = logging.getLogger(__name__)


class TextSplitter(ABC):
    """Interface for splitting text into chunks.

    Refer to https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/text_splitter.py
    """

    outgoing_edges = 1

    def __init__(
        self,
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        length_function: Callable[[str], int] = len,
        filters: list = [],
        separator: str = "",
    ):
        """Create a new TextSplitter."""
        if chunk_overlap > chunk_size:
            raise ValueError(
                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
                f"({chunk_size}), should be smaller."
            )
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._length_function = length_function
        self._filter = filters
        self._separator = separator

    @abstractmethod
    def split_text(self, text: str, **kwargs) -> List[str]:
        """Split text into multiple components."""

    def create_documents(
        self,
        texts: List[str],
        metadatas: Optional[List[dict]] = None,
        separator: Optional[str] = None,
        **kwargs,
    ) -> List[Chunk]:
        """Create documents from a list of texts."""
        _metadatas = metadatas or [{}] * len(texts)
        chunks = []
        for i, text in enumerate(texts):
            for chunk in self.split_text(text, separator, **kwargs):
                new_doc = Chunk(content=chunk, metadata=copy.deepcopy(_metadatas[i]))
                chunks.append(new_doc)
        return chunks

    def split_documents(self, documents: List[Document], **kwargs) -> List[Chunk]:
        """Split documents."""
        texts = [doc.content for doc in documents]
        metadatas = [doc.metadata for doc in documents]
        return self.create_documents(texts, metadatas, **kwargs)

    def _join_docs(self, docs: List[str], separator: str, **kwargs) -> Optional[str]:
        text = separator.join(docs)
        text = text.strip()
        if text == "":
            return None
        else:
            return text

    def _merge_splits(
        self,
        splits: Iterable[str],
        separator: str,
        chunk_size: Optional[int] = None,
        chunk_overlap: Optional[int] = None,
    ) -> List[str]:
        # We now want to combine these smaller pieces into medium size
        # chunks to send to the LLM.
        if chunk_size is None:
            chunk_size = self._chunk_size
        if chunk_overlap is None:
            chunk_overlap = self._chunk_overlap
        if separator is None:
            separator = self._separator
        separator_len = self._length_function(separator)

        docs = []
        current_doc: List[str] = []
        total = 0
        for d in splits:
            _len = self._length_function(d)
            if (
                total + _len + (separator_len if len(current_doc) > 0 else 0)
                > chunk_size
            ):
                if total > chunk_size:
                    logger.warning(
                        f"Created a chunk of size {total}, "
                        f"which is longer than the specified {chunk_size}"
                    )
                if len(current_doc) > 0:
                    doc = self._join_docs(current_doc, separator)
                    if doc is not None:
                        docs.append(doc)
                    # Keep on popping if:
                    # - we have a larger chunk than in the chunk overlap
                    # - or if we still have any chunks and the length is long
                    while total > chunk_overlap or (
                        total + _len + (separator_len if len(current_doc) > 0 else 0)
                        > chunk_size
                        and total > 0
                    ):
                        total -= self._length_function(current_doc[0]) + (
                            separator_len if len(current_doc) > 1 else 0
                        )
                        current_doc = current_doc[1:]
            current_doc.append(d)
            total += _len + (separator_len if len(current_doc) > 1 else 0)
        doc = self._join_docs(current_doc, separator)
        if doc is not None:
            docs.append(doc)
        return docs

    def clean(self, documents: List[dict], filters: List[str]):
        for special_character in filters:
            for doc in documents:
                doc["content"] = doc["content"].replace(special_character, "")
        return documents

    def run(  # type: ignore
        self,
        documents: Union[dict, List[dict]],
        meta: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None,  # type: ignore
        separator: Optional[str] = None,
        chunk_size: Optional[int] = None,
        chunk_overlap: Optional[int] = None,
        filters: Optional[List[str]] = None,
    ):
        if separator is None:
            separator = self._separator
        if chunk_size is None:
            chunk_size = self._chunk_size
        if chunk_overlap is None:
            chunk_overlap = self._chunk_overlap
        if filters is None:
            filters = self._filter
        ret = []
        if type(documents) == dict:  # single document
            text_splits = self.split_text(
                documents["content"],
                separator=separator,
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
            )
            for i, txt in enumerate(text_splits):
                doc = copy.deepcopy(documents)
                doc["content"] = txt

                if "meta" not in doc.keys() or doc["meta"] is None:
                    doc["meta"] = {}

                doc["meta"]["_split_id"] = i
                ret.append(doc)

        elif type(documents) == list:  # list document
            for document in documents:
                text_splits = self.split_text(
                    document["content"],
                    separator=separator,
                    chunk_size=chunk_size,
                    chunk_overlap=chunk_overlap,
                )
                for i, txt in enumerate(text_splits):
                    doc = copy.deepcopy(document)
                    doc["content"] = txt

                    if "meta" not in doc.keys() or doc["meta"] is None:
                        doc["meta"] = {}

                    doc["meta"]["_split_id"] = i
                    ret.append(doc)
        if filters is not None and len(filters) > 0:
            ret = self.clean(ret, filters)
        result = {"documents": ret}
        return result, "output_1"

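TextSplitter above is an ABC: subclasses only need to implement split_text, while create_documents, split_documents, _merge_splits, clean, and run are inherited. An illustrative sketch of a custom subclass (the sentence-splitting rule is an assumption for the example, not something this commit ships):

import re
from typing import List, Optional

from dbgpt.rag.text_splitter.text_splitter import TextSplitter


class SentenceTextSplitter(TextSplitter):
    """Illustrative splitter: cut on sentence ends, then merge pieces up to chunk_size."""

    def split_text(self, text: str, separator: Optional[str] = None, **kwargs) -> List[str]:
        # Hypothetical rule: break on ". ! ?" followed by whitespace.
        sentences = [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()]
        # Reuse the inherited merge logic so chunk_size / chunk_overlap still apply.
        return self._merge_splits(sentences, separator or " ", **kwargs)
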
class CharacterTextSplitter(TextSplitter):
    """Implementation of splitting text that looks at characters.

    Refer to https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/text_splitter.py
    """

    def __init__(self, separator: str = "\n\n", filters: list = [], **kwargs: Any):
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        self._separator = separator
        self._filter = filters

    def split_text(
        self, text: str, separator: Optional[str] = None, **kwargs
    ) -> List[str]:
        """Split incoming text and return chunks."""
        # First we naively split the large input into a bunch of smaller ones.
        if separator is None:
            separator = self._separator
        if separator:
            splits = text.split(separator)
        else:
            splits = list(text)
        return self._merge_splits(splits, separator, **kwargs)


class RecursiveCharacterTextSplitter(TextSplitter):
    """Implementation of splitting text that looks at characters.

    Recursively tries to split by different characters to find one
    that works.
    Refer to https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/text_splitter.py
    """

    def __init__(self, separators: Optional[List[str]] = None, **kwargs: Any):
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        self._separators = separators or ["###", "\n", " ", ""]

    def split_text(
        self, text: str, separator: Optional[str] = None, **kwargs
    ) -> List[str]:
        """Split incoming text and return chunks."""
        final_chunks = []
        # Get appropriate separator to use
        separator = self._separators[-1]
        for _s in self._separators:
            if _s == "":
                separator = _s
                break
            if _s in text:
                separator = _s
                break
        # Now that we have the separator, split the text
        if separator:
            splits = text.split(separator)
        else:
            splits = list(text)
        # Now go merging things, recursively splitting longer texts.
        _good_splits = []
        for s in splits:
            if self._length_function(s) < self._chunk_size:
                _good_splits.append(s)
            else:
                if _good_splits:
                    merged_text = self._merge_splits(
                        _good_splits,
                        separator,
                        chunk_size=kwargs.get("chunk_size", None),
                        chunk_overlap=kwargs.get("chunk_overlap", None),
                    )
                    final_chunks.extend(merged_text)
                    _good_splits = []
                other_info = self.split_text(s)
                final_chunks.extend(other_info)
        if _good_splits:
            merged_text = self._merge_splits(
                _good_splits,
                separator,
                chunk_size=kwargs.get("chunk_size", None),
                chunk_overlap=kwargs.get("chunk_overlap", None),
            )
            final_chunks.extend(merged_text)
        return final_chunks

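Usage sketch for RecursiveCharacterTextSplitter (values illustrative): it picks the first separator from ["###", "\n", " ", ""] that occurs in the text, and any piece still longer than chunk_size is split again with the separators further down the list.

from dbgpt.rag.text_splitter.text_splitter import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(chunk_size=32, chunk_overlap=0)
text = "### part one\nfirst paragraph of text\n### part two\nsecond paragraph of text"
# Sections are first cut on "###"; long sections fall back to "\n", then " ".
for chunk in splitter.split_text(text):
    print(repr(chunk))
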
class SpacyTextSplitter(TextSplitter):
    """Implementation of splitting text that looks at sentences using Spacy.

    Refer to https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/text_splitter.py
    """

    def __init__(self, pipeline: str = "zh_core_web_sm", **kwargs: Any) -> None:
        """Initialize the spacy text splitter."""
        super().__init__(**kwargs)
        try:
            import spacy
        except ImportError:
            raise ImportError(
                "Spacy is not installed, please install it with `pip install spacy`."
            )
        try:
            self._tokenizer = spacy.load(pipeline)
        except:
            spacy.cli.download(pipeline)
            self._tokenizer = spacy.load(pipeline)

    def split_text(
        self, text: str, separator: Optional[str] = None, **kwargs
    ) -> List[str]:
        """Split incoming text and return chunks."""
        if len(text) > 1000000:
            self._tokenizer.max_length = len(text) + 100
        splits = (str(s) for s in self._tokenizer(text).sents)
        return self._merge_splits(splits, separator, **kwargs)

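Usage sketch for SpacyTextSplitter, assuming spacy is installed (the constructor downloads the zh_core_web_sm pipeline on first use if it is missing). Chunks are built from spaCy sentence boundaries and then merged up to chunk_size:

from dbgpt.rag.text_splitter.text_splitter import SpacyTextSplitter

# Illustrative sizes; the default pipeline is the Chinese "zh_core_web_sm" model.
splitter = SpacyTextSplitter(chunk_size=100, chunk_overlap=0)
chunks = splitter.split_text("第一句话。第二句话。第三句话。", separator="\n")
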
class HeaderType(TypedDict):
    """Header type as typed dict."""

    level: int
    name: str
    data: str


class LineType(TypedDict):
    """Line type as typed dict."""

    metadata: Dict[str, str]
    content: str


class MarkdownHeaderTextSplitter(TextSplitter):
    """Implementation of splitting markdown files based on specified headers.

    Refer to https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/text_splitter.py
    """

    outgoing_edges = 1

    def __init__(
        self,
        headers_to_split_on: List[Tuple[str, str]] = [
            ("#", "Header 1"),
            ("##", "Header 2"),
            ("###", "Header 3"),
            ("####", "Header 4"),
            ("#####", "Header 5"),
            ("######", "Header 6"),
        ],
        return_each_line: bool = False,
        filters: list = [],
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        length_function: Callable[[str], int] = len,
        separator="\n",
    ):
        """Create a new MarkdownHeaderTextSplitter.

        Args:
            headers_to_split_on: Headers we want to track
            return_each_line: Return each line w/ associated headers
        """
        # Output line-by-line or aggregated into chunks w/ common headers
        self.return_each_line = return_each_line
        self._chunk_size = chunk_size
        # Given the headers we want to split on,
        # (e.g., "#, ##, etc") order by length
        self.headers_to_split_on = sorted(
            headers_to_split_on, key=lambda split: len(split[0]), reverse=True
        )
        self._filter = filters
        self._length_function = length_function
        self._separator = separator
        self._chunk_overlap = chunk_overlap

    def create_documents(
        self,
        texts: List[str],
        metadatas: Optional[List[dict]] = None,
        separator: Optional[str] = None,
        **kwargs,
    ) -> List[Chunk]:
        """Create documents from a list of texts."""
        _metadatas = metadatas or [{}] * len(texts)
        chunks = []
        for i, text in enumerate(texts):
            for chunk in self.split_text(text, separator, **kwargs):
                metadata = chunk.metadata or {}
                metadata.update(_metadatas[i])
                new_doc = Chunk(content=chunk.content, metadata=metadata)
                chunks.append(new_doc)
        return chunks

    def aggregate_lines_to_chunks(self, lines: List[LineType]) -> List[Chunk]:
        """Combine lines with common metadata into chunks

        Args:
            lines: Line of text / associated header metadata
        """
        aggregated_chunks: List[LineType] = []

        for line in lines:
            if (
                aggregated_chunks
                and aggregated_chunks[-1]["metadata"] == line["metadata"]
            ):
                # If the last line in the aggregated list
                # has the same metadata as the current line,
                # append the current content to the last line's content
                aggregated_chunks[-1]["content"] += " \n" + line["content"]
            else:
                # Otherwise, append the current line to the aggregated list
                line["content"] = f"{line['metadata']}, " + line["content"]
                aggregated_chunks.append(line)

        return [
            Chunk(content=chunk["content"], metadata=chunk["metadata"])
            for chunk in aggregated_chunks
        ]

    def split_text(
        self,
        text: str,
        separator: Optional[str] = None,
        chunk_size: Optional[int] = None,
        chunk_overlap: Optional[int] = None,
    ) -> List[Chunk]:
        """Split markdown file

        Args:
            text: Markdown file
        """
        if separator is None:
            separator = self._separator
        if chunk_size is None:
            chunk_size = self._chunk_size
        if chunk_overlap is None:
            chunk_overlap = self._chunk_overlap

        # Split the input text by newline character ("\n").
        lines = text.split(separator)
        # Final output
        lines_with_metadata: List[LineType] = []
        # Content and metadata of the chunk currently being processed
        current_content: List[str] = []
        current_metadata: Dict[str, str] = {}
        # Keep track of the nested header structure
        # header_stack: List[Dict[str, Union[int, str]]] = []
        header_stack: List[HeaderType] = []
        initial_metadata: Dict[str, str] = {}
        for line in lines:
            stripped_line = line.strip()
            # Check each line against each of the header types (e.g., #, ##)
            for sep, name in self.headers_to_split_on:
                # Check if line starts with a header that we intend to split on
                if stripped_line.startswith(sep) and (
                    # Header with no text OR header is followed by space
                    # Both are valid conditions that sep is being used as a header
                    len(stripped_line) == len(sep)
                    or stripped_line[len(sep)] == " "
                ):
                    # Ensure we are tracking the header as metadata
                    if name is not None:
                        # Get the current header level
                        current_header_level = sep.count("#")

                        # Pop out headers of lower or same level from the stack
                        while (
                            header_stack
                            and header_stack[-1]["level"] >= current_header_level
                        ):
                            # We have encountered a new header
                            # at the same or higher level
                            popped_header = header_stack.pop()
                            # Clear the metadata for the
                            # popped header in initial_metadata
                            if popped_header["name"] in initial_metadata:
                                initial_metadata.pop(popped_header["name"])

                        # Push the current header to the stack
                        header: HeaderType = {
                            "level": current_header_level,
                            "name": name,
                            "data": stripped_line[len(sep) :].strip(),
                        }
                        header_stack.append(header)
                        # Update initial_metadata with the current header
                        initial_metadata[name] = header["data"]

                    # Add the previous line to the lines_with_metadata
                    # only if current_content is not empty
                    if current_content:
                        lines_with_metadata.append(
                            {
                                "content": separator.join(current_content),
                                "metadata": current_metadata.copy(),
                            }
                        )
                        current_content.clear()

                    break
            else:
                if stripped_line:
                    current_content.append(stripped_line)
                elif current_content:
                    lines_with_metadata.append(
                        {
                            "content": separator.join(current_content),
                            "metadata": current_metadata.copy(),
                        }
                    )
                    current_content.clear()

            current_metadata = initial_metadata.copy()
        if current_content:
            lines_with_metadata.append(
                {
                    "content": separator.join(current_content),
                    "metadata": current_metadata,
                }
            )
        # lines_with_metadata has each line with associated header metadata
        # aggregate these into chunks based on common metadata
        if not self.return_each_line:
            return self.aggregate_lines_to_chunks(lines_with_metadata)
        else:
            return [
                Document(content=chunk["content"], metadata=chunk["metadata"])
                for chunk in lines_with_metadata
            ]

    def clean(self, documents: List[dict], filters: Optional[List[str]] = None):
        if filters is None:
            filters = self._filter
        for special_character in filters:
            for doc in documents:
                doc["content"] = doc["content"].replace(special_character, "")
        return documents

    def _join_docs(self, docs: List[str], separator: str) -> Optional[str]:
        text = separator.join(docs)
        text = text.strip()
        if text == "":
            return None
        else:
            return text

    def _merge_splits(
        self,
        documents: List[dict],
        separator: Optional[str] = None,
        chunk_size: Optional[int] = None,
        chunk_overlap: Optional[int] = None,
    ) -> List[str]:
        # We now want to combine these smaller pieces into medium size
        # chunks to send to the LLM.
        if chunk_size is None:
            chunk_size = self._chunk_size
        if chunk_overlap is None:
            chunk_overlap = self._chunk_overlap
        if separator is None:
            separator = self._separator
        separator_len = self._length_function(separator)

        docs = []
        current_doc: List[str] = []
        total = 0
        for doc in documents:
            if doc["metadata"] != {}:
                head = sorted(
                    doc["metadata"].items(), key=lambda x: x[0], reverse=True
                )[0][1]
                d = head + separator + doc["page_content"]
            else:
                d = doc["page_content"]
            _len = self._length_function(d)
            if (
                total + _len + (separator_len if len(current_doc) > 0 else 0)
                > chunk_size
            ):
                if total > chunk_size:
                    logger.warning(
                        f"Created a chunk of size {total}, "
                        f"which is longer than the specified {chunk_size}"
                    )
                if len(current_doc) > 0:
                    doc = self._join_docs(current_doc, separator)
                    if doc is not None:
                        docs.append(doc)
                    # Keep on popping if:
                    # - we have a larger chunk than in the chunk overlap
                    # - or if we still have any chunks and the length is long
                    while total > chunk_overlap or (
                        total + _len + (separator_len if len(current_doc) > 0 else 0)
                        > chunk_size
                        and total > 0
                    ):
                        total -= self._length_function(current_doc[0]) + (
                            separator_len if len(current_doc) > 1 else 0
                        )
                        current_doc = current_doc[1:]
            current_doc.append(d)
            total += _len + (separator_len if len(current_doc) > 1 else 0)
        doc = self._join_docs(current_doc, separator)
        if doc is not None:
            docs.append(doc)
        return docs

    def run(
        self,
        documents: Union[dict, List[dict]],
        meta: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None,
        filters: Optional[List[str]] = None,
        chunk_size: Optional[int] = None,
        chunk_overlap: Optional[int] = None,
        separator: Optional[str] = None,
    ):
        if filters is None:
            filters = self._filter
        if chunk_size is None:
            chunk_size = self._chunk_size
        if chunk_overlap is None:
            chunk_overlap = self._chunk_overlap
        if separator is None:
            separator = self._separator
        ret = []
        if type(documents) == list:
            for document in documents:
                text_splits = self.split_text(
                    document["content"], separator, chunk_size, chunk_overlap
                )
                for i, txt in enumerate(text_splits):
                    doc = {}
                    doc["content"] = txt

                    if "meta" not in doc.keys() or doc["meta"] is None:
                        doc["meta"] = {}

                    doc["meta"]["_split_id"] = i
                    ret.append(doc)
        elif type(documents) == dict:
            text_splits = self.split_text(
                documents["content"], separator, chunk_size, chunk_overlap
            )
            for i, txt in enumerate(text_splits):
                doc = {}
                doc["content"] = txt

                if "meta" not in doc.keys() or doc["meta"] is None:
                    doc["meta"] = {}

                doc["meta"]["_split_id"] = i
                ret.append(doc)
        if filters is None:
            filters = self._filter
        if filters is not None and len(filters) > 0:
            ret = self.clean(ret, filters)
        result = {"documents": ret}
        return result, "output_1"

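Usage sketch for MarkdownHeaderTextSplitter beyond the tests above (the markdown content is illustrative): the header stack keeps every header still in scope, so deeper sections carry Header 1..3 metadata, and aggregate_lines_to_chunks also prefixes the in-scope headers onto each chunk's content, as the expected test output shows.

from dbgpt.rag.text_splitter.text_splitter import MarkdownHeaderTextSplitter

md = (
    "# guide\n"
    "## install\n"
    "pip install dbgpt\n"
    "### from source\n"
    "clone the repository and install from source\n"
    "## usage\n"
    "start the server\n"
)
# Each chunk's metadata reflects the nested headers that are in scope for it.
for chunk in MarkdownHeaderTextSplitter().split_text(md):
    print(chunk.metadata, "->", chunk.content)
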
class ParagraphTextSplitter(CharacterTextSplitter):
    """Implementation of splitting text that looks at paragraphs."""

    def __init__(
        self,
        separator="\n",
        chunk_size: Optional[int] = 0,
        chunk_overlap: Optional[int] = 0,
    ):
        self._separator = separator
        if self._separator is None:
            self._separator = "\n"
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._is_paragraph = chunk_overlap

    def split_text(
        self, text: str, separator: Optional[str] = "\n", **kwargs
    ) -> List[str]:
        paragraphs = text.strip().split(self._separator)
        paragraphs = [p.strip() for p in paragraphs if p.strip() != ""]
        return paragraphs


class SeparatorTextSplitter(CharacterTextSplitter):
    """SeparatorTextSplitter"""

    def __init__(self, separator: str = "\n", filters: list = [], **kwargs: Any):
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        self._separator = separator
        self._filter = filters

    def split_text(
        self, text: str, separator: Optional[str] = None, **kwargs
    ) -> List[str]:
        """Split incoming text and return chunks."""
        if separator is None:
            separator = self._separator
        if separator:
            splits = text.split(separator)
        else:
            splits = list(text)
        return self._merge_splits(splits, separator, chunk_overlap=0, **kwargs)


class PageTextSplitter(TextSplitter):
    """PageTextSplitter"""

    def __init__(self, separator: str = "\n\n", filters: list = [], **kwargs: Any):
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        self._separator = separator
        self._filter = filters

    def split_text(
        self, text: str, separator: Optional[str] = None, **kwargs
    ) -> List[str]:
        """Split incoming text and return chunks."""
        return text

    def create_documents(
        self,
        texts: List[str],
        metadatas: Optional[List[dict]] = None,
        separator: Optional[str] = None,
        **kwargs,
    ) -> List[Chunk]:
        """Create documents from a list of texts."""
        _metadatas = metadatas or [{}] * len(texts)
        chunks = []
        for i, text in enumerate(texts):
            new_doc = Chunk(content=text, metadata=copy.deepcopy(_metadatas[i]))
            chunks.append(new_doc)
        return chunks

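A quick sketch of the two simpler splitters above (values illustrative): ParagraphTextSplitter just returns the non-empty paragraphs, while SeparatorTextSplitter splits on a single separator and merges pieces with no overlap.

from dbgpt.rag.text_splitter.text_splitter import ParagraphTextSplitter, SeparatorTextSplitter

print(ParagraphTextSplitter(separator="\n").split_text("first paragraph\n\nsecond paragraph\n"))
# -> ['first paragraph', 'second paragraph']

print(SeparatorTextSplitter(separator=";", chunk_size=16, chunk_overlap=0).split_text("alpha;beta;gamma;delta"))
# -> ['alpha;beta;gamma', 'delta']
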
183
dbgpt/rag/text_splitter/token_splitter.py
Normal file
@@ -0,0 +1,183 @@
"""Token splitter."""
from typing import Callable, List, Optional

from dbgpt._private.pydantic import Field, PrivateAttr, BaseModel

from dbgpt.util.global_helper import globals_helper
from dbgpt.util.splitter_utils import split_by_sep, split_by_char

DEFAULT_METADATA_FORMAT_LEN = 2
DEFAULT_CHUNK_OVERLAP = 20
DEFAULT_CHUNK_SIZE = 1024


class TokenTextSplitter(BaseModel):
    """Implementation of splitting text that looks at word tokens."""

    chunk_size: int = Field(
        default=DEFAULT_CHUNK_SIZE, description="The token chunk size for each chunk."
    )
    chunk_overlap: int = Field(
        default=DEFAULT_CHUNK_OVERLAP,
        description="The token overlap of each chunk when splitting.",
    )
    separator: str = Field(
        default=" ", description="Default separator for splitting into words"
    )
    backup_separators: List = Field(
        default_factory=list, description="Additional separators for splitting."
    )
    # callback_manager: CallbackManager = Field(
    #     default_factory=CallbackManager, exclude=True
    # )
    tokenizer: Callable = Field(
        default_factory=globals_helper.tokenizer,  # type: ignore
        description="Tokenizer for splitting words into tokens.",
        exclude=True,
    )

    _split_fns: List[Callable] = PrivateAttr()

    def __init__(
        self,
        chunk_size: int = DEFAULT_CHUNK_SIZE,
        chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
        tokenizer: Optional[Callable] = None,
        # callback_manager: Optional[CallbackManager] = None,
        separator: str = " ",
        backup_separators: Optional[List[str]] = ["\n"],
    ):
        """Initialize with parameters."""
        if chunk_overlap > chunk_size:
            raise ValueError(
                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
                f"({chunk_size}), should be smaller."
            )
        # callback_manager = callback_manager or CallbackManager([])
        tokenizer = tokenizer or globals_helper.tokenizer

        all_seps = [separator] + (backup_separators or [])
        self._split_fns = [split_by_sep(sep) for sep in all_seps] + [split_by_char()]

        super().__init__(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            separator=separator,
            backup_separators=backup_separators,
            # callback_manager=callback_manager,
            tokenizer=tokenizer,
        )

    @classmethod
    def class_name(cls) -> str:
        return "TokenTextSplitter"

    def split_text_metadata_aware(self, text: str, metadata_str: str) -> List[str]:
        """Split text into chunks, reserving space required for metadata str."""
        metadata_len = len(self.tokenizer(metadata_str)) + DEFAULT_METADATA_FORMAT_LEN
        effective_chunk_size = self.chunk_size - metadata_len
        if effective_chunk_size <= 0:
            raise ValueError(
                f"Metadata length ({metadata_len}) is longer than chunk size "
                f"({self.chunk_size}). Consider increasing the chunk size or "
                "decreasing the size of your metadata to avoid this."
            )
        elif effective_chunk_size < 50:
            print(
                f"Metadata length ({metadata_len}) is close to chunk size "
                f"({self.chunk_size}). Resulting chunks are less than 50 tokens. "
                "Consider increasing the chunk size or decreasing the size of "
                "your metadata to avoid this.",
                flush=True,
            )

        return self._split_text(text, chunk_size=effective_chunk_size)

    def split_text(self, text: str) -> List[str]:
        """Split text into chunks."""
        return self._split_text(text, chunk_size=self.chunk_size)

    def _split_text(self, text: str, chunk_size: int) -> List[str]:
        """Split text into chunks up to chunk_size."""
        if text == "":
            return []

        splits = self._split(text, chunk_size)
        chunks = self._merge(splits, chunk_size)
        return chunks

    def _split(self, text: str, chunk_size: int) -> List[str]:
        """Break text into splits that are smaller than chunk size.

        The order of splitting is:
        1. split by separator
        2. split by backup separators (if any)
        3. split by characters

        NOTE: the splits contain the separators.
        """
        if len(self.tokenizer(text)) <= chunk_size:
            return [text]

        for split_fn in self._split_fns:
            splits = split_fn(text)
            if len(splits) > 1:
                break

        new_splits = []
        for split in splits:
            split_len = len(self.tokenizer(split))
            if split_len <= chunk_size:
                new_splits.append(split)
            else:
                # recursively split
                new_splits.extend(self._split(split, chunk_size=chunk_size))
        return new_splits

    def _merge(self, splits: List[str], chunk_size: int) -> List[str]:
        """Merge splits into chunks.

        The high-level idea is to keep adding splits to a chunk until we
        exceed the chunk size, then we start a new chunk with overlap.

        When we start a new chunk, we pop off the first element of the previous
        chunk until the total length is less than the chunk size.
        """
        chunks: List[str] = []

        cur_chunk: List[str] = []
        cur_len = 0
        for split in splits:
            split_len = len(self.tokenizer(split))
            if split_len > chunk_size:
                print(
                    f"Got a split of size {split_len}, ",
                    f"larger than chunk size {chunk_size}.",
                )

            # if we exceed the chunk size after adding the new split, then
            # we need to end the current chunk and start a new one
            if cur_len + split_len > chunk_size:
                # end the previous chunk
                chunk = "".join(cur_chunk).strip()
                if chunk:
                    chunks.append(chunk)

                # start a new chunk with overlap
                # keep popping off the first element of the previous chunk until:
                #   1. the current chunk length is less than chunk overlap
                #   2. the total length is less than chunk size
                while cur_len > self.chunk_overlap or cur_len + split_len > chunk_size:
                    # pop off the first element
                    first_chunk = cur_chunk.pop(0)
                    cur_len -= len(self.tokenizer(first_chunk))

            cur_chunk.append(split)
            cur_len += split_len

        # handle the last chunk
        chunk = "".join(cur_chunk).strip()
        if chunk:
            chunks.append(chunk)

        return chunks

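Usage sketch for TokenTextSplitter, assuming the default tokenizer from dbgpt.util.global_helper is available. Sizes are counted in tokens rather than characters, and split_text_metadata_aware additionally reserves room for a metadata string; the text and sizes below are illustrative.

from dbgpt.rag.text_splitter.token_splitter import TokenTextSplitter

splitter = TokenTextSplitter(chunk_size=64, chunk_overlap=8)
text = "DB-GPT builds private, domain-specific LLM applications. " * 40
for chunk in splitter.split_text(text):
    # Chunk sizes are measured with the same tokenizer the splitter uses.
    print(len(splitter.tokenizer(chunk)), repr(chunk[:40]))
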