mongodb: Add Hybrid and Full-Text Search Retrievers, release 0.2.0 (#25057)

## Description

This pull-request extends the existing vector search strategies of
MongoDBAtlasVectorSearch to include Hybrid (Reciprocal Rank Fusion) and
Full-text via new Retrievers.

There is a small breaking change involving the `prefilter` kwarg to
search. Because of this, and because we have added a great many features
since 0.1.0, including programmatic index creation/deletion, we plan to
bump the version to 0.2.0.

### Checklist
* Unit tests have been extended
* Formatting has been applied
* One mypy error remains, which will either go away in CI or be
simplified.

---------

Signed-off-by: Casey Clements <casey.clements@mongodb.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Casey Clements
2024-08-07 16:10:29 -04:00
committed by GitHub
parent f337408b0f
commit 6e9a8b188f
22 changed files with 1749 additions and 508 deletions

View File

@@ -1,7 +1,8 @@
from __future__ import annotations
from copy import deepcopy
from typing import Any, Dict, List, Mapping, Optional, cast
from time import monotonic, sleep
from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional, Union, cast
from bson import ObjectId
from langchain_core.callbacks.manager import (
@@ -20,8 +21,47 @@ from langchain_core.pydantic_v1 import validator
from pymongo.collection import Collection
from pymongo.results import DeleteResult, InsertManyResult
from langchain_mongodb import MongoDBAtlasVectorSearch
from langchain_mongodb.cache import MongoDBAtlasSemanticCache
# Upper bound, in seconds, on how long the polling loops below keep waiting.
TIMEOUT = 120
# Pause, in seconds, between successive polls.
INTERVAL = 0.5
class PatchedMongoDBAtlasVectorSearch(MongoDBAtlasVectorSearch):
    """Vector store variant whose write/index operations block until visible.

    Both overrides delegate to the parent implementation and then poll,
    sleeping ``INTERVAL`` seconds between attempts, for at most ``TIMEOUT``
    seconds.
    """

    def bulk_embed_and_insert_texts(
        self,
        texts: Union[List[str], Iterable[str]],
        metadatas: Union[List[dict], Generator[dict, Any, Any]],
        ids: Optional[List[str]] = None,
    ) -> List:
        """Insert via the parent, then wait for the documents to be searchable.

        Polling stops as soon as a similarity search returns exactly as many
        results as ids were inserted, or once ``TIMEOUT`` seconds have elapsed.

        Returns:
            The list of inserted ids as returned by the parent method.
        """
        inserted = super().bulk_embed_and_insert_texts(texts, metadatas, ids)
        expected = len(inserted)
        began = monotonic()
        # NOTE(review): similarity_search is called without an explicit result
        # limit; if its default limit is smaller than the number of inserted
        # ids, this loop can only end via the timeout -- confirm.
        while True:
            # The search is issued before the timeout is checked, so it always
            # runs at least once.
            if len(self.similarity_search("sandwich")) == expected:
                break
            if monotonic() - began > TIMEOUT:
                break
            sleep(INTERVAL)
        return inserted

    def create_vector_search_index(
        self,
        dimensions: int,
        filters: Optional[List[str]] = None,
        update: bool = False,
    ) -> None:
        """Create the index via the parent, then wait for status ``READY``.

        Returns the parent's result once the first index listed under
        ``self._index_name`` reports status ``"READY"``. If that does not
        happen within ``TIMEOUT`` seconds, returns ``None`` without raising.
        """
        outcome = super().create_vector_search_index(
            dimensions=dimensions, filters=filters, update=update
        )
        began = monotonic()
        while monotonic() - began <= TIMEOUT:
            found = list(self._collection.list_search_indexes(name=self._index_name))
            if found and found[0].get("status") == "READY":
                return outcome
            sleep(INTERVAL)
        return None
class ConsistentFakeEmbeddings(Embeddings):
"""Fake embeddings functionality for testing."""
@@ -147,13 +187,13 @@ class MockCollection(Collection):
_aggregate_result: List[Any]
_insert_result: Optional[InsertManyResult]
_data: List[Any]
_simluate_cache_aggregation_query: bool
_simulate_cache_aggregation_query: bool
def __init__(self) -> None:
self._data = []
self._aggregate_result = []
self._insert_result = None
self._simluate_cache_aggregation_query = False
self._simulate_cache_aggregation_query = False
def delete_many(self, *args, **kwargs) -> DeleteResult: # type: ignore
old_len = len(self._data)
@@ -201,7 +241,7 @@ class MockCollection(Collection):
elif upsert:
self._data.append({**find_query, **set_options})
def _execute_cache_aggreation_query(self, *args, **kwargs) -> List[Dict[str, Any]]: # type: ignore
def _execute_cache_aggregation_query(self, *args, **kwargs) -> List[Dict[str, Any]]: # type: ignore
"""Helper function only to be used for MongoDBAtlasSemanticCache Testing
Returns:
@@ -223,12 +263,12 @@ class MockCollection(Collection):
return acc
def aggregate(self, *args, **kwargs) -> List[Any]: # type: ignore
if self._simluate_cache_aggregation_query:
return deepcopy(self._execute_cache_aggreation_query(*args, **kwargs))
if self._simulate_cache_aggregation_query:
return deepcopy(self._execute_cache_aggregation_query(*args, **kwargs))
return deepcopy(self._aggregate_result)
def count_documents(self, *args, **kwargs) -> int:  # type: ignore
    """Return how many entries ``self._data`` holds; all arguments are ignored."""
    stored = self._data
    return len(stored)
def __repr__(self) -> str:
return "FakeCollection"
return "MockCollection"