SemanticChunker: Feature Addition ("Semantic Splitting with gradient") (#22895)

`SemanticChunker` currently provides three methods to split text semantically:
- percentile
- standard_deviation
- interquartile

I propose a new method, `gradient`. In this method, the gradient of the distance array is computed, and the percentile method is then applied to that gradient to pick breakpoints. This is useful when the chunks are highly correlated with each other or are specific to a domain, e.g. legal or medical. The idea is to apply anomaly detection on the gradient array so that the distribution becomes wider, making boundaries easier to identify in highly semantic data.
I have tested this change on a set of 10 domain-specific documents (mostly legal).
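
For illustration, a minimal usage sketch (the import paths and `OpenAIEmbeddings` are assumptions about a typical setup, not part of this diff):

```python
from langchain_experimental.text_splitter import SemanticChunker
from langchain_openai import OpenAIEmbeddings

# Split where the gradient of the cosine-distance array spikes,
# rather than on the raw distances themselves.
text_splitter = SemanticChunker(
    OpenAIEmbeddings(), breakpoint_threshold_type="gradient"
)
docs = text_splitter.create_documents(["<some long legal or medical text>"])
```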

Details:
    - **Issue:** Improvement
    - **Dependencies:** NA
    - **Twitter handle:** [x.com/prajapat_ravi](https://x.com/prajapat_ravi)


@hwchase17

---------

Co-authored-by: Raviraj Prajapat <raviraj.prajapat@sirionlabs.com>
Co-authored-by: isaac hershenson <ihershenson@hmc.edu>

```diff
@@ -84,11 +84,14 @@ def calculate_cosine_distances(sentences: List[dict]) -> Tuple[List[float], List
     return distances, sentences
 
 
-BreakpointThresholdType = Literal["percentile", "standard_deviation", "interquartile"]
+BreakpointThresholdType = Literal[
+    "percentile", "standard_deviation", "interquartile", "gradient"
+]
 
 BREAKPOINT_DEFAULTS: Dict[BreakpointThresholdType, float] = {
     "percentile": 95,
     "standard_deviation": 3,
     "interquartile": 1.5,
+    "gradient": 95,
 }
 
```
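
As with `percentile`, the amount is itself a percentile (default 95, per `BREAKPOINT_DEFAULTS` above) and can be overridden per instance; a hedged sketch:

```python
from langchain_experimental.text_splitter import SemanticChunker
from langchain_openai import OpenAIEmbeddings

# Hypothetical tuning: a lower percentile yields more breakpoints and
# therefore smaller chunks; omitting the argument falls back to the
# "gradient" default of 95.
splitter = SemanticChunker(
    OpenAIEmbeddings(),
    breakpoint_threshold_type="gradient",
    breakpoint_threshold_amount=90.0,
)
```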
```diff
@@ -127,23 +130,34 @@ class SemanticChunker(BaseDocumentTransformer):
         else:
             self.breakpoint_threshold_amount = breakpoint_threshold_amount
 
-    def _calculate_breakpoint_threshold(self, distances: List[float]) -> float:
+    def _calculate_breakpoint_threshold(
+        self, distances: List[float]
+    ) -> Tuple[float, List[float]]:
         if self.breakpoint_threshold_type == "percentile":
             return cast(
                 float,
                 np.percentile(distances, self.breakpoint_threshold_amount),
-            )
+            ), distances
         elif self.breakpoint_threshold_type == "standard_deviation":
             return cast(
                 float,
                 np.mean(distances)
                 + self.breakpoint_threshold_amount * np.std(distances),
-            )
+            ), distances
         elif self.breakpoint_threshold_type == "interquartile":
             q1, q3 = np.percentile(distances, [25, 75])
             iqr = q3 - q1
 
-            return np.mean(distances) + self.breakpoint_threshold_amount * iqr
+            return np.mean(
+                distances
+            ) + self.breakpoint_threshold_amount * iqr, distances
+        elif self.breakpoint_threshold_type == "gradient":
+            # Calculate the threshold based on the distribution of gradient of distance array. # noqa: E501
+            distance_gradient = np.gradient(distances, range(0, len(distances)))
+            return cast(
+                float,
+                np.percentile(distance_gradient, self.breakpoint_threshold_amount),
+            ), distance_gradient
         else:
             raise ValueError(
                 f"Got unexpected `breakpoint_threshold_type`: "
```
```diff
@@ -201,13 +215,17 @@ class SemanticChunker(BaseDocumentTransformer):
         distances, sentences = self._calculate_sentence_distances(single_sentences_list)
 
         if self.number_of_chunks is not None:
             breakpoint_distance_threshold = self._threshold_from_clusters(distances)
+            breakpoint_array = distances
         else:
-            breakpoint_distance_threshold = self._calculate_breakpoint_threshold(
-                distances
-            )
+            (
+                breakpoint_distance_threshold,
+                breakpoint_array,
+            ) = self._calculate_breakpoint_threshold(distances)
 
         indices_above_thresh = [
-            i for i, x in enumerate(distances) if x > breakpoint_distance_threshold
+            i
+            for i, x in enumerate(breakpoint_array)
+            if x > breakpoint_distance_threshold
         ]
 
         chunks = []
```
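
For context, the collected indices then drive the grouping loop that follows (unchanged by this PR); a hypothetical standalone sketch of that step:

```python
# Hypothetical standalone version of the grouping step: sentences are
# joined between consecutive breakpoint indices.
sentences = ["s0", "s1", "s2", "s3", "s4"]
indices_above_thresh = [1, 3]

chunks, start_index = [], 0
for index in indices_above_thresh:
    # The breakpoint sentence closes the current chunk.
    chunks.append(" ".join(sentences[start_index : index + 1]))
    start_index = index + 1
# Any trailing sentences form the final chunk.
chunks.append(" ".join(sentences[start_index:]))
print(chunks)  # ['s0 s1', 's2 s3', 's4']
```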