fix[experimental]: Fix text splitter with gradient (#26629)

Fixes #26221

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
Tibor Reiss 2024-09-21 01:35:50 +02:00 committed by GitHub
parent 4ac9a6f52c
commit a8b24135a2
2 changed files with 60 additions and 0 deletions


@@ -217,6 +217,12 @@ class SemanticChunker(BaseDocumentTransformer):
        # np.percentile to fail.
        if len(single_sentences_list) == 1:
            return single_sentences_list
        # similarly, the following np.gradient would fail
        if (
            self.breakpoint_threshold_type == "gradient"
            and len(single_sentences_list) == 2
        ):
            return single_sentences_list
        distances, sentences = self._calculate_sentence_distances(single_sentences_list)
        if self.number_of_chunks is not None:
            breakpoint_distance_threshold = self._threshold_from_clusters(distances)

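For context, the added guard mirrors the existing single-sentence check: splitting a text into two sentences yields exactly one pairwise distance, and `np.gradient` needs at least two samples to estimate a slope, so it raises before any breakpoint threshold can be computed. A minimal standalone sketch of that failure mode (the distance value below is made up for illustration):

```python
import numpy as np

# Two sentences produce exactly one cosine distance; np.gradient requires
# at least two values, so it raises a ValueError.
distances = [0.42]  # hypothetical single distance between two sentences
try:
    np.gradient(distances, range(len(distances)))
except ValueError as exc:
    print(f"np.gradient failed: {exc}")
```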

@@ -0,0 +1,54 @@
import re
from typing import List

import pytest
from langchain_core.embeddings import Embeddings

from langchain_experimental.text_splitter import SemanticChunker

FAKE_EMBEDDINGS = [
    [0.02905, 0.42969, 0.65394, 0.62200],
    [0.00515, 0.47214, 0.45327, 0.75605],
    [0.57401, 0.30344, 0.41702, 0.63603],
    [0.60308, 0.18708, 0.68871, 0.35634],
    [0.52510, 0.56163, 0.34100, 0.54089],
    [0.73275, 0.22089, 0.42652, 0.48204],
    [0.47466, 0.26161, 0.79687, 0.26694],
]

SAMPLE_TEXT = (
    "We need to harvest synergy effects viral engagement, but digitalize, "
    "nor overcome key issues to meet key milestones. So digital literacy "
    "where the metal hits the meat. So this vendor is incompetent. Can "
    "you champion this? Let me diarize this. And we can synchronise "
    "ourselves at a later timepoint t-shaped individual tread it daily. "
    "That is a good problem"
)


class MockEmbeddings(Embeddings):
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return FAKE_EMBEDDINGS[: len(texts)]

    def embed_query(self, text: str) -> List[float]:
        return [1.0, 2.0]


@pytest.mark.parametrize(
    "input_length, expected_length",
    [
        (1, 1),
        (2, 2),
        (5, 2),
    ],
)
def test_split_text_gradient(input_length: int, expected_length: int) -> None:
    embeddings = MockEmbeddings()
    chunker = SemanticChunker(
        embeddings,
        breakpoint_threshold_type="gradient",
    )
    list_of_sentences = re.split(r"(?<=[.?!])\s+", SAMPLE_TEXT)[:input_length]
    chunks = chunker.split_text(" ".join(list_of_sentences))
    assert len(chunks) == expected_length