Override of create_documents for TokenTextSplitter

Venrite 2025-07-28 16:41:09 -07:00
parent 74af25e2c1
commit a997a90b86
2 changed files with 65 additions and 0 deletions
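The change adds a token-aware override of create_documents to TokenTextSplitter: each document is built from a sliding window of token ids, and when add_start_index is enabled the chunk's position in the source string is recorded in metadata["start_index"]. A minimal usage sketch (not part of the diff, assuming the tiktoken-backed TokenTextSplitter from langchain_text_splitters with this change applied):

from langchain_text_splitters import TokenTextSplitter

# Hypothetical example values; any text and chunk sizes behave the same way.
splitter = TokenTextSplitter(chunk_size=10, chunk_overlap=5, add_start_index=True)
docs = splitter.create_documents(["the quick brown fox jumped over the lazy dog " * 5])

for doc in docs:
    start = doc.metadata["start_index"]
    # start_index locates the decoded chunk text inside the original string.
    print(start, repr(doc.page_content))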

View File

@@ -1,4 +1,5 @@
from __future__ import annotations
import os
import copy
import logging
@@ -74,6 +75,8 @@ class TextSplitter(BaseDocumentTransformer, ABC):
        self, texts: list[str], metadatas: Optional[list[dict[Any, Any]]] = None
    ) -> list[Document]:
        """Create documents from a list of texts."""
        if isinstance(self, TokenTextSplitter):
            return self.token_create_documents(texts, metadatas)
        _metadatas = metadatas or [{}] * len(texts)
        documents = []
        for i, text in enumerate(texts):
@@ -285,6 +288,30 @@ class TokenTextSplitter(TextSplitter):
        return split_text_on_tokens(text=text, tokenizer=tokenizer)

    def create_documents(
        self, texts: list[str], metadatas: Optional[list[dict[Any, Any]]] = None
    ) -> list[Document]:
        """Override to create documents from a list of tokens."""
        _metadatas = metadatas or [{}] * len(texts)
        documents = []
        for i, text in enumerate(texts):
            metadata = _metadatas[i]
            input_ids = self._tokenizer.encode(text)
            start_idx = 0
            char_index = 0
            while start_idx < len(input_ids):
                end_idx = min(start_idx + self._chunk_size, len(input_ids))
                chunk_ids = input_ids[start_idx:end_idx]
                chunk_text = self._tokenizer.decode(chunk_ids)
                if self._add_start_index:
                    char_index = text.find(chunk_text, char_index)
                    metadata["start_index"] = char_index
                documents.append(Document(page_content=chunk_text, metadata=metadata))
                if end_idx == len(input_ids):
                    break
                start_idx += self._chunk_size - self._chunk_overlap
        return documents


class Language(str, Enum):
    """Enum of the programming languages."""

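The override steps through the token ids in strides of chunk_size - chunk_overlap, decodes each window back to text, and locates that text in the source with str.find starting from the previous hit, so overlapping windows map to non-decreasing character offsets. A standalone sketch of the same windowing arithmetic (illustrative only; plain integers stand in for tiktoken token ids, and window_bounds is a hypothetical helper, not part of the library):

def window_bounds(n_tokens: int, chunk_size: int, chunk_overlap: int) -> list[tuple[int, int]]:
    """Token (start, end) pairs produced by the loop in the override above."""
    bounds = []
    start = 0
    while start < n_tokens:
        end = min(start + chunk_size, n_tokens)
        bounds.append((start, end))
        if end == n_tokens:
            break
        start += chunk_size - chunk_overlap
    return bounds

# 23 tokens, windows of 10 tokens with 5 tokens of overlap -> strides of 5.
assert window_bounds(23, 10, 5) == [(0, 10), (5, 15), (10, 20), (15, 23)]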
View File

@@ -13,6 +13,7 @@ from langchain_text_splitters import (
    RecursiveCharacterTextSplitter,
    TextSplitter,
    Tokenizer,
    TokenTextSplitter,
)
from langchain_text_splitters.base import split_text_on_tokens
from langchain_text_splitters.character import CharacterTextSplitter
@@ -3666,3 +3667,40 @@ def test_character_text_splitter_chunk_size_effect(
        keep_separator=False,
    )
    assert splitter.split_text(text) == expected


def test_token_splitter_create_documents() -> None:
    splitter = TokenTextSplitter(
        add_start_index=True,
        chunk_size=10,
        chunk_overlap=5,
    )
    text = """
"Lorem ipsum dolor sit amet, consectetur adipiscing elit,
sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum."
"""
    docs = splitter.create_documents([text])
    for doc in docs:
        s_i = doc.metadata["start_index"]
        assert text[s_i : s_i + len(doc.page_content)] == doc.page_content


def test_token_splitter_create_documents_repeat_text() -> None:
    splitter = TokenTextSplitter(
        add_start_index=True,
        chunk_size=10,
        chunk_overlap=5,
    )
    text = """
"the quick brown fox jumped over the lazy fox
the quick brown fox jumped over the lazy fox
the quick brown fox jumped over the lazy fox
the quick brown fox jumped over the lazy fox
the quick brown fox jumped over the lazy fox"
"""
    docs = splitter.create_documents([text])
    for doc in docs:
        s_i = doc.metadata["start_index"]
        assert text[s_i : s_i + len(doc.page_content)] == doc.page_content
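Both tests assert that start_index still lines up with its chunk even when the source text repeats, which works because str.find is given a start position instead of always scanning from the beginning. A tiny illustration of that behavior (not taken from the diff):

text = "the quick brown fox jumped over the lazy fox\n" * 3
chunk = "the quick brown fox"

first = text.find(chunk)              # first occurrence
second = text.find(chunk, first + 1)  # searching past the first hit finds the next one
assert first < second
assert text[second : second + len(chunk)] == chunk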