mongodb: Add Hybrid and Full-Text Search Retrievers, release 0.2.0 (#25057)

## Description

This pull-request extends the existing vector search strategies of
MongoDBAtlasVectorSearch to include Hybrid (Reciprocal Rank Fusion) and
Full-text via new Retrievers.

There is a small breaking change involving the `prefilter` kwarg to
search. Because of this, and because we have added a great many features
since 0.1.0, including programmatic Index creation/deletion, we plan to
bump the version to 0.2.0.

### Checklist
* Unit tests have been extended
* Formatting has been applied
* One mypy error remains which will either go away in CI or be
simplified.

---------

Signed-off-by: Casey Clements <casey.clements@mongodb.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Casey Clements
2024-08-07 16:10:29 -04:00
committed by GitHub
parent f337408b0f
commit 6e9a8b188f
22 changed files with 1749 additions and 508 deletions

View File

@@ -1,10 +1,4 @@
"""
LangChain MongoDB Caches
Functions "_loads_generations" and "_dumps_generations"
are duplicated in this utility from modules:
- "libs/community/langchain_community/cache.py"
"""
"""LangChain MongoDB Caches."""
import json
import logging
@@ -27,100 +21,6 @@ from langchain_mongodb.vectorstores import MongoDBAtlasVectorSearch
# Name the logger after the module (not the file path) so it takes part in
# the package's logger hierarchy and can be configured via its dotted name.
logger = logging.getLogger(__name__)
def _generate_mongo_client(connection_string: str) -> MongoClient:
    """Create a MongoClient tagged with LangChain driver information.

    Args:
        connection_string (str): Connection string handed to MongoClient.

    Returns:
        MongoClient: A client whose ``driver`` is a DriverInfo named
        "Langchain", versioned by ``version("langchain-mongodb")``.
    """
    driver_info = DriverInfo(name="Langchain", version=version("langchain-mongodb"))
    return MongoClient(connection_string, driver=driver_info)
def _dumps_generations(generations: RETURN_VAL_TYPE) -> str:
    """Serialize a sequence of `Generation` objects into one JSON string.

    Every item is first turned into a string with `dumps`; the resulting
    list of strings is then JSON-encoded as a whole. `_loads_generations`
    is the inverse of this function.

    Args:
        generations (RETURN_VAL_TYPE): A list of language model generations.

    Returns:
        str: a single string representing a list of generations.
    """
    serialized_items = list(map(dumps, generations))
    return json.dumps(serialized_items)
def _loads_generations(generations_str: str) -> Union[RETURN_VAL_TYPE, None]:
    """Deserialize a string into a sequence of `Generation` objects.

    Inverse of `_dumps_generations`. Two formats are attempted in order:
    a JSON list of `dumps`-serialized items, then the legacy format, a JSON
    list of keyword dictionaries passed to `Generation`.

    Malformed entries never raise: a warning is logged and None is
    returned, so the caller should be prepared for such a cache miss.

    Args:
        generations_str (str): A string representing a list of generations.

    Returns:
        RETURN_VAL_TYPE: A list of generations, or None if unparsable.
    """
    # Attempt 1: current format.
    try:
        return [loads(serialized) for serialized in json.loads(generations_str)]
    except (json.JSONDecodeError, TypeError):
        # Soft failure: the legacy format gets a chance before giving up.
        pass

    # Attempt 2: legacy format.
    try:
        legacy_dicts = json.loads(generations_str)
        # Deliberately not using `_load_generations_from_json`, which could
        # disappear.
        legacy_generations = [Generation(**fields) for fields in legacy_dicts]
        logger.warning(
            f"Legacy 'Generation' cached blob encountered: '{generations_str}'"
        )
        return legacy_generations
    except (json.JSONDecodeError, TypeError):
        logger.warning(
            f"Malformed/unparsable cached blob encountered: '{generations_str}'"
        )
        return None
def _wait_until(
predicate: Callable, success_description: Any, timeout: float = 10.0
) -> None:
"""Wait up to 10 seconds (by default) for predicate to be true.
E.g.:
wait_until(lambda: client.primary == ('a', 1),
'connect to the primary')
If the lambda-expression isn't true after 10 seconds, we raise
AssertionError("Didn't ever connect to the primary").
Returns the predicate's first true value.
"""
start = time.time()
interval = min(float(timeout) / 100, 0.1)
while True:
retval = predicate()
if retval:
return retval
if time.time() - start > timeout:
raise TimeoutError("Didn't ever %s" % success_description)
time.sleep(interval)
class MongoDBCache(BaseCache):
"""MongoDB Atlas cache
@@ -216,7 +116,7 @@ class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch):
collection_name: str = "default",
database_name: str = "default",
index_name: str = "default",
wait_until_ready: bool = False,
wait_until_ready: Optional[float] = None,
score_threshold: Optional[float] = None,
**kwargs: Dict[str, Any],
):
@@ -233,8 +133,8 @@ class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch):
Defaults to "default".
index_name: Name of the Atlas Search index.
defaults to 'default'
wait_until_ready (bool): Block until MongoDB Atlas finishes indexing
the stored text. Hard timeout of 10 seconds. Defaults to False.
wait_until_ready (float): Wait this time for Atlas to finish indexing
the stored text. Defaults to None.
"""
client = _generate_mongo_client(connection_string)
self.collection = client[database_name][collection_name]
@@ -272,7 +172,7 @@ class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch):
prompt: str,
llm_string: str,
return_val: RETURN_VAL_TYPE,
wait_until_ready: Optional[bool] = None,
wait_until_ready: Optional[float] = None,
) -> None:
"""Update cache based on prompt and llm_string."""
self.add_texts(
@@ -290,7 +190,7 @@ class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch):
return self.lookup(prompt, llm_string) == return_val
if wait:
_wait_until(is_indexed, return_val)
_wait_until(is_indexed, return_val, timeout=wait)
def clear(self, **kwargs: Any) -> None:
"""Clear cache that can take additional keyword arguments.
@@ -302,3 +202,107 @@ class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch):
self.clear(llm_string="fake-model")
"""
self.collection.delete_many({**kwargs})
def _generate_mongo_client(connection_string: str) -> MongoClient:
    """Create a MongoClient tagged with LangChain driver information.

    Args:
        connection_string (str): Connection string handed to MongoClient.

    Returns:
        MongoClient: A client whose ``driver`` is a DriverInfo named
        "Langchain", versioned by ``version("langchain-mongodb")``.
    """
    driver_info = DriverInfo(name="Langchain", version=version("langchain-mongodb"))
    return MongoClient(connection_string, driver=driver_info)
