Mirror of https://github.com/hwchase17/langchain.git, synced 2025-07-09 22:45:49 +00:00
text-splitters: Ruff autofixes (#31858)
Auto-fixes from ruff with rule `ALL`
This commit is contained in:
parent 8aed3b61a9
commit 451c90fefa
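The diff below is dominated by a handful of rule families: EM101/EM102 (assign an exception message to a variable before raising), RET504/RET505 (no redundant assignment or `else` after `return`), RUF005 (iterable unpacking instead of concatenation), RUF022 (sorted `__all__`), SIM118 (iterate a dict directly), PLR1714 (merge repeated comparisons), PIE790 (drop unnecessary `pass`), and PT001/PT006 (pytest decorator style). As a minimal sketch of the most common rewrite here, the EM101/EM102 pattern (function names illustrative, not from the commit):

```python
# Before (flagged by EM101/EM102): the message literal sits inside the
# raise, so the traceback line repeats the full message text.
def check_overlap_before(chunk_overlap: int, chunk_size: int) -> None:
    if chunk_overlap > chunk_size:
        raise ValueError(
            f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
            f"({chunk_size}), should be smaller."
        )


# After (the autofixed form used throughout this commit):
def check_overlap_after(chunk_overlap: int, chunk_size: int) -> None:
    if chunk_overlap > chunk_size:
        msg = (
            f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
            f"({chunk_size}), should be smaller."
        )
        raise ValueError(msg)
```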
@@ -54,28 +54,28 @@ from langchain_text_splitters.sentence_transformers import (
 from langchain_text_splitters.spacy import SpacyTextSplitter
 
 __all__ = [
-    "TokenTextSplitter",
-    "TextSplitter",
-    "Tokenizer",
-    "Language",
-    "RecursiveCharacterTextSplitter",
-    "RecursiveJsonSplitter",
-    "LatexTextSplitter",
-    "JSFrameworkTextSplitter",
-    "PythonCodeTextSplitter",
-    "KonlpyTextSplitter",
-    "SpacyTextSplitter",
-    "NLTKTextSplitter",
-    "split_text_on_tokens",
-    "SentenceTransformersTokenTextSplitter",
+    "CharacterTextSplitter",
     "ElementType",
-    "HeaderType",
-    "LineType",
+    "ExperimentalMarkdownSyntaxTextSplitter",
     "HTMLHeaderTextSplitter",
     "HTMLSectionSplitter",
     "HTMLSemanticPreservingSplitter",
+    "HeaderType",
+    "JSFrameworkTextSplitter",
+    "KonlpyTextSplitter",
+    "Language",
+    "LatexTextSplitter",
+    "LineType",
     "MarkdownHeaderTextSplitter",
     "MarkdownTextSplitter",
-    "CharacterTextSplitter",
-    "ExperimentalMarkdownSyntaxTextSplitter",
+    "NLTKTextSplitter",
+    "PythonCodeTextSplitter",
+    "RecursiveCharacterTextSplitter",
+    "RecursiveJsonSplitter",
+    "SentenceTransformersTokenTextSplitter",
+    "SpacyTextSplitter",
+    "TextSplitter",
+    "TokenTextSplitter",
+    "Tokenizer",
+    "split_text_on_tokens",
 ]
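The hunk above is RUF022, which reorders `__all__`. For these names the result coincides with plain ASCII sorting, which is why `"HTMLHeaderTextSplitter"` sorts before `"HeaderType"` (uppercase letters compare lower than lowercase) and `"split_text_on_tokens"` lands last. A quick check:

```python
names = [
    "TokenTextSplitter",
    "HeaderType",
    "HTMLHeaderTextSplitter",
    "split_text_on_tokens",
]
print(sorted(names))
# ['HTMLHeaderTextSplitter', 'HeaderType', 'TokenTextSplitter', 'split_text_on_tokens']
```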
@@ -3,7 +3,8 @@ from __future__ import annotations
 import copy
 import logging
 from abc import ABC, abstractmethod
-from collections.abc import Collection, Iterable, Sequence, Set
+from collections.abc import Collection, Iterable, Sequence
+from collections.abc import Set as AbstractSet
 from dataclasses import dataclass
 from enum import Enum
 from typing import (
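The import split is the PYI025-style aliasing: `collections.abc.Set` is renamed to `AbstractSet` so it cannot be confused with the builtin `set`. Any set-like object still satisfies the annotation; a small sketch (function name illustrative):

```python
from collections.abc import Set as AbstractSet


def freeze(items: AbstractSet[str]) -> frozenset[str]:
    # set, frozenset, and dict key views all implement collections.abc.Set
    return frozenset(items)


assert freeze({"a", "b"}) == freeze(frozenset({"b", "a"}))
assert freeze({"a": 1, "b": 2}.keys()) == frozenset({"a", "b"})
```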
@@ -47,10 +48,11 @@ class TextSplitter(BaseDocumentTransformer, ABC):
                 every document
         """
         if chunk_overlap > chunk_size:
-            raise ValueError(
+            msg = (
                 f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
                 f"({chunk_size}), should be smaller."
             )
+            raise ValueError(msg)
         self._chunk_size = chunk_size
         self._chunk_overlap = chunk_overlap
         self._length_function = length_function
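The overlap check itself is unchanged; only the message handling moved. For reference, the constraint it enforces, assuming the public constructor of a concrete subclass such as `RecursiveCharacterTextSplitter`, which forwards these arguments:

```python
from langchain_text_splitters import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=20)  # fine

# An overlap larger than the chunk size raises the ValueError shown above:
# RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=200)
```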
@@ -96,7 +98,6 @@ class TextSplitter(BaseDocumentTransformer, ABC):
         text = text.strip()
         if text == "":
             return None
-        else:
-            return text
+        return text
 
     def _merge_splits(self, splits: Iterable[str], separator: str) -> list[str]:
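This hunk is RET505: once the `if` branch returns, the `else` is dead weight and the fallthrough can be flattened. The same rewrite in isolation:

```python
from __future__ import annotations


def strip_or_none(text: str) -> str | None:
    text = text.strip()
    if text == "":
        return None
    return text  # previously nested under a redundant `else`
```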
@@ -148,18 +149,20 @@ class TextSplitter(BaseDocumentTransformer, ABC):
             from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 
             if not isinstance(tokenizer, PreTrainedTokenizerBase):
-                raise ValueError(
+                msg = (
                     "Tokenizer received was not an instance of PreTrainedTokenizerBase"
                 )
+                raise ValueError(msg)
 
             def _huggingface_tokenizer_length(text: str) -> int:
                 return len(tokenizer.tokenize(text))
 
         except ImportError:
-            raise ValueError(
+            msg = (
                 "Could not import transformers python package. "
                 "Please install it with `pip install transformers`."
             )
+            raise ValueError(msg)
         return cls(length_function=_huggingface_tokenizer_length, **kwargs)
 
     @classmethod
@@ -167,7 +170,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
         cls: type[TS],
         encoding_name: str = "gpt2",
         model_name: Optional[str] = None,
-        allowed_special: Union[Literal["all"], Set[str]] = set(),
+        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
         disallowed_special: Union[Literal["all"], Collection[str]] = "all",
         **kwargs: Any,
     ) -> TS:
@@ -175,11 +178,12 @@ class TextSplitter(BaseDocumentTransformer, ABC):
         try:
             import tiktoken
         except ImportError:
-            raise ImportError(
+            msg = (
                 "Could not import tiktoken python package. "
                 "This is needed in order to calculate max_tokens_for_prompt. "
                 "Please install it with `pip install tiktoken`."
             )
+            raise ImportError(msg)
 
         if model_name is not None:
             enc = tiktoken.encoding_for_model(model_name)
@@ -220,7 +224,7 @@ class TokenTextSplitter(TextSplitter):
         self,
         encoding_name: str = "gpt2",
         model_name: Optional[str] = None,
-        allowed_special: Union[Literal["all"], Set[str]] = set(),
+        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
         disallowed_special: Union[Literal["all"], Collection[str]] = "all",
         **kwargs: Any,
     ) -> None:
@@ -229,11 +233,12 @@ class TokenTextSplitter(TextSplitter):
         try:
             import tiktoken
         except ImportError:
-            raise ImportError(
+            msg = (
                 "Could not import tiktoken python package. "
                 "This is needed in order to for TokenTextSplitter. "
                 "Please install it with `pip install tiktoken`."
             )
+            raise ImportError(msg)
 
         if model_name is not None:
             enc = tiktoken.encoding_for_model(model_name)
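The tiktoken and transformers hunks above all follow the package's optional-dependency pattern: import lazily inside the method, and convert a missing package into an actionable error. Reduced to its skeleton (helper name illustrative, message mirroring the diff):

```python
def _require_tiktoken():
    try:
        import tiktoken
    except ImportError:
        # msg-then-raise is the form the EM autofix produces
        msg = (
            "Could not import tiktoken python package. "
            "Please install it with `pip install tiktoken`."
        )
        raise ImportError(msg)
    return tiktoken
```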
@@ -60,9 +60,9 @@ def _split_text_with_regex(
             if len(_splits) % 2 == 0:
                 splits += _splits[-1:]
             splits = (
-                (splits + [_splits[-1]])
+                ([*splits, _splits[-1]])
                 if keep_separator == "end"
-                else ([_splits[0]] + splits)
+                else ([_splits[0], *splits])
             )
     else:
         splits = re.split(separator, text)
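That hunk is RUF005, which replaces list concatenation with unpacking in the literal. Both forms are equivalent here:

```python
splits = ["Hello", "world"]
tail = "!"

assert splits + [tail] == [*splits, tail]        # append: what RUF005 produces
assert ["start"] + splits == ["start", *splits]  # the prepend form
```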
@@ -170,7 +170,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
         Returns:
             List[str]: A list of separators appropriate for the specified language.
         """
-        if language == Language.C or language == Language.CPP:
+        if language in (Language.C, Language.CPP):
             return [
                 # Split along class definitions
                 "\nclass ",
@@ -191,7 +191,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.GO:
+        if language == Language.GO:
             return [
                 # Split along function definitions
                 "\nfunc ",
@@ -209,7 +209,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.JAVA:
+        if language == Language.JAVA:
             return [
                 # Split along class definitions
                 "\nclass ",
@@ -230,7 +230,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.KOTLIN:
+        if language == Language.KOTLIN:
             return [
                 # Split along class definitions
                 "\nclass ",
@@ -256,7 +256,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.JS:
+        if language == Language.JS:
             return [
                 # Split along function definitions
                 "\nfunction ",
@@ -277,7 +277,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.TS:
+        if language == Language.TS:
             return [
                 "\nenum ",
                 "\ninterface ",
@@ -303,7 +303,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.PHP:
+        if language == Language.PHP:
             return [
                 # Split along function definitions
                 "\nfunction ",
@@ -322,7 +322,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.PROTO:
+        if language == Language.PROTO:
             return [
                 # Split along message definitions
                 "\nmessage ",
@@ -342,7 +342,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.PYTHON:
+        if language == Language.PYTHON:
             return [
                 # First, try to split along class definitions
                 "\nclass ",
@@ -354,7 +354,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.RST:
+        if language == Language.RST:
             return [
                 # Split along section titles
                 "\n=+\n",
@@ -368,7 +368,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.RUBY:
+        if language == Language.RUBY:
             return [
                 # Split along method definitions
                 "\ndef ",
@@ -387,7 +387,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.ELIXIR:
+        if language == Language.ELIXIR:
             return [
                 # Split along method function and module definition
                 "\ndef ",
@@ -411,7 +411,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.RUST:
+        if language == Language.RUST:
             return [
                 # Split along function definitions
                 "\nfn ",
@@ -430,7 +430,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.SCALA:
+        if language == Language.SCALA:
             return [
                 # Split along class definitions
                 "\nclass ",
@@ -451,7 +451,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.SWIFT:
+        if language == Language.SWIFT:
             return [
                 # Split along function definitions
                 "\nfunc ",
@@ -472,7 +472,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.MARKDOWN:
+        if language == Language.MARKDOWN:
             return [
                 # First, try to split along Markdown headings (starting with level 2)
                 "\n#{1,6} ",
@@ -492,7 +492,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.LATEX:
+        if language == Language.LATEX:
             return [
                 # First, try to split along Latex sections
                 "\n\\\\chapter{",
@@ -516,7 +516,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.HTML:
+        if language == Language.HTML:
             return [
                 # First, try to split along HTML tags
                 "<body",
@@ -548,7 +548,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 "<title",
                 "",
             ]
-        elif language == Language.CSHARP:
+        if language == Language.CSHARP:
             return [
                 "\ninterface ",
                 "\nenum ",
@@ -585,7 +585,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.SOL:
+        if language == Language.SOL:
             return [
                 # Split along compiler information definitions
                 "\npragma ",
@@ -615,7 +615,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.COBOL:
+        if language == Language.COBOL:
             return [
                 # Split along divisions
                 "\nIDENTIFICATION DIVISION.",
@@ -647,7 +647,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.LUA:
+        if language == Language.LUA:
             return [
                 # Split along variable and table definitions
                 "\nlocal ",
@@ -664,7 +664,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.HASKELL:
+        if language == Language.HASKELL:
             return [
                 # Split along function definitions
                 "\nmain :: ",
@@ -703,7 +703,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.POWERSHELL:
+        if language == Language.POWERSHELL:
             return [
                 # Split along function definitions
                 "\nfunction ",
@@ -727,10 +727,10 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language in Language._value2member_map_:
-            raise ValueError(f"Language {language} is not implemented yet!")
-        else:
-            raise ValueError(
-                f"Language {language} is not supported! "
-                f"Please choose from {list(Language)}"
-            )
+        if language in Language._value2member_map_:
+            msg = f"Language {language} is not implemented yet!"
+            raise ValueError(msg)
+        msg = (
+            f"Language {language} is not supported! Please choose from {list(Language)}"
+        )
+        raise ValueError(msg)
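Two rules combine in that long run of hunks: PLR1714 merges the repeated equality test for C/C++ into a membership check, and RET505 turns every `elif` after a `return [...]` into a plain `if`, flattening the whole ladder. The first rewrite in miniature (enum reduced to a stand-in for langchain's `Language`):

```python
from enum import Enum


class Lang(str, Enum):  # stand-in for the real Language enum
    C = "c"
    CPP = "cpp"
    GO = "go"


language = Lang.CPP

# Before: language == Lang.C or language == Lang.CPP
# After (PLR1714):
assert language in (Lang.C, Lang.CPP)
```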
@@ -194,9 +194,10 @@ class HTMLHeaderTextSplitter:
         try:
             from bs4 import BeautifulSoup
         except ImportError as e:
-            raise ImportError(
+            msg = (
                 "Unable to import BeautifulSoup. Please install via `pip install bs4`."
-            ) from e
+            )
+            raise ImportError(msg) from e
 
         soup = BeautifulSoup(html_content, "html.parser")
         body = soup.body if soup.body else soup
@@ -352,7 +353,7 @@ class HTMLSectionSplitter:
         for chunk in self.split_text(text):
             metadata = copy.deepcopy(_metadatas[i])
 
-            for key in chunk.metadata.keys():
+            for key in chunk.metadata:
                 if chunk.metadata[key] == "#TITLE#":
                     chunk.metadata[key] = metadata["Title"]
             metadata = {**metadata, **chunk.metadata}
@@ -382,17 +383,16 @@ class HTMLSectionSplitter:
             from bs4 import BeautifulSoup
             from bs4.element import PageElement
         except ImportError as e:
-            raise ImportError(
-                "Unable to import BeautifulSoup/PageElement, \
+            msg = "Unable to import BeautifulSoup/PageElement, \
                 please install with `pip install \
                 bs4`."
-            ) from e
+            raise ImportError(msg) from e
 
         soup = BeautifulSoup(html_doc, "html.parser")
         headers = list(self.headers_to_split_on.keys())
         sections: list[dict[str, str | None]] = []
 
-        headers = soup.find_all(["body"] + headers)  # type: ignore[assignment]
+        headers = soup.find_all(["body", *headers])  # type: ignore[assignment]
 
         for i, header in enumerate(headers):
             header_element = cast(PageElement, header)
@@ -441,9 +441,8 @@ class HTMLSectionSplitter:
         try:
             from lxml import etree
         except ImportError as e:
-            raise ImportError(
-                "Unable to import lxml, please install with `pip install lxml`."
-            ) from e
+            msg = "Unable to import lxml, please install with `pip install lxml`."
+            raise ImportError(msg) from e
         # use lxml library to parse html document and return xml ElementTree
         # Create secure parsers to prevent XXE attacks
         html_parser = etree.HTMLParser(no_network=True)
@@ -594,10 +593,11 @@ class HTMLSemanticPreservingSplitter(BaseDocumentTransformer):
             self._BeautifulSoup = BeautifulSoup
             self._Tag = Tag
         except ImportError:
-            raise ImportError(
+            msg = (
                 "Could not import BeautifulSoup. "
                 "Please install it with 'pip install bs4'."
             )
+            raise ImportError(msg)
 
         self._headers_to_split_on = sorted(headers_to_split_on)
         self._max_chunk_size = max_chunk_size
@@ -646,9 +646,10 @@ class HTMLSemanticPreservingSplitter(BaseDocumentTransformer):
                 nltk.download("stopwords")
                 self._stopwords = set(nltk.corpus.stopwords.words(self._stopword_lang))
             except ImportError:
-                raise ImportError(
+                msg = (
                     "Could not import nltk. Please install it with 'pip install nltk'."
                 )
+                raise ImportError(msg)
 
     def split_text(self, text: str) -> list[Document]:
         """Splits the provided HTML text into smaller chunks based on the configuration.
@@ -927,7 +928,6 @@ class HTMLSemanticPreservingSplitter(BaseDocumentTransformer):
                 content, preserved_elements
             )
             return [Document(page_content=page_content, metadata=metadata)]
-        else:
-            return self._further_split_chunk(content, metadata, preserved_elements)
+        return self._further_split_chunk(content, metadata, preserved_elements)
 
     def _further_split_chunk(
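The `for key in chunk.metadata.keys()` fix above is SIM118: iterating a mapping already yields its keys. In isolation:

```python
metadata = {"Header 1": "#TITLE#", "Header 2": "Intro"}

for key in metadata:  # same iteration as metadata.keys(), one call fewer
    if metadata[key] == "#TITLE#":
        metadata[key] = "Example Title"

assert metadata["Header 1"] == "Example Title"
```

Overwriting values while iterating is safe here because the key set itself never changes.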
@@ -64,13 +64,12 @@ class RecursiveJsonSplitter:
         if isinstance(data, dict):
             # Process each key-value pair in the dictionary
             return {k: self._list_to_dict_preprocessing(v) for k, v in data.items()}
-        elif isinstance(data, list):
+        if isinstance(data, list):
             # Convert the list to a dictionary with index-based keys
             return {
                 str(i): self._list_to_dict_preprocessing(item)
                 for i, item in enumerate(data)
             }
-        else:
-            # Base case: the item is neither a dict nor a list, so return it unchanged
-            return data
+        # Base case: the item is neither a dict nor a list, so return it unchanged
+        return data
@@ -85,7 +84,7 @@ class RecursiveJsonSplitter:
         chunks = chunks if chunks is not None else [{}]
         if isinstance(data, dict):
             for key, value in data.items():
-                new_path = current_path + [key]
+                new_path = [*current_path, key]
                 chunk_size = self._json_size(chunks[-1])
                 size = self._json_size({key: value})
                 remaining = self.max_chunk_size - chunk_size
@@ -94,5 +94,4 @@ class JSFrameworkTextSplitter(RecursiveCharacterTextSplitter):
             + ["<>", "\n\n", "&&\n", "||\n"]
         )
         self._separators = separators
-        chunks = super().split_text(text)
-        return chunks
+        return super().split_text(text)
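That one is RET504: a variable assigned and immediately returned can be inlined, as with `chunks` above. Sketch:

```python
def split_words(text: str) -> list[str]:
    # Before:
    #     chunks = text.split()
    #     return chunks
    return text.split()
```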
@@ -22,12 +22,11 @@ class KonlpyTextSplitter(TextSplitter):
         try:
             import konlpy
         except ImportError:
-            raise ImportError(
-                """
+            msg = """
                 Konlpy is not installed, please install it with
                 `pip install konlpy`
                 """
-            )
+            raise ImportError(msg)
         self.kkma = konlpy.tag.Kkma()
 
     def split_text(self, text: str) -> list[str]:
@@ -121,8 +121,7 @@ class MarkdownHeaderTextSplitter:
                 elif stripped_line.startswith("~~~"):
                     in_code_block = True
                     opening_fence = "~~~"
-            else:
-                if stripped_line.startswith(opening_fence):
-                    in_code_block = False
-                    opening_fence = ""
+            elif stripped_line.startswith(opening_fence):
+                in_code_block = False
+                opening_fence = ""
 
@@ -207,7 +206,6 @@ class MarkdownHeaderTextSplitter:
         # aggregate these into chunks based on common metadata
         if not self.return_each_line:
             return self.aggregate_lines_to_chunks(lines_with_metadata)
-        else:
-            return [
-                Document(page_content=chunk["content"], metadata=chunk["metadata"])
-                for chunk in lines_with_metadata
+        return [
+            Document(page_content=chunk["content"], metadata=chunk["metadata"])
+            for chunk in lines_with_metadata
@@ -22,7 +22,8 @@ class NLTKTextSplitter(TextSplitter):
         self._language = language
         self._use_span_tokenize = use_span_tokenize
         if self._use_span_tokenize and self._separator != "":
-            raise ValueError("When use_span_tokenize is True, separator should be ''")
+            msg = "When use_span_tokenize is True, separator should be ''"
+            raise ValueError(msg)
         try:
             import nltk
 
@@ -31,9 +32,8 @@ class NLTKTextSplitter(TextSplitter):
             else:
                 self._tokenizer = nltk.tokenize.sent_tokenize
         except ImportError:
-            raise ImportError(
-                "NLTK is not installed, please install it with `pip install nltk`."
-            )
+            msg = "NLTK is not installed, please install it with `pip install nltk`."
+            raise ImportError(msg)
 
     def split_text(self, text: str) -> list[str]:
         """Split incoming text and return chunks."""
@@ -21,11 +21,12 @@ class SentenceTransformersTokenTextSplitter(TextSplitter):
         try:
             from sentence_transformers import SentenceTransformer
         except ImportError:
-            raise ImportError(
+            msg = (
                 "Could not import sentence_transformers python package. "
                 "This is needed in order to for SentenceTransformersTokenTextSplitter. "
                 "Please install it with `pip install sentence-transformers`."
             )
+            raise ImportError(msg)
 
         self.model_name = model_name
         self._model = SentenceTransformer(self.model_name)
@@ -43,12 +44,13 @@ class SentenceTransformersTokenTextSplitter(TextSplitter):
         self.tokens_per_chunk = tokens_per_chunk
 
         if self.tokens_per_chunk > self.maximum_tokens_per_chunk:
-            raise ValueError(
+            msg = (
                 f"The token limit of the models '{self.model_name}'"
                 f" is: {self.maximum_tokens_per_chunk}."
                 f" Argument tokens_per_chunk={self.tokens_per_chunk}"
                 f" > maximum token limit."
             )
+            raise ValueError(msg)
 
     def split_text(self, text: str) -> list[str]:
         """Splits the input text into smaller components by splitting text on tokens.
@@ -46,9 +46,8 @@ def _make_spacy_pipeline_for_splitting(
     try:
         import spacy
     except ImportError:
-        raise ImportError(
-            "Spacy is not installed, please install it with `pip install spacy`."
-        )
+        msg = "Spacy is not installed, please install it with `pip install spacy`."
+        raise ImportError(msg)
     if pipeline == "sentencizer":
         sentencizer: Any = spacy.lang.en.English()
         sentencizer.add_pipe("sentencizer")
@@ -4,4 +4,3 @@ import pytest
 @pytest.mark.compile
 def test_placeholder() -> None:
     """Used for compiling integration tests without running any real tests."""
-    pass
@@ -14,7 +14,7 @@ def setup_module() -> None:
     nltk.download("punkt_tab")
 
 
-@pytest.fixture()
+@pytest.fixture
 def spacy() -> Any:
     try:
         import spacy
@@ -13,7 +13,7 @@ from langchain_text_splitters.sentence_transformers import (
 )
 
 
-@pytest.fixture()
+@pytest.fixture
 def sentence_transformers() -> Any:
     try:
         import sentence_transformers
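These two fixture hunks are PT001, which drops the empty parentheses from `@pytest.fixture()`. A self-contained example (fixture and test names illustrative):

```python
import pytest


@pytest.fixture  # rather than @pytest.fixture()
def sample_text() -> str:
    return "some text to split"


def test_sample_text(sample_text: str) -> None:
    assert sample_text.startswith("some")
```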
@@ -45,7 +45,8 @@ def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) ->
     only_core = config.getoption("--only-core") or False
 
     if only_extended and only_core:
-        raise ValueError("Cannot specify both `--only-extended` and `--only-core`.")
+        msg = "Cannot specify both `--only-extended` and `--only-core`."
+        raise ValueError(msg)
 
     for item in items:
         requires_marker = item.get_closest_marker("requires")
@@ -81,8 +82,5 @@ def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) ->
                         pytest.mark.skip(reason=f"Requires pkg: `{pkg}`")
                     )
                     break
-        else:
-            if only_extended:
-                item.add_marker(
-                    pytest.mark.skip(reason="Skipping not an extended test.")
-                )
+        elif only_extended:
+            item.add_marker(pytest.mark.skip(reason="Skipping not an extended test."))
@@ -98,7 +98,7 @@ def test_character_text_splitter_longer_words() -> None:
 
 
 @pytest.mark.parametrize(
-    "separator, is_separator_regex", [(re.escape("."), True), (".", False)]
+    ("separator", "is_separator_regex"), [(re.escape("."), True), (".", False)]
 )
 def test_character_text_splitter_keep_separator_regex(
     separator: str, is_separator_regex: bool
@@ -120,7 +120,7 @@ def test_character_text_splitter_keep_separator_regex(
 
 
 @pytest.mark.parametrize(
-    "separator, is_separator_regex", [(re.escape("."), True), (".", False)]
+    ("separator", "is_separator_regex"), [(re.escape("."), True), (".", False)]
 )
 def test_character_text_splitter_keep_separator_regex_start(
     separator: str, is_separator_regex: bool
@@ -142,7 +142,7 @@ def test_character_text_splitter_keep_separator_regex_start(
 
 
 @pytest.mark.parametrize(
-    "separator, is_separator_regex", [(re.escape("."), True), (".", False)]
+    ("separator", "is_separator_regex"), [(re.escape("."), True), (".", False)]
 )
 def test_character_text_splitter_keep_separator_regex_end(
     separator: str, is_separator_regex: bool
@@ -164,7 +164,7 @@ def test_character_text_splitter_keep_separator_regex_end(
 
 
 @pytest.mark.parametrize(
-    "separator, is_separator_regex", [(re.escape("."), True), (".", False)]
+    ("separator", "is_separator_regex"), [(re.escape("."), True), (".", False)]
 )
 def test_character_text_splitter_discard_separator_regex(
     separator: str, is_separator_regex: bool
@@ -250,7 +250,7 @@ def test_create_documents_with_metadata() -> None:
 
 
 @pytest.mark.parametrize(
-    "splitter, text, expected_docs",
+    ("splitter", "text", "expected_docs"),
     [
         (
             CharacterTextSplitter(
@@ -1390,7 +1390,7 @@ def test_md_header_text_splitter_fenced_code_block(fence: str) -> None:
     assert output == expected_output
 
 
-@pytest.mark.parametrize(["fence", "other_fence"], [("```", "~~~"), ("~~~", "```")])
+@pytest.mark.parametrize(("fence", "other_fence"), [("```", "~~~"), ("~~~", "```")])
 def test_md_header_text_splitter_fenced_code_block_interleaved(
     fence: str, other_fence: str
 ) -> None:
@@ -2240,7 +2240,7 @@ def html_header_splitter_splitter_factory() -> Callable[
 
 
 @pytest.mark.parametrize(
-    "headers_to_split_on, html_input, expected_documents, test_case",
+    ("headers_to_split_on", "html_input", "expected_documents", "test_case"),
     [
         (
             # Test Case 1: Split on h1 and h2
@@ -2469,7 +2469,7 @@ def test_html_header_text_splitter(
 
 
 @pytest.mark.parametrize(
-    "headers_to_split_on, html_content, expected_output, test_case",
+    ("headers_to_split_on", "html_content", "expected_output", "test_case"),
     [
         (
             # Test Case A: Split on h1 and h2 with h3 in content
@@ -2624,7 +2624,7 @@ def test_additional_html_header_text_splitter(
 
 
 @pytest.mark.parametrize(
-    "headers_to_split_on, html_content, expected_output, test_case",
+    ("headers_to_split_on", "html_content", "expected_output", "test_case"),
     [
         (
            # Test Case C: Split on h1, h2, and h3 with no headers present
@@ -3551,7 +3551,7 @@ def test_character_text_splitter_discard_regex_separator_on_merge() -> None:
 
 
 @pytest.mark.parametrize(
-    "separator,is_regex,text,chunk_size,expected",
+    ("separator", "is_regex", "text", "chunk_size", "expected"),
    [
         # 1) regex lookaround & split happens
         # "abcmiddef" split by "(?<=mid)" → ["abcmid","def"], chunk_size=5 keeps both
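Every remaining hunk is PT006: `parametrize` argument names become a tuple of strings instead of one comma-separated string, so each name is a distinct literal. A runnable reduction (test name illustrative):

```python
import pytest


@pytest.mark.parametrize(
    ("separator", "is_separator_regex"),  # was: "separator, is_separator_regex"
    [(r"\.", True), (".", False)],
)
def test_separator_forms(separator: str, is_separator_regex: bool) -> None:
    assert isinstance(separator, str)
    assert isinstance(is_separator_regex, bool)
```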