mirror of
https://github.com/hwchase17/langchain.git
synced 2025-10-24 12:17:47 +00:00
367 lines
13 KiB
Python
367 lines
13 KiB
Python
"""Text splitter base interface."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import copy
|
|
import logging
|
|
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass
|
|
from enum import Enum
|
|
from typing import (
|
|
TYPE_CHECKING,
|
|
Any,
|
|
Callable,
|
|
Literal,
|
|
Optional,
|
|
TypeVar,
|
|
Union,
|
|
)
|
|
|
|
from langchain_core.documents import BaseDocumentTransformer, Document
|
|
from typing_extensions import Self, override
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import Collection, Iterable, Sequence
|
|
from collections.abc import Set as AbstractSet
|
|
|
|
|
|
try:
|
|
import tiktoken
|
|
|
|
_HAS_TIKTOKEN = True
|
|
except ImportError:
|
|
_HAS_TIKTOKEN = False
|
|
|
|
try:
|
|
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
|
|
|
|
_HAS_TRANSFORMERS = True
|
|
except ImportError:
|
|
_HAS_TRANSFORMERS = False
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
TS = TypeVar("TS", bound="TextSplitter")
|
|
|
|
|
|
class TextSplitter(BaseDocumentTransformer, ABC):
    """Interface for splitting text into chunks.

    Subclasses implement `split_text`; this base class stores the shared
    chunking configuration and provides helpers that turn raw texts or
    `Document` objects into chunked `Document` objects.
    """

    def __init__(
        self,
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        length_function: Callable[[str], int] = len,
        keep_separator: Union[bool, Literal["start", "end"]] = False,  # noqa: FBT001,FBT002
        add_start_index: bool = False,  # noqa: FBT001,FBT002
        strip_whitespace: bool = True,  # noqa: FBT001,FBT002
    ) -> None:
        """Create a new TextSplitter.

        Args:
            chunk_size: Maximum size of chunks to return, measured with
                `length_function`.
            chunk_overlap: Overlap between consecutive chunks, measured with
                `length_function`.
            length_function: Function that measures the length of given chunks.
            keep_separator: Whether to keep the separator and where to place it
                in each corresponding chunk (True='start').
            add_start_index: If `True`, includes chunk's start index in metadata.
            strip_whitespace: If `True`, strips whitespace from the start and end of
                every document.

        Raises:
            ValueError: If `chunk_size` is not positive, `chunk_overlap` is
                negative, or `chunk_overlap` is larger than `chunk_size`.
        """
        if chunk_size <= 0:
            msg = f"chunk_size must be > 0, got {chunk_size}"
            raise ValueError(msg)
        if chunk_overlap < 0:
            msg = f"chunk_overlap must be >= 0, got {chunk_overlap}"
            raise ValueError(msg)
        if chunk_overlap > chunk_size:
            msg = (
                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
                f"({chunk_size}), should be smaller."
            )
            raise ValueError(msg)
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._length_function = length_function
        self._keep_separator = keep_separator
        self._add_start_index = add_start_index
        self._strip_whitespace = strip_whitespace

    @abstractmethod
    def split_text(self, text: str) -> list[str]:
        """Split text into multiple components."""

    def create_documents(
        self, texts: list[str], metadatas: Optional[list[dict[Any, Any]]] = None
    ) -> list[Document]:
        """Create documents from a list of texts.

        Each text is split with `split_text`, and every resulting chunk becomes
        one `Document` carrying a deep copy of the metadata for its source text.

        Args:
            texts: The texts to split.
            metadatas: Optional metadata dicts, one per entry in `texts`. When
                omitted (or empty), every chunk starts from an empty dict.

        Returns:
            The chunk documents, in the order of `texts` and of the chunks
            within each text. When `add_start_index` is enabled, each
            document's metadata has a `start_index` key holding the position
            at which the chunk was located in its source text, or -1 when the
            chunk could not be located verbatim.
        """
        metadatas_ = metadatas or [{}] * len(texts)
        documents = []
        for i, text in enumerate(texts):
            index = 0
            previous_chunk_len = 0
            for chunk in self.split_text(text):
                # Deep copy so chunks never share (and mutate) one metadata dict.
                metadata = copy.deepcopy(metadatas_[i])
                if self._add_start_index:
                    # Earliest position the chunk is expected at: the end of
                    # the previous chunk minus the configured overlap.
                    offset = index + previous_chunk_len - self._chunk_overlap
                    previous_start = max(0, index)
                    index = text.find(chunk, max(0, offset))
                    if index == -1 and offset > previous_start:
                        # `previous_chunk_len` is in characters while
                        # `chunk_overlap` is in `length_function` units (for
                        # example tokens), so `offset` can overshoot the real
                        # start. Retry from the previous chunk's start before
                        # reporting the chunk as not found.
                        index = text.find(chunk, previous_start)
                    metadata["start_index"] = index
                    previous_chunk_len = len(chunk)
                new_doc = Document(page_content=chunk, metadata=metadata)
                documents.append(new_doc)
        return documents

    def split_documents(self, documents: Iterable[Document]) -> list[Document]:
        """Split documents.

        Args:
            documents: Documents whose `page_content` is split and whose
                `metadata` is carried over to every resulting chunk.

        Returns:
            The chunk documents produced by `create_documents`.
        """
        texts, metadatas = [], []
        for doc in documents:
            texts.append(doc.page_content)
            metadatas.append(doc.metadata)
        return self.create_documents(texts, metadatas=metadatas)

    def _join_docs(self, docs: list[str], separator: str) -> Optional[str]:
        """Join pieces with `separator`; return `None` if the result is empty.

        Leading and trailing whitespace is stripped first when
        `strip_whitespace` is enabled.
        """
        text = separator.join(docs)
        if self._strip_whitespace:
            text = text.strip()
        return text or None

    def _merge_splits(self, splits: Iterable[str], separator: str) -> list[str]:
        """Merge small splits into chunks of at most `chunk_size` where possible.

        Args:
            splits: The pieces to combine, in order.
            separator: The string placed between pieces when they are joined;
                its length counts towards the chunk size.

        Returns:
            The merged chunks. Pieces from the tail of one chunk are carried
            into the next chunk to provide the configured overlap.
        """
        # We now want to combine these smaller pieces into medium size
        # chunks to send to the LLM.
        separator_len = self._length_function(separator)

        docs = []
        current_doc: list[str] = []
        total = 0
        for d in splits:
            len_ = self._length_function(d)
            if (
                total + len_ + (separator_len if len(current_doc) > 0 else 0)
                > self._chunk_size
            ):
                # Adding this piece would overflow: flush the pending chunk.
                if total > self._chunk_size:
                    logger.warning(
                        "Created a chunk of size %s, "
                        "which is longer than the specified %s",
                        total,
                        self._chunk_size,
                    )
                if len(current_doc) > 0:
                    doc = self._join_docs(current_doc, separator)
                    if doc is not None:
                        docs.append(doc)
                    # Keep on popping if:
                    # - we have a larger chunk than in the chunk overlap
                    # - or if we still have any chunks and the length is long
                    while total > self._chunk_overlap or (
                        total + len_ + (separator_len if len(current_doc) > 0 else 0)
                        > self._chunk_size
                        and total > 0
                    ):
                        total -= self._length_function(current_doc[0]) + (
                            separator_len if len(current_doc) > 1 else 0
                        )
                        current_doc = current_doc[1:]
            current_doc.append(d)
            total += len_ + (separator_len if len(current_doc) > 1 else 0)
        # Flush whatever is left after the last piece.
        doc = self._join_docs(current_doc, separator)
        if doc is not None:
            docs.append(doc)
        return docs

    @classmethod
    def from_huggingface_tokenizer(
        cls, tokenizer: PreTrainedTokenizerBase, **kwargs: Any
    ) -> Self:
        """Text splitter that uses HuggingFace tokenizer to count length.

        Args:
            tokenizer: A `PreTrainedTokenizerBase` instance; chunk lengths are
                measured as the number of tokens it produces.
            **kwargs: Forwarded to the splitter's constructor.

        Returns:
            An instance of `cls` whose length function counts tokens.

        Raises:
            ValueError: If `transformers` is not installed, or `tokenizer` is
                not a `PreTrainedTokenizerBase`.
        """
        if not _HAS_TRANSFORMERS:
            msg = (
                "Could not import transformers python package. "
                "Please install it with `pip install transformers`."
            )
            raise ValueError(msg)

        if not isinstance(tokenizer, PreTrainedTokenizerBase):
            msg = "Tokenizer received was not an instance of PreTrainedTokenizerBase"  # type: ignore[unreachable]
            raise ValueError(msg)  # noqa: TRY004

        def _huggingface_tokenizer_length(text: str) -> int:
            return len(tokenizer.tokenize(text))

        return cls(length_function=_huggingface_tokenizer_length, **kwargs)

    @classmethod
    def from_tiktoken_encoder(
        cls,
        encoding_name: str = "gpt2",
        model_name: Optional[str] = None,
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ) -> Self:
        """Text splitter that uses tiktoken encoder to count length.

        Args:
            encoding_name: Name of the tiktoken encoding, used when
                `model_name` is not given.
            model_name: If given, the encoding is looked up for this model
                instead of by `encoding_name`.
            allowed_special: Forwarded to the encoder's `encode` call.
            disallowed_special: Forwarded to the encoder's `encode` call.
            **kwargs: Forwarded to the splitter's constructor.

        Returns:
            An instance of `cls` whose length function counts tiktoken tokens.

        Raises:
            ImportError: If `tiktoken` is not installed.
        """
        if not _HAS_TIKTOKEN:
            msg = (
                "Could not import tiktoken python package. "
                "This is needed in order to calculate max_tokens_for_prompt. "
                "Please install it with `pip install tiktoken`."
            )
            raise ImportError(msg)

        if model_name is not None:
            enc = tiktoken.encoding_for_model(model_name)
        else:
            enc = tiktoken.get_encoding(encoding_name)

        def _tiktoken_encoder(text: str) -> int:
            return len(
                enc.encode(
                    text,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            )

        if issubclass(cls, TokenTextSplitter):
            # TokenTextSplitter builds its own encoder, so it needs the same
            # encoder settings passed to its constructor as well.
            extra_kwargs = {
                "encoding_name": encoding_name,
                "model_name": model_name,
                "allowed_special": allowed_special,
                "disallowed_special": disallowed_special,
            }
            kwargs = {**kwargs, **extra_kwargs}

        return cls(length_function=_tiktoken_encoder, **kwargs)

    @override
    def transform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        """Transform sequence of documents by splitting them.

        Args:
            documents: The documents to split.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The chunk documents produced by `split_documents`.
        """
        return self.split_documents(list(documents))
|
|
|
|
|
|
class TokenTextSplitter(TextSplitter):
    """Splitting text to tokens using model tokenizer."""

    def __init__(
        self,
        encoding_name: str = "gpt2",
        model_name: Optional[str] = None,
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter.

        Args:
            encoding_name: Name of the tiktoken encoding, used when
                `model_name` is not given.
            model_name: If given, the encoding is looked up for this model
                instead of by `encoding_name`.
            allowed_special: Forwarded to the encoder's `encode` call.
            disallowed_special: Forwarded to the encoder's `encode` call.
            **kwargs: Forwarded to `TextSplitter.__init__`.

        Raises:
            ImportError: If `tiktoken` is not installed.
        """
        super().__init__(**kwargs)
        if not _HAS_TIKTOKEN:
            msg = (
                "Could not import tiktoken python package. "
                "This is needed in order to use TokenTextSplitter. "
                "Please install it with `pip install tiktoken`."
            )
            raise ImportError(msg)

        if model_name is not None:
            enc = tiktoken.encoding_for_model(model_name)
        else:
            enc = tiktoken.get_encoding(encoding_name)
        self._tokenizer = enc
        self._allowed_special = allowed_special
        self._disallowed_special = disallowed_special

    def split_text(self, text: str) -> list[str]:
        """Split the input text into smaller chunks based on tokenization.

        The text is encoded into token ids with the configured tiktoken
        encoder, cut into windows of `chunk_size` tokens that overlap by
        `chunk_overlap` tokens, and each window is decoded back into text.
        The windowing itself is done by `split_text_on_tokens`.

        Args:
            text: The input text to be split into smaller chunks.

        Returns:
            A list of text chunks, each decoded from one window of tokens.
        """

        def _encode(_text: str) -> list[int]:
            # Bind the special-token settings so the Tokenizer only needs a
            # plain str -> token ids callable.
            return self._tokenizer.encode(
                _text,
                allowed_special=self._allowed_special,
                disallowed_special=self._disallowed_special,
            )

        tokenizer = Tokenizer(
            chunk_overlap=self._chunk_overlap,
            tokens_per_chunk=self._chunk_size,
            decode=self._tokenizer.decode,
            encode=_encode,
        )

        return split_text_on_tokens(text=text, tokenizer=tokenizer)
