Running make lint, make format, and make test

This commit is contained in:
Venrite 2025-07-28 17:07:31 -07:00
parent d50965f7b2
commit 68bede3e24
2 changed files with 16 additions and 24 deletions

View File

@ -74,8 +74,6 @@ class TextSplitter(BaseDocumentTransformer, ABC):
self, texts: list[str], metadatas: Optional[list[dict[Any, Any]]] = None
) -> list[Document]:
"""Create documents from a list of texts."""
if isinstance(self,TokenTextSplitter):
return self.token_create_documents(texts,metadatas)
_metadatas = metadatas or [{}] * len(texts)
documents = []
for i, text in enumerate(texts):
@ -227,6 +225,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
class TokenTextSplitter(TextSplitter):
"""Splitting text to tokens using model tokenizer."""
def __init__(
self,
encoding_name: str = "gpt2",
@ -286,10 +285,10 @@ class TokenTextSplitter(TextSplitter):
)
return split_text_on_tokens(text=text, tokenizer=tokenizer)
def create_documents(
self, texts: list[str], metadatas: Optional[list[dict[Any, Any]]] = None
) -> list[Document]:
self, texts: list[str], metadatas: Optional[list[dict[Any, Any]]] = None
) -> list[Document]:
"""Override to create documents from a list of tokens."""
_metadatas = metadatas or [{}] * len(texts)
documents = []
@ -299,18 +298,19 @@ class TokenTextSplitter(TextSplitter):
start_idx = 0
char_index = 0
while start_idx < len(input_ids):
end_idx = min(start_idx + self._chunk_size, len(input_ids))
chunk_ids = input_ids[start_idx:end_idx]
end_idx = min(start_idx + self._chunk_size, len(input_ids))
chunk_ids = input_ids[start_idx:end_idx]
chunk_text = self._tokenizer.decode(chunk_ids)
if self._add_start_index:
char_index = text.find(chunk_text, char_index)
metadata["start_index"] = char_index
documents.append(Document(page_content=chunk_text,metadata=metadata))
documents.append(Document(page_content=chunk_text, metadata=metadata))
if end_idx == len(input_ids):
break
start_idx += self._chunk_size - self._chunk_overlap
return documents
class Language(str, Enum):
"""Enum of the programming languages."""
@ -372,4 +372,3 @@ def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> list[str]:
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
chunk_ids = input_ids[start_idx:cur_idx]
return splits

View File

@ -13,7 +13,7 @@ from langchain_text_splitters import (
RecursiveCharacterTextSplitter,
TextSplitter,
Tokenizer,
TokenTextSplitter
TokenTextSplitter,
)
from langchain_text_splitters.base import split_text_on_tokens
from langchain_text_splitters.character import CharacterTextSplitter
@ -3668,13 +3668,10 @@ def test_character_text_splitter_chunk_size_effect(
)
assert splitter.split_text(text) == expected
def test_token_splitter_create_documents() -> None:
splitter = TokenTextSplitter(
add_start_index=True,
chunk_size=10,
chunk_overlap=5
)
text="""
splitter = TokenTextSplitter(add_start_index=True, chunk_size=10, chunk_overlap=5)
text = """
"Lorem ipsum dolor sit amet, consectetur adipiscing elit,
sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
@ -3686,13 +3683,10 @@ def test_token_splitter_create_documents() -> None:
s_i = doc.metadata["start_index"]
assert text[s_i : s_i + len(doc.page_content)] == doc.page_content
def test_token_splitter_create_documents_repeat_text() -> None:
splitter = TokenTextSplitter(
add_start_index=True,
chunk_size=10,
chunk_overlap=5
)
text="""
splitter = TokenTextSplitter(add_start_index=True,chunk_size=10,chunk_overlap=5)
text = """
"the quick brown fox jumped over the lazy fox
the quick brown fox jumped over the lazy fox
the quick brown fox jumped over the lazy fox
@ -3703,4 +3697,3 @@ def test_token_splitter_create_documents_repeat_text() -> None:
for doc in docs:
s_i = doc.metadata["start_index"]
assert text[s_i : s_i + len(doc.page_content)] == doc.page_content