mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-08 22:42:05 +00:00
SemanticChunker : Feature Addition ("Semantic Splitting with gradient") (#22895)
`SemanticChunker` currently provides three methods to split text semantically: - percentile - standard_deviation - interquartile. I propose a new method, `gradient`. In this method, the gradient of the distance array is used to split chunks, together with the percentile method (technically, the percentile threshold is computed on the gradient array rather than on the raw distances). This method is useful when chunks are highly correlated with each other or specific to a domain, e.g. legal or medical. The idea is to apply anomaly detection to the gradient array so that the distribution becomes wider and boundaries are easier to identify in highly semantic data. I have tested this merge on a set of 10 domain-specific documents (mostly legal). Details: - **Issue:** Improvement - **Dependencies:** NA - **Twitter handle:** [x.com/prajapat_ravi](https://x.com/prajapat_ravi) @hwchase17 --------- Co-authored-by: Raviraj Prajapat <raviraj.prajapat@sirionlabs.com> Co-authored-by: isaac hershenson <ihershenson@hmc.edu>
This commit is contained in:
@@ -84,11 +84,14 @@ def calculate_cosine_distances(sentences: List[dict]) -> Tuple[List[float], List
|
||||
return distances, sentences
|
||||
|
||||
|
||||
BreakpointThresholdType = Literal["percentile", "standard_deviation", "interquartile"]
|
||||
BreakpointThresholdType = Literal[
|
||||
"percentile", "standard_deviation", "interquartile", "gradient"
|
||||
]
|
||||
BREAKPOINT_DEFAULTS: Dict[BreakpointThresholdType, float] = {
|
||||
"percentile": 95,
|
||||
"standard_deviation": 3,
|
||||
"interquartile": 1.5,
|
||||
"gradient": 95,
|
||||
}
|
||||
|
||||
|
||||
@@ -127,23 +130,34 @@ class SemanticChunker(BaseDocumentTransformer):
|
||||
else:
|
||||
self.breakpoint_threshold_amount = breakpoint_threshold_amount
|
||||
|
||||
def _calculate_breakpoint_threshold(self, distances: List[float]) -> float:
|
||||
def _calculate_breakpoint_threshold(
|
||||
self, distances: List[float]
|
||||
) -> Tuple[float, List[float]]:
|
||||
if self.breakpoint_threshold_type == "percentile":
|
||||
return cast(
|
||||
float,
|
||||
np.percentile(distances, self.breakpoint_threshold_amount),
|
||||
)
|
||||
), distances
|
||||
elif self.breakpoint_threshold_type == "standard_deviation":
|
||||
return cast(
|
||||
float,
|
||||
np.mean(distances)
|
||||
+ self.breakpoint_threshold_amount * np.std(distances),
|
||||
)
|
||||
), distances
|
||||
elif self.breakpoint_threshold_type == "interquartile":
|
||||
q1, q3 = np.percentile(distances, [25, 75])
|
||||
iqr = q3 - q1
|
||||
|
||||
return np.mean(distances) + self.breakpoint_threshold_amount * iqr
|
||||
return np.mean(
|
||||
distances
|
||||
) + self.breakpoint_threshold_amount * iqr, distances
|
||||
elif self.breakpoint_threshold_type == "gradient":
|
||||
# Calculate the threshold based on the distribution of gradient of distance array. # noqa: E501
|
||||
distance_gradient = np.gradient(distances, range(0, len(distances)))
|
||||
return cast(
|
||||
float,
|
||||
np.percentile(distance_gradient, self.breakpoint_threshold_amount),
|
||||
), distance_gradient
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Got unexpected `breakpoint_threshold_type`: "
|
||||
@@ -201,13 +215,17 @@ class SemanticChunker(BaseDocumentTransformer):
|
||||
distances, sentences = self._calculate_sentence_distances(single_sentences_list)
|
||||
if self.number_of_chunks is not None:
|
||||
breakpoint_distance_threshold = self._threshold_from_clusters(distances)
|
||||
breakpoint_array = distances
|
||||
else:
|
||||
breakpoint_distance_threshold = self._calculate_breakpoint_threshold(
|
||||
distances
|
||||
)
|
||||
(
|
||||
breakpoint_distance_threshold,
|
||||
breakpoint_array,
|
||||
) = self._calculate_breakpoint_threshold(distances)
|
||||
|
||||
indices_above_thresh = [
|
||||
i for i, x in enumerate(distances) if x > breakpoint_distance_threshold
|
||||
i
|
||||
for i, x in enumerate(breakpoint_array)
|
||||
if x > breakpoint_distance_threshold
|
||||
]
|
||||
|
||||
chunks = []
|
||||
|
Reference in New Issue
Block a user