mongodb: Add Hybrid and Full-Text Search Retrievers, release 0.2.0 (#25057)

## Description

This pull-request extends the existing vector search strategies of
MongoDBAtlasVectorSearch to include Hybrid (Reciprocal Rank Fusion) and
Full-text via new Retrievers.

There is a small breaking change in the form of the `prefilter` kwarg to
search. For this, and because we have now added a great deal of
features, including programmatic Index creation/deletion since 0.1.0, we
plan to bump the version to 0.2.0.

### Checklist
* Unit tests have been extended
* formatting has been applied
* One mypy error remains which will either go away in CI or be
simplified.

---------

Signed-off-by: Casey Clements <casey.clements@mongodb.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Casey Clements
2024-08-07 16:10:29 -04:00
committed by GitHub
parent f337408b0f
commit 6e9a8b188f
22 changed files with 1749 additions and 508 deletions

View File

@@ -1,8 +1,16 @@
"""
Integrate your operational database and vector search in a single, unified,
fully managed platform with full vector database capabilities on MongoDB Atlas.
Store your operational data, metadata, and vector embeddings in oue VectorStore,
MongoDBAtlasVectorSearch.
Insert into a Chain via a Vector, FullText, or Hybrid Retriever.
"""
from langchain_mongodb.cache import MongoDBAtlasSemanticCache, MongoDBCache
from langchain_mongodb.chat_message_histories import MongoDBChatMessageHistory
from langchain_mongodb.vectorstores import (
MongoDBAtlasVectorSearch,
)
from langchain_mongodb.vectorstores import MongoDBAtlasVectorSearch
__all__ = [
"MongoDBAtlasVectorSearch",

View File

@@ -1,10 +1,4 @@
"""
LangChain MongoDB Caches
Functions "_loads_generations" and "_dumps_generations"
are duplicated in this utility from modules:
- "libs/community/langchain_community/cache.py"
"""
"""LangChain MongoDB Caches."""
import json
import logging
@@ -27,100 +21,6 @@ from langchain_mongodb.vectorstores import MongoDBAtlasVectorSearch
logger = logging.getLogger(__file__)
def _generate_mongo_client(connection_string: str) -> MongoClient:
return MongoClient(
connection_string,
driver=DriverInfo(name="Langchain", version=version("langchain-mongodb")),
)
def _dumps_generations(generations: RETURN_VAL_TYPE) -> str:
"""
Serialization for generic RETURN_VAL_TYPE, i.e. sequence of `Generation`
Args:
generations (RETURN_VAL_TYPE): A list of language model generations.
Returns:
str: a single string representing a list of generations.
This function (+ its counterpart `_loads_generations`) rely on
the dumps/loads pair with Reviver, so are able to deal
with all subclasses of Generation.
Each item in the list can be `dumps`ed to a string,
then we make the whole list of strings into a json-dumped.
"""
return json.dumps([dumps(_item) for _item in generations])
def _loads_generations(generations_str: str) -> Union[RETURN_VAL_TYPE, None]:
"""
Deserialization of a string into a generic RETURN_VAL_TYPE
(i.e. a sequence of `Generation`).
See `_dumps_generations`, the inverse of this function.
Args:
generations_str (str): A string representing a list of generations.
Compatible with the legacy cache-blob format
Does not raise exceptions for malformed entries, just logs a warning
and returns none: the caller should be prepared for such a cache miss.
Returns:
RETURN_VAL_TYPE: A list of generations.
"""
try:
generations = [loads(_item_str) for _item_str in json.loads(generations_str)]
return generations
except (json.JSONDecodeError, TypeError):
# deferring the (soft) handling to after the legacy-format attempt
pass
try:
gen_dicts = json.loads(generations_str)
# not relying on `_load_generations_from_json` (which could disappear):
generations = [Generation(**generation_dict) for generation_dict in gen_dicts]
logger.warning(
f"Legacy 'Generation' cached blob encountered: '{generations_str}'"
)
return generations
except (json.JSONDecodeError, TypeError):
logger.warning(
f"Malformed/unparsable cached blob encountered: '{generations_str}'"
)
return None
def _wait_until(
predicate: Callable, success_description: Any, timeout: float = 10.0
) -> None:
"""Wait up to 10 seconds (by default) for predicate to be true.
E.g.:
wait_until(lambda: client.primary == ('a', 1),
'connect to the primary')
If the lambda-expression isn't true after 10 seconds, we raise
AssertionError("Didn't ever connect to the primary").
Returns the predicate's first true value.
"""
start = time.time()
interval = min(float(timeout) / 100, 0.1)
while True:
retval = predicate()
if retval:
return retval
if time.time() - start > timeout:
raise TimeoutError("Didn't ever %s" % success_description)
time.sleep(interval)
class MongoDBCache(BaseCache):
"""MongoDB Atlas cache
@@ -216,7 +116,7 @@ class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch):
collection_name: str = "default",
database_name: str = "default",
index_name: str = "default",
wait_until_ready: bool = False,
wait_until_ready: Optional[float] = None,
score_threshold: Optional[float] = None,
**kwargs: Dict[str, Any],
):
@@ -233,8 +133,8 @@ class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch):
Defaults to "default".
index_name: Name of the Atlas Search index.
defaults to 'default'
wait_until_ready (bool): Block until MongoDB Atlas finishes indexing
the stored text. Hard timeout of 10 seconds. Defaults to False.
wait_until_ready (float): Wait this time for Atlas to finish indexing
the stored text. Defaults to None.
"""
client = _generate_mongo_client(connection_string)
self.collection = client[database_name][collection_name]
@@ -272,7 +172,7 @@ class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch):
prompt: str,
llm_string: str,
return_val: RETURN_VAL_TYPE,
wait_until_ready: Optional[bool] = None,
wait_until_ready: Optional[float] = None,
) -> None:
"""Update cache based on prompt and llm_string."""
self.add_texts(
@@ -290,7 +190,7 @@ class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch):
return self.lookup(prompt, llm_string) == return_val
if wait:
_wait_until(is_indexed, return_val)
_wait_until(is_indexed, return_val, timeout=wait)
def clear(self, **kwargs: Any) -> None:
"""Clear cache that can take additional keyword arguments.
@@ -302,3 +202,107 @@ class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch):
self.clear(llm_string="fake-model")
"""
self.collection.delete_many({**kwargs})
def _generate_mongo_client(connection_string: str) -> MongoClient:
return MongoClient(
connection_string,
driver=DriverInfo(name="Langchain", version=version("langchain-mongodb")),
)
def _dumps_generations(generations: RETURN_VAL_TYPE) -> str:
"""
Serialization for generic RETURN_VAL_TYPE, i.e. sequence of `Generation`
Args:
generations (RETURN_VAL_TYPE): A list of language model generations.
Returns:
str: a single string representing a list of generations.
This, and "_dumps_generations" are duplicated in this utility
from modules: "libs/community/langchain_community/cache.py"
This function and its counterpart rely on
the dumps/loads pair with Reviver, so are able to deal
with all subclasses of Generation.
Each item in the list can be `dumps`ed to a string,
then we make the whole list of strings into a json-dumped.
"""
return json.dumps([dumps(_item) for _item in generations])
def _loads_generations(generations_str: str) -> Union[RETURN_VAL_TYPE, None]:
"""
Deserialization of a string into a generic RETURN_VAL_TYPE
(i.e. a sequence of `Generation`).
Args:
generations_str (str): A string representing a list of generations.
Returns:
RETURN_VAL_TYPE: A list of generations.
This function and its counterpart rely on
the dumps/loads pair with Reviver, so are able to deal
with all subclasses of Generation.
See `_dumps_generations`, the inverse of this function.
Compatible with the legacy cache-blob format
Does not raise exceptions for malformed entries, just logs a warning
and returns none: the caller should be prepared for such a cache miss.
"""
try:
generations = [loads(_item_str) for _item_str in json.loads(generations_str)]
return generations
except (json.JSONDecodeError, TypeError):
# deferring the (soft) handling to after the legacy-format attempt
pass
try:
gen_dicts = json.loads(generations_str)
# not relying on `_load_generations_from_json` (which could disappear):
generations = [Generation(**generation_dict) for generation_dict in gen_dicts]
logger.warning(
f"Legacy 'Generation' cached blob encountered: '{generations_str}'"
)
return generations
except (json.JSONDecodeError, TypeError):
logger.warning(
f"Malformed/unparsable cached blob encountered: '{generations_str}'"
)
return None
def _wait_until(
predicate: Callable, success_description: Any, timeout: float = 10.0
) -> None:
"""Wait up to 10 seconds (by default) for predicate to be true.
E.g.:
wait_until(lambda: client.primary == ('a', 1),
'connect to the primary')
If the lambda-expression isn't true after 10 seconds, we raise
AssertionError("Didn't ever connect to the primary").
Returns the predicate's first true value.
"""
start = time.time()
interval = min(float(timeout) / 100, 0.1)
while True:
retval = predicate()
if retval:
return retval
if time.time() - start > timeout:
raise TimeoutError("Didn't ever %s" % success_description)
time.sleep(interval)

