mongodb: [performance] Increase DEFAULT_INSERT_BATCH_SIZE to 100,000 and introduce sizing constraints (#19608)

This commit is contained in:
Jib 2024-05-14 18:11:26 -04:00 committed by GitHub
parent e69a9bedf8
commit f369495fa0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -32,7 +32,7 @@ VST = TypeVar("VST", bound=VectorStore)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_INSERT_BATCH_SIZE = 100 DEFAULT_INSERT_BATCH_SIZE = 100_000
class MongoDBAtlasVectorSearch(VectorStore): class MongoDBAtlasVectorSearch(VectorStore):
@ -151,18 +151,24 @@ class MongoDBAtlasVectorSearch(VectorStore):
""" """
batch_size = kwargs.get("batch_size", DEFAULT_INSERT_BATCH_SIZE) batch_size = kwargs.get("batch_size", DEFAULT_INSERT_BATCH_SIZE)
_metadatas: Union[List, Generator] = metadatas or ({} for _ in texts) _metadatas: Union[List, Generator] = metadatas or ({} for _ in texts)
texts_batch = [] texts_batch = texts
metadatas_batch = [] metadatas_batch = _metadatas
result_ids = [] result_ids = []
for i, (text, metadata) in enumerate(zip(texts, _metadatas)): if batch_size:
texts_batch.append(text) texts_batch = []
metadatas_batch.append(metadata) metadatas_batch = []
if (i + 1) % batch_size == 0: size = 0
result_ids.extend(self._insert_texts(texts_batch, metadatas_batch)) for i, (text, metadata) in enumerate(zip(texts, _metadatas)):
texts_batch = [] size += len(text) + len(metadata)
metadatas_batch = [] texts_batch.append(text)
metadatas_batch.append(metadata)
if (i + 1) % batch_size == 0 or size >= 47_000_000:
result_ids.extend(self._insert_texts(texts_batch, metadatas_batch))
texts_batch = []
metadatas_batch = []
size = 0
if texts_batch: if texts_batch:
result_ids.extend(self._insert_texts(texts_batch, metadatas_batch)) result_ids.extend(self._insert_texts(texts_batch, metadatas_batch)) # type: ignore
return result_ids return result_ids
def _insert_texts(self, texts: List[str], metadatas: List[Dict[str, Any]]) -> List: def _insert_texts(self, texts: List[str], metadatas: List[Dict[str, Any]]) -> List: