diff --git a/libs/text-splitters/langchain_text_splitters/base.py b/libs/text-splitters/langchain_text_splitters/base.py
index 2f3891f64f3..7ff118f9a3f 100644
--- a/libs/text-splitters/langchain_text_splitters/base.py
+++ b/libs/text-splitters/langchain_text_splitters/base.py
@@ -285,6 +285,32 @@ class TokenTextSplitter(TextSplitter):
 
         return split_text_on_tokens(text=text, tokenizer=tokenizer)
 
+    def create_documents(
+        self, texts: list[str], metadatas: Optional[list[dict[Any, Any]]] = None
+    ) -> list[Document]:
+        """Create documents from a list of texts, chunking on token boundaries."""
+        _metadatas = metadatas or [{}] * len(texts)
+        documents = []
+        for i, text in enumerate(texts):
+            input_ids = self._tokenizer.encode(text)
+            start_idx = 0
+            char_index = 0
+            while start_idx < len(input_ids):
+                end_idx = min(start_idx + self._chunk_size, len(input_ids))
+                chunk_ids = input_ids[start_idx:end_idx]
+                chunk_text = self._tokenizer.decode(chunk_ids)
+                # Copy per chunk so each Document gets its own metadata dict.
+                metadata = copy.deepcopy(_metadatas[i])
+                if self._add_start_index:
+                    # Search forward from the previous chunk's start.
+                    char_index = text.find(chunk_text, char_index)
+                    metadata["start_index"] = char_index
+                documents.append(Document(page_content=chunk_text, metadata=metadata))
+                if end_idx == len(input_ids):
+                    break
+                start_idx += self._chunk_size - self._chunk_overlap
+        return documents
+
 
 class Language(str, Enum):
     """Enum of the programming languages."""
diff --git a/libs/text-splitters/tests/unit_tests/test_text_splitters.py b/libs/text-splitters/tests/unit_tests/test_text_splitters.py
index 0d72e806309..683a57b026d 100644
--- a/libs/text-splitters/tests/unit_tests/test_text_splitters.py
+++ b/libs/text-splitters/tests/unit_tests/test_text_splitters.py
@@ -13,6 +13,7 @@ from langchain_text_splitters import (
     RecursiveCharacterTextSplitter,
     TextSplitter,
     Tokenizer,
+    TokenTextSplitter,
 )
 from langchain_text_splitters.base import split_text_on_tokens
 from langchain_text_splitters.character import CharacterTextSplitter
@@ -3666,3 +3667,41 @@ def test_character_text_splitter_chunk_size_effect(
         keep_separator=False,
     )
     assert splitter.split_text(text) == expected
+
+
+def test_token_splitter_create_documents() -> None:
+    splitter = TokenTextSplitter(
+        add_start_index=True,
+        chunk_size=10,
+        chunk_overlap=5,
+    )
+    text = """
+    "Lorem ipsum dolor sit amet, consectetur adipiscing elit,
+    sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
+    Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
+    Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
+    Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum."
+ """ + docs = splitter.create_documents([text]) + for doc in docs: + s_i = doc.metadata["start_index"] + assert text[s_i : s_i + len(doc.page_content)] == doc.page_content + +def test_token_splitter_create_documents_repeat_text() -> None: + splitter = TokenTextSplitter( + add_start_index=True, + chunk_size=10, + chunk_overlap=5 + ) + text=""" + "the quick brown fox jumped over the lazy fox + the quick brown fox jumped over the lazy fox + the quick brown fox jumped over the lazy fox + the quick brown fox jumped over the lazy fox + the quick brown fox jumped over the lazy fox" + """ + docs = splitter.create_documents([text]) + for doc in docs: + s_i = doc.metadata["start_index"] + assert text[s_i : s_i + len(doc.page_content)] == doc.page_content +