mirror of https://github.com/hwchase17/langchain.git (synced 2025-09-16 06:53:16 +00:00)
text-splitters: Set strict mypy rules (#30900)
* Add strict mypy rules
* Fix mypy violations
* Add error codes to all type ignores
* Add ruff rule PGH003
* Bump mypy version to 1.15
Committed by GitHub. Parent: eedda164c6. Commit: 8c5ae108dd
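The hunks below follow two recurring patterns: PEP 585 builtin generics (`list[...]`, `dict[...]`) replace the `typing.List`/`typing.Dict` aliases, and every `# type: ignore` is narrowed to a specific error code, which the new ruff rule PGH003 enforces. A minimal sketch of that shape, using a made-up helper rather than code from the repository:

```python
from typing import Any, Optional


# Builtin generics (PEP 585) replace typing.List / typing.Dict throughout the package.
def describe(
    texts: list[str], metadatas: Optional[list[dict[Any, Any]]] = None
) -> list[dict[str, Any]]:
    """Pair each text with its metadata and length, mirroring the new annotation style."""
    _metadatas = metadatas or [{}] * len(texts)
    return [{"length": len(t), **m} for t, m in zip(texts, _metadatas)]


# Every suppression now names the error it silences; ruff's PGH003 flags bare ignores.
# Before:  import nltk  # type: ignore
# After:   import nltk  # type: ignore[import-untyped]

if __name__ == "__main__":
    print(describe(["foo", "quux"], [{"source": "a"}, {"source": "b"}]))
```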
@@ -68,7 +68,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
         """Split text into multiple components."""

     def create_documents(
-        self, texts: List[str], metadatas: Optional[List[dict]] = None
+        self, texts: list[str], metadatas: Optional[list[dict[Any, Any]]] = None
     ) -> List[Document]:
         """Create documents from a list of texts."""
         _metadatas = metadatas or [{}] * len(texts)
@@ -353,8 +353,8 @@ class HTMLSectionSplitter:
         return self.split_text_from_file(StringIO(text))

     def create_documents(
-        self, texts: List[str], metadatas: Optional[List[dict]] = None
-    ) -> List[Document]:
+        self, texts: list[str], metadatas: Optional[list[dict[Any, Any]]] = None
+    ) -> list[Document]:
         """Create documents from a list of texts."""
         _metadatas = metadatas or [{}] * len(texts)
         documents = []
@@ -389,10 +389,8 @@ class HTMLSectionSplitter:
             - 'tag_name': The name of the header tag (e.g., "h1", "h2").
         """
         try:
-            from bs4 import (
-                BeautifulSoup,  # type: ignore[import-untyped]
-                PageElement,
-            )
+            from bs4 import BeautifulSoup
+            from bs4.element import PageElement
         except ImportError as e:
             raise ImportError(
                 "Unable to import BeautifulSoup/PageElement, \
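For context, `PageElement` is defined in `bs4.element`, and the replacement imports carry no type-ignore comments, presumably because mypy can resolve the symbol directly from its defining module. A small standalone sketch of the new import style, assuming `beautifulsoup4` is installed and using a made-up HTML snippet:

```python
from bs4 import BeautifulSoup
from bs4.element import PageElement

soup = BeautifulSoup("<h1>Title</h1><p>Body text</p>", "html.parser")
header = soup.find("h1")

# PageElement is the shared base class for tags and text nodes,
# which is what HTMLSectionSplitter iterates over via next_elements.
print(isinstance(header, PageElement), header.get_text() if header else None)
```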
@@ -411,13 +409,13 @@ class HTMLSectionSplitter:
             if i == 0:
                 current_header = "#TITLE#"
                 current_header_tag = "h1"
-                section_content: List = []
+                section_content: list[str] = []
             else:
                 current_header = header_element.text.strip()
                 current_header_tag = header_element.name  # type: ignore[attr-defined]
                 section_content = []
             for element in header_element.next_elements:
-                if i + 1 < len(headers) and element == headers[i + 1]:
+                if i + 1 < len(headers) and element == headers[i + 1]:  # type: ignore[comparison-overlap]
                     break
                 if isinstance(element, str):
                     section_content.append(element)
@@ -637,8 +635,8 @@ class HTMLSemanticPreservingSplitter(BaseDocumentTransformer):

         if self._stopword_removal:
             try:
-                import nltk  # type: ignore
-                from nltk.corpus import stopwords  # type: ignore
+                import nltk
+                from nltk.corpus import stopwords  # type: ignore[import-untyped]

                 nltk.download("stopwords")
                 self._stopwords = set(stopwords.words(self._stopword_lang))
@@ -902,7 +900,7 @@ class HTMLSemanticPreservingSplitter(BaseDocumentTransformer):
         return documents

     def _create_documents(
-        self, headers: dict, content: str, preserved_elements: dict
+        self, headers: dict[str, str], content: str, preserved_elements: dict[str, str]
     ) -> List[Document]:
         """Creates Document objects from the provided headers, content, and elements.

@@ -928,7 +926,7 @@ class HTMLSemanticPreservingSplitter(BaseDocumentTransformer):
         return self._further_split_chunk(content, metadata, preserved_elements)

     def _further_split_chunk(
-        self, content: str, metadata: dict, preserved_elements: dict
+        self, content: str, metadata: dict[Any, Any], preserved_elements: dict[str, str]
     ) -> List[Document]:
         """Further splits the content into smaller chunks.

@@ -959,7 +957,7 @@ class HTMLSemanticPreservingSplitter(BaseDocumentTransformer):
         return result

     def _reinsert_preserved_elements(
-        self, content: str, preserved_elements: dict
+        self, content: str, preserved_elements: dict[str, str]
     ) -> str:
         """Reinserts preserved elements into the content into their original positions.

@@ -49,12 +49,12 @@ class RecursiveJsonSplitter:
         )

     @staticmethod
-    def _json_size(data: Dict) -> int:
+    def _json_size(data: dict[str, Any]) -> int:
         """Calculate the size of the serialized JSON object."""
         return len(json.dumps(data))

     @staticmethod
-    def _set_nested_dict(d: Dict, path: List[str], value: Any) -> None:
+    def _set_nested_dict(d: dict[str, Any], path: list[str], value: Any) -> None:
         """Set a value in a nested dictionary based on the given path."""
         for key in path[:-1]:
             d = d.setdefault(key, {})
@@ -76,10 +76,10 @@ class RecursiveJsonSplitter:

     def _json_split(
         self,
-        data: Dict[str, Any],
-        current_path: Optional[List[str]] = None,
-        chunks: Optional[List[Dict]] = None,
-    ) -> List[Dict]:
+        data: dict[str, Any],
+        current_path: Optional[list[str]] = None,
+        chunks: Optional[list[dict[str, Any]]] = None,
+    ) -> list[dict[str, Any]]:
         """Split json into maximum size dictionaries while preserving structure."""
         current_path = current_path or []
         chunks = chunks if chunks is not None else [{}]
@@ -107,9 +107,9 @@ class RecursiveJsonSplitter:

     def split_json(
         self,
-        json_data: Dict[str, Any],
+        json_data: dict[str, Any],
         convert_lists: bool = False,
-    ) -> List[Dict]:
+    ) -> list[dict[str, Any]]:
         """Splits JSON into a list of JSON chunks."""
         if convert_lists:
             chunks = self._json_split(self._list_to_dict_preprocessing(json_data))
@@ -135,11 +135,11 @@ class RecursiveJsonSplitter:

     def create_documents(
         self,
-        texts: List[Dict],
+        texts: list[dict[str, Any]],
         convert_lists: bool = False,
         ensure_ascii: bool = True,
-        metadatas: Optional[List[dict]] = None,
-    ) -> List[Document]:
+        metadatas: Optional[list[dict[Any, Any]]] = None,
+    ) -> list[Document]:
         """Create documents from a list of json objects (Dict)."""
         _metadatas = metadatas or [{}] * len(texts)
         documents = []
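As a quick check of the updated RecursiveJsonSplitter signatures, here is a minimal usage sketch; the sample data and `max_chunk_size` value are made up, and it assumes the `langchain-text-splitters` package is installed:

```python
from typing import Any

from langchain_text_splitters import RecursiveJsonSplitter

json_data: dict[str, Any] = {
    "team": {"name": "text-splitters", "members": ["a", "b", "c"]},
    "settings": {"strict_mypy": True, "mypy_version": "1.15"},
}

splitter = RecursiveJsonSplitter(max_chunk_size=100)

# split_json now returns list[dict[str, Any]] rather than List[Dict].
chunks = splitter.split_json(json_data=json_data)
# create_documents accepts list[dict[str, Any]] and returns list[Document].
docs = splitter.create_documents(texts=chunks)

for doc in docs:
    print(doc.page_content)
```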
@@ -404,18 +404,18 @@ class ExperimentalMarkdownSyntaxTextSplitter:
         self.current_chunk = Document(page_content="")

     # Match methods
-    def _match_header(self, line: str) -> Union[re.Match, None]:
+    def _match_header(self, line: str) -> Union[re.Match[str], None]:
         match = re.match(r"^(#{1,6}) (.*)", line)
         # Only matches on the configured headers
         if match and match.group(1) in self.splittable_headers:
             return match
         return None

-    def _match_code(self, line: str) -> Union[re.Match, None]:
+    def _match_code(self, line: str) -> Union[re.Match[str], None]:
         matches = [re.match(rule, line) for rule in [r"^```(.*)", r"^~~~(.*)"]]
         return next((match for match in matches if match), None)

-    def _match_horz(self, line: str) -> Union[re.Match, None]:
+    def _match_horz(self, line: str) -> Union[re.Match[str], None]:
         matches = [
             re.match(rule, line) for rule in [r"^\*\*\*+\n", r"^---+\n", r"^___+\n"]
         ]
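Strict mypy requires the generic `re.Match` to be parameterized; matching a `str` pattern against a `str` yields `Optional[re.Match[str]]`, which is what the three `_match_*` signatures above now declare. A tiny self-contained illustration (the standalone helper name is made up):

```python
import re
from typing import Union


def match_header(line: str) -> Union[re.Match[str], None]:
    """Return the match object for an ATX-style markdown header, if any."""
    return re.match(r"^(#{1,6}) (.*)", line)


m = match_header("## Section title")
if m is not None:
    print(m.group(1), m.group(2))  # "##" "Section title"
```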
@@ -35,7 +35,7 @@ class SentenceTransformersTokenTextSplitter(TextSplitter):
     def _initialize_chunk_configuration(
         self, *, tokens_per_chunk: Optional[int]
     ) -> None:
-        self.maximum_tokens_per_chunk = cast(int, self._model.max_seq_length)
+        self.maximum_tokens_per_chunk = self._model.max_seq_length

         if tokens_per_chunk is None:
             self.tokens_per_chunk = self.maximum_tokens_per_chunk
@@ -93,10 +93,10 @@ class SentenceTransformersTokenTextSplitter(TextSplitter):

     _max_length_equal_32_bit_integer: int = 2**32

-    def _encode(self, text: str) -> List[int]:
+    def _encode(self, text: str) -> list[int]:
         token_ids_with_start_and_end_token_ids = self.tokenizer.encode(
             text,
             max_length=self._max_length_equal_32_bit_integer,
             truncation="do_not_truncate",
         )
-        return token_ids_with_start_and_end_token_ids
+        return cast("list[int]", token_ids_with_start_and_end_token_ids)
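In these last two hunks a redundant `cast` around `max_seq_length` is dropped, while `_encode` gains one on its return value because the tokenizer call is untyped. A small sketch of that pattern with a stand-in tokenizer, so it runs without `sentence-transformers`:

```python
from typing import Any, cast


class FakeTokenizer:
    """Stand-in for an untyped third-party tokenizer (illustrative only)."""

    def encode(self, text: str) -> Any:
        return [ord(ch) for ch in text]


def encode(text: str, tokenizer: FakeTokenizer) -> list[int]:
    # The tokenizer returns Any; cast pins the value to list[int] for mypy
    # without changing runtime behaviour. The quoted type mirrors the diff.
    token_ids = tokenizer.encode(text)
    return cast("list[int]", token_ids)


print(encode("hi", FakeTokenizer()))  # [104, 105]
```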