def _dumps_generations(generations: RETURN_VAL_TYPE) -> str:
    """Serialize a sequence of `Generation` objects into one JSON string.

    This function and "_loads_generations" are duplicated in this utility
    from modules: "libs/community/langchain_community/cache.py"

    Every item is first turned into a string with `dumps`; the resulting
    list of strings is then JSON-encoded as a whole. `_loads_generations`
    is the inverse of this function.

    Args:
        generations (RETURN_VAL_TYPE): A list of language model generations.

    Returns:
        str: a single string representing a list of generations.
    """
    serialized_items = list(map(dumps, generations))
    return json.dumps(serialized_items)
def _loads_generations(generations_str: str) -> Union[RETURN_VAL_TYPE, None]:
    """Deserialize a string into a sequence of `Generation` objects.

    Inverse of `_dumps_generations`. Two formats are attempted in order:
    a JSON list of `dumps`-serialized items, then the legacy format, a JSON
    list of keyword dictionaries passed to `Generation`.

    Malformed entries never raise: a warning is logged and None is
    returned, so the caller should be prepared for such a cache miss.

    Args:
        generations_str (str): A string representing a list of generations.

    Returns:
        RETURN_VAL_TYPE: A list of generations, or None if unparsable.
    """
    # Attempt 1: current format.
    try:
        return [loads(serialized) for serialized in json.loads(generations_str)]
    except (json.JSONDecodeError, TypeError):
        # Soft failure: the legacy format gets a chance before giving up.
        pass

    # Attempt 2: legacy format.
    try:
        legacy_dicts = json.loads(generations_str)
        # Deliberately not using `_load_generations_from_json`, which could
        # disappear.
        legacy_generations = [Generation(**fields) for fields in legacy_dicts]
        logger.warning(
            f"Legacy 'Generation' cached blob encountered: '{generations_str}'"
        )
        return legacy_generations
    except (json.JSONDecodeError, TypeError):
        logger.warning(
            f"Malformed/unparsable cached blob encountered: '{generations_str}'"
        )
        return None
def _wait_until(
predicate: Callable, success_description: Any, timeout: float = 10.0
) -> None:
"""Wait up to 10 seconds (by default) for predicate to be true.
E.g.:
wait_until(lambda: client.primary == ('a', 1),
'connect to the primary')
If the lambda-expression isn't true after 10 seconds, we raise
AssertionError("Didn't ever connect to the primary").
Returns the predicate's first true value.
"""
start = time.time()
interval = min(float(timeout) / 100, 0.1)
while True:
retval = predicate()
if retval:
return retval
if time.time() - start > timeout:
raise TimeoutError("Didn't ever %s" % success_description)
time.sleep(interval)