Mirror of https://github.com/hwchase17/langchain.git (synced 2025-07-30 08:14:47 +00:00)

Override of create_documents for TokenTextSplitter

parent 74af25e2c1
commit a997a90b86
@@ -1,4 +1,5 @@
from __future__ import annotations
import os

import copy
import logging
@@ -74,6 +75,8 @@ class TextSplitter(BaseDocumentTransformer, ABC):
        self, texts: list[str], metadatas: Optional[list[dict[Any, Any]]] = None
    ) -> list[Document]:
        """Create documents from a list of texts."""
        if isinstance(self, TokenTextSplitter):
            return self.token_create_documents(texts, metadatas)
        _metadatas = metadatas or [{}] * len(texts)
        documents = []
        for i, text in enumerate(texts):
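Since TokenTextSplitter gains its own create_documents (next hunk), normal attribute lookup already routes instance calls to that override; the isinstance branch above only comes into play if the base implementation is invoked explicitly. A minimal, standalone sketch of that dispatch rule, using hypothetical class names that are not part of the library:

class Base:
    def create(self) -> str:
        return "base"


class Token(Base):
    # The subclass override is what instance calls resolve to.
    def create(self) -> str:
        return "token"


assert Token().create() == "token"      # normal dispatch uses the override
assert Base.create(Token()) == "base"   # only an explicit base-class call bypasses it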
@@ -285,6 +288,30 @@ class TokenTextSplitter(TextSplitter):

        return split_text_on_tokens(text=text, tokenizer=tokenizer)

    def create_documents(
        self, texts: list[str], metadatas: Optional[list[dict[Any, Any]]] = None
    ) -> list[Document]:
        """Override to create documents from a list of tokens."""
        _metadatas = metadatas or [{}] * len(texts)
        documents = []
        for i, text in enumerate(texts):
            input_ids = self._tokenizer.encode(text)
            start_idx = 0
            char_index = 0
            while start_idx < len(input_ids):
                end_idx = min(start_idx + self._chunk_size, len(input_ids))
                chunk_ids = input_ids[start_idx:end_idx]
                chunk_text = self._tokenizer.decode(chunk_ids)
                # Copy the per-text metadata so each chunk gets its own dict
                # instead of all chunks sharing (and mutating) the caller's.
                metadata = copy.deepcopy(_metadatas[i])
                if self._add_start_index:
                    char_index = text.find(chunk_text, char_index)
                    metadata["start_index"] = char_index
                documents.append(Document(page_content=chunk_text, metadata=metadata))
                if end_idx == len(input_ids):
                    break
                start_idx += self._chunk_size - self._chunk_overlap
        return documents


class Language(str, Enum):
    """Enum of the programming languages."""

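The override above walks the encoded token ids in fixed-size windows that advance by chunk_size - chunk_overlap, so consecutive chunks share chunk_overlap tokens. A standalone sketch of that window arithmetic, using plain integers in place of tiktoken-encoded ids (the helper name token_windows is illustrative, not part of the library):

def token_windows(input_ids: list[int], chunk_size: int, chunk_overlap: int) -> list[list[int]]:
    """Slice input_ids into overlapping windows, mirroring the loop above."""
    windows = []
    start_idx = 0
    while start_idx < len(input_ids):
        end_idx = min(start_idx + chunk_size, len(input_ids))
        windows.append(input_ids[start_idx:end_idx])
        if end_idx == len(input_ids):
            break
        # Advance by the stride; adjacent windows overlap by chunk_overlap ids.
        start_idx += chunk_size - chunk_overlap
    return windows


# Twelve ids with chunk_size=5, chunk_overlap=2 -> windows starting at 0, 3, 6, 9.
print(token_windows(list(range(12)), chunk_size=5, chunk_overlap=2))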
@@ -13,6 +13,7 @@ from langchain_text_splitters import (
    RecursiveCharacterTextSplitter,
    TextSplitter,
    Tokenizer,
    TokenTextSplitter,
)
from langchain_text_splitters.base import split_text_on_tokens
from langchain_text_splitters.character import CharacterTextSplitter
@@ -3666,3 +3667,40 @@ def test_character_text_splitter_chunk_size_effect(
        keep_separator=False,
    )
    assert splitter.split_text(text) == expected


def test_token_splitter_create_documents() -> None:
    splitter = TokenTextSplitter(
        add_start_index=True,
        chunk_size=10,
        chunk_overlap=5,
    )
    text = """
"Lorem ipsum dolor sit amet, consectetur adipiscing elit,
sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum."
"""
    docs = splitter.create_documents([text])
    for doc in docs:
        s_i = doc.metadata["start_index"]
        assert text[s_i : s_i + len(doc.page_content)] == doc.page_content


def test_token_splitter_create_documents_repeat_text() -> None:
    splitter = TokenTextSplitter(
        add_start_index=True,
        chunk_size=10,
        chunk_overlap=5,
    )
    text = """
"the quick brown fox jumped over the lazy fox
the quick brown fox jumped over the lazy fox
the quick brown fox jumped over the lazy fox
the quick brown fox jumped over the lazy fox
the quick brown fox jumped over the lazy fox"
"""
    docs = splitter.create_documents([text])
    for doc in docs:
        s_i = doc.metadata["start_index"]
        assert text[s_i : s_i + len(doc.page_content)] == doc.page_content