mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-11 07:50:47 +00:00
mongodb: Add Hybrid and Full-Text Search Retrievers, release 0.2.0 (#25057)
## Description

This pull request extends the existing vector search strategies of MongoDBAtlasVectorSearch to include Hybrid (Reciprocal Rank Fusion) and Full-text search via new Retrievers. There is a small breaking change in the form of the `prefilter` kwarg to search. Because of this, and because a great deal of features have been added since 0.1.0 (including programmatic Index creation/deletion), we plan to bump the version to 0.2.0.

### Checklist

* Unit tests have been extended
* Formatting has been applied
* One mypy error remains, which will either go away in CI or be simplified.

---------

Signed-off-by: Casey Clements <casey.clements@mongodb.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, List, Mapping, Optional, cast
|
||||
from time import monotonic, sleep
|
||||
from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional, Union, cast
|
||||
|
||||
from bson import ObjectId
|
||||
from langchain_core.callbacks.manager import (
|
||||
@@ -20,8 +21,47 @@ from langchain_core.pydantic_v1 import validator
|
||||
from pymongo.collection import Collection
|
||||
from pymongo.results import DeleteResult, InsertManyResult
|
||||
|
||||
from langchain_mongodb import MongoDBAtlasVectorSearch
|
||||
from langchain_mongodb.cache import MongoDBAtlasSemanticCache
|
||||
|
||||
TIMEOUT = 120
|
||||
INTERVAL = 0.5
|
||||
|
||||
|
||||
class PatchedMongoDBAtlasVectorSearch(MongoDBAtlasVectorSearch):
    """Test helper that blocks until inserted data / created indexes are usable.

    Both overrides delegate to the parent implementation and then poll every
    ``INTERVAL`` seconds, for at most ``TIMEOUT`` seconds. Polling is
    best-effort: when the timeout elapses the methods return normally instead
    of raising.
    """

    def bulk_embed_and_insert_texts(
        self,
        texts: Union[List[str], Iterable[str]],
        metadatas: Union[List[dict], Generator[dict, Any, Any]],
        ids: Optional[List[str]] = None,
    ) -> List:
        """Patched insert_texts that waits for data to be indexed before returning.

        Args:
            texts: Texts to embed and insert.
            metadatas: Metadata, one entry per text.
            ids: Optional explicit ids for the inserted documents.

        Returns:
            The ids returned by the parent ``bulk_embed_and_insert_texts``.
        """
        ids_inserted = super().bulk_embed_and_insert_texts(texts, metadatas, ids)
        if not ids_inserted:
            # Nothing new to wait for (and k=0 would not be a valid search).
            return ids_inserted

        # Ask the search for every document in the collection. Relying on the
        # default ``k`` would cap the number of hits, so a larger insert (or a
        # collection that already held documents) could never satisfy the
        # check and would always wait for the full TIMEOUT.
        n_docs = self._collection.count_documents({})
        start = monotonic()
        while monotonic() - start <= TIMEOUT:
            if len(self.similarity_search("sandwich", k=n_docs)) >= n_docs:
                break
            sleep(INTERVAL)
        return ids_inserted

    def create_vector_search_index(
        self,
        dimensions: int,
        filters: Optional[List[str]] = None,
        update: bool = False,
    ) -> None:
        """Create the vector search index, then wait until it reports READY.

        Args:
            dimensions: Forwarded to the parent implementation.
            filters: Forwarded to the parent implementation.
            update: Forwarded to the parent implementation.

        Returns:
            Whatever the parent call returned once the first listed index named
            ``self._index_name`` has status ``"READY"``; falls through (implicit
            ``None``) if that does not happen within ``TIMEOUT`` seconds.
        """
        result = super().create_vector_search_index(
            dimensions=dimensions, filters=filters, update=update
        )
        start = monotonic()
        while monotonic() - start <= TIMEOUT:
            # The listing is empty until the index has been registered.
            if indexes := list(
                self._collection.list_search_indexes(name=self._index_name)
            ):
                if indexes[0].get("status") == "READY":
                    return result
            sleep(INTERVAL)
|
||||
|
||||
|
||||
class ConsistentFakeEmbeddings(Embeddings):
|
||||
"""Fake embeddings functionality for testing."""
|
||||
@@ -147,13 +187,13 @@ class MockCollection(Collection):
|
||||
_aggregate_result: List[Any]
|
||||
_insert_result: Optional[InsertManyResult]
|
||||
_data: List[Any]
|
||||
_simluate_cache_aggregation_query: bool
|
||||
_simulate_cache_aggregation_query: bool
|
||||
|
||||
def __init__(self) -> None:
    """Start with no documents, no canned results and cache simulation off."""
    self._insert_result = None
    self._aggregate_result = []
    self._data = []
    # Both spellings of the cache-simulation flag are initialised here, exactly
    # as in the original block.
    self._simluate_cache_aggregation_query = False
    self._simulate_cache_aggregation_query = False
|
||||
|
||||
def delete_many(self, *args, **kwargs) -> DeleteResult: # type: ignore
|
||||
old_len = len(self._data)
|
||||
@@ -201,7 +241,7 @@ class MockCollection(Collection):
|
||||
elif upsert:
|
||||
self._data.append({**find_query, **set_options})
|
||||
|
||||
def _execute_cache_aggreation_query(self, *args, **kwargs) -> List[Dict[str, Any]]: # type: ignore
|
||||
def _execute_cache_aggregation_query(self, *args, **kwargs) -> List[Dict[str, Any]]: # type: ignore
|
||||
"""Helper function only to be used for MongoDBAtlasSemanticCache Testing
|
||||
|
||||
Returns:
|
||||
@@ -223,12 +263,12 @@ class MockCollection(Collection):
|
||||
return acc
|
||||
|
||||
def aggregate(self, *args, **kwargs) -> List[Any]:  # type: ignore
    """Return a deep copy of the mocked aggregation output.

    When ``self._simulate_cache_aggregation_query`` is truthy, the arguments
    are forwarded to ``self._execute_cache_aggregation_query`` and its result
    is returned; otherwise the canned ``self._aggregate_result`` is returned.
    A deep copy is returned in both cases, so callers cannot mutate the
    mock's stored state through the result.
    """
    if self._simulate_cache_aggregation_query:
        return deepcopy(self._execute_cache_aggregation_query(*args, **kwargs))
    return deepcopy(self._aggregate_result)
|
||||
|
||||
def count_documents(self, *args, **kwargs) -> int:  # type: ignore
    """Return how many documents the mock holds; all arguments are ignored."""
    stored = self._data
    return len(stored)
|
||||
|
||||
def __repr__(self) -> str:
    """Return the fixed label identifying this mock collection."""
    return "MockCollection"
|
||||
|
Reference in New Issue
Block a user