|
|
|
|
|
|
class Language(str, Enum):
    """Identifiers for the source and markup languages known to this module.

    Members are also `str` instances, so a member compares equal to its
    lowercase string value (for example `Language.PYTHON == "python"`).
    """

    # Members are kept in their original declaration order, which is also the
    # iteration order of the enum.
    CPP = "cpp"
    GO = "go"
    JAVA = "java"
    KOTLIN = "kotlin"
    JS = "js"
    TS = "ts"
    PHP = "php"
    PROTO = "proto"
    PYTHON = "python"
    RST = "rst"
    RUBY = "ruby"
    RUST = "rust"
    SCALA = "scala"
    SWIFT = "swift"
    MARKDOWN = "markdown"
    LATEX = "latex"
    HTML = "html"
    SOL = "sol"
    CSHARP = "csharp"
    COBOL = "cobol"
    C = "c"
    LUA = "lua"
    PERL = "perl"
    HASKELL = "haskell"
    ELIXIR = "elixir"
    POWERSHELL = "powershell"
    VISUALBASIC6 = "visualbasic6"
|
|
|
|
|
|
@dataclass(frozen=True)
class Tokenizer:
    """Tokenizer data class.

    Immutable bundle of the settings and callables used to split text on
    token boundaries.

    Attributes:
        chunk_overlap: Overlap in tokens between chunks.
        tokens_per_chunk: Maximum number of tokens per chunk.
        decode: Function to decode a list of token ids to a string.
        encode: Function to encode a string to a list of token ids.
    """

    # Windowing parameters.
    chunk_overlap: int
    tokens_per_chunk: int

    # Conversions between text and token ids.
    decode: Callable[[list[int]], str]
    encode: Callable[[str], list[int]]
