Mirror of https://github.com/hwchase17/langchain.git
Experimental: Add other threshold types to SemanticChunker (#16807)
**Description** Adds additional threshold types to the semantic chunker. I've had much better and more predictable performance using standard deviations instead of percentiles. For every document I've tried, the distribution of distances looks similar: a positively skewed, roughly normal distribution. All the skews I've seen are less than 1, which explains why standard deviations perform well, but I've also included IQR for anyone who wants something more robust. Finally, by running the percentile method backwards, you can declare the number of chunks up front and let semantic chunking find an 'optimal' splitting.

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
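A quick usage sketch of the new options. The import path matches this PR's experimental package, and `FakeEmbeddings` is a hypothetical stand-in added here only so the snippet is self-contained; use a real embedding model in practice:

```python
from typing import List

from langchain_core.embeddings import Embeddings
from langchain_experimental.text_splitter import SemanticChunker


class FakeEmbeddings(Embeddings):
    """Hypothetical stand-in embedding model, just to keep the sketch runnable."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [[float(len(t)), 1.0] for t in texts]

    def embed_query(self, text: str) -> List[float]:
        return [float(len(text)), 1.0]


embeddings = FakeEmbeddings()

# Default behavior: split where the cosine distance between adjacent sentence
# groups exceeds the 95th percentile of all distances.
chunker = SemanticChunker(embeddings)

# New: split at mean + 3 standard deviations of the distance distribution.
chunker = SemanticChunker(embeddings, breakpoint_threshold_type="standard_deviation")

# New: robust variant, split at mean + 1.5 * IQR.
chunker = SemanticChunker(embeddings, breakpoint_threshold_type="interquartile")

# New: run the percentile method backwards and aim for roughly ten chunks.
chunker = SemanticChunker(embeddings, number_of_chunks=10)

chunks = chunker.split_text("First sentence. Second one! A third? And a fourth.")
```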
parent ce682f5a09
commit a4896da2a0
```diff
@@ -1,6 +1,6 @@
 import copy
 import re
-from typing import Any, Iterable, List, Optional, Sequence, Tuple
+from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence, Tuple, cast

 import numpy as np
 from langchain_community.utils.math import (
```
```diff
@@ -83,6 +83,14 @@ def calculate_cosine_distances(sentences: List[dict]) -> Tuple[List[float], List
     return distances, sentences


+BreakpointThresholdType = Literal["percentile", "standard_deviation", "interquartile"]
+BREAKPOINT_DEFAULTS: Dict[BreakpointThresholdType, float] = {
+    "percentile": 95,
+    "standard_deviation": 3,
+    "interquartile": 1.5,
+}
+
+
 class SemanticChunker(BaseDocumentTransformer):
     """Split the text based on semantic similarity.

```
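The defaults registered above become concrete cutoffs once a distance distribution is in hand. A small numpy sketch (synthetic, positively skewed distances; illustrative numbers only) of what each rule computes:

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in for the sentence-to-sentence cosine distances; real distributions
# tend to be positively skewed, per the PR description.
distances = rng.gamma(shape=2.0, scale=0.1, size=200)

# "percentile": 95 -> the 95th percentile of all distances.
print(np.percentile(distances, 95))

# "standard_deviation": 3 -> mean + 3 * standard deviation.
print(np.mean(distances) + 3 * np.std(distances))

# "interquartile": 1.5 -> mean + 1.5 * IQR.
q1, q3 = np.percentile(distances, [25, 75])
print(np.mean(distances) + 1.5 * (q3 - q1))
```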
```diff
@@ -95,12 +103,89 @@ class SemanticChunker(BaseDocumentTransformer):
     sentences, and then merges ones that are similar in the embedding space.
     """

-    def __init__(self, embeddings: Embeddings, add_start_index: bool = False):
+    def __init__(
+        self,
+        embeddings: Embeddings,
+        add_start_index: bool = False,
+        breakpoint_threshold_type: BreakpointThresholdType = "percentile",
+        breakpoint_threshold_amount: Optional[float] = None,
+        number_of_chunks: Optional[int] = None,
+    ):
         self._add_start_index = add_start_index
         self.embeddings = embeddings
+        self.breakpoint_threshold_type = breakpoint_threshold_type
+        self.number_of_chunks = number_of_chunks
+        if breakpoint_threshold_amount is None:
+            self.breakpoint_threshold_amount = BREAKPOINT_DEFAULTS[
+                breakpoint_threshold_type
+            ]
+        else:
+            self.breakpoint_threshold_amount = breakpoint_threshold_amount

-    def split_text(self, text: str) -> List[str]:
+    def _calculate_breakpoint_threshold(self, distances: List[float]) -> float:
+        if self.breakpoint_threshold_type == "percentile":
+            return cast(
+                float,
+                np.percentile(distances, self.breakpoint_threshold_amount),
+            )
+        elif self.breakpoint_threshold_type == "standard_deviation":
+            return cast(
+                float,
+                np.mean(distances)
+                + self.breakpoint_threshold_amount * np.std(distances),
+            )
+        elif self.breakpoint_threshold_type == "interquartile":
+            q1, q3 = np.percentile(distances, [25, 75])
+            iqr = q3 - q1
+
+            return np.mean(distances) + self.breakpoint_threshold_amount * iqr
+        else:
+            raise ValueError(
+                f"Got unexpected `breakpoint_threshold_type`: "
+                f"{self.breakpoint_threshold_type}"
+            )
+
+    def _threshold_from_clusters(self, distances: List[float]) -> float:
+        """
+        Calculate the threshold based on the number of chunks.
+        Inverse of percentile method.
+        """
+        if self.number_of_chunks is None:
+            raise ValueError(
+                "This should never be called if `number_of_chunks` is None."
+            )
+        x1, y1 = len(distances), 0.0
+        x2, y2 = 1.0, 100.0
+
+        x = max(min(self.number_of_chunks, x1), x2)
+
+        # Linear interpolation formula
+        y = y1 + ((y2 - y1) / (x2 - x1)) * (x - x1)
+        y = min(max(y, 0), 100)
+
+        return cast(float, np.percentile(distances, y))
+
+    def _calculate_sentence_distances(
+        self, single_sentences_list: List[str]
+    ) -> Tuple[List[float], List[dict]]:
         """Split text into multiple components."""
+
+        _sentences = [
+            {"sentence": x, "index": i} for i, x in enumerate(single_sentences_list)
+        ]
+        sentences = combine_sentences(_sentences)
+        embeddings = self.embeddings.embed_documents(
+            [x["combined_sentence"] for x in sentences]
+        )
+        for i, sentence in enumerate(sentences):
+            sentence["combined_sentence_embedding"] = embeddings[i]
+
+        return calculate_cosine_distances(sentences)
+
+    def split_text(
+        self,
+        text: str,
+    ) -> List[str]:
         # Splitting the essay on '.', '?', and '!'
         single_sentences_list = re.split(r"(?<=[.?!])\s+", text)

```
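The `_threshold_from_clusters` helper above is plain linear interpolation between the points `(len(distances), 0)` and `(1, 100)`: asking for as many chunks as there are distances maps to the 0th percentile (split everywhere), while asking for a single chunk maps to the 100th (split nowhere). A worked check of the arithmetic, assuming 99 distances and a request for 10 chunks:

```python
import numpy as np

distances = list(np.linspace(0.01, 0.99, 99))  # 99 illustrative distances
number_of_chunks = 10

x1, y1 = len(distances), 0.0  # many chunks  -> low percentile
x2, y2 = 1.0, 100.0           # single chunk -> high percentile

x = max(min(number_of_chunks, x1), x2)
y = y1 + ((y2 - y1) / (x2 - x1)) * (x - x1)
y = min(max(y, 0), 100)

print(round(y, 1))  # 90.8 -> ten chunks maps to roughly the 91st percentile
threshold = np.percentile(distances, y)
```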
```diff
@@ -108,29 +193,20 @@ class SemanticChunker(BaseDocumentTransformer):
         # np.percentile to fail.
         if len(single_sentences_list) == 1:
             return single_sentences_list
-        sentences = [
-            {"sentence": x, "index": i} for i, x in enumerate(single_sentences_list)
-        ]
-        sentences = combine_sentences(sentences)
-        embeddings = self.embeddings.embed_documents(
-            [x["combined_sentence"] for x in sentences]
-        )
-        for i, sentence in enumerate(sentences):
-            sentence["combined_sentence_embedding"] = embeddings[i]
-        distances, sentences = calculate_cosine_distances(sentences)
-        start_index = 0
-
-        # Create a list to hold the grouped sentences
-        chunks = []
-        breakpoint_percentile_threshold = 95
-        breakpoint_distance_threshold = np.percentile(
-            distances, breakpoint_percentile_threshold
-        )  # If you want more chunks, lower the percentile cutoff
+        distances, sentences = self._calculate_sentence_distances(single_sentences_list)
+        if self.number_of_chunks is not None:
+            breakpoint_distance_threshold = self._threshold_from_clusters(distances)
+        else:
+            breakpoint_distance_threshold = self._calculate_breakpoint_threshold(
+                distances
+            )

         indices_above_thresh = [
             i for i, x in enumerate(distances) if x > breakpoint_distance_threshold
-        ]  # The indices of those breakpoints on your list
+        ]
+
+        chunks = []
+        start_index = 0

         # Iterate through the breakpoints to slice the sentences
         for index in indices_above_thresh:
```
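The page viewer truncates the diff inside this loop. For orientation only, here is a hypothetical sketch of the usual pattern for turning breakpoint indices into chunks; `slice_into_chunks` is an illustrative helper, not the commit's verbatim code:

```python
from typing import Dict, List


def slice_into_chunks(
    sentences: List[Dict], indices_above_thresh: List[int]
) -> List[str]:
    # Illustrative only: join the sentences between consecutive breakpoint
    # indices into one chunk each.
    chunks = []
    start_index = 0
    for index in indices_above_thresh:
        group = sentences[start_index : index + 1]
        chunks.append(" ".join(d["sentence"] for d in group))
        start_index = index + 1
    # Whatever remains after the last breakpoint becomes the final chunk.
    if start_index < len(sentences):
        chunks.append(" ".join(d["sentence"] for d in sentences[start_index:]))
    return chunks
```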