Mirror of https://github.com/hwchase17/langchain.git (synced 2025-07-09 22:45:49 +00:00)

text-splitters: Ruff autofixes (#31858)

Auto-fixes from ruff with rule `ALL`.

parent 8aed3b61a9
commit 451c90fefa
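The hunks below are mechanical rewrites of a few recurring kinds: exception messages moved into a `msg` variable before the `raise`, `else` branches dropped after `return`, list concatenation replaced with unpacking, mappings iterated directly instead of via `.keys()`, `__all__` sorted, and pytest parametrize names switched to tuples. These appear to correspond to Ruff rules such as EM101/EM102, RET505, RUF005, SIM118, RUF022, and PT006/PT023, though the commit message only says rule `ALL` was enabled. The sketch below is illustrative only and is not taken from the diff; the function names are hypothetical and not langchain APIs.

```python
"""Illustrative sketch of the rewrite patterns Ruff's autofixer applies.

Assumes rules like EM102, RET505, RUF005, and SIM118; names are made up.
"""


def check_overlap(chunk_size: int, chunk_overlap: int) -> None:
    # EM102-style: build the message in a variable instead of putting the
    # f-string literal inside the raise, so tracebacks don't repeat it.
    if chunk_overlap > chunk_size:
        msg = f"Got a larger chunk overlap ({chunk_overlap}) than chunk size ({chunk_size})."
        raise ValueError(msg)


def join_text(parts: list[str], separator: str) -> str | None:
    # RET505-style: no `else` after a branch that returns.
    text = separator.join(parts).strip()
    if text == "":
        return None
    return text


def with_defaults(separators: list[str]) -> list[str]:
    # RUF005-style: unpack into a new list literal instead of concatenating.
    return ["\n\n", "\n", *separators, " ", ""]


def first_title(metadata: dict[str, str]) -> str | None:
    # SIM118-style: iterate the mapping directly rather than calling .keys().
    for key in metadata:
        if key.lower() == "title":
            return metadata[key]
    return None


if __name__ == "__main__":
    check_overlap(chunk_size=100, chunk_overlap=20)
    print(join_text(["a", "b"], " "))
    print(with_defaults(["##"]))
    print(first_title({"Title": "Intro"}))
```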
@@ -54,28 +54,28 @@ from langchain_text_splitters.sentence_transformers import (
 from langchain_text_splitters.spacy import SpacyTextSplitter
 
 __all__ = [
-    "TokenTextSplitter",
-    "TextSplitter",
-    "Tokenizer",
-    "Language",
-    "RecursiveCharacterTextSplitter",
-    "RecursiveJsonSplitter",
-    "LatexTextSplitter",
-    "JSFrameworkTextSplitter",
-    "PythonCodeTextSplitter",
-    "KonlpyTextSplitter",
-    "SpacyTextSplitter",
-    "NLTKTextSplitter",
-    "split_text_on_tokens",
-    "SentenceTransformersTokenTextSplitter",
+    "CharacterTextSplitter",
     "ElementType",
-    "HeaderType",
-    "LineType",
+    "ExperimentalMarkdownSyntaxTextSplitter",
     "HTMLHeaderTextSplitter",
     "HTMLSectionSplitter",
     "HTMLSemanticPreservingSplitter",
+    "HeaderType",
+    "JSFrameworkTextSplitter",
+    "KonlpyTextSplitter",
+    "Language",
+    "LatexTextSplitter",
+    "LineType",
     "MarkdownHeaderTextSplitter",
     "MarkdownTextSplitter",
-    "CharacterTextSplitter",
-    "ExperimentalMarkdownSyntaxTextSplitter",
+    "NLTKTextSplitter",
+    "PythonCodeTextSplitter",
+    "RecursiveCharacterTextSplitter",
+    "RecursiveJsonSplitter",
+    "SentenceTransformersTokenTextSplitter",
+    "SpacyTextSplitter",
+    "TextSplitter",
+    "TokenTextSplitter",
+    "Tokenizer",
+    "split_text_on_tokens",
 ]
@@ -3,7 +3,8 @@ from __future__ import annotations
 import copy
 import logging
 from abc import ABC, abstractmethod
-from collections.abc import Collection, Iterable, Sequence, Set
+from collections.abc import Collection, Iterable, Sequence
+from collections.abc import Set as AbstractSet
 from dataclasses import dataclass
 from enum import Enum
 from typing import (
@@ -47,10 +48,11 @@ class TextSplitter(BaseDocumentTransformer, ABC):
                 every document
         """
         if chunk_overlap > chunk_size:
-            raise ValueError(
+            msg = (
                 f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
                 f"({chunk_size}), should be smaller."
             )
+            raise ValueError(msg)
         self._chunk_size = chunk_size
         self._chunk_overlap = chunk_overlap
         self._length_function = length_function
@@ -96,8 +98,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
             text = text.strip()
         if text == "":
             return None
-        else:
-            return text
+        return text
 
     def _merge_splits(self, splits: Iterable[str], separator: str) -> list[str]:
         # We now want to combine these smaller pieces into medium size
@@ -148,18 +149,20 @@ class TextSplitter(BaseDocumentTransformer, ABC):
             from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 
             if not isinstance(tokenizer, PreTrainedTokenizerBase):
-                raise ValueError(
+                msg = (
                     "Tokenizer received was not an instance of PreTrainedTokenizerBase"
                 )
+                raise ValueError(msg)
 
             def _huggingface_tokenizer_length(text: str) -> int:
                 return len(tokenizer.tokenize(text))
 
         except ImportError:
-            raise ValueError(
+            msg = (
                 "Could not import transformers python package. "
                 "Please install it with `pip install transformers`."
             )
+            raise ValueError(msg)
         return cls(length_function=_huggingface_tokenizer_length, **kwargs)
 
     @classmethod
@@ -167,7 +170,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
         cls: type[TS],
         encoding_name: str = "gpt2",
         model_name: Optional[str] = None,
-        allowed_special: Union[Literal["all"], Set[str]] = set(),
+        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
         disallowed_special: Union[Literal["all"], Collection[str]] = "all",
         **kwargs: Any,
     ) -> TS:
@@ -175,11 +178,12 @@ class TextSplitter(BaseDocumentTransformer, ABC):
         try:
             import tiktoken
         except ImportError:
-            raise ImportError(
+            msg = (
                 "Could not import tiktoken python package. "
                 "This is needed in order to calculate max_tokens_for_prompt. "
                 "Please install it with `pip install tiktoken`."
             )
+            raise ImportError(msg)
 
         if model_name is not None:
             enc = tiktoken.encoding_for_model(model_name)
@@ -220,7 +224,7 @@ class TokenTextSplitter(TextSplitter):
         self,
         encoding_name: str = "gpt2",
         model_name: Optional[str] = None,
-        allowed_special: Union[Literal["all"], Set[str]] = set(),
+        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
         disallowed_special: Union[Literal["all"], Collection[str]] = "all",
         **kwargs: Any,
     ) -> None:
@@ -229,11 +233,12 @@ class TokenTextSplitter(TextSplitter):
         try:
             import tiktoken
         except ImportError:
-            raise ImportError(
+            msg = (
                 "Could not import tiktoken python package. "
                 "This is needed in order to for TokenTextSplitter. "
                 "Please install it with `pip install tiktoken`."
             )
+            raise ImportError(msg)
 
         if model_name is not None:
             enc = tiktoken.encoding_for_model(model_name)
@@ -60,9 +60,9 @@ def _split_text_with_regex(
         if len(_splits) % 2 == 0:
             splits += _splits[-1:]
         splits = (
-            (splits + [_splits[-1]])
+            ([*splits, _splits[-1]])
             if keep_separator == "end"
-            else ([_splits[0]] + splits)
+            else ([_splits[0], *splits])
         )
     else:
         splits = re.split(separator, text)
@@ -170,7 +170,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
         Returns:
             List[str]: A list of separators appropriate for the specified language.
         """
-        if language == Language.C or language == Language.CPP:
+        if language in (Language.C, Language.CPP):
             return [
                 # Split along class definitions
                 "\nclass ",
@@ -191,7 +191,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.GO:
+        if language == Language.GO:
             return [
                 # Split along function definitions
                 "\nfunc ",
@@ -209,7 +209,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.JAVA:
+        if language == Language.JAVA:
             return [
                 # Split along class definitions
                 "\nclass ",
@@ -230,7 +230,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.KOTLIN:
+        if language == Language.KOTLIN:
             return [
                 # Split along class definitions
                 "\nclass ",
@@ -256,7 +256,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.JS:
+        if language == Language.JS:
             return [
                 # Split along function definitions
                 "\nfunction ",
@@ -277,7 +277,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.TS:
+        if language == Language.TS:
             return [
                 "\nenum ",
                 "\ninterface ",
@@ -303,7 +303,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.PHP:
+        if language == Language.PHP:
             return [
                 # Split along function definitions
                 "\nfunction ",
@@ -322,7 +322,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.PROTO:
+        if language == Language.PROTO:
             return [
                 # Split along message definitions
                 "\nmessage ",
@@ -342,7 +342,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.PYTHON:
+        if language == Language.PYTHON:
             return [
                 # First, try to split along class definitions
                 "\nclass ",
@@ -354,7 +354,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.RST:
+        if language == Language.RST:
             return [
                 # Split along section titles
                 "\n=+\n",
@@ -368,7 +368,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.RUBY:
+        if language == Language.RUBY:
             return [
                 # Split along method definitions
                 "\ndef ",
@@ -387,7 +387,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.ELIXIR:
+        if language == Language.ELIXIR:
             return [
                 # Split along method function and module definition
                 "\ndef ",
@@ -411,7 +411,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.RUST:
+        if language == Language.RUST:
             return [
                 # Split along function definitions
                 "\nfn ",
@@ -430,7 +430,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.SCALA:
+        if language == Language.SCALA:
             return [
                 # Split along class definitions
                 "\nclass ",
@@ -451,7 +451,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.SWIFT:
+        if language == Language.SWIFT:
             return [
                 # Split along function definitions
                 "\nfunc ",
@@ -472,7 +472,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.MARKDOWN:
+        if language == Language.MARKDOWN:
             return [
                 # First, try to split along Markdown headings (starting with level 2)
                 "\n#{1,6} ",
@@ -492,7 +492,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.LATEX:
+        if language == Language.LATEX:
             return [
                 # First, try to split along Latex sections
                 "\n\\\\chapter{",
@@ -516,7 +516,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.HTML:
+        if language == Language.HTML:
             return [
                 # First, try to split along HTML tags
                 "<body",
@@ -548,7 +548,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 "<title",
                 "",
             ]
-        elif language == Language.CSHARP:
+        if language == Language.CSHARP:
             return [
                 "\ninterface ",
                 "\nenum ",
@@ -585,7 +585,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.SOL:
+        if language == Language.SOL:
             return [
                 # Split along compiler information definitions
                 "\npragma ",
@@ -615,7 +615,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.COBOL:
+        if language == Language.COBOL:
             return [
                 # Split along divisions
                 "\nIDENTIFICATION DIVISION.",
@@ -647,7 +647,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.LUA:
+        if language == Language.LUA:
             return [
                 # Split along variable and table definitions
                 "\nlocal ",
@@ -664,7 +664,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.HASKELL:
+        if language == Language.HASKELL:
             return [
                 # Split along function definitions
                 "\nmain :: ",
@@ -703,7 +703,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language == Language.POWERSHELL:
+        if language == Language.POWERSHELL:
             return [
                 # Split along function definitions
                 "\nfunction ",
@@ -727,10 +727,10 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 " ",
                 "",
             ]
-        elif language in Language._value2member_map_:
-            raise ValueError(f"Language {language} is not implemented yet!")
-        else:
-            raise ValueError(
-                f"Language {language} is not supported! "
-                f"Please choose from {list(Language)}"
-            )
+        if language in Language._value2member_map_:
+            msg = f"Language {language} is not implemented yet!"
+            raise ValueError(msg)
+        msg = (
+            f"Language {language} is not supported! Please choose from {list(Language)}"
+        )
+        raise ValueError(msg)
@@ -194,9 +194,10 @@ class HTMLHeaderTextSplitter:
         try:
             from bs4 import BeautifulSoup
         except ImportError as e:
-            raise ImportError(
+            msg = (
                 "Unable to import BeautifulSoup. Please install via `pip install bs4`."
-            ) from e
+            )
+            raise ImportError(msg) from e
 
         soup = BeautifulSoup(html_content, "html.parser")
         body = soup.body if soup.body else soup
@@ -352,7 +353,7 @@ class HTMLSectionSplitter:
             for chunk in self.split_text(text):
                 metadata = copy.deepcopy(_metadatas[i])
 
-                for key in chunk.metadata.keys():
+                for key in chunk.metadata:
                     if chunk.metadata[key] == "#TITLE#":
                         chunk.metadata[key] = metadata["Title"]
                 metadata = {**metadata, **chunk.metadata}
@@ -382,17 +383,16 @@ class HTMLSectionSplitter:
             from bs4 import BeautifulSoup
             from bs4.element import PageElement
         except ImportError as e:
-            raise ImportError(
-                "Unable to import BeautifulSoup/PageElement, \
+            msg = "Unable to import BeautifulSoup/PageElement, \
                     please install with `pip install \
                         bs4`."
-            ) from e
+            raise ImportError(msg) from e
 
         soup = BeautifulSoup(html_doc, "html.parser")
         headers = list(self.headers_to_split_on.keys())
         sections: list[dict[str, str | None]] = []
 
-        headers = soup.find_all(["body"] + headers)  # type: ignore[assignment]
+        headers = soup.find_all(["body", *headers])  # type: ignore[assignment]
 
         for i, header in enumerate(headers):
             header_element = cast(PageElement, header)
@@ -441,9 +441,8 @@ class HTMLSectionSplitter:
         try:
             from lxml import etree
         except ImportError as e:
-            raise ImportError(
-                "Unable to import lxml, please install with `pip install lxml`."
-            ) from e
+            msg = "Unable to import lxml, please install with `pip install lxml`."
+            raise ImportError(msg) from e
         # use lxml library to parse html document and return xml ElementTree
         # Create secure parsers to prevent XXE attacks
         html_parser = etree.HTMLParser(no_network=True)
@@ -594,10 +593,11 @@ class HTMLSemanticPreservingSplitter(BaseDocumentTransformer):
             self._BeautifulSoup = BeautifulSoup
             self._Tag = Tag
         except ImportError:
-            raise ImportError(
+            msg = (
                 "Could not import BeautifulSoup. "
                 "Please install it with 'pip install bs4'."
             )
+            raise ImportError(msg)
 
         self._headers_to_split_on = sorted(headers_to_split_on)
         self._max_chunk_size = max_chunk_size
@@ -646,9 +646,10 @@ class HTMLSemanticPreservingSplitter(BaseDocumentTransformer):
                 nltk.download("stopwords")
                 self._stopwords = set(nltk.corpus.stopwords.words(self._stopword_lang))
             except ImportError:
-                raise ImportError(
+                msg = (
                     "Could not import nltk. Please install it with 'pip install nltk'."
                 )
+                raise ImportError(msg)
 
     def split_text(self, text: str) -> list[Document]:
         """Splits the provided HTML text into smaller chunks based on the configuration.
@@ -927,8 +928,7 @@ class HTMLSemanticPreservingSplitter(BaseDocumentTransformer):
                 content, preserved_elements
             )
             return [Document(page_content=page_content, metadata=metadata)]
-        else:
-            return self._further_split_chunk(content, metadata, preserved_elements)
+        return self._further_split_chunk(content, metadata, preserved_elements)
 
     def _further_split_chunk(
         self, content: str, metadata: dict[Any, Any], preserved_elements: dict[str, str]
@@ -64,15 +64,14 @@ class RecursiveJsonSplitter:
         if isinstance(data, dict):
             # Process each key-value pair in the dictionary
             return {k: self._list_to_dict_preprocessing(v) for k, v in data.items()}
-        elif isinstance(data, list):
+        if isinstance(data, list):
             # Convert the list to a dictionary with index-based keys
             return {
                 str(i): self._list_to_dict_preprocessing(item)
                 for i, item in enumerate(data)
             }
-        else:
-            # Base case: the item is neither a dict nor a list, so return it unchanged
-            return data
+        # Base case: the item is neither a dict nor a list, so return it unchanged
+        return data
 
     def _json_split(
         self,
@@ -85,7 +84,7 @@ class RecursiveJsonSplitter:
         chunks = chunks if chunks is not None else [{}]
         if isinstance(data, dict):
             for key, value in data.items():
-                new_path = current_path + [key]
+                new_path = [*current_path, key]
                 chunk_size = self._json_size(chunks[-1])
                 size = self._json_size({key: value})
                 remaining = self.max_chunk_size - chunk_size
@@ -94,5 +94,4 @@ class JSFrameworkTextSplitter(RecursiveCharacterTextSplitter):
             + ["<>", "\n\n", "&&\n", "||\n"]
         )
         self._separators = separators
-        chunks = super().split_text(text)
-        return chunks
+        return super().split_text(text)
@@ -22,12 +22,11 @@ class KonlpyTextSplitter(TextSplitter):
         try:
             import konlpy
         except ImportError:
-            raise ImportError(
-                """
-                Konlpy is not installed, please install it with
+            msg = """
+                Konlpy is not installed, please install it with
                 `pip install konlpy`
                 """
-            )
+            raise ImportError(msg)
         self.kkma = konlpy.tag.Kkma()
 
     def split_text(self, text: str) -> list[str]:
@@ -121,10 +121,9 @@ class MarkdownHeaderTextSplitter:
                 elif stripped_line.startswith("~~~"):
                     in_code_block = True
                     opening_fence = "~~~"
-            else:
-                if stripped_line.startswith(opening_fence):
-                    in_code_block = False
-                    opening_fence = ""
+            elif stripped_line.startswith(opening_fence):
+                in_code_block = False
+                opening_fence = ""
 
             if in_code_block:
                 current_content.append(stripped_line)
@@ -207,11 +206,10 @@ class MarkdownHeaderTextSplitter:
         # aggregate these into chunks based on common metadata
         if not self.return_each_line:
             return self.aggregate_lines_to_chunks(lines_with_metadata)
-        else:
-            return [
-                Document(page_content=chunk["content"], metadata=chunk["metadata"])
-                for chunk in lines_with_metadata
-            ]
+        return [
+            Document(page_content=chunk["content"], metadata=chunk["metadata"])
+            for chunk in lines_with_metadata
+        ]
 
 
 class LineType(TypedDict):
@@ -22,7 +22,8 @@ class NLTKTextSplitter(TextSplitter):
         self._language = language
         self._use_span_tokenize = use_span_tokenize
         if self._use_span_tokenize and self._separator != "":
-            raise ValueError("When use_span_tokenize is True, separator should be ''")
+            msg = "When use_span_tokenize is True, separator should be ''"
+            raise ValueError(msg)
         try:
             import nltk
 
@@ -31,9 +32,8 @@ class NLTKTextSplitter(TextSplitter):
             else:
                 self._tokenizer = nltk.tokenize.sent_tokenize
         except ImportError:
-            raise ImportError(
-                "NLTK is not installed, please install it with `pip install nltk`."
-            )
+            msg = "NLTK is not installed, please install it with `pip install nltk`."
+            raise ImportError(msg)
 
     def split_text(self, text: str) -> list[str]:
         """Split incoming text and return chunks."""
@@ -21,11 +21,12 @@ class SentenceTransformersTokenTextSplitter(TextSplitter):
         try:
             from sentence_transformers import SentenceTransformer
         except ImportError:
-            raise ImportError(
+            msg = (
                 "Could not import sentence_transformers python package. "
                 "This is needed in order to for SentenceTransformersTokenTextSplitter. "
                 "Please install it with `pip install sentence-transformers`."
             )
+            raise ImportError(msg)
 
         self.model_name = model_name
         self._model = SentenceTransformer(self.model_name)
@@ -43,12 +44,13 @@ class SentenceTransformersTokenTextSplitter(TextSplitter):
         self.tokens_per_chunk = tokens_per_chunk
 
         if self.tokens_per_chunk > self.maximum_tokens_per_chunk:
-            raise ValueError(
+            msg = (
                 f"The token limit of the models '{self.model_name}'"
                 f" is: {self.maximum_tokens_per_chunk}."
                 f" Argument tokens_per_chunk={self.tokens_per_chunk}"
                 f" > maximum token limit."
             )
+            raise ValueError(msg)
 
     def split_text(self, text: str) -> list[str]:
         """Splits the input text into smaller components by splitting text on tokens.
@@ -46,9 +46,8 @@ def _make_spacy_pipeline_for_splitting(
     try:
         import spacy
    except ImportError:
        msg = "Spacy is not installed, please install it with `pip install spacy`."
        raise ImportError(msg)
    if pipeline == "sentencizer":
        sentencizer: Any = spacy.lang.en.English()
        sentencizer.add_pipe("sentencizer")
@@ -4,4 +4,3 @@ import pytest
 @pytest.mark.compile
 def test_placeholder() -> None:
     """Used for compiling integration tests without running any real tests."""
-    pass
@@ -14,7 +14,7 @@ def setup_module() -> None:
     nltk.download("punkt_tab")
 
 
-@pytest.fixture()
+@pytest.fixture
 def spacy() -> Any:
     try:
         import spacy
@@ -13,7 +13,7 @@ from langchain_text_splitters.sentence_transformers import (
 )
 
 
-@pytest.fixture()
+@pytest.fixture
 def sentence_transformers() -> Any:
     try:
         import sentence_transformers
@@ -45,7 +45,8 @@ def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) ->
     only_core = config.getoption("--only-core") or False
 
     if only_extended and only_core:
-        raise ValueError("Cannot specify both `--only-extended` and `--only-core`.")
+        msg = "Cannot specify both `--only-extended` and `--only-core`."
+        raise ValueError(msg)
 
     for item in items:
         requires_marker = item.get_closest_marker("requires")
@@ -81,8 +82,5 @@ def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) ->
                         pytest.mark.skip(reason=f"Requires pkg: `{pkg}`")
                     )
                     break
-        else:
-            if only_extended:
-                item.add_marker(
-                    pytest.mark.skip(reason="Skipping not an extended test.")
-                )
+        elif only_extended:
+            item.add_marker(pytest.mark.skip(reason="Skipping not an extended test."))
@@ -98,7 +98,7 @@ def test_character_text_splitter_longer_words() -> None:
 
 
 @pytest.mark.parametrize(
-    "separator, is_separator_regex", [(re.escape("."), True), (".", False)]
+    ("separator", "is_separator_regex"), [(re.escape("."), True), (".", False)]
 )
 def test_character_text_splitter_keep_separator_regex(
     separator: str, is_separator_regex: bool
@@ -120,7 +120,7 @@ def test_character_text_splitter_keep_separator_regex(
 
 
 @pytest.mark.parametrize(
-    "separator, is_separator_regex", [(re.escape("."), True), (".", False)]
+    ("separator", "is_separator_regex"), [(re.escape("."), True), (".", False)]
 )
 def test_character_text_splitter_keep_separator_regex_start(
     separator: str, is_separator_regex: bool
@@ -142,7 +142,7 @@ def test_character_text_splitter_keep_separator_regex_start(
 
 
 @pytest.mark.parametrize(
-    "separator, is_separator_regex", [(re.escape("."), True), (".", False)]
+    ("separator", "is_separator_regex"), [(re.escape("."), True), (".", False)]
 )
 def test_character_text_splitter_keep_separator_regex_end(
     separator: str, is_separator_regex: bool
@@ -164,7 +164,7 @@ def test_character_text_splitter_keep_separator_regex_end(
 
 
 @pytest.mark.parametrize(
-    "separator, is_separator_regex", [(re.escape("."), True), (".", False)]
+    ("separator", "is_separator_regex"), [(re.escape("."), True), (".", False)]
 )
 def test_character_text_splitter_discard_separator_regex(
     separator: str, is_separator_regex: bool
@@ -250,7 +250,7 @@ def test_create_documents_with_metadata() -> None:
 
 
 @pytest.mark.parametrize(
-    "splitter, text, expected_docs",
+    ("splitter", "text", "expected_docs"),
     [
         (
             CharacterTextSplitter(
@@ -1390,7 +1390,7 @@ def test_md_header_text_splitter_fenced_code_block(fence: str) -> None:
     assert output == expected_output
 
 
-@pytest.mark.parametrize(["fence", "other_fence"], [("```", "~~~"), ("~~~", "```")])
+@pytest.mark.parametrize(("fence", "other_fence"), [("```", "~~~"), ("~~~", "```")])
 def test_md_header_text_splitter_fenced_code_block_interleaved(
     fence: str, other_fence: str
 ) -> None:
@@ -2240,7 +2240,7 @@ def html_header_splitter_splitter_factory() -> Callable[
 
 
 @pytest.mark.parametrize(
-    "headers_to_split_on, html_input, expected_documents, test_case",
+    ("headers_to_split_on", "html_input", "expected_documents", "test_case"),
     [
         (
             # Test Case 1: Split on h1 and h2
@@ -2469,7 +2469,7 @@ def test_html_header_text_splitter(
 
 
 @pytest.mark.parametrize(
-    "headers_to_split_on, html_content, expected_output, test_case",
+    ("headers_to_split_on", "html_content", "expected_output", "test_case"),
     [
         (
             # Test Case A: Split on h1 and h2 with h3 in content
@@ -2624,7 +2624,7 @@ def test_additional_html_header_text_splitter(
 
 
 @pytest.mark.parametrize(
-    "headers_to_split_on, html_content, expected_output, test_case",
+    ("headers_to_split_on", "html_content", "expected_output", "test_case"),
     [
         (
            # Test Case C: Split on h1, h2, and h3 with no headers present
@@ -3551,7 +3551,7 @@ def test_character_text_splitter_discard_regex_separator_on_merge() -> None:
 
 
 @pytest.mark.parametrize(
-    "separator,is_regex,text,chunk_size,expected",
+    ("separator", "is_regex", "text", "chunk_size", "expected"),
    [
        # 1) regex lookaround & split happens
        # "abcmiddef" split by "(?<=mid)" → ["abcmid","def"], chunk_size=5 keeps both