diff --git a/libs/text-splitters/langchain_text_splitters/__init__.py b/libs/text-splitters/langchain_text_splitters/__init__.py
index 2bcc8d0731e..030d948d898 100644
--- a/libs/text-splitters/langchain_text_splitters/__init__.py
+++ b/libs/text-splitters/langchain_text_splitters/__init__.py
@@ -54,28 +54,28 @@ from langchain_text_splitters.sentence_transformers import (
 from langchain_text_splitters.spacy import SpacyTextSplitter

 __all__ = [
-    "TokenTextSplitter",
-    "TextSplitter",
-    "Tokenizer",
-    "Language",
-    "RecursiveCharacterTextSplitter",
-    "RecursiveJsonSplitter",
-    "LatexTextSplitter",
-    "JSFrameworkTextSplitter",
-    "PythonCodeTextSplitter",
-    "KonlpyTextSplitter",
-    "SpacyTextSplitter",
-    "NLTKTextSplitter",
-    "split_text_on_tokens",
-    "SentenceTransformersTokenTextSplitter",
+    "CharacterTextSplitter",
     "ElementType",
-    "HeaderType",
-    "LineType",
+    "ExperimentalMarkdownSyntaxTextSplitter",
     "HTMLHeaderTextSplitter",
     "HTMLSectionSplitter",
     "HTMLSemanticPreservingSplitter",
+    "HeaderType",
+    "JSFrameworkTextSplitter",
+    "KonlpyTextSplitter",
+    "Language",
+    "LatexTextSplitter",
+    "LineType",
     "MarkdownHeaderTextSplitter",
     "MarkdownTextSplitter",
-    "CharacterTextSplitter",
-    "ExperimentalMarkdownSyntaxTextSplitter",
+    "NLTKTextSplitter",
+    "PythonCodeTextSplitter",
+    "RecursiveCharacterTextSplitter",
+    "RecursiveJsonSplitter",
+    "SentenceTransformersTokenTextSplitter",
+    "SpacyTextSplitter",
+    "TextSplitter",
+    "TokenTextSplitter",
+    "Tokenizer",
+    "split_text_on_tokens",
 ]
diff --git a/libs/text-splitters/langchain_text_splitters/base.py b/libs/text-splitters/langchain_text_splitters/base.py
index ac115e8f94f..f9e4a92222a 100644
--- a/libs/text-splitters/langchain_text_splitters/base.py
+++ b/libs/text-splitters/langchain_text_splitters/base.py
@@ -3,7 +3,8 @@ from __future__ import annotations
 import copy
 import logging
 from abc import ABC, abstractmethod
-from collections.abc import Collection, Iterable, Sequence, Set
+from collections.abc import Collection, Iterable, Sequence
+from collections.abc import Set as AbstractSet
 from dataclasses import dataclass
 from enum import Enum
 from typing import (
@@ -47,10 +48,11 @@ class TextSplitter(BaseDocumentTransformer, ABC):
                 every document
         """
         if chunk_overlap > chunk_size:
-            raise ValueError(
+            msg = (
                 f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
                 f"({chunk_size}), should be smaller."
             )
+            raise ValueError(msg)
         self._chunk_size = chunk_size
         self._chunk_overlap = chunk_overlap
         self._length_function = length_function
@@ -96,8 +98,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
             text = text.strip()
         if text == "":
             return None
-        else:
-            return text
+        return text

     def _merge_splits(self, splits: Iterable[str], separator: str) -> list[str]:
         # We now want to combine these smaller pieces into medium size
@@ -148,18 +149,20 @@ class TextSplitter(BaseDocumentTransformer, ABC):
             from transformers.tokenization_utils_base import PreTrainedTokenizerBase

             if not isinstance(tokenizer, PreTrainedTokenizerBase):
-                raise ValueError(
+                msg = (
                     "Tokenizer received was not an instance of PreTrainedTokenizerBase"
                 )
+                raise ValueError(msg)

             def _huggingface_tokenizer_length(text: str) -> int:
                 return len(tokenizer.tokenize(text))

         except ImportError:
-            raise ValueError(
+            msg = (
                 "Could not import transformers python package. "
                 "Please install it with `pip install transformers`."
             )
+            raise ValueError(msg)
         return cls(length_function=_huggingface_tokenizer_length, **kwargs)

     @classmethod
@@ -167,7 +170,7 @@
         cls: type[TS],
         encoding_name: str = "gpt2",
         model_name: Optional[str] = None,
-        allowed_special: Union[Literal["all"], Set[str]] = set(),
+        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
         disallowed_special: Union[Literal["all"], Collection[str]] = "all",
         **kwargs: Any,
     ) -> TS:
@@ -175,11 +178,12 @@
         try:
             import tiktoken
         except ImportError:
-            raise ImportError(
+            msg = (
                 "Could not import tiktoken python package. "
                 "This is needed in order to calculate max_tokens_for_prompt. "
                 "Please install it with `pip install tiktoken`."
             )
+            raise ImportError(msg)

         if model_name is not None:
             enc = tiktoken.encoding_for_model(model_name)
@@ -220,7 +224,7 @@ class TokenTextSplitter(TextSplitter):
         self,
         encoding_name: str = "gpt2",
         model_name: Optional[str] = None,
-        allowed_special: Union[Literal["all"], Set[str]] = set(),
+        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
         disallowed_special: Union[Literal["all"], Collection[str]] = "all",
         **kwargs: Any,
     ) -> None:
@@ -229,11 +233,12 @@ class TokenTextSplitter(TextSplitter):
         try:
             import tiktoken
         except ImportError:
-            raise ImportError(
+            msg = (
                 "Could not import tiktoken python package. "
                 "This is needed in order to for TokenTextSplitter. "
                 "Please install it with `pip install tiktoken`."
             )
+            raise ImportError(msg)

         if model_name is not None:
             enc = tiktoken.encoding_for_model(model_name)
diff --git a/libs/text-splitters/langchain_text_splitters/character.py b/libs/text-splitters/langchain_text_splitters/character.py
index 0060a6462f9..517cea5f640 100644
--- a/libs/text-splitters/langchain_text_splitters/character.py
+++ b/libs/text-splitters/langchain_text_splitters/character.py
@@ -60,9 +60,9 @@ def _split_text_with_regex(
             if len(_splits) % 2 == 0:
                 splits += _splits[-1:]
             splits = (
-                (splits + [_splits[-1]])
+                ([*splits, _splits[-1]])
                 if keep_separator == "end"
-                else ([_splits[0]] + splits)
+                else ([_splits[0], *splits])
             )
         else:
             splits = re.split(separator, text)
@@ -170,7 +170,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
         Returns:
             List[str]: A list of separators appropriate for the specified language.
""" - if language == Language.C or language == Language.CPP: + if language in (Language.C, Language.CPP): return [ # Split along class definitions "\nclass ", @@ -191,7 +191,7 @@ class RecursiveCharacterTextSplitter(TextSplitter): " ", "", ] - elif language == Language.GO: + if language == Language.GO: return [ # Split along function definitions "\nfunc ", @@ -209,7 +209,7 @@ class RecursiveCharacterTextSplitter(TextSplitter): " ", "", ] - elif language == Language.JAVA: + if language == Language.JAVA: return [ # Split along class definitions "\nclass ", @@ -230,7 +230,7 @@ class RecursiveCharacterTextSplitter(TextSplitter): " ", "", ] - elif language == Language.KOTLIN: + if language == Language.KOTLIN: return [ # Split along class definitions "\nclass ", @@ -256,7 +256,7 @@ class RecursiveCharacterTextSplitter(TextSplitter): " ", "", ] - elif language == Language.JS: + if language == Language.JS: return [ # Split along function definitions "\nfunction ", @@ -277,7 +277,7 @@ class RecursiveCharacterTextSplitter(TextSplitter): " ", "", ] - elif language == Language.TS: + if language == Language.TS: return [ "\nenum ", "\ninterface ", @@ -303,7 +303,7 @@ class RecursiveCharacterTextSplitter(TextSplitter): " ", "", ] - elif language == Language.PHP: + if language == Language.PHP: return [ # Split along function definitions "\nfunction ", @@ -322,7 +322,7 @@ class RecursiveCharacterTextSplitter(TextSplitter): " ", "", ] - elif language == Language.PROTO: + if language == Language.PROTO: return [ # Split along message definitions "\nmessage ", @@ -342,7 +342,7 @@ class RecursiveCharacterTextSplitter(TextSplitter): " ", "", ] - elif language == Language.PYTHON: + if language == Language.PYTHON: return [ # First, try to split along class definitions "\nclass ", @@ -354,7 +354,7 @@ class RecursiveCharacterTextSplitter(TextSplitter): " ", "", ] - elif language == Language.RST: + if language == Language.RST: return [ # Split along section titles "\n=+\n", @@ -368,7 +368,7 @@ class RecursiveCharacterTextSplitter(TextSplitter): " ", "", ] - elif language == Language.RUBY: + if language == Language.RUBY: return [ # Split along method definitions "\ndef ", @@ -387,7 +387,7 @@ class RecursiveCharacterTextSplitter(TextSplitter): " ", "", ] - elif language == Language.ELIXIR: + if language == Language.ELIXIR: return [ # Split along method function and module definition "\ndef ", @@ -411,7 +411,7 @@ class RecursiveCharacterTextSplitter(TextSplitter): " ", "", ] - elif language == Language.RUST: + if language == Language.RUST: return [ # Split along function definitions "\nfn ", @@ -430,7 +430,7 @@ class RecursiveCharacterTextSplitter(TextSplitter): " ", "", ] - elif language == Language.SCALA: + if language == Language.SCALA: return [ # Split along class definitions "\nclass ", @@ -451,7 +451,7 @@ class RecursiveCharacterTextSplitter(TextSplitter): " ", "", ] - elif language == Language.SWIFT: + if language == Language.SWIFT: return [ # Split along function definitions "\nfunc ", @@ -472,7 +472,7 @@ class RecursiveCharacterTextSplitter(TextSplitter): " ", "", ] - elif language == Language.MARKDOWN: + if language == Language.MARKDOWN: return [ # First, try to split along Markdown headings (starting with level 2) "\n#{1,6} ", @@ -492,7 +492,7 @@ class RecursiveCharacterTextSplitter(TextSplitter): " ", "", ] - elif language == Language.LATEX: + if language == Language.LATEX: return [ # First, try to split along Latex sections "\n\\\\chapter{", @@ -516,7 +516,7 @@ class 
RecursiveCharacterTextSplitter(TextSplitter): " ", "", ] - elif language == Language.HTML: + if language == Language.HTML: return [ # First, try to split along HTML tags " list[Document]: """Splits the provided HTML text into smaller chunks based on the configuration. @@ -927,8 +928,7 @@ class HTMLSemanticPreservingSplitter(BaseDocumentTransformer): content, preserved_elements ) return [Document(page_content=page_content, metadata=metadata)] - else: - return self._further_split_chunk(content, metadata, preserved_elements) + return self._further_split_chunk(content, metadata, preserved_elements) def _further_split_chunk( self, content: str, metadata: dict[Any, Any], preserved_elements: dict[str, str] diff --git a/libs/text-splitters/langchain_text_splitters/json.py b/libs/text-splitters/langchain_text_splitters/json.py index 0f71942e9a4..a5c9385febd 100644 --- a/libs/text-splitters/langchain_text_splitters/json.py +++ b/libs/text-splitters/langchain_text_splitters/json.py @@ -64,15 +64,14 @@ class RecursiveJsonSplitter: if isinstance(data, dict): # Process each key-value pair in the dictionary return {k: self._list_to_dict_preprocessing(v) for k, v in data.items()} - elif isinstance(data, list): + if isinstance(data, list): # Convert the list to a dictionary with index-based keys return { str(i): self._list_to_dict_preprocessing(item) for i, item in enumerate(data) } - else: - # Base case: the item is neither a dict nor a list, so return it unchanged - return data + # Base case: the item is neither a dict nor a list, so return it unchanged + return data def _json_split( self, @@ -85,7 +84,7 @@ class RecursiveJsonSplitter: chunks = chunks if chunks is not None else [{}] if isinstance(data, dict): for key, value in data.items(): - new_path = current_path + [key] + new_path = [*current_path, key] chunk_size = self._json_size(chunks[-1]) size = self._json_size({key: value}) remaining = self.max_chunk_size - chunk_size diff --git a/libs/text-splitters/langchain_text_splitters/jsx.py b/libs/text-splitters/langchain_text_splitters/jsx.py index 3c0b73ebd28..596bf0463e7 100644 --- a/libs/text-splitters/langchain_text_splitters/jsx.py +++ b/libs/text-splitters/langchain_text_splitters/jsx.py @@ -94,5 +94,4 @@ class JSFrameworkTextSplitter(RecursiveCharacterTextSplitter): + ["<>", "\n\n", "&&\n", "||\n"] ) self._separators = separators - chunks = super().split_text(text) - return chunks + return super().split_text(text) diff --git a/libs/text-splitters/langchain_text_splitters/konlpy.py b/libs/text-splitters/langchain_text_splitters/konlpy.py index 60b35091677..2ffd9b0e8df 100644 --- a/libs/text-splitters/langchain_text_splitters/konlpy.py +++ b/libs/text-splitters/langchain_text_splitters/konlpy.py @@ -22,12 +22,11 @@ class KonlpyTextSplitter(TextSplitter): try: import konlpy except ImportError: - raise ImportError( - """ - Konlpy is not installed, please install it with + msg = """ + Konlpy is not installed, please install it with `pip install konlpy` """ - ) + raise ImportError(msg) self.kkma = konlpy.tag.Kkma() def split_text(self, text: str) -> list[str]: diff --git a/libs/text-splitters/langchain_text_splitters/markdown.py b/libs/text-splitters/langchain_text_splitters/markdown.py index bbf10828ed4..3d60a9f7269 100644 --- a/libs/text-splitters/langchain_text_splitters/markdown.py +++ b/libs/text-splitters/langchain_text_splitters/markdown.py @@ -121,10 +121,9 @@ class MarkdownHeaderTextSplitter: elif stripped_line.startswith("~~~"): in_code_block = True opening_fence = "~~~" - else: - if 
stripped_line.startswith(opening_fence): - in_code_block = False - opening_fence = "" + elif stripped_line.startswith(opening_fence): + in_code_block = False + opening_fence = "" if in_code_block: current_content.append(stripped_line) @@ -207,11 +206,10 @@ class MarkdownHeaderTextSplitter: # aggregate these into chunks based on common metadata if not self.return_each_line: return self.aggregate_lines_to_chunks(lines_with_metadata) - else: - return [ - Document(page_content=chunk["content"], metadata=chunk["metadata"]) - for chunk in lines_with_metadata - ] + return [ + Document(page_content=chunk["content"], metadata=chunk["metadata"]) + for chunk in lines_with_metadata + ] class LineType(TypedDict): diff --git a/libs/text-splitters/langchain_text_splitters/nltk.py b/libs/text-splitters/langchain_text_splitters/nltk.py index 931e7b8cf3e..7b0573fbc46 100644 --- a/libs/text-splitters/langchain_text_splitters/nltk.py +++ b/libs/text-splitters/langchain_text_splitters/nltk.py @@ -22,7 +22,8 @@ class NLTKTextSplitter(TextSplitter): self._language = language self._use_span_tokenize = use_span_tokenize if self._use_span_tokenize and self._separator != "": - raise ValueError("When use_span_tokenize is True, separator should be ''") + msg = "When use_span_tokenize is True, separator should be ''" + raise ValueError(msg) try: import nltk @@ -31,9 +32,8 @@ class NLTKTextSplitter(TextSplitter): else: self._tokenizer = nltk.tokenize.sent_tokenize except ImportError: - raise ImportError( - "NLTK is not installed, please install it with `pip install nltk`." - ) + msg = "NLTK is not installed, please install it with `pip install nltk`." + raise ImportError(msg) def split_text(self, text: str) -> list[str]: """Split incoming text and return chunks.""" diff --git a/libs/text-splitters/langchain_text_splitters/sentence_transformers.py b/libs/text-splitters/langchain_text_splitters/sentence_transformers.py index b3c88331d96..33e27a53cb5 100644 --- a/libs/text-splitters/langchain_text_splitters/sentence_transformers.py +++ b/libs/text-splitters/langchain_text_splitters/sentence_transformers.py @@ -21,11 +21,12 @@ class SentenceTransformersTokenTextSplitter(TextSplitter): try: from sentence_transformers import SentenceTransformer except ImportError: - raise ImportError( + msg = ( "Could not import sentence_transformers python package. " "This is needed in order to for SentenceTransformersTokenTextSplitter. " "Please install it with `pip install sentence-transformers`." ) + raise ImportError(msg) self.model_name = model_name self._model = SentenceTransformer(self.model_name) @@ -43,12 +44,13 @@ class SentenceTransformersTokenTextSplitter(TextSplitter): self.tokens_per_chunk = tokens_per_chunk if self.tokens_per_chunk > self.maximum_tokens_per_chunk: - raise ValueError( + msg = ( f"The token limit of the models '{self.model_name}'" f" is: {self.maximum_tokens_per_chunk}." f" Argument tokens_per_chunk={self.tokens_per_chunk}" f" > maximum token limit." ) + raise ValueError(msg) def split_text(self, text: str) -> list[str]: """Splits the input text into smaller components by splitting text on tokens. 
diff --git a/libs/text-splitters/langchain_text_splitters/spacy.py b/libs/text-splitters/langchain_text_splitters/spacy.py
index 4d39caab398..df65aefa0f6 100644
--- a/libs/text-splitters/langchain_text_splitters/spacy.py
+++ b/libs/text-splitters/langchain_text_splitters/spacy.py
@@ -46,9 +46,8 @@ def _make_spacy_pipeline_for_splitting(
     try:
         import spacy
     except ImportError:
-        raise ImportError(
-            "Spacy is not installed, please install it with `pip install spacy`."
-        )
+        msg = "Spacy is not installed, please install it with `pip install spacy`."
+        raise ImportError(msg)
     if pipeline == "sentencizer":
         sentencizer: Any = spacy.lang.en.English()
         sentencizer.add_pipe("sentencizer")
diff --git a/libs/text-splitters/tests/integration_tests/test_compile.py b/libs/text-splitters/tests/integration_tests/test_compile.py
index 33ecccdfa0f..f315e45f521 100644
--- a/libs/text-splitters/tests/integration_tests/test_compile.py
+++ b/libs/text-splitters/tests/integration_tests/test_compile.py
@@ -4,4 +4,3 @@ import pytest
 @pytest.mark.compile
 def test_placeholder() -> None:
     """Used for compiling integration tests without running any real tests."""
-    pass
diff --git a/libs/text-splitters/tests/integration_tests/test_nlp_text_splitters.py b/libs/text-splitters/tests/integration_tests/test_nlp_text_splitters.py
index 79e20455ae2..8ba6a268d33 100644
--- a/libs/text-splitters/tests/integration_tests/test_nlp_text_splitters.py
+++ b/libs/text-splitters/tests/integration_tests/test_nlp_text_splitters.py
@@ -14,7 +14,7 @@ def setup_module() -> None:
     nltk.download("punkt_tab")


-@pytest.fixture()
+@pytest.fixture
 def spacy() -> Any:
     try:
         import spacy
diff --git a/libs/text-splitters/tests/integration_tests/test_text_splitter.py b/libs/text-splitters/tests/integration_tests/test_text_splitter.py
index 94f687ba3b2..ad0549e2424 100644
--- a/libs/text-splitters/tests/integration_tests/test_text_splitter.py
+++ b/libs/text-splitters/tests/integration_tests/test_text_splitter.py
@@ -13,7 +13,7 @@ from langchain_text_splitters.sentence_transformers import (
 )


-@pytest.fixture()
+@pytest.fixture
 def sentence_transformers() -> Any:
     try:
         import sentence_transformers
diff --git a/libs/text-splitters/tests/unit_tests/conftest.py b/libs/text-splitters/tests/unit_tests/conftest.py
index f6219faaa18..83c9f53b64d 100644
--- a/libs/text-splitters/tests/unit_tests/conftest.py
+++ b/libs/text-splitters/tests/unit_tests/conftest.py
@@ -45,7 +45,8 @@ def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) ->
     only_core = config.getoption("--only-core") or False

     if only_extended and only_core:
-        raise ValueError("Cannot specify both `--only-extended` and `--only-core`.")
+        msg = "Cannot specify both `--only-extended` and `--only-core`."
+        raise ValueError(msg)

     for item in items:
         requires_marker = item.get_closest_marker("requires")
@@ -81,8 +82,5 @@ def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) ->
                         pytest.mark.skip(reason=f"Requires pkg: `{pkg}`")
                     )
                 break
-        else:
-            if only_extended:
-                item.add_marker(
-                    pytest.mark.skip(reason="Skipping not an extended test.")
-                )
+        elif only_extended:
+            item.add_marker(pytest.mark.skip(reason="Skipping not an extended test."))
diff --git a/libs/text-splitters/tests/unit_tests/test_text_splitters.py b/libs/text-splitters/tests/unit_tests/test_text_splitters.py
index 935092f56c5..6f03a2d59f0 100644
--- a/libs/text-splitters/tests/unit_tests/test_text_splitters.py
+++ b/libs/text-splitters/tests/unit_tests/test_text_splitters.py
@@ -98,7 +98,7 @@ def test_character_text_splitter_longer_words() -> None:


 @pytest.mark.parametrize(
-    "separator, is_separator_regex", [(re.escape("."), True), (".", False)]
+    ("separator", "is_separator_regex"), [(re.escape("."), True), (".", False)]
 )
 def test_character_text_splitter_keep_separator_regex(
     separator: str, is_separator_regex: bool
@@ -120,7 +120,7 @@


 @pytest.mark.parametrize(
-    "separator, is_separator_regex", [(re.escape("."), True), (".", False)]
+    ("separator", "is_separator_regex"), [(re.escape("."), True), (".", False)]
 )
 def test_character_text_splitter_keep_separator_regex_start(
     separator: str, is_separator_regex: bool
@@ -142,7 +142,7 @@


 @pytest.mark.parametrize(
-    "separator, is_separator_regex", [(re.escape("."), True), (".", False)]
+    ("separator", "is_separator_regex"), [(re.escape("."), True), (".", False)]
 )
 def test_character_text_splitter_keep_separator_regex_end(
     separator: str, is_separator_regex: bool
@@ -164,7 +164,7 @@


 @pytest.mark.parametrize(
-    "separator, is_separator_regex", [(re.escape("."), True), (".", False)]
+    ("separator", "is_separator_regex"), [(re.escape("."), True), (".", False)]
 )
 def test_character_text_splitter_discard_separator_regex(
     separator: str, is_separator_regex: bool
@@ -250,7 +250,7 @@ def test_create_documents_with_metadata() -> None:


 @pytest.mark.parametrize(
-    "splitter, text, expected_docs",
+    ("splitter", "text", "expected_docs"),
     [
         (
             CharacterTextSplitter(
@@ -1390,7 +1390,7 @@ def test_md_header_text_splitter_fenced_code_block(fence: str) -> None:
     assert output == expected_output


-@pytest.mark.parametrize(["fence", "other_fence"], [("```", "~~~"), ("~~~", "```")])
+@pytest.mark.parametrize(("fence", "other_fence"), [("```", "~~~"), ("~~~", "```")])
 def test_md_header_text_splitter_fenced_code_block_interleaved(
     fence: str, other_fence: str
 ) -> None:
@@ -2240,7 +2240,7 @@ def html_header_splitter_splitter_factory() -> Callable[


 @pytest.mark.parametrize(
-    "headers_to_split_on, html_input, expected_documents, test_case",
+    ("headers_to_split_on", "html_input", "expected_documents", "test_case"),
     [
         (
             # Test Case 1: Split on h1 and h2
@@ -2469,7 +2469,7 @@ def test_html_header_text_splitter(


 @pytest.mark.parametrize(
-    "headers_to_split_on, html_content, expected_output, test_case",
+    ("headers_to_split_on", "html_content", "expected_output", "test_case"),
     [
         (
             # Test Case A: Split on h1 and h2 with h3 in content
@@ -2624,7 +2624,7 @@ def test_additional_html_header_text_splitter(


 @pytest.mark.parametrize(
-    "headers_to_split_on, html_content, expected_output, test_case",
+    ("headers_to_split_on", "html_content", "expected_output", "test_case"),
     [
         (
             # Test Case C: Split on h1, h2, and h3 with no headers present
@@ -3551,7 +3551,7 @@ def test_character_text_splitter_discard_regex_separator_on_merge() -> None:


 @pytest.mark.parametrize(
-    "separator,is_regex,text,chunk_size,expected",
+    ("separator", "is_regex", "text", "chunk_size", "expected"),
     [
         # 1) regex lookaround & split happens
         # "abcmiddef" split by "(?<=mid)" → ["abcmid","def"], chunk_size=5 keeps both