|
|
|
|
|
|
def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> list[str]:
    """Split incoming text and return chunks using tokenizer.

    The text is encoded to token ids, cut into windows of at most
    `tokenizer.tokens_per_chunk` ids, and each window is decoded back to text.
    Consecutive windows start `tokens_per_chunk - chunk_overlap` ids apart.

    Args:
        text: The text to split.
        tokenizer: Provides `encode`, `decode`, `tokens_per_chunk` and
            `chunk_overlap`.

    Returns:
        The decoded chunks in order; an empty list when the text encodes to
        no tokens.

    Raises:
        ValueError: If the text needs more than one chunk but
            `tokens_per_chunk` is not greater than `chunk_overlap`, because
            the window could then never advance.
    """
    splits: list[str] = []
    input_ids = tokenizer.encode(text)
    total = len(input_ids)
    # Distance between the starts of two consecutive windows.
    step = tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
    start_idx = 0
    while start_idx < total:
        cur_idx = min(start_idx + tokenizer.tokens_per_chunk, total)
        splits.append(tokenizer.decode(input_ids[start_idx:cur_idx]))
        if cur_idx == total:
            break
        if step <= 0:
            # Without this guard `start_idx` would never move forward and the
            # loop would append chunks forever.
            msg = (
                f"tokens_per_chunk ({tokenizer.tokens_per_chunk}) must be "
                f"greater than chunk_overlap ({tokenizer.chunk_overlap}) to "
                "split text that does not fit in a single chunk."
            )
            raise ValueError(msg)
        start_idx += step
    return splits
|