text-splitters: Set strict mypy rules (#30900)

* Add strict mypy rules
* Fix mypy violations
* Add error codes to all type ignores
* Add ruff rule PGH003
* Bump mypy version to 1.15
Author: Christophe Bornet
Date: 2025-04-23 06:41:24 +03:00
Committed by: GitHub
Commit: 8c5ae108dd
Parent: eedda164c6
9 changed files with 81 additions and 77 deletions
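
The bullet points above boil down to two enforcement changes. A hedged, illustrative sketch (not part of this diff; all names invented) of what the stricter settings accept and reject:

    from typing import Any

    # mypy --strict enables disallow_any_generics, so bare generics such as
    # "dict" or "list" are rejected in annotations; the parameters must be explicit.
    def summarize(payload: dict[str, Any]) -> list[str]:
        return [f"{key}={value}" for key, value in payload.items()]

    # Ruff rule PGH003 (blanket-type-ignore) rejects a bare "# type: ignore"
    # and requires a specific error code, e.g. "# type: ignore[import-untyped]"
    # on an import from a package that ships no type information.
    print(summarize({"retries": 3, "timeout": 30}))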


@@ -68,7 +68,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
"""Split text into multiple components."""
def create_documents(
- self, texts: List[str], metadatas: Optional[List[dict]] = None
+ self, texts: list[str], metadatas: Optional[list[dict[Any, Any]]] = None
) -> List[Document]:
"""Create documents from a list of texts."""
_metadatas = metadatas or [{}] * len(texts)
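
An illustrative call of the signature above (not part of this diff), assuming the RecursiveCharacterTextSplitter subclass exported by langchain_text_splitters:

    from langchain_text_splitters import RecursiveCharacterTextSplitter

    splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=0)
    docs = splitter.create_documents(
        texts=["Some long text to be chunked ..."],
        metadatas=[{"source": "example.txt"}],  # one metadata dict per input text
    )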


@@ -353,8 +353,8 @@ class HTMLSectionSplitter:
return self.split_text_from_file(StringIO(text))
def create_documents(
- self, texts: List[str], metadatas: Optional[List[dict]] = None
- ) -> List[Document]:
+ self, texts: list[str], metadatas: Optional[list[dict[Any, Any]]] = None
+ ) -> list[Document]:
"""Create documents from a list of texts."""
_metadatas = metadatas or [{}] * len(texts)
documents = []
@@ -389,10 +389,8 @@ class HTMLSectionSplitter:
- 'tag_name': The name of the header tag (e.g., "h1", "h2").
"""
try:
- from bs4 import (
-     BeautifulSoup,  # type: ignore[import-untyped]
-     PageElement,
- )
+ from bs4 import BeautifulSoup
+ from bs4.element import PageElement
except ImportError as e:
raise ImportError(
"Unable to import BeautifulSoup/PageElement, \
@@ -411,13 +409,13 @@ class HTMLSectionSplitter:
if i == 0:
current_header = "#TITLE#"
current_header_tag = "h1"
- section_content: List = []
+ section_content: list[str] = []
else:
current_header = header_element.text.strip()
current_header_tag = header_element.name # type: ignore[attr-defined]
section_content = []
for element in header_element.next_elements:
- if i + 1 < len(headers) and element == headers[i + 1]:
+ if i + 1 < len(headers) and element == headers[i + 1]:  # type: ignore[comparison-overlap]
break
if isinstance(element, str):
section_content.append(element)
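
The [comparison-overlap] code added above comes from mypy's strict equality checking, which --strict enables. A minimal, unrelated sketch of that diagnostic (not tied to the bs4 types in the real code):

    def same(x: int, y: str) -> bool:
        # mypy --strict: non-overlapping equality check  [comparison-overlap]
        return x == y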
@@ -637,8 +635,8 @@ class HTMLSemanticPreservingSplitter(BaseDocumentTransformer):
if self._stopword_removal:
try:
- import nltk  # type: ignore
- from nltk.corpus import stopwords  # type: ignore
+ import nltk
+ from nltk.corpus import stopwords  # type: ignore[import-untyped]
nltk.download("stopwords")
self._stopwords = set(stopwords.words(self._stopword_lang))
@@ -902,7 +900,7 @@ class HTMLSemanticPreservingSplitter(BaseDocumentTransformer):
return documents
def _create_documents(
- self, headers: dict, content: str, preserved_elements: dict
+ self, headers: dict[str, str], content: str, preserved_elements: dict[str, str]
) -> List[Document]:
"""Creates Document objects from the provided headers, content, and elements.
@@ -928,7 +926,7 @@ class HTMLSemanticPreservingSplitter(BaseDocumentTransformer):
return self._further_split_chunk(content, metadata, preserved_elements)
def _further_split_chunk(
- self, content: str, metadata: dict, preserved_elements: dict
+ self, content: str, metadata: dict[Any, Any], preserved_elements: dict[str, str]
) -> List[Document]:
"""Further splits the content into smaller chunks.
@@ -959,7 +957,7 @@ class HTMLSemanticPreservingSplitter(BaseDocumentTransformer):
return result
def _reinsert_preserved_elements(
- self, content: str, preserved_elements: dict
+ self, content: str, preserved_elements: dict[str, str]
) -> str:
"""Reinserts preserved elements into the content into their original positions.


@@ -49,12 +49,12 @@ class RecursiveJsonSplitter:
)
@staticmethod
- def _json_size(data: Dict) -> int:
+ def _json_size(data: dict[str, Any]) -> int:
"""Calculate the size of the serialized JSON object."""
return len(json.dumps(data))
@staticmethod
- def _set_nested_dict(d: Dict, path: List[str], value: Any) -> None:
+ def _set_nested_dict(d: dict[str, Any], path: list[str], value: Any) -> None:
"""Set a value in a nested dictionary based on the given path."""
for key in path[:-1]:
d = d.setdefault(key, {})
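
A self-contained sketch (not from this diff) of the nested-set pattern _set_nested_dict uses, written with a hypothetical helper name:

    from typing import Any

    def set_nested(d: dict[str, Any], path: list[str], value: Any) -> None:
        for key in path[:-1]:
            d = d.setdefault(key, {})  # descend, creating intermediate dicts as needed
        d[path[-1]] = value            # assign at the final key

    doc: dict[str, Any] = {}
    set_nested(doc, ["a", "b", "c"], 1)
    assert doc == {"a": {"b": {"c": 1}}}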
@@ -76,10 +76,10 @@ class RecursiveJsonSplitter:
def _json_split(
self,
- data: Dict[str, Any],
- current_path: Optional[List[str]] = None,
- chunks: Optional[List[Dict]] = None,
- ) -> List[Dict]:
+ data: dict[str, Any],
+ current_path: Optional[list[str]] = None,
+ chunks: Optional[list[dict[str, Any]]] = None,
+ ) -> list[dict[str, Any]]:
"""Split json into maximum size dictionaries while preserving structure."""
current_path = current_path or []
chunks = chunks if chunks is not None else [{}]
@@ -107,9 +107,9 @@ class RecursiveJsonSplitter:
def split_json(
self,
- json_data: Dict[str, Any],
+ json_data: dict[str, Any],
convert_lists: bool = False,
- ) -> List[Dict]:
+ ) -> list[dict[str, Any]]:
"""Splits JSON into a list of JSON chunks."""
if convert_lists:
chunks = self._json_split(self._list_to_dict_preprocessing(json_data))
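
An illustrative use of split_json (not part of this diff), assuming RecursiveJsonSplitter's public max_chunk_size constructor argument:

    from langchain_text_splitters import RecursiveJsonSplitter

    data = {"config": {"name": "demo", "options": {"retries": 3, "timeout": 30}}}
    splitter = RecursiveJsonSplitter(max_chunk_size=50)
    chunks = splitter.split_json(json_data=data)  # list of dicts, each within the size budget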
@@ -135,11 +135,11 @@ class RecursiveJsonSplitter:
def create_documents(
self,
- texts: List[Dict],
+ texts: list[dict[str, Any]],
convert_lists: bool = False,
ensure_ascii: bool = True,
- metadatas: Optional[List[dict]] = None,
- ) -> List[Document]:
+ metadatas: Optional[list[dict[Any, Any]]] = None,
+ ) -> list[Document]:
"""Create documents from a list of json objects (Dict)."""
_metadatas = metadatas or [{}] * len(texts)
documents = []
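
And a hedged sketch of create_documents for JSON inputs (same assumptions as the split_json example above):

    from langchain_text_splitters import RecursiveJsonSplitter

    splitter = RecursiveJsonSplitter(max_chunk_size=100)
    docs = splitter.create_documents(
        texts=[{"title": "Doc A"}, {"title": "Doc B"}],           # one JSON object per entry
        metadatas=[{"source": "a.json"}, {"source": "b.json"}],   # aligned with texts
    )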


@@ -404,18 +404,18 @@ class ExperimentalMarkdownSyntaxTextSplitter:
self.current_chunk = Document(page_content="")
# Match methods
- def _match_header(self, line: str) -> Union[re.Match, None]:
+ def _match_header(self, line: str) -> Union[re.Match[str], None]:
match = re.match(r"^(#{1,6}) (.*)", line)
# Only matches on the configured headers
if match and match.group(1) in self.splittable_headers:
return match
return None
- def _match_code(self, line: str) -> Union[re.Match, None]:
+ def _match_code(self, line: str) -> Union[re.Match[str], None]:
matches = [re.match(rule, line) for rule in [r"^```(.*)", r"^~~~(.*)"]]
return next((match for match in matches if match), None)
- def _match_horz(self, line: str) -> Union[re.Match, None]:
+ def _match_horz(self, line: str) -> Union[re.Match[str], None]:
matches = [
re.match(rule, line) for rule in [r"^\*\*\*+\n", r"^---+\n", r"^___+\n"]
]
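
re.Match is generic over the matched string type, which is why these annotations gain the [str] parameter under strict mypy. A small sketch, not from the diff:

    import re
    from typing import Optional

    def match_header(line: str) -> Optional[re.Match[str]]:
        return re.match(r"^(#{1,6}) (.*)", line)

    m = match_header("## Section title")
    if m is not None:
        print(m.group(1), m.group(2))  # '##' 'Section title'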


@@ -35,7 +35,7 @@ class SentenceTransformersTokenTextSplitter(TextSplitter):
def _initialize_chunk_configuration(
self, *, tokens_per_chunk: Optional[int]
) -> None:
- self.maximum_tokens_per_chunk = cast(int, self._model.max_seq_length)
+ self.maximum_tokens_per_chunk = self._model.max_seq_length
if tokens_per_chunk is None:
self.tokens_per_chunk = self.maximum_tokens_per_chunk
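
Dropping the cast above is consistent with warn_redundant_casts (also enabled by --strict), which flags casts to a type mypy already infers. A minimal, unrelated example:

    from typing import cast

    n: int = 3
    m = cast(int, n)  # mypy --strict: redundant cast to "int"  [redundant-cast]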
@@ -93,10 +93,10 @@ class SentenceTransformersTokenTextSplitter(TextSplitter):
_max_length_equal_32_bit_integer: int = 2**32
- def _encode(self, text: str) -> List[int]:
+ def _encode(self, text: str) -> list[int]:
token_ids_with_start_and_end_token_ids = self.tokenizer.encode(
text,
max_length=self._max_length_equal_32_bit_integer,
truncation="do_not_truncate",
)
- return token_ids_with_start_and_end_token_ids
+ return cast("list[int]", token_ids_with_start_and_end_token_ids)
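
Conversely, the cast added here assigns a precise type to a loosely typed tokenizer result, which keeps warn_return_any quiet under --strict. A minimal sketch, not tied to this library:

    from typing import Any, cast

    def as_token_ids(raw: Any) -> list[int]:
        # Returning "raw" directly would trip warn_return_any; cast() narrows it
        # for the type checker with no runtime effect.
        return cast("list[int]", raw)

    print(as_token_ids([101, 2023, 102]))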