experimental: add buffer_size hyperparameter to SemanticChunker as in source video (#19208)

Add a buffer_size hyperparameter, which is used in the combine_sentences function.
Cycle 2024-03-19 06:54:20 +03:00 committed by GitHub
parent ae3c7f702c
commit 77868b1974


@@ -106,6 +106,7 @@ class SemanticChunker(BaseDocumentTransformer):
     def __init__(
         self,
         embeddings: Embeddings,
+        buffer_size: int = 1,
         add_start_index: bool = False,
         breakpoint_threshold_type: BreakpointThresholdType = "percentile",
         breakpoint_threshold_amount: Optional[float] = None,
@@ -113,6 +114,7 @@ class SemanticChunker(BaseDocumentTransformer):
     ):
         self._add_start_index = add_start_index
         self.embeddings = embeddings
+        self.buffer_size = buffer_size
         self.breakpoint_threshold_type = breakpoint_threshold_type
         self.number_of_chunks = number_of_chunks
         if breakpoint_threshold_amount is None:
@@ -173,7 +175,7 @@ class SemanticChunker(BaseDocumentTransformer):
         _sentences = [
            {"sentence": x, "index": i} for i, x in enumerate(single_sentences_list)
         ]
-        sentences = combine_sentences(_sentences)
+        sentences = combine_sentences(_sentences, self.buffer_size)
        embeddings = self.embeddings.embed_documents(
            [x["combined_sentence"] for x in sentences]
        )