text-splitters: Ruff autofixes (#31858)

Auto-fixes from ruff with rule `ALL`
Christophe Bornet 2025-07-07 16:06:08 +02:00 committed by GitHub
parent 8aed3b61a9
commit 451c90fefa
GPG Key ID: B5690EEEBB952194
16 changed files with 121 additions and 123 deletions

View File

@@ -54,28 +54,28 @@ from langchain_text_splitters.sentence_transformers import (
 from langchain_text_splitters.spacy import SpacyTextSplitter

 __all__ = [
-    "TokenTextSplitter",
-    "TextSplitter",
-    "Tokenizer",
-    "Language",
-    "RecursiveCharacterTextSplitter",
-    "RecursiveJsonSplitter",
-    "LatexTextSplitter",
-    "JSFrameworkTextSplitter",
-    "PythonCodeTextSplitter",
-    "KonlpyTextSplitter",
-    "SpacyTextSplitter",
-    "NLTKTextSplitter",
-    "split_text_on_tokens",
-    "SentenceTransformersTokenTextSplitter",
+    "CharacterTextSplitter",
     "ElementType",
-    "HeaderType",
-    "LineType",
+    "ExperimentalMarkdownSyntaxTextSplitter",
     "HTMLHeaderTextSplitter",
     "HTMLSectionSplitter",
     "HTMLSemanticPreservingSplitter",
+    "HeaderType",
+    "JSFrameworkTextSplitter",
+    "KonlpyTextSplitter",
+    "Language",
+    "LatexTextSplitter",
+    "LineType",
     "MarkdownHeaderTextSplitter",
     "MarkdownTextSplitter",
-    "CharacterTextSplitter",
-    "ExperimentalMarkdownSyntaxTextSplitter",
+    "NLTKTextSplitter",
+    "PythonCodeTextSplitter",
+    "RecursiveCharacterTextSplitter",
+    "RecursiveJsonSplitter",
+    "SentenceTransformersTokenTextSplitter",
+    "SpacyTextSplitter",
+    "TextSplitter",
+    "TokenTextSplitter",
+    "Tokenizer",
+    "split_text_on_tokens",
 ]
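The reordering above adds and removes nothing; the export list is simply sorted, which is consistent with Ruff's RUF022 (unsorted `__all__`) autofix (rule code assumed, not stated in the commit). A minimal sketch of the effect on a hypothetical three-entry module:

# Before the autofix the list was in arbitrary order:
#     __all__ = ["TokenTextSplitter", "CharacterTextSplitter", "TextSplitter"]
# After, it is sorted, with no names added or removed:
__all__ = ["CharacterTextSplitter", "TextSplitter", "TokenTextSplitter"]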

View File

@@ -3,7 +3,8 @@ from __future__ import annotations
 import copy
 import logging
 from abc import ABC, abstractmethod
-from collections.abc import Collection, Iterable, Sequence, Set
+from collections.abc import Collection, Iterable, Sequence
+from collections.abc import Set as AbstractSet
 from dataclasses import dataclass
 from enum import Enum
 from typing import (

@@ -47,10 +48,11 @@ class TextSplitter(BaseDocumentTransformer, ABC):
                 every document
         """
         if chunk_overlap > chunk_size:
-            raise ValueError(
+            msg = (
                 f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
                 f"({chunk_size}), should be smaller."
             )
+            raise ValueError(msg)
         self._chunk_size = chunk_size
         self._chunk_overlap = chunk_overlap
         self._length_function = length_function

@@ -96,8 +98,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
         text = text.strip()
         if text == "":
             return None
-        else:
-            return text
+        return text

     def _merge_splits(self, splits: Iterable[str], separator: str) -> list[str]:
         # We now want to combine these smaller pieces into medium size

@@ -148,18 +149,20 @@ class TextSplitter(BaseDocumentTransformer, ABC):
             from transformers.tokenization_utils_base import PreTrainedTokenizerBase

             if not isinstance(tokenizer, PreTrainedTokenizerBase):
-                raise ValueError(
+                msg = (
                     "Tokenizer received was not an instance of PreTrainedTokenizerBase"
                 )
+                raise ValueError(msg)

             def _huggingface_tokenizer_length(text: str) -> int:
                 return len(tokenizer.tokenize(text))

         except ImportError:
-            raise ValueError(
+            msg = (
                 "Could not import transformers python package. "
                 "Please install it with `pip install transformers`."
             )
+            raise ValueError(msg)
         return cls(length_function=_huggingface_tokenizer_length, **kwargs)

     @classmethod

@@ -167,7 +170,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
         cls: type[TS],
         encoding_name: str = "gpt2",
         model_name: Optional[str] = None,
-        allowed_special: Union[Literal["all"], Set[str]] = set(),
+        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
         disallowed_special: Union[Literal["all"], Collection[str]] = "all",
         **kwargs: Any,
     ) -> TS:

@@ -175,11 +178,12 @@ class TextSplitter(BaseDocumentTransformer, ABC):
         try:
             import tiktoken
         except ImportError:
-            raise ImportError(
+            msg = (
                 "Could not import tiktoken python package. "
                 "This is needed in order to calculate max_tokens_for_prompt. "
                 "Please install it with `pip install tiktoken`."
             )
+            raise ImportError(msg)

         if model_name is not None:
             enc = tiktoken.encoding_for_model(model_name)

@@ -220,7 +224,7 @@ class TokenTextSplitter(TextSplitter):
         self,
         encoding_name: str = "gpt2",
         model_name: Optional[str] = None,
-        allowed_special: Union[Literal["all"], Set[str]] = set(),
+        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
         disallowed_special: Union[Literal["all"], Collection[str]] = "all",
         **kwargs: Any,
     ) -> None:

@@ -229,11 +233,12 @@ class TokenTextSplitter(TextSplitter):
         try:
             import tiktoken
         except ImportError:
-            raise ImportError(
+            msg = (
                 "Could not import tiktoken python package. "
                 "This is needed in order to for TokenTextSplitter. "
                 "Please install it with `pip install tiktoken`."
             )
+            raise ImportError(msg)

         if model_name is not None:
             enc = tiktoken.encoding_for_model(model_name)
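Two recurring fixes appear in this file: the exception message is bound to a msg variable before raising (the pattern Ruff's EM101/EM102 rules enforce; rule codes are my assumption), and the collections.abc Set import is aliased to AbstractSet so it cannot be confused with the built-in set (likely PYI025, also an assumption). A minimal, self-contained sketch of the msg pattern, using a hypothetical check_overlap helper:

def check_overlap(chunk_size: int, chunk_overlap: int) -> None:
    """Raise if the requested overlap exceeds the chunk size."""
    if chunk_overlap > chunk_size:
        # Binding the message first keeps the long literal off the raise line,
        # so tracebacks show `raise ValueError(msg)` instead of the full string.
        msg = (
            f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
            f"({chunk_size}), should be smaller."
        )
        raise ValueError(msg)


check_overlap(chunk_size=100, chunk_overlap=20)  # passes silently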

View File

@@ -60,9 +60,9 @@ def _split_text_with_regex(
             if len(_splits) % 2 == 0:
                 splits += _splits[-1:]
             splits = (
-                (splits + [_splits[-1]])
+                ([*splits, _splits[-1]])
                 if keep_separator == "end"
-                else ([_splits[0]] + splits)
+                else ([_splits[0], *splits])
             )
         else:
             splits = re.split(separator, text)

@@ -170,7 +170,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
         Returns:
             List[str]: A list of separators appropriate for the specified language.
         """
-        if language == Language.C or language == Language.CPP:
+        if language in (Language.C, Language.CPP):
             return [
                 # Split along class definitions
                 "\nclass ",

@@ -191,7 +191,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.GO:
+        if language == Language.GO:
             return [
                 # Split along function definitions
                 "\nfunc ",

@@ -209,7 +209,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.JAVA:
+        if language == Language.JAVA:
             return [
                 # Split along class definitions
                 "\nclass ",

@@ -230,7 +230,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.KOTLIN:
+        if language == Language.KOTLIN:
             return [
                 # Split along class definitions
                 "\nclass ",

@@ -256,7 +256,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.JS:
+        if language == Language.JS:
             return [
                 # Split along function definitions
                 "\nfunction ",

@@ -277,7 +277,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.TS:
+        if language == Language.TS:
             return [
                 "\nenum ",
                 "\ninterface ",

@@ -303,7 +303,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.PHP:
+        if language == Language.PHP:
             return [
                 # Split along function definitions
                 "\nfunction ",

@@ -322,7 +322,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.PROTO:
+        if language == Language.PROTO:
             return [
                 # Split along message definitions
                 "\nmessage ",

@@ -342,7 +342,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.PYTHON:
+        if language == Language.PYTHON:
             return [
                 # First, try to split along class definitions
                 "\nclass ",

@@ -354,7 +354,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.RST:
+        if language == Language.RST:
             return [
                 # Split along section titles
                 "\n=+\n",

@@ -368,7 +368,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.RUBY:
+        if language == Language.RUBY:
             return [
                 # Split along method definitions
                 "\ndef ",

@@ -387,7 +387,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.ELIXIR:
+        if language == Language.ELIXIR:
             return [
                 # Split along method function and module definition
                 "\ndef ",

@@ -411,7 +411,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.RUST:
+        if language == Language.RUST:
             return [
                 # Split along function definitions
                 "\nfn ",

@@ -430,7 +430,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.SCALA:
+        if language == Language.SCALA:
             return [
                 # Split along class definitions
                 "\nclass ",

@@ -451,7 +451,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.SWIFT:
+        if language == Language.SWIFT:
             return [
                 # Split along function definitions
                 "\nfunc ",

@@ -472,7 +472,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.MARKDOWN:
+        if language == Language.MARKDOWN:
             return [
                 # First, try to split along Markdown headings (starting with level 2)
                 "\n#{1,6} ",

@@ -492,7 +492,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.LATEX:
+        if language == Language.LATEX:
             return [
                 # First, try to split along Latex sections
                 "\n\\\\chapter{",

@@ -516,7 +516,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.HTML:
+        if language == Language.HTML:
             return [
                 # First, try to split along HTML tags
                 "<body",

@@ -548,7 +548,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 "<title",
                 "",
             ]
-        elif language == Language.CSHARP:
+        if language == Language.CSHARP:
             return [
                 "\ninterface ",
                 "\nenum ",

@@ -585,7 +585,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.SOL:
+        if language == Language.SOL:
             return [
                 # Split along compiler information definitions
                 "\npragma ",

@@ -615,7 +615,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.COBOL:
+        if language == Language.COBOL:
             return [
                 # Split along divisions
                 "\nIDENTIFICATION DIVISION.",

@@ -647,7 +647,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.LUA:
+        if language == Language.LUA:
             return [
                 # Split along variable and table definitions
                 "\nlocal ",

@@ -664,7 +664,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.HASKELL:
+        if language == Language.HASKELL:
             return [
                 # Split along function definitions
                 "\nmain :: ",

@@ -703,7 +703,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.POWERSHELL:
+        if language == Language.POWERSHELL:
             return [
                 # Split along function definitions
                 "\nfunction ",

@@ -727,10 +727,10 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language in Language._value2member_map_:
-            raise ValueError(f"Language {language} is not implemented yet!")
-        else:
-            raise ValueError(
-                f"Language {language} is not supported! "
-                f"Please choose from {list(Language)}"
-            )
+        if language in Language._value2member_map_:
+            msg = f"Language {language} is not implemented yet!"
+            raise ValueError(msg)
+        msg = (
+            f"Language {language} is not supported! Please choose from {list(Language)}"
+        )
+        raise ValueError(msg)
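The edits in this file fall into two patterns: elif branches that follow a return are rewritten as plain if (Ruff's RET505, assumed), and repeated equality checks against the same variable are folded into a membership test (PLR1714, assumed). A short sketch with a hypothetical pick_separators helper:

def pick_separators(language: str) -> list[str]:
    # `language == "c" or language == "cpp"` becomes a membership test.
    if language in ("c", "cpp"):
        return ["\nclass ", "\nvoid "]
    # Every branch returns, so the former `elif` chain can be flattened
    # to independent `if` statements without changing control flow.
    if language == "python":
        return ["\nclass ", "\ndef "]
    msg = f"Language {language} is not supported!"
    raise ValueError(msg)


print(pick_separators("python"))  # ['\nclass ', '\ndef ']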

View File

@@ -194,9 +194,10 @@ class HTMLHeaderTextSplitter:
         try:
             from bs4 import BeautifulSoup
         except ImportError as e:
-            raise ImportError(
+            msg = (
                 "Unable to import BeautifulSoup. Please install via `pip install bs4`."
-            ) from e
+            )
+            raise ImportError(msg) from e

         soup = BeautifulSoup(html_content, "html.parser")
         body = soup.body if soup.body else soup

@@ -352,7 +353,7 @@ class HTMLSectionSplitter:
             for chunk in self.split_text(text):
                 metadata = copy.deepcopy(_metadatas[i])

-                for key in chunk.metadata.keys():
+                for key in chunk.metadata:
                     if chunk.metadata[key] == "#TITLE#":
                         chunk.metadata[key] = metadata["Title"]
                 metadata = {**metadata, **chunk.metadata}

@@ -382,17 +383,16 @@ class HTMLSectionSplitter:
             from bs4 import BeautifulSoup
             from bs4.element import PageElement
         except ImportError as e:
-            raise ImportError(
-                "Unable to import BeautifulSoup/PageElement, \
+            msg = "Unable to import BeautifulSoup/PageElement, \
                     please install with `pip install \
                         bs4`."
-            ) from e
+            raise ImportError(msg) from e

         soup = BeautifulSoup(html_doc, "html.parser")
         headers = list(self.headers_to_split_on.keys())
         sections: list[dict[str, str | None]] = []

-        headers = soup.find_all(["body"] + headers)  # type: ignore[assignment]
+        headers = soup.find_all(["body", *headers])  # type: ignore[assignment]

         for i, header in enumerate(headers):
             header_element = cast(PageElement, header)

@@ -441,9 +441,8 @@ class HTMLSectionSplitter:
         try:
             from lxml import etree
         except ImportError as e:
-            raise ImportError(
-                "Unable to import lxml, please install with `pip install lxml`."
-            ) from e
+            msg = "Unable to import lxml, please install with `pip install lxml`."
+            raise ImportError(msg) from e
         # use lxml library to parse html document and return xml ElementTree
         # Create secure parsers to prevent XXE attacks
         html_parser = etree.HTMLParser(no_network=True)

@@ -594,10 +593,11 @@ class HTMLSemanticPreservingSplitter(BaseDocumentTransformer):
             self._BeautifulSoup = BeautifulSoup
             self._Tag = Tag
         except ImportError:
-            raise ImportError(
+            msg = (
                 "Could not import BeautifulSoup. "
                 "Please install it with 'pip install bs4'."
             )
+            raise ImportError(msg)

         self._headers_to_split_on = sorted(headers_to_split_on)
         self._max_chunk_size = max_chunk_size

@@ -646,9 +646,10 @@ class HTMLSemanticPreservingSplitter(BaseDocumentTransformer):
                 nltk.download("stopwords")
                 self._stopwords = set(nltk.corpus.stopwords.words(self._stopword_lang))
             except ImportError:
-                raise ImportError(
+                msg = (
                     "Could not import nltk. Please install it with 'pip install nltk'."
                 )
+                raise ImportError(msg)

     def split_text(self, text: str) -> list[Document]:
         """Splits the provided HTML text into smaller chunks based on the configuration.

@@ -927,8 +928,7 @@ class HTMLSemanticPreservingSplitter(BaseDocumentTransformer):
                 content, preserved_elements
             )
             return [Document(page_content=page_content, metadata=metadata)]
-        else:
-            return self._further_split_chunk(content, metadata, preserved_elements)
+        return self._further_split_chunk(content, metadata, preserved_elements)

     def _further_split_chunk(
         self, content: str, metadata: dict[Any, Any], preserved_elements: dict[str, str]
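Besides the msg-before-raise rewrites, this file drops a redundant .keys() call: iterating a dict already yields its keys (Ruff's SIM118, assumed). A tiny sketch with made-up metadata:

metadata = {"Header 1": "#TITLE#", "Header 2": "Introduction"}

for key in metadata:  # equivalent to `for key in metadata.keys():`
    if metadata[key] == "#TITLE#":
        metadata[key] = "Document title"

print(metadata)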

View File

@@ -64,15 +64,14 @@ class RecursiveJsonSplitter:
         if isinstance(data, dict):
             # Process each key-value pair in the dictionary
             return {k: self._list_to_dict_preprocessing(v) for k, v in data.items()}
-        elif isinstance(data, list):
+        if isinstance(data, list):
             # Convert the list to a dictionary with index-based keys
             return {
                 str(i): self._list_to_dict_preprocessing(item)
                 for i, item in enumerate(data)
             }
-        else:
-            # Base case: the item is neither a dict nor a list, so return it unchanged
-            return data
+        # Base case: the item is neither a dict nor a list, so return it unchanged
+        return data

     def _json_split(
         self,

@@ -85,7 +84,7 @@ class RecursiveJsonSplitter:
         chunks = chunks if chunks is not None else [{}]
         if isinstance(data, dict):
             for key, value in data.items():
-                new_path = current_path + [key]
+                new_path = [*current_path, key]
                 chunk_size = self._json_size(chunks[-1])
                 size = self._json_size({key: value})
                 remaining = self.max_chunk_size - chunk_size
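The new_path change replaces list concatenation with unpacking into a list literal (Ruff's RUF005, assumed); the result is identical, it just avoids building a throw-away single-element list:

current_path = ["root", "config"]
key = "timeout"

new_path = [*current_path, key]          # spelling after the autofix
assert new_path == current_path + [key]  # same value as the old concatenation
print(new_path)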

View File

@@ -94,5 +94,4 @@ class JSFrameworkTextSplitter(RecursiveCharacterTextSplitter):
             + ["<>", "\n\n", "&&\n", "||\n"]
         )
         self._separators = separators
-        chunks = super().split_text(text)
-        return chunks
+        return super().split_text(text)
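Here a value that was assigned and immediately returned is now returned directly (Ruff's RET504, assumed). A simplified stand-in, not the real method, which delegates to super().split_text:

def split_text(text: str) -> list[str]:
    # Before the fix: chunks = text.split(); return chunks
    return text.split()


print(split_text("a b c"))  # ['a', 'b', 'c']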

View File

@@ -22,12 +22,11 @@ class KonlpyTextSplitter(TextSplitter):
         try:
             import konlpy
         except ImportError:
-            raise ImportError(
-                """
-                Konlpy is not installed, please install it with
+            msg = """
+            Konlpy is not installed, please install it with
                 `pip install konlpy`
                 """
-            )
+            raise ImportError(msg)
         self.kkma = konlpy.tag.Kkma()

     def split_text(self, text: str) -> list[str]:

View File

@@ -121,10 +121,9 @@ class MarkdownHeaderTextSplitter:
                 elif stripped_line.startswith("~~~"):
                     in_code_block = True
                     opening_fence = "~~~"
-            else:
-                if stripped_line.startswith(opening_fence):
-                    in_code_block = False
-                    opening_fence = ""
+            elif stripped_line.startswith(opening_fence):
+                in_code_block = False
+                opening_fence = ""

             if in_code_block:
                 current_content.append(stripped_line)

@@ -207,11 +206,10 @@ class MarkdownHeaderTextSplitter:
         # aggregate these into chunks based on common metadata
         if not self.return_each_line:
             return self.aggregate_lines_to_chunks(lines_with_metadata)
-        else:
-            return [
-                Document(page_content=chunk["content"], metadata=chunk["metadata"])
-                for chunk in lines_with_metadata
-            ]
+        return [
+            Document(page_content=chunk["content"], metadata=chunk["metadata"])
+            for chunk in lines_with_metadata
+        ]


 class LineType(TypedDict):

View File

@@ -22,7 +22,8 @@ class NLTKTextSplitter(TextSplitter):
         self._language = language
         self._use_span_tokenize = use_span_tokenize
         if self._use_span_tokenize and self._separator != "":
-            raise ValueError("When use_span_tokenize is True, separator should be ''")
+            msg = "When use_span_tokenize is True, separator should be ''"
+            raise ValueError(msg)
         try:
             import nltk

@@ -31,9 +32,8 @@ class NLTKTextSplitter(TextSplitter):
             else:
                 self._tokenizer = nltk.tokenize.sent_tokenize
         except ImportError:
-            raise ImportError(
-                "NLTK is not installed, please install it with `pip install nltk`."
-            )
+            msg = "NLTK is not installed, please install it with `pip install nltk`."
+            raise ImportError(msg)

     def split_text(self, text: str) -> list[str]:
         """Split incoming text and return chunks."""

View File

@@ -21,11 +21,12 @@ class SentenceTransformersTokenTextSplitter(TextSplitter):
         try:
             from sentence_transformers import SentenceTransformer
         except ImportError:
-            raise ImportError(
+            msg = (
                 "Could not import sentence_transformers python package. "
                 "This is needed in order to for SentenceTransformersTokenTextSplitter. "
                 "Please install it with `pip install sentence-transformers`."
             )
+            raise ImportError(msg)

         self.model_name = model_name
         self._model = SentenceTransformer(self.model_name)

@@ -43,12 +44,13 @@ class SentenceTransformersTokenTextSplitter(TextSplitter):
         self.tokens_per_chunk = tokens_per_chunk

         if self.tokens_per_chunk > self.maximum_tokens_per_chunk:
-            raise ValueError(
+            msg = (
                 f"The token limit of the models '{self.model_name}'"
                 f" is: {self.maximum_tokens_per_chunk}."
                 f" Argument tokens_per_chunk={self.tokens_per_chunk}"
                 f" > maximum token limit."
             )
+            raise ValueError(msg)

     def split_text(self, text: str) -> list[str]:
         """Splits the input text into smaller components by splitting text on tokens.

View File

@@ -46,9 +46,8 @@ def _make_spacy_pipeline_for_splitting(
     try:
         import spacy
    except ImportError:
-        raise ImportError(
-            "Spacy is not installed, please install it with `pip install spacy`."
-        )
+        msg = "Spacy is not installed, please install it with `pip install spacy`."
+        raise ImportError(msg)
     if pipeline == "sentencizer":
         sentencizer: Any = spacy.lang.en.English()
         sentencizer.add_pipe("sentencizer")

View File

@@ -4,4 +4,3 @@ import pytest
 @pytest.mark.compile
 def test_placeholder() -> None:
     """Used for compiling integration tests without running any real tests."""
-    pass
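The removed pass illustrates that a docstring alone is a complete function body, so the placeholder statement is redundant (Ruff's PIE790, assumed):

def test_placeholder() -> None:
    """A docstring by itself is a valid (and sufficient) function body."""


test_placeholder()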

View File

@@ -14,7 +14,7 @@ def setup_module() -> None:
     nltk.download("punkt_tab")


-@pytest.fixture()
+@pytest.fixture
 def spacy() -> Any:
     try:
         import spacy

View File

@@ -13,7 +13,7 @@ from langchain_text_splitters.sentence_transformers import (
 )


-@pytest.fixture()
+@pytest.fixture
 def sentence_transformers() -> Any:
     try:
         import sentence_transformers
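This hunk and the previous one drop the empty parentheses from the fixture decorator; @pytest.fixture without arguments does not need to be called (Ruff's PT001, assumed). A sketch with a hypothetical sample_text fixture:

import pytest


@pytest.fixture
def sample_text() -> str:
    return "some text to split"


def test_sample_text(sample_text: str) -> None:
    assert sample_text.startswith("some")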

View File

@@ -45,7 +45,8 @@ def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) ->
     only_core = config.getoption("--only-core") or False

     if only_extended and only_core:
-        raise ValueError("Cannot specify both `--only-extended` and `--only-core`.")
+        msg = "Cannot specify both `--only-extended` and `--only-core`."
+        raise ValueError(msg)

     for item in items:
         requires_marker = item.get_closest_marker("requires")

@@ -81,8 +82,5 @@ def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) ->
                     pytest.mark.skip(reason=f"Requires pkg: `{pkg}`")
                 )
                 break
-        else:
-            if only_extended:
-                item.add_marker(
-                    pytest.mark.skip(reason="Skipping not an extended test.")
-                )
+        elif only_extended:
+            item.add_marker(pytest.mark.skip(reason="Skipping not an extended test."))
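The second hunk collapses an if that is the only statement of an else block into elif (Ruff's PLR5501, assumed). A hypothetical classify helper showing the shape of the change:

def classify(has_requires: bool, only_extended: bool) -> str:
    label = "run"
    if has_requires:
        label = "check required packages"
    elif only_extended:  # previously: else: / if only_extended:
        label = "skip: not an extended test"
    return label


print(classify(has_requires=False, only_extended=True))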

View File

@@ -98,7 +98,7 @@ def test_character_text_splitter_longer_words() -> None:


 @pytest.mark.parametrize(
-    "separator, is_separator_regex", [(re.escape("."), True), (".", False)]
+    ("separator", "is_separator_regex"), [(re.escape("."), True), (".", False)]
 )
 def test_character_text_splitter_keep_separator_regex(
     separator: str, is_separator_regex: bool

@@ -120,7 +120,7 @@ def test_character_text_splitter_keep_separator_regex(


 @pytest.mark.parametrize(
-    "separator, is_separator_regex", [(re.escape("."), True), (".", False)]
+    ("separator", "is_separator_regex"), [(re.escape("."), True), (".", False)]
 )
 def test_character_text_splitter_keep_separator_regex_start(
     separator: str, is_separator_regex: bool

@@ -142,7 +142,7 @@ def test_character_text_splitter_keep_separator_regex_start(


 @pytest.mark.parametrize(
-    "separator, is_separator_regex", [(re.escape("."), True), (".", False)]
+    ("separator", "is_separator_regex"), [(re.escape("."), True), (".", False)]
 )
 def test_character_text_splitter_keep_separator_regex_end(
     separator: str, is_separator_regex: bool

@@ -164,7 +164,7 @@ def test_character_text_splitter_keep_separator_regex_end(


 @pytest.mark.parametrize(
-    "separator, is_separator_regex", [(re.escape("."), True), (".", False)]
+    ("separator", "is_separator_regex"), [(re.escape("."), True), (".", False)]
 )
 def test_character_text_splitter_discard_separator_regex(
     separator: str, is_separator_regex: bool

@@ -250,7 +250,7 @@ def test_create_documents_with_metadata() -> None:


 @pytest.mark.parametrize(
-    "splitter, text, expected_docs",
+    ("splitter", "text", "expected_docs"),
     [
         (
             CharacterTextSplitter(

@@ -1390,7 +1390,7 @@ def test_md_header_text_splitter_fenced_code_block(fence: str) -> None:
     assert output == expected_output


-@pytest.mark.parametrize(["fence", "other_fence"], [("```", "~~~"), ("~~~", "```")])
+@pytest.mark.parametrize(("fence", "other_fence"), [("```", "~~~"), ("~~~", "```")])
 def test_md_header_text_splitter_fenced_code_block_interleaved(
     fence: str, other_fence: str
 ) -> None:

@@ -2240,7 +2240,7 @@ def html_header_splitter_splitter_factory() -> Callable[


 @pytest.mark.parametrize(
-    "headers_to_split_on, html_input, expected_documents, test_case",
+    ("headers_to_split_on", "html_input", "expected_documents", "test_case"),
     [
         (
             # Test Case 1: Split on h1 and h2

@@ -2469,7 +2469,7 @@ def test_html_header_text_splitter(


 @pytest.mark.parametrize(
-    "headers_to_split_on, html_content, expected_output, test_case",
+    ("headers_to_split_on", "html_content", "expected_output", "test_case"),
     [
         (
             # Test Case A: Split on h1 and h2 with h3 in content

@@ -2624,7 +2624,7 @@ def test_additional_html_header_text_splitter(


 @pytest.mark.parametrize(
-    "headers_to_split_on, html_content, expected_output, test_case",
+    ("headers_to_split_on", "html_content", "expected_output", "test_case"),
     [
         (
             # Test Case C: Split on h1, h2, and h3 with no headers present

@@ -3551,7 +3551,7 @@ def test_character_text_splitter_discard_regex_separator_on_merge() -> None:


 @pytest.mark.parametrize(
-    "separator,is_regex,text,chunk_size,expected",
+    ("separator", "is_regex", "text", "chunk_size", "expected"),
     [
         # 1) regex lookaround & split happens
         # "abcmiddef" split by "(?<=mid)" → ["abcmid","def"], chunk_size=5 keeps both