diff --git a/docs/docs/how_to/semantic-chunker.ipynb b/docs/docs/how_to/semantic-chunker.ipynb
index fbf50f3f57d..ff952b2d817 100644
--- a/docs/docs/how_to/semantic-chunker.ipynb
+++ b/docs/docs/how_to/semantic-chunker.ipynb
@@ -297,13 +297,67 @@
     "print(len(docs))"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "### Gradient\n",
+    "\n",
+    "In this method, the gradient of the distance array is used to split chunks, together with the percentile method.\n",
+    "This method is useful when chunks are highly correlated with each other or specific to a domain, e.g. legal or medical. The idea is to apply anomaly detection on the gradient array: the gradient distribution is wider, which makes it easier to identify boundaries in highly semantic data."
+   ],
+   "metadata": {
+    "collapsed": false
+   },
+   "id": "423c6e099e94ca69"
+  },
   {
    "cell_type": "code",
    "execution_count": null,
    "id": "b1f65472",
    "metadata": {},
    "outputs": [],
-   "source": []
+   "source": [
+    "text_splitter = SemanticChunker(\n",
+    "    OpenAIEmbeddings(), breakpoint_threshold_type=\"gradient\"\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Madam Speaker, Madam Vice President, our First Lady and Second Gentleman.\n"
+     ]
+    }
+   ],
+   "source": [
+    "docs = text_splitter.create_documents([state_of_the_union])\n",
+    "print(docs[0].page_content)"
+   ],
+   "metadata": {},
+   "id": "e9f393d316ce1f6c"
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "26\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(len(docs))"
+   ],
+   "metadata": {},
+   "id": "a407cd57f02a0db4"
   }
  ],
  "metadata": {
diff --git a/libs/experimental/langchain_experimental/text_splitter.py b/libs/experimental/langchain_experimental/text_splitter.py
index 1d0c462fa2a..4ef5796f59f 100644
--- a/libs/experimental/langchain_experimental/text_splitter.py
+++ b/libs/experimental/langchain_experimental/text_splitter.py
@@ -84,11 +84,14 @@ def calculate_cosine_distances(sentences: List[dict]) -> Tuple[List[float], List
     return distances, sentences
 
 
-BreakpointThresholdType = Literal["percentile", "standard_deviation", "interquartile"]
+BreakpointThresholdType = Literal[
+    "percentile", "standard_deviation", "interquartile", "gradient"
+]
 BREAKPOINT_DEFAULTS: Dict[BreakpointThresholdType, float] = {
     "percentile": 95,
     "standard_deviation": 3,
     "interquartile": 1.5,
+    "gradient": 95,
 }
 
 
@@ -127,23 +130,34 @@ class SemanticChunker(BaseDocumentTransformer):
         else:
             self.breakpoint_threshold_amount = breakpoint_threshold_amount
 
-    def _calculate_breakpoint_threshold(self, distances: List[float]) -> float:
+    def _calculate_breakpoint_threshold(
+        self, distances: List[float]
+    ) -> Tuple[float, List[float]]:
         if self.breakpoint_threshold_type == "percentile":
             return cast(
                 float,
                 np.percentile(distances, self.breakpoint_threshold_amount),
-            )
+            ), distances
         elif self.breakpoint_threshold_type == "standard_deviation":
             return cast(
                 float,
                 np.mean(distances)
                 + self.breakpoint_threshold_amount * np.std(distances),
-            )
+            ), distances
         elif self.breakpoint_threshold_type == "interquartile":
             q1, q3 = np.percentile(distances, [25, 75])
             iqr = q3 - q1
-            return np.mean(distances) + self.breakpoint_threshold_amount * iqr
+            return np.mean(
+                distances
+            ) + self.breakpoint_threshold_amount * iqr, distances
+        elif self.breakpoint_threshold_type == "gradient":
+            # Calculate the threshold from the distribution of the gradient of the distance array.  # noqa: E501
+            distance_gradient = np.gradient(distances, range(0, len(distances)))
+            return cast(
+                float,
+                np.percentile(distance_gradient, self.breakpoint_threshold_amount),
+            ), distance_gradient
         else:
             raise ValueError(
                 f"Got unexpected `breakpoint_threshold_type`: "
@@ -201,13 +215,17 @@ class SemanticChunker(BaseDocumentTransformer):
         distances, sentences = self._calculate_sentence_distances(single_sentences_list)
         if self.number_of_chunks is not None:
             breakpoint_distance_threshold = self._threshold_from_clusters(distances)
+            breakpoint_array = distances
         else:
-            breakpoint_distance_threshold = self._calculate_breakpoint_threshold(
-                distances
-            )
+            (
+                breakpoint_distance_threshold,
+                breakpoint_array,
+            ) = self._calculate_breakpoint_threshold(distances)
 
         indices_above_thresh = [
-            i for i, x in enumerate(distances) if x > breakpoint_distance_threshold
+            i
+            for i, x in enumerate(breakpoint_array)
+            if x > breakpoint_distance_threshold
         ]
 
         chunks = []
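
For intuition, the following minimal sketch (not part of the diff; the toy `distances` values are invented for illustration) reproduces what the `gradient` branch above computes: differentiate the distance curve with `np.gradient`, then take a percentile of the gradient, rather than of the raw distances, as the split threshold.

```python
import numpy as np

# Toy cosine distances between consecutive sentence embeddings.
# In domain-specific text the raw distances are all similarly high,
# but the *changes* between them still stand out.
distances = [0.41, 0.43, 0.42, 0.58, 0.44, 0.43, 0.61, 0.42]

breakpoint_threshold_amount = 95  # the default for "gradient" in BREAKPOINT_DEFAULTS

# Differentiate the distance curve, as in the gradient branch above.
distance_gradient = np.gradient(distances, range(0, len(distances)))

# Threshold on the gradient's distribution instead of the distances'.
threshold = np.percentile(distance_gradient, breakpoint_threshold_amount)

# Sentence indices after which a chunk boundary is placed, mirroring
# `indices_above_thresh` in split_text (which scans `breakpoint_array`).
boundaries = [i for i, g in enumerate(distance_gradient) if g > threshold]
print(threshold, boundaries)
```

Differentiating first widens a flat distance distribution, which is the anomaly-detection idea the new notebook cell describes: boundaries show up as spikes in the gradient even when every raw distance is similar.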