mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-16 23:13:31 +00:00
mongodb: Add Hybrid and Full-Text Search Retrievers, release 0.2.0 (#25057)
## Description This pull-request extends the existing vector search strategies of MongoDBAtlasVectorSearch to include Hybrid (Reciprocal Rank Fusion) and Full-text via new Retrievers. There is a small breaking change in the form of the `prefilter` kwarg to search. For this, and because we have now added a great deal of features, including programmatic Index creation/deletion since 0.1.0, we plan to bump the version to 0.2.0. ### Checklist * Unit tests have been extended * formatting has been applied * One mypy error remains which will either go away in CI or be simplified. --------- Signed-off-by: Casey Clements <casey.clements@mongodb.com> Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
@@ -1,8 +1,16 @@
|
||||
"""
|
||||
Integrate your operational database and vector search in a single, unified,
|
||||
fully managed platform with full vector database capabilities on MongoDB Atlas.
|
||||
|
||||
|
||||
Store your operational data, metadata, and vector embeddings in oue VectorStore,
|
||||
MongoDBAtlasVectorSearch.
|
||||
Insert into a Chain via a Vector, FullText, or Hybrid Retriever.
|
||||
"""
|
||||
|
||||
from langchain_mongodb.cache import MongoDBAtlasSemanticCache, MongoDBCache
|
||||
from langchain_mongodb.chat_message_histories import MongoDBChatMessageHistory
|
||||
from langchain_mongodb.vectorstores import (
|
||||
MongoDBAtlasVectorSearch,
|
||||
)
|
||||
from langchain_mongodb.vectorstores import MongoDBAtlasVectorSearch
|
||||
|
||||
__all__ = [
|
||||
"MongoDBAtlasVectorSearch",
|
||||
|
@@ -1,10 +1,4 @@
|
||||
"""
|
||||
LangChain MongoDB Caches
|
||||
|
||||
Functions "_loads_generations" and "_dumps_generations"
|
||||
are duplicated in this utility from modules:
|
||||
- "libs/community/langchain_community/cache.py"
|
||||
"""
|
||||
"""LangChain MongoDB Caches."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
@@ -27,100 +21,6 @@ from langchain_mongodb.vectorstores import MongoDBAtlasVectorSearch
|
||||
logger = logging.getLogger(__file__)
|
||||
|
||||
|
||||
def _generate_mongo_client(connection_string: str) -> MongoClient:
|
||||
return MongoClient(
|
||||
connection_string,
|
||||
driver=DriverInfo(name="Langchain", version=version("langchain-mongodb")),
|
||||
)
|
||||
|
||||
|
||||
def _dumps_generations(generations: RETURN_VAL_TYPE) -> str:
|
||||
"""
|
||||
Serialization for generic RETURN_VAL_TYPE, i.e. sequence of `Generation`
|
||||
|
||||
Args:
|
||||
generations (RETURN_VAL_TYPE): A list of language model generations.
|
||||
|
||||
Returns:
|
||||
str: a single string representing a list of generations.
|
||||
|
||||
This function (+ its counterpart `_loads_generations`) rely on
|
||||
the dumps/loads pair with Reviver, so are able to deal
|
||||
with all subclasses of Generation.
|
||||
|
||||
Each item in the list can be `dumps`ed to a string,
|
||||
then we make the whole list of strings into a json-dumped.
|
||||
"""
|
||||
return json.dumps([dumps(_item) for _item in generations])
|
||||
|
||||
|
||||
def _loads_generations(generations_str: str) -> Union[RETURN_VAL_TYPE, None]:
|
||||
"""
|
||||
Deserialization of a string into a generic RETURN_VAL_TYPE
|
||||
(i.e. a sequence of `Generation`).
|
||||
|
||||
See `_dumps_generations`, the inverse of this function.
|
||||
|
||||
Args:
|
||||
generations_str (str): A string representing a list of generations.
|
||||
|
||||
Compatible with the legacy cache-blob format
|
||||
Does not raise exceptions for malformed entries, just logs a warning
|
||||
and returns none: the caller should be prepared for such a cache miss.
|
||||
|
||||
Returns:
|
||||
RETURN_VAL_TYPE: A list of generations.
|
||||
"""
|
||||
try:
|
||||
generations = [loads(_item_str) for _item_str in json.loads(generations_str)]
|
||||
return generations
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
# deferring the (soft) handling to after the legacy-format attempt
|
||||
pass
|
||||
|
||||
try:
|
||||
gen_dicts = json.loads(generations_str)
|
||||
# not relying on `_load_generations_from_json` (which could disappear):
|
||||
generations = [Generation(**generation_dict) for generation_dict in gen_dicts]
|
||||
logger.warning(
|
||||
f"Legacy 'Generation' cached blob encountered: '{generations_str}'"
|
||||
)
|
||||
return generations
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
logger.warning(
|
||||
f"Malformed/unparsable cached blob encountered: '{generations_str}'"
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _wait_until(
|
||||
predicate: Callable, success_description: Any, timeout: float = 10.0
|
||||
) -> None:
|
||||
"""Wait up to 10 seconds (by default) for predicate to be true.
|
||||
|
||||
E.g.:
|
||||
|
||||
wait_until(lambda: client.primary == ('a', 1),
|
||||
'connect to the primary')
|
||||
|
||||
If the lambda-expression isn't true after 10 seconds, we raise
|
||||
AssertionError("Didn't ever connect to the primary").
|
||||
|
||||
Returns the predicate's first true value.
|
||||
"""
|
||||
start = time.time()
|
||||
interval = min(float(timeout) / 100, 0.1)
|
||||
while True:
|
||||
retval = predicate()
|
||||
if retval:
|
||||
return retval
|
||||
|
||||
if time.time() - start > timeout:
|
||||
raise TimeoutError("Didn't ever %s" % success_description)
|
||||
|
||||
time.sleep(interval)
|
||||
|
||||
|
||||
class MongoDBCache(BaseCache):
|
||||
"""MongoDB Atlas cache
|
||||
|
||||
@@ -216,7 +116,7 @@ class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch):
|
||||
collection_name: str = "default",
|
||||
database_name: str = "default",
|
||||
index_name: str = "default",
|
||||
wait_until_ready: bool = False,
|
||||
wait_until_ready: Optional[float] = None,
|
||||
score_threshold: Optional[float] = None,
|
||||
**kwargs: Dict[str, Any],
|
||||
):
|
||||
@@ -233,8 +133,8 @@ class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch):
|
||||
Defaults to "default".
|
||||
index_name: Name of the Atlas Search index.
|
||||
defaults to 'default'
|
||||
wait_until_ready (bool): Block until MongoDB Atlas finishes indexing
|
||||
the stored text. Hard timeout of 10 seconds. Defaults to False.
|
||||
wait_until_ready (float): Wait this time for Atlas to finish indexing
|
||||
the stored text. Defaults to None.
|
||||
"""
|
||||
client = _generate_mongo_client(connection_string)
|
||||
self.collection = client[database_name][collection_name]
|
||||
@@ -272,7 +172,7 @@ class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch):
|
||||
prompt: str,
|
||||
llm_string: str,
|
||||
return_val: RETURN_VAL_TYPE,
|
||||
wait_until_ready: Optional[bool] = None,
|
||||
wait_until_ready: Optional[float] = None,
|
||||
) -> None:
|
||||
"""Update cache based on prompt and llm_string."""
|
||||
self.add_texts(
|
||||
@@ -290,7 +190,7 @@ class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch):
|
||||
return self.lookup(prompt, llm_string) == return_val
|
||||
|
||||
if wait:
|
||||
_wait_until(is_indexed, return_val)
|
||||
_wait_until(is_indexed, return_val, timeout=wait)
|
||||
|
||||
def clear(self, **kwargs: Any) -> None:
|
||||
"""Clear cache that can take additional keyword arguments.
|
||||
@@ -302,3 +202,107 @@ class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch):
|
||||
self.clear(llm_string="fake-model")
|
||||
"""
|
||||
self.collection.delete_many({**kwargs})
|
||||
|
||||
|
||||
def _generate_mongo_client(connection_string: str) -> MongoClient:
|
||||
return MongoClient(
|
||||
connection_string,
|
||||
driver=DriverInfo(name="Langchain", version=version("langchain-mongodb")),
|
||||
)
|
||||
|
||||
|
||||
def _dumps_generations(generations: RETURN_VAL_TYPE) -> str:
|
||||
"""
|
||||
Serialization for generic RETURN_VAL_TYPE, i.e. sequence of `Generation`
|
||||
|
||||
Args:
|
||||
generations (RETURN_VAL_TYPE): A list of language model generations.
|
||||
|
||||
Returns:
|
||||
str: a single string representing a list of generations.
|
||||
|
||||
This, and "_dumps_generations" are duplicated in this utility
|
||||
from modules: "libs/community/langchain_community/cache.py"
|
||||
|
||||
This function and its counterpart rely on
|
||||
the dumps/loads pair with Reviver, so are able to deal
|
||||
with all subclasses of Generation.
|
||||
|
||||
Each item in the list can be `dumps`ed to a string,
|
||||
then we make the whole list of strings into a json-dumped.
|
||||
"""
|
||||
return json.dumps([dumps(_item) for _item in generations])
|
||||
|
||||
|
||||
def _loads_generations(generations_str: str) -> Union[RETURN_VAL_TYPE, None]:
|
||||
"""
|
||||
Deserialization of a string into a generic RETURN_VAL_TYPE
|
||||
(i.e. a sequence of `Generation`).
|
||||
|
||||
Args:
|
||||
generations_str (str): A string representing a list of generations.
|
||||
|
||||
Returns:
|
||||
RETURN_VAL_TYPE: A list of generations.
|
||||
|
||||
|
||||
This function and its counterpart rely on
|
||||
the dumps/loads pair with Reviver, so are able to deal
|
||||
with all subclasses of Generation.
|
||||
|
||||
See `_dumps_generations`, the inverse of this function.
|
||||
|
||||
Compatible with the legacy cache-blob format
|
||||
Does not raise exceptions for malformed entries, just logs a warning
|
||||
and returns none: the caller should be prepared for such a cache miss.
|
||||
|
||||
|
||||
"""
|
||||
try:
|
||||
generations = [loads(_item_str) for _item_str in json.loads(generations_str)]
|
||||
return generations
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
# deferring the (soft) handling to after the legacy-format attempt
|
||||
pass
|
||||
|
||||
try:
|
||||
gen_dicts = json.loads(generations_str)
|
||||
# not relying on `_load_generations_from_json` (which could disappear):
|
||||
generations = [Generation(**generation_dict) for generation_dict in gen_dicts]
|
||||
logger.warning(
|
||||
f"Legacy 'Generation' cached blob encountered: '{generations_str}'"
|
||||
)
|
||||
return generations
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
logger.warning(
|
||||
f"Malformed/unparsable cached blob encountered: '{generations_str}'"
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _wait_until(
|
||||
predicate: Callable, success_description: Any, timeout: float = 10.0
|
||||
) -> None:
|
||||
"""Wait up to 10 seconds (by default) for predicate to be true.
|
||||
|
||||
E.g.:
|
||||
|
||||
wait_until(lambda: client.primary == ('a', 1),
|
||||
'connect to the primary')
|
||||
|
||||
If the lambda-expression isn't true after 10 seconds, we raise
|
||||
AssertionError("Didn't ever connect to the primary").
|
||||
|
||||
Returns the predicate's first true value.
|
||||
"""
|
||||
start = time.time()
|
||||
interval = min(float(timeout) / 100, 0.1)
|
||||
while True:
|
||||
retval = predicate()
|
||||
if retval:
|
||||
return retval
|
||||
|
||||
if time.time() - start > timeout:
|
||||
raise TimeoutError("Didn't ever %s" % success_description)
|
||||
|
||||
time.sleep(interval)
|
||||
|
@@ -1,3 +1,5 @@
|
||||
"""Search Index Commands"""
|
||||
|
||||
import logging
|
||||
from time import monotonic, sleep
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
@@ -8,8 +10,6 @@ from pymongo.operations import SearchIndexModel
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
|
||||
_DELAY = 0.5 # Interval between checks for index operations
|
||||
|
||||
|
||||
def _search_index_error_message() -> str:
|
||||
return (
|
||||
@@ -25,19 +25,24 @@ def _vector_search_index_definition(
|
||||
dimensions: int,
|
||||
path: str,
|
||||
similarity: str,
|
||||
filters: Optional[List[Dict[str, str]]],
|
||||
filters: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"fields": [
|
||||
{
|
||||
"numDimensions": dimensions,
|
||||
"path": path,
|
||||
"similarity": similarity,
|
||||
"type": "vector",
|
||||
},
|
||||
*(filters or []),
|
||||
]
|
||||
}
|
||||
# https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/
|
||||
fields = [
|
||||
{
|
||||
"numDimensions": dimensions,
|
||||
"path": path,
|
||||
"similarity": similarity,
|
||||
"type": "vector",
|
||||
},
|
||||
]
|
||||
if filters:
|
||||
for field in filters:
|
||||
fields.append({"type": "filter", "path": field})
|
||||
definition = {"fields": fields}
|
||||
definition.update(kwargs)
|
||||
return definition
|
||||
|
||||
|
||||
def create_vector_search_index(
|
||||
@@ -46,9 +51,10 @@ def create_vector_search_index(
|
||||
dimensions: int,
|
||||
path: str,
|
||||
similarity: str,
|
||||
filters: Optional[List[Dict[str, str]]] = None,
|
||||
filters: Optional[List[str]] = None,
|
||||
*,
|
||||
wait_until_complete: Optional[float] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Experimental Utility function to create a vector search index
|
||||
|
||||
@@ -58,9 +64,10 @@ def create_vector_search_index(
|
||||
dimensions (int): Number of dimensions in embedding
|
||||
path (str): field with vector embedding
|
||||
similarity (str): The similarity score used for the index
|
||||
filters (List[Dict[str, str]]): additional filters for index definition.
|
||||
filters (List[str]): Fields/paths to index to allow filtering in $vectorSearch
|
||||
wait_until_complete (Optional[float]): If provided, number of seconds to wait
|
||||
until search index is ready.
|
||||
kwargs: Keyword arguments supplying any additional options to SearchIndexModel.
|
||||
"""
|
||||
logger.info("Creating Search Index %s on %s", index_name, collection.name)
|
||||
|
||||
@@ -72,6 +79,7 @@ def create_vector_search_index(
|
||||
path=path,
|
||||
similarity=similarity,
|
||||
filters=filters,
|
||||
**kwargs,
|
||||
),
|
||||
name=index_name,
|
||||
type="vectorSearch",
|
||||
@@ -83,7 +91,7 @@ def create_vector_search_index(
|
||||
if wait_until_complete:
|
||||
_wait_for_predicate(
|
||||
predicate=lambda: _is_index_ready(collection, index_name),
|
||||
err=f"Index {index_name} creation did not finish in {wait_until_complete}!",
|
||||
err=f"{index_name=} did not complete in {wait_until_complete}!",
|
||||
timeout=wait_until_complete,
|
||||
)
|
||||
logger.info(result)
|
||||
@@ -127,9 +135,10 @@ def update_vector_search_index(
|
||||
dimensions: int,
|
||||
path: str,
|
||||
similarity: str,
|
||||
filters: List[Dict[str, str]],
|
||||
filters: Optional[List[str]] = None,
|
||||
*,
|
||||
wait_until_complete: Optional[float] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Update a search index.
|
||||
|
||||
@@ -138,12 +147,13 @@ def update_vector_search_index(
|
||||
Args:
|
||||
collection (Collection): MongoDB Collection
|
||||
index_name (str): Name of Index
|
||||
dimensions (int): Number of dimensions in embedding.
|
||||
path (str): field with vector embedding.
|
||||
dimensions (int): Number of dimensions in embedding
|
||||
path (str): field with vector embedding
|
||||
similarity (str): The similarity score used for the index.
|
||||
filters (List[Dict[str, str]]): additional filters for index definition.
|
||||
filters (List[str]): Fields/paths to index to allow filtering in $vectorSearch
|
||||
wait_until_complete (Optional[float]): If provided, number of seconds to wait
|
||||
until search index is ready.
|
||||
kwargs: Keyword arguments supplying any additional options to SearchIndexModel.
|
||||
"""
|
||||
|
||||
logger.info(
|
||||
@@ -157,6 +167,7 @@ def update_vector_search_index(
|
||||
path=path,
|
||||
similarity=similarity,
|
||||
filters=filters,
|
||||
**kwargs,
|
||||
),
|
||||
)
|
||||
except OperationFailure as e:
|
||||
@@ -201,7 +212,7 @@ def _wait_for_predicate(
|
||||
Args:
|
||||
predicate (Callable[, bool]): A function that returns a boolean value
|
||||
err (str): Error message to raise if nothing occurs
|
||||
timeout (float, optional): wait time for predicate. Defaults to TIMEOUT.
|
||||
timeout (float, optional): Wait time for predicate. Defaults to TIMEOUT.
|
||||
interval (float, optional): Interval to check predicate. Defaults to DELAY.
|
||||
|
||||
Raises:
|
||||
@@ -212,3 +223,48 @@ def _wait_for_predicate(
|
||||
if monotonic() - start > timeout:
|
||||
raise TimeoutError(err)
|
||||
sleep(interval)
|
||||
|
||||
|
||||
def create_fulltext_search_index(
|
||||
collection: Collection,
|
||||
index_name: str,
|
||||
field: str,
|
||||
*,
|
||||
wait_until_complete: Optional[float] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Experimental Utility function to create an Atlas Search index
|
||||
|
||||
Args:
|
||||
collection (Collection): MongoDB Collection
|
||||
index_name (str): Name of Index
|
||||
field (str): Field to index
|
||||
wait_until_complete (Optional[float]): If provided, number of seconds to wait
|
||||
until search index is ready
|
||||
kwargs: Keyword arguments supplying any additional options to SearchIndexModel.
|
||||
"""
|
||||
logger.info("Creating Search Index %s on %s", index_name, collection.name)
|
||||
|
||||
definition = {
|
||||
"mappings": {"dynamic": False, "fields": {field: [{"type": "string"}]}}
|
||||
}
|
||||
|
||||
try:
|
||||
result = collection.create_search_index(
|
||||
SearchIndexModel(
|
||||
definition=definition,
|
||||
name=index_name,
|
||||
type="search",
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
except OperationFailure as e:
|
||||
raise OperationFailure(_search_index_error_message()) from e
|
||||
|
||||
if wait_until_complete:
|
||||
_wait_for_predicate(
|
||||
predicate=lambda: _is_index_ready(collection, index_name),
|
||||
err=f"{index_name=} did not complete in {wait_until_complete}!",
|
||||
timeout=wait_until_complete,
|
||||
)
|
||||
logger.info(result)
|
||||
|
160
libs/partners/mongodb/langchain_mongodb/pipelines.py
Normal file
160
libs/partners/mongodb/langchain_mongodb/pipelines.py
Normal file
@@ -0,0 +1,160 @@
|
||||
"""Aggregation pipeline components used in Atlas Full-Text, Vector, and Hybrid Search
|
||||
|
||||
See the following for more:
|
||||
- `Full-Text Search <https://www.mongodb.com/docs/atlas/atlas-search/aggregation-stages/search/#mongodb-pipeline-pipe.-search>`_
|
||||
- `MongoDB Operators <https://www.mongodb.com/docs/atlas/atlas-search/operators-and-collectors/#std-label-operators-ref>`_
|
||||
- `Vector Search <https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/>`_
|
||||
- `Filter Example <https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#atlas-vector-search-pre-filter>`_
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
def text_search_stage(
|
||||
query: str,
|
||||
search_field: str,
|
||||
index_name: str,
|
||||
limit: Optional[int] = None,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
include_scores: Optional[bool] = True,
|
||||
**kwargs: Any,
|
||||
) -> List[Dict[str, Any]]: # noqa: E501
|
||||
"""Full-Text search using Lucene's standard (BM25) analyzer
|
||||
|
||||
Args:
|
||||
query: Input text to search for
|
||||
search_field: Field in Collection that will be searched
|
||||
index_name: Atlas Search Index name
|
||||
limit: Maximum number of documents to return. Default of no limit
|
||||
filter: Any MQL match expression comparing an indexed field
|
||||
include_scores: Scores provide measure of relative relevance
|
||||
|
||||
Returns:
|
||||
Dictionary defining the $search stage
|
||||
"""
|
||||
pipeline = [
|
||||
{
|
||||
"$search": {
|
||||
"index": index_name,
|
||||
"text": {"query": query, "path": search_field},
|
||||
}
|
||||
}
|
||||
]
|
||||
if filter:
|
||||
pipeline.append({"$match": filter}) # type: ignore
|
||||
if include_scores:
|
||||
pipeline.append({"$set": {"score": {"$meta": "searchScore"}}})
|
||||
if limit:
|
||||
pipeline.append({"$limit": limit}) # type: ignore
|
||||
|
||||
return pipeline # type: ignore
|
||||
|
||||
|
||||
def vector_search_stage(
|
||||
query_vector: List[float],
|
||||
search_field: str,
|
||||
index_name: str,
|
||||
top_k: int = 4,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
oversampling_factor: int = 10,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]: # noqa: E501
|
||||
"""Vector Search Stage without Scores.
|
||||
|
||||
Scoring is applied later depending on strategy.
|
||||
vector search includes a vectorSearchScore that is typically used.
|
||||
hybrid uses Reciprocal Rank Fusion.
|
||||
|
||||
Args:
|
||||
query_vector: List of embedding vector
|
||||
search_field: Field in Collection containing embedding vectors
|
||||
index_name: Name of Atlas Vector Search Index tied to Collection
|
||||
top_k: Number of documents to return
|
||||
oversampling_factor: this times limit is the number of candidates
|
||||
filter: MQL match expression comparing an indexed field.
|
||||
Some operators are not supported.
|
||||
See `vectorSearch filter docs <https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#atlas-vector-search-pre-filter>`_
|
||||
|
||||
|
||||
Returns:
|
||||
Dictionary defining the $vectorSearch
|
||||
"""
|
||||
stage = {
|
||||
"index": index_name,
|
||||
"path": search_field,
|
||||
"queryVector": query_vector,
|
||||
"numCandidates": top_k * oversampling_factor,
|
||||
"limit": top_k,
|
||||
}
|
||||
if filter:
|
||||
stage["filter"] = filter
|
||||
return {"$vectorSearch": stage}
|
||||
|
||||
|
||||
def combine_pipelines(
|
||||
pipeline: List[Any], stage: List[Dict[str, Any]], collection_name: str
|
||||
) -> None:
|
||||
"""Combines two aggregations into a single result set in-place."""
|
||||
if pipeline:
|
||||
pipeline.append({"$unionWith": {"coll": collection_name, "pipeline": stage}})
|
||||
else:
|
||||
pipeline.extend(stage)
|
||||
|
||||
|
||||
def reciprocal_rank_stage(
|
||||
score_field: str, penalty: float = 0, **kwargs: Any
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Stage adds Reciprocal Rank Fusion weighting.
|
||||
|
||||
First, it pushes documents retrieved from previous stage
|
||||
into a temporary sub-document. It then unwinds to establish
|
||||
the rank to each and applies the penalty.
|
||||
|
||||
Args:
|
||||
score_field: A unique string to identify the search being ranked
|
||||
penalty: A non-negative float.
|
||||
extra_fields: Any fields other than text_field that one wishes to keep.
|
||||
|
||||
Returns:
|
||||
RRF score := \frac{1}{rank + penalty} with rank in [1,2,..,n]
|
||||
"""
|
||||
|
||||
rrf_pipeline = [
|
||||
{"$group": {"_id": None, "docs": {"$push": "$$ROOT"}}},
|
||||
{"$unwind": {"path": "$docs", "includeArrayIndex": "rank"}},
|
||||
{
|
||||
"$addFields": {
|
||||
f"docs.{score_field}": {
|
||||
"$divide": [1.0, {"$add": ["$rank", penalty, 1]}]
|
||||
},
|
||||
"docs.rank": "$rank",
|
||||
"_id": "$docs._id",
|
||||
}
|
||||
},
|
||||
{"$replaceRoot": {"newRoot": "$docs"}},
|
||||
]
|
||||
|
||||
return rrf_pipeline # type: ignore
|
||||
|
||||
|
||||
def final_hybrid_stage(
|
||||
scores_fields: List[str], limit: int, **kwargs: Any
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Sum weighted scores, sort, and apply limit.
|
||||
|
||||
Args:
|
||||
scores_fields: List of fields given to scores of vector and text searches
|
||||
limit: Number of documents to return
|
||||
|
||||
Returns:
|
||||
Final aggregation stages
|
||||
"""
|
||||
|
||||
return [
|
||||
{"$group": {"_id": "$_id", "docs": {"$mergeObjects": "$$ROOT"}}},
|
||||
{"$replaceRoot": {"newRoot": "$docs"}},
|
||||
{"$set": {score: {"$ifNull": [f"${score}", 0]} for score in scores_fields}},
|
||||
{"$addFields": {"score": {"$add": [f"${score}" for score in scores_fields]}}},
|
||||
{"$sort": {"score": -1}},
|
||||
{"$limit": limit},
|
||||
]
|
@@ -0,0 +1,15 @@
|
||||
"""Search Retrievers of various types.
|
||||
|
||||
Use ``MongoDBAtlasVectorSearch.as_retriever(**)``
|
||||
to create MongoDB's core Vector Search Retriever.
|
||||
"""
|
||||
|
||||
from langchain_mongodb.retrievers.full_text_search import (
|
||||
MongoDBAtlasFullTextSearchRetriever,
|
||||
)
|
||||
from langchain_mongodb.retrievers.hybrid_search import MongoDBAtlasHybridSearchRetriever
|
||||
|
||||
__all__ = [
|
||||
"MongoDBAtlasHybridSearchRetriever",
|
||||
"MongoDBAtlasFullTextSearchRetriever",
|
||||
]
|
@@ -0,0 +1,59 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from pymongo.collection import Collection
|
||||
|
||||
from langchain_mongodb.pipelines import text_search_stage
|
||||
from langchain_mongodb.utils import make_serializable
|
||||
|
||||
|
||||
class MongoDBAtlasFullTextSearchRetriever(BaseRetriever):
|
||||
"""Hybrid Search Retriever performs full-text searches
|
||||
using Lucene's standard (BM25) analyzer.
|
||||
"""
|
||||
|
||||
collection: Collection
|
||||
"""MongoDB Collection on an Atlas cluster"""
|
||||
search_index_name: str
|
||||
"""Atlas Search Index name"""
|
||||
search_field: str
|
||||
"""Collection field that contains the text to be searched. It must be indexed"""
|
||||
top_k: Optional[int] = None
|
||||
"""Number of documents to return. Default is no limit"""
|
||||
filter: Optional[Dict[str, Any]] = None
|
||||
"""(Optional) List of MQL match expression comparing an indexed field"""
|
||||
show_embeddings: float = False
|
||||
"""If true, returned Document metadata will include vectors"""
|
||||
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
"""Retrieve documents that are highest scoring / most similar to query.
|
||||
|
||||
Args:
|
||||
query: String to find relevant documents for
|
||||
run_manager: The callback handler to use
|
||||
Returns:
|
||||
List of relevant documents
|
||||
"""
|
||||
|
||||
pipeline = text_search_stage( # type: ignore
|
||||
query=query,
|
||||
search_field=self.search_field,
|
||||
index_name=self.search_index_name,
|
||||
limit=self.top_k,
|
||||
filter=self.filter,
|
||||
)
|
||||
|
||||
# Execution
|
||||
cursor = self.collection.aggregate(pipeline) # type: ignore[arg-type]
|
||||
|
||||
# Formatting
|
||||
docs = []
|
||||
for res in cursor:
|
||||
text = res.pop(self.search_field)
|
||||
make_serializable(res)
|
||||
docs.append(Document(page_content=text, metadata=res))
|
||||
return docs
|
@@ -0,0 +1,126 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from pymongo.collection import Collection
|
||||
|
||||
from langchain_mongodb import MongoDBAtlasVectorSearch
|
||||
from langchain_mongodb.pipelines import (
|
||||
combine_pipelines,
|
||||
final_hybrid_stage,
|
||||
reciprocal_rank_stage,
|
||||
text_search_stage,
|
||||
vector_search_stage,
|
||||
)
|
||||
from langchain_mongodb.utils import make_serializable
|
||||
|
||||
|
||||
class MongoDBAtlasHybridSearchRetriever(BaseRetriever):
|
||||
"""Hybrid Search Retriever combines vector and full-text searches
|
||||
weighting them the via Reciprocal Rank Fusion (RRF) algorithm.
|
||||
|
||||
Increasing the vector_penalty will reduce the importance on the vector search.
|
||||
Increasing the fulltext_penalty will correspondingly reduce the fulltext score.
|
||||
For more on the algorithm,see
|
||||
https://learn.microsoft.com/en-us/azure/search/hybrid-search-ranking
|
||||
"""
|
||||
|
||||
vectorstore: MongoDBAtlasVectorSearch
|
||||
"""MongoDBAtlas VectorStore"""
|
||||
search_index_name: str
|
||||
"""Atlas Search Index (full-text) name"""
|
||||
top_k: int = 4
|
||||
"""Number of documents to return."""
|
||||
oversampling_factor: int = 10
|
||||
"""This times top_k is the number of candidates chosen at each step"""
|
||||
pre_filter: Optional[Dict[str, Any]] = None
|
||||
"""(Optional) Any MQL match expression comparing an indexed field"""
|
||||
post_filter: Optional[List[Dict[str, Any]]] = None
|
||||
"""(Optional) Pipeline of MongoDB aggregation stages for postprocessing."""
|
||||
vector_penalty: float = 60.0
|
||||
"""Penalty applied to vector search results in RRF: scores=1/(rank + penalty)"""
|
||||
fulltext_penalty: float = 60.0
|
||||
"""Penalty applied to full-text search results in RRF: scores=1/(rank + penalty)"""
|
||||
show_embeddings: float = False
|
||||
"""If true, returned Document metadata will include vectors."""
|
||||
|
||||
@property
|
||||
def collection(self) -> Collection:
|
||||
return self.vectorstore._collection
|
||||
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
"""Retrieve documents that are highest scoring / most similar to query.
|
||||
|
||||
Note that the same query is used in both searches,
|
||||
embedded for vector search, and as-is for full-text search.
|
||||
|
||||
Args:
|
||||
query: String to find relevant documents for
|
||||
run_manager: The callback handler to use
|
||||
Returns:
|
||||
List of relevant documents
|
||||
"""
|
||||
|
||||
query_vector = self.vectorstore._embedding.embed_query(query)
|
||||
|
||||
scores_fields = ["vector_score", "fulltext_score"]
|
||||
pipeline: List[Any] = []
|
||||
|
||||
# First we build up the aggregation pipeline,
|
||||
# then it is passed to the server to execute
|
||||
# Vector Search stage
|
||||
vector_pipeline = [
|
||||
vector_search_stage(
|
||||
query_vector=query_vector,
|
||||
search_field=self.vectorstore._embedding_key,
|
||||
index_name=self.vectorstore._index_name,
|
||||
top_k=self.top_k,
|
||||
filter=self.pre_filter,
|
||||
oversampling_factor=self.oversampling_factor,
|
||||
)
|
||||
]
|
||||
vector_pipeline += reciprocal_rank_stage("vector_score", self.vector_penalty)
|
||||
|
||||
combine_pipelines(pipeline, vector_pipeline, self.collection.name)
|
||||
|
||||
# Full-Text Search stage
|
||||
text_pipeline = text_search_stage(
|
||||
query=query,
|
||||
search_field=self.vectorstore._text_key,
|
||||
index_name=self.search_index_name,
|
||||
limit=self.top_k,
|
||||
filter=self.pre_filter,
|
||||
)
|
||||
|
||||
text_pipeline.extend(
|
||||
reciprocal_rank_stage("fulltext_score", self.fulltext_penalty)
|
||||
)
|
||||
|
||||
combine_pipelines(pipeline, text_pipeline, self.collection.name)
|
||||
|
||||
# Sum and sort stage
|
||||
pipeline.extend(
|
||||
final_hybrid_stage(scores_fields=scores_fields, limit=self.top_k)
|
||||
)
|
||||
|
||||
# Removal of embeddings unless requested.
|
||||
if not self.show_embeddings:
|
||||
pipeline.append({"$project": {self.vectorstore._embedding_key: 0}})
|
||||
# Post filtering
|
||||
if self.post_filter is not None:
|
||||
pipeline.extend(self.post_filter)
|
||||
|
||||
# Execution
|
||||
cursor = self.collection.aggregate(pipeline) # type: ignore[arg-type]
|
||||
|
||||
# Formatting
|
||||
docs = []
|
||||
for res in cursor:
|
||||
text = res.pop(self.vectorstore._text_key)
|
||||
# score = res.pop("score") # The score remains buried!
|
||||
make_serializable(res)
|
||||
docs.append(Document(page_content=text, metadata=res))
|
||||
return docs
|
@@ -1,6 +1,13 @@
|
||||
"""
|
||||
Tools for the Maximal Marginal Relevance (MMR) reranking.
|
||||
Duplicated from langchain_community to avoid cross-dependencies.
|
||||
"""Various Utility Functions
|
||||
|
||||
- Tools for handling bson.ObjectId
|
||||
|
||||
The help IDs live as ObjectId in MongoDB and str in Langchain and JSON.
|
||||
|
||||
|
||||
- Tools for the Maximal Marginal Relevance (MMR) reranking
|
||||
|
||||
These are duplicated from langchain_community to avoid cross-dependencies.
|
||||
|
||||
Functions "maximal_marginal_relevance" and "cosine_similarity"
|
||||
are duplicated in this utility respectively from modules:
|
||||
@@ -21,11 +28,6 @@ logger = logging.getLogger(__name__)
|
||||
Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray]
|
||||
|
||||
|
||||
class FailCode:
|
||||
INDEX_NOT_FOUND = 27
|
||||
INDEX_ALREADY_EXISTS = 68
|
||||
|
||||
|
||||
def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
|
||||
"""Row-wise cosine similarity between two equal-width matrices."""
|
||||
if len(X) == 0 or len(Y) == 0:
|
||||
@@ -65,7 +67,37 @@ def maximal_marginal_relevance(
|
||||
lambda_mult: float = 0.5,
|
||||
k: int = 4,
|
||||
) -> List[int]:
|
||||
"""Calculate maximal marginal relevance."""
|
||||
"""Compute Maximal Marginal Relevance (MMR).
|
||||
|
||||
MMR is a technique used to select documents that are both relevant to the query
|
||||
and diverse among themselves. This function returns the indices
|
||||
of the top-k embeddings that maximize the marginal relevance.
|
||||
|
||||
Args:
|
||||
query_embedding (np.ndarray): The embedding vector of the query.
|
||||
embedding_list (list of np.ndarray): A list containing the embedding vectors
|
||||
of the candidate documents.
|
||||
lambda_mult (float, optional): The trade-off parameter between
|
||||
relevance and diversity. Defaults to 0.5.
|
||||
k (int, optional): The number of embeddings to select. Defaults to 4.
|
||||
|
||||
Returns:
|
||||
list of int: The indices of the embeddings that maximize the marginal relevance.
|
||||
|
||||
Notes:
|
||||
The Maximal Marginal Relevance (MMR) is computed using the following formula:
|
||||
|
||||
MMR = argmax_{D_i ∈ R \ S} [λ * Sim(D_i, Q) - (1 - λ) * max_{D_j ∈ S} Sim(D_i, D_j)]
|
||||
|
||||
where:
|
||||
- R is the set of candidate documents,
|
||||
- S is the set of selected documents,
|
||||
- Q is the query embedding,
|
||||
- Sim(D_i, Q) is the similarity between document D_i and the query,
|
||||
- Sim(D_i, D_j) is the similarity between documents D_i and D_j,
|
||||
- λ is the trade-off parameter.
|
||||
"""
|
||||
|
||||
if min(k, len(embedding_list)) <= 0:
|
||||
return []
|
||||
if query_embedding.ndim == 1:
|
||||
@@ -137,6 +169,7 @@ def make_serializable(
|
||||
obj: Dict[str, Any],
|
||||
) -> None:
|
||||
"""Recursively cast values in a dict to a form able to json.dump"""
|
||||
|
||||
from bson import ObjectId
|
||||
|
||||
for k, v in obj.items():
|
||||
|
@@ -29,6 +29,7 @@ from langchain_mongodb.index import (
|
||||
create_vector_search_index,
|
||||
update_vector_search_index,
|
||||
)
|
||||
from langchain_mongodb.pipelines import vector_search_stage
|
||||
from langchain_mongodb.utils import (
|
||||
make_serializable,
|
||||
maximal_marginal_relevance,
|
||||
@@ -36,7 +37,6 @@ from langchain_mongodb.utils import (
|
||||
str_to_oid,
|
||||
)
|
||||
|
||||
MongoDBDocumentType = TypeVar("MongoDBDocumentType", bound=Dict[str, Any])
|
||||
VST = TypeVar("VST", bound=VectorStore)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -45,19 +45,40 @@ DEFAULT_INSERT_BATCH_SIZE = 100_000
|
||||
|
||||
|
||||
class MongoDBAtlasVectorSearch(VectorStore):
|
||||
"""MongoDBAtlas vector store integration.
|
||||
"""MongoDB Atlas vector store integration.
|
||||
|
||||
MongoDBAtlasVectorSearch performs data operations on
|
||||
text, embeddings and arbitrary data. In addition to CRUD operations,
|
||||
the VectorStore provides Vector Search
|
||||
based on similarity of embedding vectors following the
|
||||
Hierarchical Navigable Small Worlds (HNSW) algorithm.
|
||||
|
||||
This supports a number of models to ascertain scores,
|
||||
"similarity" (default), "MMR", and "similarity_score_threshold".
|
||||
These are described in the search_type argument to as_retriever,
|
||||
which provides the Runnable.invoke(query) API, allowing
|
||||
MongoDBAtlasVectorSearch to be used within a chain.
|
||||
|
||||
Setup:
|
||||
Install ``langchain-mongodb`` and ``pymongo`` and setup a MongoDB Atlas cluster (read through [this guide](https://www.mongodb.com/docs/manual/reference/connection-string/) to do so).
|
||||
* Set up a MongoDB Atlas cluster. The free tier M0 will allow you to start.
|
||||
Search Indexes are only available on Atlas, the fully managed cloud service,
|
||||
not the self-managed MongoDB.
|
||||
Follow [this guide](https://www.mongodb.com/basics/mongodb-atlas-tutorial)
|
||||
|
||||
* Create a Collection and a Vector Search Index.The procedure is described
|
||||
[here](https://www.mongodb.com/docs/atlas/atlas-vector-search/create-index/#procedure).
|
||||
|
||||
* Install ``langchain-mongodb``
|
||||
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -qU langchain-mongodb pymongo
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import getpass
|
||||
|
||||
MONGODB_ATLAS_CLUSTER_URI = getpass.getpass("MongoDB Atlas Cluster URI:")
|
||||
|
||||
Key init args — indexing params:
|
||||
@@ -127,7 +148,7 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
||||
Search with filter:
|
||||
.. code-block:: python
|
||||
|
||||
results = vector_store.similarity_search(query="thud",k=1,filter={"bar": "baz"})
|
||||
results = vector_store.similarity_search(query="thud",k=1,post_filter=[{"bar": "baz"]})
|
||||
for doc in results:
|
||||
print(f"* {doc.page_content} [{doc.metadata}]")
|
||||
|
||||
@@ -184,29 +205,24 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
collection: Collection[MongoDBDocumentType],
|
||||
collection: Collection[Dict[str, Any]],
|
||||
embedding: Embeddings,
|
||||
*,
|
||||
index_name: str = "default",
|
||||
index_name: str = "vector_index",
|
||||
text_key: str = "text",
|
||||
embedding_key: str = "embedding",
|
||||
relevance_score_fn: str = "cosine",
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
collection: MongoDB collection to add the texts to.
|
||||
embedding: Text embedding model to use.
|
||||
text_key: MongoDB field that will contain the text for each
|
||||
document.
|
||||
defaults to 'text'
|
||||
embedding_key: MongoDB field that will contain the embedding for
|
||||
each document.
|
||||
defaults to 'embedding'
|
||||
index_name: Name of the Atlas Search index.
|
||||
defaults to 'default'
|
||||
relevance_score_fn: The similarity score used for the index.
|
||||
defaults to 'cosine'
|
||||
Currently supported: 'euclidean', 'cosine', and 'dotProduct'.
|
||||
collection: MongoDB collection to add the texts to
|
||||
embedding: Text embedding model to use
|
||||
text_key: MongoDB field that will contain the text for each document
|
||||
index_name: Existing Atlas Vector Search Index
|
||||
embedding_key: Field that will contain the embedding for each document
|
||||
vector_index_name: Name of the Atlas Vector Search index
|
||||
relevance_score_fn: The similarity score used for the index
|
||||
Currently supported: 'euclidean', 'cosine', and 'dotProduct'
|
||||
"""
|
||||
self._collection = collection
|
||||
self._embedding = embedding
|
||||
@@ -412,69 +428,32 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
||||
start = end
|
||||
return result_ids
|
||||
|
||||
def _similarity_search_with_score(
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = 4,
|
||||
pre_filter: Optional[Dict] = None,
|
||||
post_filter_pipeline: Optional[List[Dict]] = None,
|
||||
include_embedding: bool = False,
|
||||
include_ids: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Core implementation."""
|
||||
params = {
|
||||
"queryVector": embedding,
|
||||
"path": self._embedding_key,
|
||||
"numCandidates": k * 10,
|
||||
"limit": k,
|
||||
"index": self._index_name,
|
||||
}
|
||||
if pre_filter:
|
||||
params["filter"] = pre_filter
|
||||
query = {"$vectorSearch": params}
|
||||
|
||||
pipeline = [
|
||||
query,
|
||||
{"$set": {"score": {"$meta": "vectorSearchScore"}}},
|
||||
]
|
||||
|
||||
# Exclude the embedding key from the return payload
|
||||
if not include_embedding:
|
||||
pipeline.append({"$project": {self._embedding_key: 0}})
|
||||
|
||||
if post_filter_pipeline is not None:
|
||||
pipeline.extend(post_filter_pipeline)
|
||||
cursor = self._collection.aggregate(pipeline) # type: ignore[arg-type]
|
||||
docs = []
|
||||
|
||||
for res in cursor:
|
||||
text = res.pop(self._text_key)
|
||||
score = res.pop("score")
|
||||
make_serializable(res)
|
||||
docs.append((Document(page_content=text, metadata=res), score))
|
||||
return docs
|
||||
|
||||
def similarity_search_with_score(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
pre_filter: Optional[Dict] = None,
|
||||
pre_filter: Optional[Dict[str, Any]] = None,
|
||||
post_filter_pipeline: Optional[List[Dict]] = None,
|
||||
oversampling_factor: int = 10,
|
||||
include_embeddings: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
) -> List[Tuple[Document, float]]: # noqa: E501
|
||||
"""Return MongoDB documents most similar to the given query and their scores.
|
||||
|
||||
Uses the vectorSearch operator available in MongoDB Atlas Search.
|
||||
For more: https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/
|
||||
Atlas Vector Search eliminates the need to run a separate
|
||||
search system alongside your database.
|
||||
|
||||
Args:
|
||||
query: Text to look up documents similar to.
|
||||
k: (Optional) number of documents to return. Defaults to 4.
|
||||
pre_filter: (Optional) dictionary of argument(s) to prefilter document
|
||||
fields on.
|
||||
post_filter_pipeline: (Optional) Pipeline of MongoDB aggregation stages
|
||||
following the vectorSearch stage.
|
||||
Args:
|
||||
query: Input text of semantic query
|
||||
k: Number of documents to return. Also known as top_k.
|
||||
pre_filter: List of MQL match expressions comparing an indexed field
|
||||
post_filter_pipeline: (Optional) Arbitrary pipeline of MongoDB
|
||||
aggregation stages applied after the search is complete.
|
||||
oversampling_factor: This times k is the number of candidates chosen
|
||||
at each step in the in HNSW Vector Search
|
||||
include_embeddings: If True, the embedding vector of each result
|
||||
will be included in metadata.
|
||||
kwargs: Additional arguments are specific to the search_type
|
||||
|
||||
Returns:
|
||||
List of documents most similar to the query and their scores.
|
||||
@@ -485,6 +464,8 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
||||
k=k,
|
||||
pre_filter=pre_filter,
|
||||
post_filter_pipeline=post_filter_pipeline,
|
||||
oversampling_factor=oversampling_factor,
|
||||
include_embeddings=include_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
return docs
|
||||
@@ -493,36 +474,46 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
pre_filter: Optional[Dict] = None,
|
||||
pre_filter: Optional[Dict[str, Any]] = None,
|
||||
post_filter_pipeline: Optional[List[Dict]] = None,
|
||||
oversampling_factor: int = 10,
|
||||
include_scores: bool = False,
|
||||
include_embeddings: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
) -> List[Document]: # noqa: E501
|
||||
"""Return MongoDB documents most similar to the given query.
|
||||
|
||||
Uses the vectorSearch operator available in MongoDB Atlas Search.
|
||||
For more: https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/
|
||||
Atlas Vector Search eliminates the need to run a separate
|
||||
search system alongside your database.
|
||||
|
||||
Args:
|
||||
query: Text to look up documents similar to.
|
||||
Args:
|
||||
query: Input text of semantic query
|
||||
k: (Optional) number of documents to return. Defaults to 4.
|
||||
pre_filter: (Optional) dictionary of argument(s) to prefilter document
|
||||
fields on.
|
||||
pre_filter: List of MQL match expressions comparing an indexed field
|
||||
post_filter_pipeline: (Optional) Pipeline of MongoDB aggregation stages
|
||||
following the vectorSearch stage.
|
||||
to filter/process results after $vectorSearch.
|
||||
oversampling_factor: Multiple of k used when generating number of candidates
|
||||
at each step in the HNSW Vector Search,
|
||||
include_scores: If True, the query score of each result
|
||||
will be included in metadata.
|
||||
include_embeddings: If True, the embedding vector of each result
|
||||
will be included in metadata.
|
||||
kwargs: Additional arguments are specific to the search_type
|
||||
|
||||
Returns:
|
||||
List of documents most similar to the query and their scores.
|
||||
"""
|
||||
additional = kwargs.get("additional")
|
||||
docs_and_scores = self.similarity_search_with_score(
|
||||
query,
|
||||
k=k,
|
||||
pre_filter=pre_filter,
|
||||
post_filter_pipeline=post_filter_pipeline,
|
||||
oversampling_factor=oversampling_factor,
|
||||
include_embeddings=include_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if additional and "similarity_score" in additional:
|
||||
if include_scores:
|
||||
for doc, score in docs_and_scores:
|
||||
doc.metadata["score"] = score
|
||||
return [doc for doc, _ in docs_and_scores]
|
||||
@@ -533,7 +524,7 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
pre_filter: Optional[Dict] = None,
|
||||
pre_filter: Optional[Dict[str, Any]] = None,
|
||||
post_filter_pipeline: Optional[List[Dict]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
@@ -548,19 +539,16 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
||||
fetch_k: (Optional) number of documents to fetch before passing to MMR
|
||||
algorithm. Defaults to 20.
|
||||
lambda_mult: Number between 0 and 1 that determines the degree
|
||||
of diversity among the results with 0 corresponding
|
||||
to maximum diversity and 1 to minimum diversity.
|
||||
Defaults to 0.5.
|
||||
pre_filter: (Optional) dictionary of argument(s) to prefilter on document
|
||||
fields.
|
||||
of diversity among the results with 0 corresponding
|
||||
to maximum diversity and 1 to minimum diversity. Defaults to 0.5.
|
||||
pre_filter: List of MQL match expressions comparing an indexed field
|
||||
post_filter_pipeline: (Optional) pipeline of MongoDB aggregation stages
|
||||
following the vectorSearch stage.
|
||||
following the $vectorSearch stage.
|
||||
Returns:
|
||||
List of documents selected by maximal marginal relevance.
|
||||
"""
|
||||
query_embedding = self._embedding.embed_query(query)
|
||||
return self.max_marginal_relevance_search_by_vector(
|
||||
embedding=query_embedding,
|
||||
embedding=self._embedding.embed_query(query),
|
||||
k=k,
|
||||
fetch_k=fetch_k,
|
||||
lambda_mult=lambda_mult,
|
||||
@@ -575,7 +563,7 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
||||
texts: List[str],
|
||||
embedding: Embeddings,
|
||||
metadatas: Optional[List[Dict]] = None,
|
||||
collection: Optional[Collection[MongoDBDocumentType]] = None,
|
||||
collection: Optional[Collection] = None,
|
||||
ids: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> MongoDBAtlasVectorSearch:
|
||||
@@ -588,6 +576,9 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
||||
|
||||
This is intended to be a quick way to get started.
|
||||
|
||||
See `MongoDBAtlasVectorSearch` for kwargs and further description.
|
||||
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
from pymongo import MongoClient
|
||||
@@ -649,8 +640,9 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
pre_filter: Optional[Dict] = None,
|
||||
pre_filter: Optional[Dict[str, Any]] = None,
|
||||
post_filter_pipeline: Optional[List[Dict]] = None,
|
||||
oversampling_factor: int = 10,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]: # type: ignore
|
||||
"""Return docs selected using the maximal marginal relevance.
|
||||
@@ -666,10 +658,13 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
||||
of diversity among the results with 0 corresponding
|
||||
to maximum diversity and 1 to minimum diversity.
|
||||
Defaults to 0.5.
|
||||
pre_filter: (Optional) dictionary of argument(s) to prefilter on document
|
||||
fields.
|
||||
pre_filter: (Optional) dictionary of arguments to filter document fields on.
|
||||
post_filter_pipeline: (Optional) pipeline of MongoDB aggregation stages
|
||||
following the vectorSearch stage.
|
||||
oversampling_factor: Multiple of k used when generating number
|
||||
of candidates in HNSW Vector Search,
|
||||
kwargs: Additional arguments are specific to the search_type
|
||||
|
||||
Returns:
|
||||
List of Documents selected by maximal marginal relevance.
|
||||
"""
|
||||
@@ -678,7 +673,8 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
||||
k=fetch_k,
|
||||
pre_filter=pre_filter,
|
||||
post_filter_pipeline=post_filter_pipeline,
|
||||
include_embedding=kwargs.pop("include_embedding", True),
|
||||
include_embeddings=True,
|
||||
oversampling_factor=oversampling_factor,
|
||||
**kwargs,
|
||||
)
|
||||
mmr_doc_indexes = maximal_marginal_relevance(
|
||||
@@ -696,31 +692,82 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
pre_filter: Optional[Dict[str, Any]] = None,
|
||||
post_filter_pipeline: Optional[List[Dict]] = None,
|
||||
oversampling_factor: int = 10,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Return docs selected using the maximal marginal relevance."""
|
||||
return await run_in_executor(
|
||||
None,
|
||||
self.max_marginal_relevance_search_by_vector,
|
||||
self.max_marginal_relevance_search_by_vector, # type: ignore[arg-type]
|
||||
embedding,
|
||||
k=k,
|
||||
fetch_k=fetch_k,
|
||||
lambda_mult=lambda_mult,
|
||||
pre_filter=pre_filter,
|
||||
post_filter_pipeline=post_filter_pipeline,
|
||||
oversampling_factor=oversampling_factor,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _similarity_search_with_score(
|
||||
self,
|
||||
query_vector: List[float],
|
||||
k: int = 4,
|
||||
pre_filter: Optional[Dict[str, Any]] = None,
|
||||
post_filter_pipeline: Optional[List[Dict]] = None,
|
||||
oversampling_factor: int = 10,
|
||||
include_embeddings: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Core search routine. See external methods for details."""
|
||||
|
||||
# Atlas Vector Search, potentially with filter
|
||||
pipeline = [
|
||||
vector_search_stage(
|
||||
query_vector,
|
||||
self._embedding_key,
|
||||
self._index_name,
|
||||
k,
|
||||
pre_filter,
|
||||
oversampling_factor,
|
||||
**kwargs,
|
||||
),
|
||||
{"$set": {"score": {"$meta": "vectorSearchScore"}}},
|
||||
]
|
||||
|
||||
# Remove embeddings unless requested.
|
||||
if not include_embeddings:
|
||||
pipeline.append({"$project": {self._embedding_key: 0}})
|
||||
# Post-processing
|
||||
if post_filter_pipeline is not None:
|
||||
pipeline.extend(post_filter_pipeline)
|
||||
|
||||
# Execution
|
||||
cursor = self._collection.aggregate(pipeline) # type: ignore[arg-type]
|
||||
docs = []
|
||||
|
||||
# Format
|
||||
for res in cursor:
|
||||
text = res.pop(self._text_key)
|
||||
score = res.pop("score")
|
||||
make_serializable(res)
|
||||
docs.append((Document(page_content=text, metadata=res), score))
|
||||
return docs
|
||||
|
||||
def create_vector_search_index(
|
||||
self,
|
||||
dimensions: int,
|
||||
filters: Optional[List[Dict[str, str]]] = None,
|
||||
filters: Optional[List[str]] = None,
|
||||
update: bool = False,
|
||||
) -> None:
|
||||
"""Creates a MongoDB Atlas vectorSearch index for the VectorStore
|
||||
|
||||
Note**: This method may fail as it requires a MongoDB Atlas with
|
||||
these pre-requisites:
|
||||
- M10 cluster or higher
|
||||
- https://www.mongodb.com/docs/atlas/atlas-vector-search/create-index/#prerequisites
|
||||
Note**: This method may fail as it requires a MongoDB Atlas with these
|
||||
`pre-requisites <https://www.mongodb.com/docs/atlas/atlas-vector-search/create-index/#prerequisites>`.
|
||||
Currently, vector and full-text search index operations need to be
|
||||
performed manually on the Atlas UI for shared M0 clusters.
|
||||
|
||||
Args:
|
||||
dimensions (int): Number of dimensions in embedding
|
||||
|
Reference in New Issue
Block a user