Mirror of https://github.com/hwchase17/langchain.git, synced 2025-06-20 13:54:48 +00:00
SemanticChunker : Feature Addition ("Semantic Splitting with gradient") (#22895)
`SemanticChunker` currently provides three methods to split texts semantically:

- percentile
- standard_deviation
- interquartile

I propose a new method, `gradient`. In this method, the gradient of distance is used to split chunks along with the percentile method (technically). This method is useful when chunks are highly correlated with each other or specific to a domain, e.g. legal or medical. The idea is to apply anomaly detection on the gradient array so that the distribution becomes wider, making it easier to identify boundaries in highly semantic data. I have tested this merge on a set of 10 domain-specific documents (mostly legal).

Details:
- **Issue:** Improvement
- **Dependencies:** NA
- **Twitter handle:** [x.com/prajapat_ravi](https://x.com/prajapat_ravi)

@hwchase17

---------

Co-authored-by: Raviraj Prajapat <raviraj.prajapat@sirionlabs.com>
Co-authored-by: isaac hershenson <ihershenson@hmc.edu>
This commit is contained in:
parent 55705c0f5e · commit 858ce264ef
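Before the diff, a quick illustration of the idea from the description above. This is a minimal sketch in plain NumPy, not the library code; the `distances` array is made up and stands in for the cosine distances between consecutive sentence groups. It shows why thresholding the gradient of a flat, highly correlated distance series can expose boundaries that the same percentile cut on the raw distances would not:

```python
import numpy as np

# Hypothetical cosine distances between consecutive sentence groups.
# In domain-specific text (legal, medical) these tend to sit uniformly
# high and tightly clustered, so a percentile cut on the raw values
# has little room to separate real topic shifts from noise.
distances = np.array([0.81, 0.82, 0.80, 0.83, 0.95, 0.82, 0.81, 0.96, 0.80])

# Raw-percentile method: threshold the distances directly.
raw_threshold = np.percentile(distances, 95)

# Gradient method: threshold the first derivative instead; the gradient
# spans a wider relative range, so abrupt jumps stand out more clearly.
gradient = np.gradient(distances, range(len(distances)))
grad_threshold = np.percentile(gradient, 95)

print([i for i, d in enumerate(distances) if d > raw_threshold])
print([i for i, g in enumerate(gradient) if g > grad_threshold])
```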
```diff
@@ -297,13 +297,67 @@
     "print(len(docs))"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "### Gradient\n",
+    "\n",
+    "In this method, the gradient of distance is used to split chunks along with the percentile method.\n",
+    "This method is useful when chunks are highly correlated with each other or specific to a domain, e.g. legal or medical. The idea is to apply anomaly detection on the gradient array so that the distribution becomes wider, making it easier to identify boundaries in highly semantic data."
+   ],
+   "metadata": {
+    "collapsed": false
+   },
+   "id": "423c6e099e94ca69"
+  },
   {
    "cell_type": "code",
    "execution_count": null,
    "id": "b1f65472",
    "metadata": {},
    "outputs": [],
-   "source": []
+   "source": [
+    "text_splitter = SemanticChunker(\n",
+    "    OpenAIEmbeddings(), breakpoint_threshold_type=\"gradient\"\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Madam Speaker, Madam Vice President, our First Lady and Second Gentleman.\n"
+     ]
+    }
+   ],
+   "source": [
+    "docs = text_splitter.create_documents([state_of_the_union])\n",
+    "print(docs[0].page_content)"
+   ],
+   "metadata": {},
+   "id": "e9f393d316ce1f6c"
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "26\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(len(docs))"
+   ],
+   "metadata": {},
+   "id": "a407cd57f02a0db4"
+  }
  ],
  "metadata": {
```
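Assembled into one script, the cells added above look like the following. This is a sketch rather than the notebook verbatim: it assumes `langchain_experimental` and `langchain_openai` are installed, an OpenAI API key is configured, and the `state_of_the_union.txt` sample file used elsewhere in the notebook is present.

```python
from langchain_experimental.text_splitter import SemanticChunker
from langchain_openai import OpenAIEmbeddings

# Load the sample document used throughout the notebook.
with open("state_of_the_union.txt") as f:
    state_of_the_union = f.read()

# "gradient" thresholds the gradient of the distance series at the
# 95th percentile by default (see BREAKPOINT_DEFAULTS below).
text_splitter = SemanticChunker(
    OpenAIEmbeddings(), breakpoint_threshold_type="gradient"
)

docs = text_splitter.create_documents([state_of_the_union])
print(docs[0].page_content)  # first chunk
print(len(docs))             # 26 chunks in the recorded notebook run
```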
```diff
@@ -84,11 +84,14 @@ def calculate_cosine_distances(sentences: List[dict]) -> Tuple[List[float], List
     return distances, sentences
 
 
-BreakpointThresholdType = Literal["percentile", "standard_deviation", "interquartile"]
+BreakpointThresholdType = Literal[
+    "percentile", "standard_deviation", "interquartile", "gradient"
+]
 BREAKPOINT_DEFAULTS: Dict[BreakpointThresholdType, float] = {
     "percentile": 95,
     "standard_deviation": 3,
     "interquartile": 1.5,
+    "gradient": 95,
 }
 
 
```
```diff
@@ -127,23 +130,34 @@ class SemanticChunker(BaseDocumentTransformer):
         else:
             self.breakpoint_threshold_amount = breakpoint_threshold_amount
 
-    def _calculate_breakpoint_threshold(self, distances: List[float]) -> float:
+    def _calculate_breakpoint_threshold(
+        self, distances: List[float]
+    ) -> Tuple[float, List[float]]:
         if self.breakpoint_threshold_type == "percentile":
             return cast(
                 float,
                 np.percentile(distances, self.breakpoint_threshold_amount),
-            )
+            ), distances
         elif self.breakpoint_threshold_type == "standard_deviation":
             return cast(
                 float,
                 np.mean(distances)
                 + self.breakpoint_threshold_amount * np.std(distances),
-            )
+            ), distances
         elif self.breakpoint_threshold_type == "interquartile":
             q1, q3 = np.percentile(distances, [25, 75])
             iqr = q3 - q1
 
-            return np.mean(distances) + self.breakpoint_threshold_amount * iqr
+            return np.mean(
+                distances
+            ) + self.breakpoint_threshold_amount * iqr, distances
+        elif self.breakpoint_threshold_type == "gradient":
+            # Calculate the threshold based on the distribution of gradient of distance array. # noqa: E501
+            distance_gradient = np.gradient(distances, range(0, len(distances)))
+            return cast(
+                float,
+                np.percentile(distance_gradient, self.breakpoint_threshold_amount),
+            ), distance_gradient
         else:
             raise ValueError(
                 f"Got unexpected `breakpoint_threshold_type`: "
```
```diff
@@ -201,13 +215,17 @@ class SemanticChunker(BaseDocumentTransformer):
         distances, sentences = self._calculate_sentence_distances(single_sentences_list)
         if self.number_of_chunks is not None:
             breakpoint_distance_threshold = self._threshold_from_clusters(distances)
+            breakpoint_array = distances
         else:
-            breakpoint_distance_threshold = self._calculate_breakpoint_threshold(
-                distances
-            )
+            (
+                breakpoint_distance_threshold,
+                breakpoint_array,
+            ) = self._calculate_breakpoint_threshold(distances)
 
         indices_above_thresh = [
-            i for i, x in enumerate(distances) if x > breakpoint_distance_threshold
+            i
+            for i, x in enumerate(breakpoint_array)
+            if x > breakpoint_distance_threshold
         ]
 
         chunks = []
```
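Why the caller now unpacks a pair: `gradient` is the one mode where the threshold lives in a different space than the raw distances, so `_calculate_breakpoint_threshold` returns both the threshold and the array it applies to, and the caller scans that array for breakpoints. A simplified standalone sketch of this contract, with a hypothetical `calculate_threshold` function standing in for the library method:

```python
from typing import List, Tuple

import numpy as np


def calculate_threshold(
    distances: List[float], kind: str, amount: float = 95.0
) -> Tuple[float, List[float]]:
    """Return (threshold, array the threshold applies to)."""
    if kind == "gradient":
        # The threshold lives in gradient space, so the caller must scan
        # the gradient array rather than the raw distances.
        grad = np.gradient(distances, range(len(distances)))
        return float(np.percentile(grad, amount)), list(grad)
    # Every other mode thresholds the distances themselves.
    return float(np.percentile(distances, amount)), list(distances)


distances = [0.20, 0.25, 0.30, 0.90, 0.25, 0.30, 0.80, 0.25]
threshold, breakpoint_array = calculate_threshold(distances, "gradient")
indices_above_thresh = [
    i for i, x in enumerate(breakpoint_array) if x > threshold
]
print(indices_above_thresh)
```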