diff --git a/libs/experimental/langchain_experimental/text_splitter.py b/libs/experimental/langchain_experimental/text_splitter.py index 135f28346e3..c5b6ed513af 100644 --- a/libs/experimental/langchain_experimental/text_splitter.py +++ b/libs/experimental/langchain_experimental/text_splitter.py @@ -106,6 +106,7 @@ class SemanticChunker(BaseDocumentTransformer): def __init__( self, embeddings: Embeddings, + buffer_size: int = 1, add_start_index: bool = False, breakpoint_threshold_type: BreakpointThresholdType = "percentile", breakpoint_threshold_amount: Optional[float] = None, @@ -113,6 +114,7 @@ class SemanticChunker(BaseDocumentTransformer): ): self._add_start_index = add_start_index self.embeddings = embeddings + self.buffer_size = buffer_size self.breakpoint_threshold_type = breakpoint_threshold_type self.number_of_chunks = number_of_chunks if breakpoint_threshold_amount is None: @@ -173,7 +175,7 @@ class SemanticChunker(BaseDocumentTransformer): _sentences = [ {"sentence": x, "index": i} for i, x in enumerate(single_sentences_list) ] - sentences = combine_sentences(_sentences) + sentences = combine_sentences(_sentences, self.buffer_size) embeddings = self.embeddings.embed_documents( [x["combined_sentence"] for x in sentences] )