SemanticChunker: Feature Addition ("Semantic Splitting with gradient") (#22895)
`SemanticChunker` currently provides three methods for splitting texts semantically:

- percentile
- standard_deviation
- interquartile

I propose a new method, `gradient`. In this method, the gradient of the distance array is used to split chunks; technically, the percentile method is then applied to that gradient. It is useful when chunks are highly correlated with each other or specific to a domain, e.g. legal or medical. The idea is to apply anomaly detection to the gradient array so that the distribution becomes wider, making it easier to identify boundaries in highly semantic data. I have tested this change on a set of 10 domain-specific documents (mostly legal).

Details:
- **Issue:** Improvement
- **Dependencies:** NA
- **Twitter handle:** [x.com/prajapat_ravi](https://x.com/prajapat_ravi)

@hwchase17

---------

Co-authored-by: Raviraj Prajapat <raviraj.prajapat@sirionlabs.com>
Co-authored-by: isaac hershenson <ihershenson@hmc.edu>
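A minimal sketch of the idea (illustrative only, not code from this PR), using a synthetic distance array:

```python
import numpy as np

# Synthetic cosine distances between consecutive sentence embeddings in a
# highly correlated document: a narrow band with one sharp jump at index 4.
distances = [0.21, 0.22, 0.20, 0.23, 0.60, 0.22, 0.21, 0.24, 0.30, 0.22]

# Differentiating the sequence spreads the distribution out, so a percentile
# threshold applied to the gradient separates the jump from the background.
gradient = np.gradient(distances, range(len(distances)))
threshold = np.percentile(gradient, 95)

boundaries = [i for i, g in enumerate(gradient) if g > threshold]
print(boundaries)  # [3] -- the steep rise leading into the jump
```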
This commit is contained in:
parent 55705c0f5e · commit 858ce264ef
The semantic-chunker how-to notebook gains a "Gradient" section and example cells:

```diff
@@ -297,13 +297,67 @@
     "print(len(docs))"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "### Gradient\n",
+    "\n",
+    "In this method, the gradient of distance is used to split chunks along with the percentile method.\n",
+    "This method is useful when chunks are highly correlated with each other or specific to a domain e.g. legal or medical. The idea is to apply anomaly detection on gradient array so that the distribution become wider and easy to identify boundaries in highly semantic data."
+   ],
+   "metadata": {
+    "collapsed": false
+   },
+   "id": "423c6e099e94ca69"
+  },
   {
    "cell_type": "code",
    "execution_count": null,
    "id": "b1f65472",
    "metadata": {},
    "outputs": [],
-   "source": []
+   "source": [
+    "text_splitter = SemanticChunker(\n",
+    "    OpenAIEmbeddings(), breakpoint_threshold_type=\"gradient\"\n",
+    ")"
+   ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Madam Speaker, Madam Vice President, our First Lady and Second Gentleman.\n"
+     ]
+    }
+   ],
+   "source": [
+    "docs = text_splitter.create_documents([state_of_the_union])\n",
+    "print(docs[0].page_content)"
+   ],
+   "metadata": {},
+   "id": "e9f393d316ce1f6c"
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "26\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(len(docs))"
+   ],
+   "metadata": {},
+   "id": "a407cd57f02a0db4"
+  }
  ],
  "metadata": {
```
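Taken together, the new notebook cells run end to end roughly as below. This is a sketch, not the notebook verbatim: it assumes `state_of_the_union.txt` is available locally (earlier cells of the how-to load it) and that an OpenAI API key is configured.

```python
from langchain_experimental.text_splitter import SemanticChunker
from langchain_openai import OpenAIEmbeddings

# Load the sample text the notebook splits (the path is an assumption here).
with open("state_of_the_union.txt") as f:
    state_of_the_union = f.read()

# Split on anomalies in the gradient of the embedding-distance array.
text_splitter = SemanticChunker(
    OpenAIEmbeddings(), breakpoint_threshold_type="gradient"
)

docs = text_splitter.create_documents([state_of_the_union])
print(docs[0].page_content)  # first chunk
print(len(docs))             # number of chunks; 26 in the notebook run
```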
The splitter registers the new threshold type and its default amount:

```diff
@@ -84,11 +84,14 @@ def calculate_cosine_distances(sentences: List[dict]) -> Tuple[List[float], List
     return distances, sentences


-BreakpointThresholdType = Literal["percentile", "standard_deviation", "interquartile"]
+BreakpointThresholdType = Literal[
+    "percentile", "standard_deviation", "interquartile", "gradient"
+]
 BREAKPOINT_DEFAULTS: Dict[BreakpointThresholdType, float] = {
     "percentile": 95,
     "standard_deviation": 3,
     "interquartile": 1.5,
+    "gradient": 95,
 }

```
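The `gradient` default mirrors `percentile`: the amount is interpreted as a percentile, applied here to the gradient array rather than to the raw distances. It can be overridden via the existing `breakpoint_threshold_amount` parameter; the value below is illustrative, not a recommendation:

```python
from langchain_experimental.text_splitter import SemanticChunker
from langchain_openai import OpenAIEmbeddings

# Threshold at the 90th percentile of the distance-gradient
# distribution instead of the default 95.
text_splitter = SemanticChunker(
    OpenAIEmbeddings(),
    breakpoint_threshold_type="gradient",
    breakpoint_threshold_amount=90.0,
)
```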
`_calculate_breakpoint_threshold` now returns both the threshold and the array it was computed over, with a new branch for `gradient`:

```diff
@@ -127,23 +130,34 @@ class SemanticChunker(BaseDocumentTransformer):
         else:
             self.breakpoint_threshold_amount = breakpoint_threshold_amount

-    def _calculate_breakpoint_threshold(self, distances: List[float]) -> float:
+    def _calculate_breakpoint_threshold(
+        self, distances: List[float]
+    ) -> Tuple[float, List[float]]:
         if self.breakpoint_threshold_type == "percentile":
             return cast(
                 float,
                 np.percentile(distances, self.breakpoint_threshold_amount),
-            )
+            ), distances
         elif self.breakpoint_threshold_type == "standard_deviation":
             return cast(
                 float,
                 np.mean(distances)
                 + self.breakpoint_threshold_amount * np.std(distances),
-            )
+            ), distances
         elif self.breakpoint_threshold_type == "interquartile":
             q1, q3 = np.percentile(distances, [25, 75])
             iqr = q3 - q1

-            return np.mean(distances) + self.breakpoint_threshold_amount * iqr
+            return np.mean(
+                distances
+            ) + self.breakpoint_threshold_amount * iqr, distances
+        elif self.breakpoint_threshold_type == "gradient":
+            # Calculate the threshold based on the distribution of gradient of distance array. # noqa: E501
+            distance_gradient = np.gradient(distances, range(0, len(distances)))
+            return cast(
+                float,
+                np.percentile(distance_gradient, self.breakpoint_threshold_amount),
+            ), distance_gradient
         else:
             raise ValueError(
                 f"Got unexpected `breakpoint_threshold_type`: "
```
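Extracted from the class, the gradient branch amounts to this small pure function (a sketch for illustration, not library API):

```python
from typing import List, Tuple

import numpy as np


def gradient_breakpoint(
    distances: List[float], threshold_amount: float = 95.0
) -> Tuple[float, List[float]]:
    """Threshold on the gradient of the distance array (sketch of the PR logic)."""
    # Differentiate the distances; in near-uniform (highly correlated) data
    # this widens the distribution so outliers become separable.
    distance_gradient = np.gradient(distances, range(len(distances)))
    threshold = float(np.percentile(distance_gradient, threshold_amount))
    # Return the array alongside the threshold, since breakpoints are now
    # indexed against the gradient rather than the raw distances.
    return threshold, distance_gradient
```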
Finally, the chunking step thresholds against the returned array (the gradient for the new type, the raw distances otherwise):

```diff
@@ -201,13 +215,17 @@ class SemanticChunker(BaseDocumentTransformer):
         distances, sentences = self._calculate_sentence_distances(single_sentences_list)
         if self.number_of_chunks is not None:
             breakpoint_distance_threshold = self._threshold_from_clusters(distances)
+            breakpoint_array = distances
         else:
-            breakpoint_distance_threshold = self._calculate_breakpoint_threshold(
-                distances
-            )
+            (
+                breakpoint_distance_threshold,
+                breakpoint_array,
+            ) = self._calculate_breakpoint_threshold(distances)

         indices_above_thresh = [
-            i for i, x in enumerate(distances) if x > breakpoint_distance_threshold
+            i
+            for i, x in enumerate(breakpoint_array)
+            if x > breakpoint_distance_threshold
         ]

         chunks = []
```
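The hunk stops at `chunks = []`; downstream, the indices above the threshold delimit runs of consecutive sentences. A minimal sketch of that grouping step (not the library's exact code):

```python
from typing import List


def group_sentences(sentences: List[str], breakpoints: List[int]) -> List[str]:
    """Join consecutive sentences into chunks, splitting after each breakpoint.

    A breakpoint at index i means the boundary falls between sentence i and
    sentence i + 1, since distances/gradients are defined between neighbours.
    """
    chunks = []
    start = 0
    for index in breakpoints:
        chunks.append(" ".join(sentences[start : index + 1]))
        start = index + 1
    # Whatever follows the last breakpoint forms the final chunk.
    if start < len(sentences):
        chunks.append(" ".join(sentences[start:]))
    return chunks


print(group_sentences(["A.", "B.", "C.", "D."], [1]))
# ['A. B.', 'C. D.']
```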