View File

@@ -1,3 +1,5 @@
"""Search Index Commands"""
import logging
from time import monotonic, sleep
from typing import Any, Callable, Dict, List, Optional
@@ -8,8 +10,6 @@ from pymongo.operations import SearchIndexModel
logger = logging.getLogger(__file__)
_DELAY = 0.5 # Interval between checks for index operations
def _search_index_error_message() -> str:
return (
@@ -25,19 +25,24 @@ def _vector_search_index_definition(
dimensions: int,
path: str,
similarity: str,
filters: Optional[List[Dict[str, str]]],
filters: Optional[List[str]] = None,
**kwargs: Any,
) -> Dict[str, Any]:
return {
"fields": [
{
"numDimensions": dimensions,
"path": path,
"similarity": similarity,
"type": "vector",
},
*(filters or []),
]
}
# https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/
fields = [
{
"numDimensions": dimensions,
"path": path,
"similarity": similarity,
"type": "vector",
},
]
if filters:
for field in filters:
fields.append({"type": "filter", "path": field})
definition = {"fields": fields}
definition.update(kwargs)
return definition
def create_vector_search_index(
@@ -46,9 +51,10 @@ def create_vector_search_index(
dimensions: int,
path: str,
similarity: str,
filters: Optional[List[Dict[str, str]]] = None,
filters: Optional[List[str]] = None,
*,
wait_until_complete: Optional[float] = None,
**kwargs: Any,
) -> None:
"""Experimental Utility function to create a vector search index
@@ -58,9 +64,10 @@ def create_vector_search_index(
dimensions (int): Number of dimensions in embedding
path (str): field with vector embedding
similarity (str): The similarity score used for the index
filters (List[Dict[str, str]]): additional filters for index definition.
filters (List[str]): Fields/paths to index to allow filtering in $vectorSearch
wait_until_complete (Optional[float]): If provided, number of seconds to wait
until search index is ready.
kwargs: Keyword arguments supplying any additional options to SearchIndexModel.
"""
logger.info("Creating Search Index %s on %s", index_name, collection.name)
@@ -72,6 +79,7 @@ def create_vector_search_index(
path=path,
similarity=similarity,
filters=filters,
**kwargs,
),
name=index_name,
type="vectorSearch",
@@ -83,7 +91,7 @@ def create_vector_search_index(
if wait_until_complete:
_wait_for_predicate(
predicate=lambda: _is_index_ready(collection, index_name),
err=f"Index {index_name} creation did not finish in {wait_until_complete}!",
err=f"{index_name=} did not complete in {wait_until_complete}!",
timeout=wait_until_complete,
)
logger.info(result)
@@ -127,9 +135,10 @@ def update_vector_search_index(
dimensions: int,
path: str,
similarity: str,
filters: List[Dict[str, str]],
filters: Optional[List[str]] = None,
*,
wait_until_complete: Optional[float] = None,
**kwargs: Any,
) -> None:
"""Update a search index.
@@ -138,12 +147,13 @@ def update_vector_search_index(
Args:
collection (Collection): MongoDB Collection
index_name (str): Name of Index
dimensions (int): Number of dimensions in embedding.
path (str): field with vector embedding.
dimensions (int): Number of dimensions in embedding
path (str): field with vector embedding
similarity (str): The similarity score used for the index.
filters (List[Dict[str, str]]): additional filters for index definition.
filters (List[str]): Fields/paths to index to allow filtering in $vectorSearch
wait_until_complete (Optional[float]): If provided, number of seconds to wait
until search index is ready.
kwargs: Keyword arguments supplying any additional options to SearchIndexModel.
"""
logger.info(
@@ -157,6 +167,7 @@ def update_vector_search_index(
path=path,
similarity=similarity,
filters=filters,
**kwargs,
),
)
except OperationFailure as e:
@@ -201,7 +212,7 @@ def _wait_for_predicate(
Args:
predicate (Callable[, bool]): A function that returns a boolean value
err (str): Error message to raise if nothing occurs
timeout (float, optional): wait time for predicate. Defaults to TIMEOUT.
timeout (float, optional): Wait time for predicate. Defaults to TIMEOUT.
interval (float, optional): Interval to check predicate. Defaults to DELAY.
Raises:
@@ -212,3 +223,48 @@ def _wait_for_predicate(
if monotonic() - start > timeout:
raise TimeoutError(err)
sleep(interval)
def create_fulltext_search_index(
collection: Collection,
index_name: str,
field: str,
*,
wait_until_complete: Optional[float] = None,
**kwargs: Any,
) -> None:
"""Experimental Utility function to create an Atlas Search index
Args:
collection (Collection): MongoDB Collection
index_name (str): Name of Index
field (str): Field to index
wait_until_complete (Optional[float]): If provided, number of seconds to wait
until search index is ready
kwargs: Keyword arguments supplying any additional options to SearchIndexModel.
"""
logger.info("Creating Search Index %s on %s", index_name, collection.name)
definition = {
"mappings": {"dynamic": False, "fields": {field: [{"type": "string"}]}}
}
try:
result = collection.create_search_index(
SearchIndexModel(
definition=definition,
name=index_name,
type="search",
**kwargs,
)
)
except OperationFailure as e:
raise OperationFailure(_search_index_error_message()) from e
if wait_until_complete:
_wait_for_predicate(
predicate=lambda: _is_index_ready(collection, index_name),
err=f"{index_name=} did not complete in {wait_until_complete}!",
timeout=wait_until_complete,
)
logger.info(result)

View File

@@ -0,0 +1,160 @@
"""Aggregation pipeline components used in Atlas Full-Text, Vector, and Hybrid Search
See the following for more:
- `Full-Text Search <https://www.mongodb.com/docs/atlas/atlas-search/aggregation-stages/search/#mongodb-pipeline-pipe.-search>`_
- `MongoDB Operators <https://www.mongodb.com/docs/atlas/atlas-search/operators-and-collectors/#std-label-operators-ref>`_
- `Vector Search <https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/>`_
- `Filter Example <https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#atlas-vector-search-pre-filter>`_
"""
from typing import Any, Dict, List, Optional
def text_search_stage(
query: str,
search_field: str,
index_name: str,
limit: Optional[int] = None,
filter: Optional[Dict[str, Any]] = None,
include_scores: Optional[bool] = True,
**kwargs: Any,
) -> List[Dict[str, Any]]: # noqa: E501
"""Full-Text search using Lucene's standard (BM25) analyzer
Args:
query: Input text to search for
search_field: Field in Collection that will be searched
index_name: Atlas Search Index name
limit: Maximum number of documents to return. Default of no limit
filter: Any MQL match expression comparing an indexed field
include_scores: Scores provide measure of relative relevance
Returns:
Dictionary defining the $search stage
"""
pipeline = [
{
"$search": {
"index": index_name,
"text": {"query": query, "path": search_field},
}
}
]
if filter:
pipeline.append({"$match": filter}) # type: ignore
if include_scores:
pipeline.append({"$set": {"score": {"$meta": "searchScore"}}})
if limit:
pipeline.append({"$limit": limit}) # type: ignore
return pipeline # type: ignore
def vector_search_stage(
query_vector: List[float],
search_field: str,
index_name: str,
top_k: int = 4,
filter: Optional[Dict[str, Any]] = None,
oversampling_factor: int = 10,
**kwargs: Any,
) -> Dict[str, Any]: # noqa: E501
"""Vector Search Stage without Scores.
Scoring is applied later depending on strategy.
vector search includes a vectorSearchScore that is typically used.
hybrid uses Reciprocal Rank Fusion.
Args:
query_vector: List of embedding vector
search_field: Field in Collection containing embedding vectors
index_name: Name of Atlas Vector Search Index tied to Collection
top_k: Number of documents to return
oversampling_factor: this times limit is the number of candidates
filter: MQL match expression comparing an indexed field.
Some operators are not supported.
See `vectorSearch filter docs <https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#atlas-vector-search-pre-filter>`_
Returns:
Dictionary defining the $vectorSearch
"""
stage = {
"index": index_name,
"path": search_field,
"queryVector": query_vector,
"numCandidates": top_k * oversampling_factor,
"limit": top_k,
}
if filter:
stage["filter"] = filter
return {"$vectorSearch": stage}
def combine_pipelines(
pipeline: List[Any], stage: List[Dict[str, Any]], collection_name: str
) -> None:
"""Combines two aggregations into a single result set in-place."""
if pipeline:
pipeline.append({"$unionWith": {"coll": collection_name, "pipeline": stage}})
else:
pipeline.extend(stage)
def reciprocal_rank_stage(
score_field: str, penalty: float = 0, **kwargs: Any
) -> List[Dict[str, Any]]:
"""Stage adds Reciprocal Rank Fusion weighting.
First, it pushes documents retrieved from previous stage
into a temporary sub-document. It then unwinds to establish
the rank to each and applies the penalty.
Args:
score_field: A unique string to identify the search being ranked
penalty: A non-negative float.
extra_fields: Any fields other than text_field that one wishes to keep.
Returns:
RRF score := \frac{1}{rank + penalty} with rank in [1,2,..,n]
"""
rrf_pipeline = [
{"$group": {"_id": None, "docs": {"$push": "$$ROOT"}}},
{"$unwind": {"path": "$docs", "includeArrayIndex": "rank"}},
{
"$addFields": {
f"docs.{score_field}": {
"$divide": [1.0, {"$add": ["$rank", penalty, 1]}]
},
"docs.rank": "$rank",
"_id": "$docs._id",
}
},
{"$replaceRoot": {"newRoot": "$docs"}},
]
return rrf_pipeline # type: ignore
def final_hybrid_stage(
scores_fields: List[str], limit: int, **kwargs: Any
) -> List[Dict[str, Any]]:
"""Sum weighted scores, sort, and apply limit.
Args:
scores_fields: List of fields given to scores of vector and text searches
limit: Number of documents to return
Returns:
Final aggregation stages
"""
return [
{"$group": {"_id": "$_id", "docs": {"$mergeObjects": "$$ROOT"}}},
{"$replaceRoot": {"newRoot": "$docs"}},
{"$set": {score: {"$ifNull": [f"${score}", 0]} for score in scores_fields}},
{"$addFields": {"score": {"$add": [f"${score}" for score in scores_fields]}}},
{"$sort": {"score": -1}},
{"$limit": limit},
]

View File

@@ -0,0 +1,15 @@
"""Search Retrievers of various types.
Use ``MongoDBAtlasVectorSearch.as_retriever(**)``
to create MongoDB's core Vector Search Retriever.
"""
from langchain_mongodb.retrievers.full_text_search import (
MongoDBAtlasFullTextSearchRetriever,
)
from langchain_mongodb.retrievers.hybrid_search import MongoDBAtlasHybridSearchRetriever
__all__ = [
"MongoDBAtlasHybridSearchRetriever",
"MongoDBAtlasFullTextSearchRetriever",
]

View File

@@ -0,0 +1,59 @@
from typing import Any, Dict, List, Optional
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from pymongo.collection import Collection
from langchain_mongodb.pipelines import text_search_stage
from langchain_mongodb.utils import make_serializable
class MongoDBAtlasFullTextSearchRetriever(BaseRetriever):
"""Hybrid Search Retriever performs full-text searches
using Lucene's standard (BM25) analyzer.
"""
collection: Collection
"""MongoDB Collection on an Atlas cluster"""
search_index_name: str
"""Atlas Search Index name"""
search_field: str
"""Collection field that contains the text to be searched. It must be indexed"""
top_k: Optional[int] = None
"""Number of documents to return. Default is no limit"""
filter: Optional[Dict[str, Any]] = None
"""(Optional) List of MQL match expression comparing an indexed field"""
show_embeddings: float = False
"""If true, returned Document metadata will include vectors"""
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
"""Retrieve documents that are highest scoring / most similar to query.
Args:
query: String to find relevant documents for
run_manager: The callback handler to use
Returns:
List of relevant documents
"""
pipeline = text_search_stage( # type: ignore
query=query,
search_field=self.search_field,
index_name=self.search_index_name,
limit=self.top_k,
filter=self.filter,
)
# Execution
cursor = self.collection.aggregate(pipeline) # type: ignore[arg-type]
# Formatting
docs = []
for res in cursor:
text = res.pop(self.search_field)
make_serializable(res)
docs.append(Document(page_content=text, metadata=res))
return docs

View File

@@ -0,0 +1,126 @@
from typing import Any, Dict, List, Optional
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from pymongo.collection import Collection
from langchain_mongodb import MongoDBAtlasVectorSearch
from langchain_mongodb.pipelines import (
combine_pipelines,
final_hybrid_stage,
reciprocal_rank_stage,
text_search_stage,
vector_search_stage,
)
from langchain_mongodb.utils import make_serializable
class MongoDBAtlasHybridSearchRetriever(BaseRetriever):
"""Hybrid Search Retriever combines vector and full-text searches
weighting them the via Reciprocal Rank Fusion (RRF) algorithm.
Increasing the vector_penalty will reduce the importance on the vector search.
Increasing the fulltext_penalty will correspondingly reduce the fulltext score.
For more on the algorithm,see
https://learn.microsoft.com/en-us/azure/search/hybrid-search-ranking
"""
vectorstore: MongoDBAtlasVectorSearch
"""MongoDBAtlas VectorStore"""
search_index_name: str
"""Atlas Search Index (full-text) name"""
top_k: int = 4
"""Number of documents to return."""
oversampling_factor: int = 10
"""This times top_k is the number of candidates chosen at each step"""
pre_filter: Optional[Dict[str, Any]] = None
"""(Optional) Any MQL match expression comparing an indexed field"""
post_filter: Optional[List[Dict[str, Any]]] = None
"""(Optional) Pipeline of MongoDB aggregation stages for postprocessing."""
vector_penalty: float = 60.0
"""Penalty applied to vector search results in RRF: scores=1/(rank + penalty)"""
fulltext_penalty: float = 60.0
"""Penalty applied to full-text search results in RRF: scores=1/(rank + penalty)"""
show_embeddings: float = False
"""If true, returned Document metadata will include vectors."""
@property
def collection(self) -> Collection:
return self.vectorstore._collection
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
"""Retrieve documents that are highest scoring / most similar to query.
Note that the same query is used in both searches,
embedded for vector search, and as-is for full-text search.
Args:
query: String to find relevant documents for
run_manager: The callback handler to use
Returns:
List of relevant documents
"""
query_vector = self.vectorstore._embedding.embed_query(query)
scores_fields = ["vector_score", "fulltext_score"]
pipeline: List[Any] = []
# First we build up the aggregation pipeline,
# then it is passed to the server to execute
# Vector Search stage
vector_pipeline = [
vector_search_stage(
query_vector=query_vector,
search_field=self.vectorstore._embedding_key,
index_name=self.vectorstore._index_name,
top_k=self.top_k,
filter=self.pre_filter,
oversampling_factor=self.oversampling_factor,
)
]
vector_pipeline += reciprocal_rank_stage("vector_score", self.vector_penalty)
combine_pipelines(pipeline, vector_pipeline, self.collection.name)
# Full-Text Search stage
text_pipeline = text_search_stage(
query=query,
search_field=self.vectorstore._text_key,
index_name=self.search_index_name,
limit=self.top_k,
filter=self.pre_filter,
)
text_pipeline.extend(
reciprocal_rank_stage("fulltext_score", self.fulltext_penalty)
)
combine_pipelines(pipeline, text_pipeline, self.collection.name)
# Sum and sort stage
pipeline.extend(
final_hybrid_stage(scores_fields=scores_fields, limit=self.top_k)
)
# Removal of embeddings unless requested.
if not self.show_embeddings:
pipeline.append({"$project": {self.vectorstore._embedding_key: 0}})
# Post filtering
if self.post_filter is not None:
pipeline.extend(self.post_filter)
# Execution
cursor = self.collection.aggregate(pipeline) # type: ignore[arg-type]
# Formatting
docs = []
for res in cursor:
text = res.pop(self.vectorstore._text_key)
# score = res.pop("score") # The score remains buried!
make_serializable(res)
docs.append(Document(page_content=text, metadata=res))
return docs

View File

@@ -1,6 +1,13 @@
"""
Tools for the Maximal Marginal Relevance (MMR) reranking.
Duplicated from langchain_community to avoid cross-dependencies.
"""Various Utility Functions
- Tools for handling bson.ObjectId
The help IDs live as ObjectId in MongoDB and str in Langchain and JSON.
- Tools for the Maximal Marginal Relevance (MMR) reranking
These are duplicated from langchain_community to avoid cross-dependencies.
Functions "maximal_marginal_relevance" and "cosine_similarity"
are duplicated in this utility respectively from modules:
@@ -21,11 +28,6 @@ logger = logging.getLogger(__name__)
Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray]
class FailCode:
INDEX_NOT_FOUND = 27
INDEX_ALREADY_EXISTS = 68
def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
"""Row-wise cosine similarity between two equal-width matrices."""
if len(X) == 0 or len(Y) == 0:
@@ -65,7 +67,37 @@ def maximal_marginal_relevance(
lambda_mult: float = 0.5,
k: int = 4,
) -> List[int]:
"""Calculate maximal marginal relevance."""
"""Compute Maximal Marginal Relevance (MMR).
MMR is a technique used to select documents that are both relevant to the query
and diverse among themselves. This function returns the indices
of the top-k embeddings that maximize the marginal relevance.
Args:
query_embedding (np.ndarray): The embedding vector of the query.
embedding_list (list of np.ndarray): A list containing the embedding vectors
of the candidate documents.
lambda_mult (float, optional): The trade-off parameter between
relevance and diversity. Defaults to 0.5.
k (int, optional): The number of embeddings to select. Defaults to 4.
Returns:
list of int: The indices of the embeddings that maximize the marginal relevance.
Notes:
The Maximal Marginal Relevance (MMR) is computed using the following formula:
MMR = argmax_{D_i ∈ R \ S} [λ * Sim(D_i, Q) - (1 - λ) * max_{D_j ∈ S} Sim(D_i, D_j)]
where:
- R is the set of candidate documents,
- S is the set of selected documents,
- Q is the query embedding,
- Sim(D_i, Q) is the similarity between document D_i and the query,
- Sim(D_i, D_j) is the similarity between documents D_i and D_j,
- λ is the trade-off parameter.
"""
if min(k, len(embedding_list)) <= 0:
return []
if query_embedding.ndim == 1:
@@ -137,6 +169,7 @@ def make_serializable(
obj: Dict[str, Any],
) -> None:
"""Recursively cast values in a dict to a form able to json.dump"""
from bson import ObjectId
for k, v in obj.items():

View File

@@ -29,6 +29,7 @@ from langchain_mongodb.index import (
create_vector_search_index,
update_vector_search_index,
)
from langchain_mongodb.pipelines import vector_search_stage
from langchain_mongodb.utils import (
make_serializable,
maximal_marginal_relevance,
@@ -36,7 +37,6 @@ from langchain_mongodb.utils import (
str_to_oid,
)
MongoDBDocumentType = TypeVar("MongoDBDocumentType", bound=Dict[str, Any])
VST = TypeVar("VST", bound=VectorStore)
logger = logging.getLogger(__name__)
@@ -45,19 +45,40 @@ DEFAULT_INSERT_BATCH_SIZE = 100_000
class MongoDBAtlasVectorSearch(VectorStore):
"""MongoDBAtlas vector store integration.
"""MongoDB Atlas vector store integration.
MongoDBAtlasVectorSearch performs data operations on
text, embeddings and arbitrary data. In addition to CRUD operations,
the VectorStore provides Vector Search
based on similarity of embedding vectors following the
Hierarchical Navigable Small Worlds (HNSW) algorithm.
This supports a number of models to ascertain scores,
"similarity" (default), "MMR", and "similarity_score_threshold".
These are described in the search_type argument to as_retriever,
which provides the Runnable.invoke(query) API, allowing
MongoDBAtlasVectorSearch to be used within a chain.
Setup:
Install ``langchain-mongodb`` and ``pymongo`` and setup a MongoDB Atlas cluster (read through [this guide](https://www.mongodb.com/docs/manual/reference/connection-string/) to do so).
* Set up a MongoDB Atlas cluster. The free tier M0 will allow you to start.
Search Indexes are only available on Atlas, the fully managed cloud service,
not the self-managed MongoDB.
Follow [this guide](https://www.mongodb.com/basics/mongodb-atlas-tutorial)
* Create a Collection and a Vector Search Index.The procedure is described
[here](https://www.mongodb.com/docs/atlas/atlas-vector-search/create-index/#procedure).
* Install ``langchain-mongodb``
.. code-block:: bash
pip install -qU langchain-mongodb pymongo
.. code-block:: python
import getpass
MONGODB_ATLAS_CLUSTER_URI = getpass.getpass("MongoDB Atlas Cluster URI:")
Key init args — indexing params:
@@ -127,7 +148,7 @@ class MongoDBAtlasVectorSearch(VectorStore):
Search with filter:
.. code-block:: python
results = vector_store.similarity_search(query="thud",k=1,filter={"bar": "baz"})
results = vector_store.similarity_search(query="thud",k=1,post_filter=[{"bar": "baz"]})
for doc in results:
print(f"* {doc.page_content} [{doc.metadata}]")
@@ -184,29 +205,24 @@ class MongoDBAtlasVectorSearch(VectorStore):
def __init__(
self,
collection: Collection[MongoDBDocumentType],
collection: Collection[Dict[str, Any]],
embedding: Embeddings,
*,
index_name: str = "default",
index_name: str = "vector_index",
text_key: str = "text",
embedding_key: str = "embedding",
relevance_score_fn: str = "cosine",
**kwargs: Any,
):
"""
Args:
collection: MongoDB collection to add the texts to.
embedding: Text embedding model to use.
text_key: MongoDB field that will contain the text for each
document.
defaults to 'text'
embedding_key: MongoDB field that will contain the embedding for
each document.
defaults to 'embedding'
index_name: Name of the Atlas Search index.
defaults to 'default'
relevance_score_fn: The similarity score used for the index.
defaults to 'cosine'
Currently supported: 'euclidean', 'cosine', and 'dotProduct'.
collection: MongoDB collection to add the texts to
embedding: Text embedding model to use
text_key: MongoDB field that will contain the text for each document
index_name: Existing Atlas Vector Search Index
embedding_key: Field that will contain the embedding for each document
vector_index_name: Name of the Atlas Vector Search index
relevance_score_fn: The similarity score used for the index
Currently supported: 'euclidean', 'cosine', and 'dotProduct'
"""
self._collection = collection
self._embedding = embedding
@@ -412,69 +428,32 @@ class MongoDBAtlasVectorSearch(VectorStore):
start = end
return result_ids
def _similarity_search_with_score(
self,
embedding: List[float],
k: int = 4,
pre_filter: Optional[Dict] = None,
post_filter_pipeline: Optional[List[Dict]] = None,
include_embedding: bool = False,
include_ids: bool = False,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Core implementation."""
params = {
"queryVector": embedding,
"path": self._embedding_key,
"numCandidates": k * 10,
"limit": k,
"index": self._index_name,
}
if pre_filter:
params["filter"] = pre_filter
query = {"$vectorSearch": params}
pipeline = [
query,
{"$set": {"score": {"$meta": "vectorSearchScore"}}},
]
# Exclude the embedding key from the return payload
if not include_embedding:
pipeline.append({"$project": {self._embedding_key: 0}})
if post_filter_pipeline is not None:
pipeline.extend(post_filter_pipeline)
cursor = self._collection.aggregate(pipeline) # type: ignore[arg-type]
docs = []
for res in cursor:
text = res.pop(self._text_key)
score = res.pop("score")
make_serializable(res)
docs.append((Document(page_content=text, metadata=res), score))
return docs
def similarity_search_with_score(
self,
query: str,
k: int = 4,
pre_filter: Optional[Dict] = None,
pre_filter: Optional[Dict[str, Any]] = None,
post_filter_pipeline: Optional[List[Dict]] = None,
oversampling_factor: int = 10,
include_embeddings: bool = False,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
) -> List[Tuple[Document, float]]: # noqa: E501
"""Return MongoDB documents most similar to the given query and their scores.
Uses the vectorSearch operator available in MongoDB Atlas Search.
For more: https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/
Atlas Vector Search eliminates the need to run a separate
search system alongside your database.
Args:
query: Text to look up documents similar to.
k: (Optional) number of documents to return. Defaults to 4.
pre_filter: (Optional) dictionary of argument(s) to prefilter document
fields on.
post_filter_pipeline: (Optional) Pipeline of MongoDB aggregation stages
following the vectorSearch stage.
Args:
query: Input text of semantic query
k: Number of documents to return. Also known as top_k.
pre_filter: List of MQL match expressions comparing an indexed field
post_filter_pipeline: (Optional) Arbitrary pipeline of MongoDB
aggregation stages applied after the search is complete.
oversampling_factor: This times k is the number of candidates chosen
at each step in the in HNSW Vector Search
include_embeddings: If True, the embedding vector of each result
will be included in metadata.
kwargs: Additional arguments are specific to the search_type
Returns:
List of documents most similar to the query and their scores.
@@ -485,6 +464,8 @@ class MongoDBAtlasVectorSearch(VectorStore):
k=k,
pre_filter=pre_filter,
post_filter_pipeline=post_filter_pipeline,
oversampling_factor=oversampling_factor,
include_embeddings=include_embeddings,
**kwargs,
)
return docs
@@ -493,36 +474,46 @@ class MongoDBAtlasVectorSearch(VectorStore):
self,
query: str,
k: int = 4,
pre_filter: Optional[Dict] = None,
pre_filter: Optional[Dict[str, Any]] = None,
post_filter_pipeline: Optional[List[Dict]] = None,
oversampling_factor: int = 10,
include_scores: bool = False,
include_embeddings: bool = False,
**kwargs: Any,
) -> List[Document]:
) -> List[Document]: # noqa: E501
"""Return MongoDB documents most similar to the given query.
Uses the vectorSearch operator available in MongoDB Atlas Search.
For more: https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/
Atlas Vector Search eliminates the need to run a separate
search system alongside your database.
Args:
query: Text to look up documents similar to.
Args:
query: Input text of semantic query
k: (Optional) number of documents to return. Defaults to 4.
pre_filter: (Optional) dictionary of argument(s) to prefilter document
fields on.
pre_filter: List of MQL match expressions comparing an indexed field
post_filter_pipeline: (Optional) Pipeline of MongoDB aggregation stages
following the vectorSearch stage.
to filter/process results after $vectorSearch.
oversampling_factor: Multiple of k used when generating number of candidates
at each step in the HNSW Vector Search,
include_scores: If True, the query score of each result
will be included in metadata.
include_embeddings: If True, the embedding vector of each result
will be included in metadata.
kwargs: Additional arguments are specific to the search_type
Returns:
List of documents most similar to the query and their scores.
"""
additional = kwargs.get("additional")
docs_and_scores = self.similarity_search_with_score(
query,
k=k,
pre_filter=pre_filter,
post_filter_pipeline=post_filter_pipeline,
oversampling_factor=oversampling_factor,
include_embeddings=include_embeddings,
**kwargs,
)
if additional and "similarity_score" in additional:
if include_scores:
for doc, score in docs_and_scores:
doc.metadata["score"] = score
return [doc for doc, _ in docs_and_scores]
@@ -533,7 +524,7 @@ class MongoDBAtlasVectorSearch(VectorStore):
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
pre_filter: Optional[Dict] = None,
pre_filter: Optional[Dict[str, Any]] = None,
post_filter_pipeline: Optional[List[Dict]] = None,
**kwargs: Any,
) -> List[Document]:
@@ -548,19 +539,16 @@ class MongoDBAtlasVectorSearch(VectorStore):
fetch_k: (Optional) number of documents to fetch before passing to MMR
algorithm. Defaults to 20.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5.
pre_filter: (Optional) dictionary of argument(s) to prefilter on document
fields.
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity. Defaults to 0.5.
pre_filter: List of MQL match expressions comparing an indexed field
post_filter_pipeline: (Optional) pipeline of MongoDB aggregation stages
following the vectorSearch stage.
following the $vectorSearch stage.
Returns:
List of documents selected by maximal marginal relevance.
"""
query_embedding = self._embedding.embed_query(query)
return self.max_marginal_relevance_search_by_vector(
embedding=query_embedding,
embedding=self._embedding.embed_query(query),
k=k,
fetch_k=fetch_k,
lambda_mult=lambda_mult,
@@ -575,7 +563,7 @@ class MongoDBAtlasVectorSearch(VectorStore):
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[Dict]] = None,
collection: Optional[Collection[MongoDBDocumentType]] = None,
collection: Optional[Collection] = None,
ids: Optional[List[str]] = None,
**kwargs: Any,
) -> MongoDBAtlasVectorSearch:
@@ -588,6 +576,9 @@ class MongoDBAtlasVectorSearch(VectorStore):
This is intended to be a quick way to get started.
See `MongoDBAtlasVectorSearch` for kwargs and further description.
Example:
.. code-block:: python
from pymongo import MongoClient
@@ -649,8 +640,9 @@ class MongoDBAtlasVectorSearch(VectorStore):
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
pre_filter: Optional[Dict] = None,
pre_filter: Optional[Dict[str, Any]] = None,
post_filter_pipeline: Optional[List[Dict]] = None,
oversampling_factor: int = 10,
**kwargs: Any,
) -> List[Document]: # type: ignore
"""Return docs selected using the maximal marginal relevance.
@@ -666,10 +658,13 @@ class MongoDBAtlasVectorSearch(VectorStore):
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5.
pre_filter: (Optional) dictionary of argument(s) to prefilter on document
fields.
pre_filter: (Optional) dictionary of arguments to filter document fields on.
post_filter_pipeline: (Optional) pipeline of MongoDB aggregation stages
following the vectorSearch stage.
oversampling_factor: Multiple of k used when generating number
of candidates in HNSW Vector Search,
kwargs: Additional arguments are specific to the search_type
Returns:
List of Documents selected by maximal marginal relevance.
"""
@@ -678,7 +673,8 @@ class MongoDBAtlasVectorSearch(VectorStore):
k=fetch_k,
pre_filter=pre_filter,
post_filter_pipeline=post_filter_pipeline,
include_embedding=kwargs.pop("include_embedding", True),
include_embeddings=True,
oversampling_factor=oversampling_factor,
**kwargs,
)
mmr_doc_indexes = maximal_marginal_relevance(
@@ -696,31 +692,82 @@ class MongoDBAtlasVectorSearch(VectorStore):
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
pre_filter: Optional[Dict[str, Any]] = None,
post_filter_pipeline: Optional[List[Dict]] = None,
oversampling_factor: int = 10,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance."""
return await run_in_executor(
None,
self.max_marginal_relevance_search_by_vector,
self.max_marginal_relevance_search_by_vector, # type: ignore[arg-type]
embedding,
k=k,
fetch_k=fetch_k,
lambda_mult=lambda_mult,
pre_filter=pre_filter,
post_filter_pipeline=post_filter_pipeline,
oversampling_factor=oversampling_factor,
**kwargs,
)
def _similarity_search_with_score(
self,
query_vector: List[float],
k: int = 4,
pre_filter: Optional[Dict[str, Any]] = None,
post_filter_pipeline: Optional[List[Dict]] = None,
oversampling_factor: int = 10,
include_embeddings: bool = False,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Core search routine. See external methods for details."""
# Atlas Vector Search, potentially with filter
pipeline = [
vector_search_stage(
query_vector,
self._embedding_key,
self._index_name,
k,
pre_filter,
oversampling_factor,
**kwargs,
),
{"$set": {"score": {"$meta": "vectorSearchScore"}}},
]
# Remove embeddings unless requested.
if not include_embeddings:
pipeline.append({"$project": {self._embedding_key: 0}})
# Post-processing
if post_filter_pipeline is not None:
pipeline.extend(post_filter_pipeline)
# Execution
cursor = self._collection.aggregate(pipeline) # type: ignore[arg-type]
docs = []
# Format
for res in cursor:
text = res.pop(self._text_key)
score = res.pop("score")
make_serializable(res)
docs.append((Document(page_content=text, metadata=res), score))
return docs
def create_vector_search_index(
self,
dimensions: int,
filters: Optional[List[Dict[str, str]]] = None,
filters: Optional[List[str]] = None,
update: bool = False,
) -> None:
"""Creates a MongoDB Atlas vectorSearch index for the VectorStore
Note**: This method may fail as it requires a MongoDB Atlas with
these pre-requisites:
- M10 cluster or higher
- https://www.mongodb.com/docs/atlas/atlas-vector-search/create-index/#prerequisites
Note**: This method may fail as it requires a MongoDB Atlas with these
`pre-requisites <https://www.mongodb.com/docs/atlas/atlas-vector-search/create-index/#prerequisites>`.
Currently, vector and full-text search index operations need to be
performed manually on the Atlas UI for shared M0 clusters.
Args:
dimensions (int): Number of dimensions in embedding