mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-21 10:31:23 +00:00
mongodb: Add Hybrid and Full-Text Search Retrievers, release 0.2.0 (#25057)
## Description

This pull request extends the existing vector search strategies of MongoDBAtlasVectorSearch to include Hybrid (Reciprocal Rank Fusion) and Full-Text search via new Retrievers.

There is a small breaking change in the form of the `prefilter` kwarg to search. Because of this, and because a great deal of features have been added since 0.1.0 (including programmatic index creation/deletion), we plan to bump the version to 0.2.0.

### Checklist

* Unit tests have been extended.
* Formatting has been applied.
* One mypy error remains, which will either go away in CI or be simplified.

---------

Signed-off-by: Casey Clements <casey.clements@mongodb.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
@@ -1,10 +1,4 @@
|
||||
"""
|
||||
LangChain MongoDB Caches
|
||||
|
||||
Functions "_loads_generations" and "_dumps_generations"
|
||||
are duplicated in this utility from modules:
|
||||
- "libs/community/langchain_community/cache.py"
|
||||
"""
|
||||
"""LangChain MongoDB Caches."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
@@ -27,100 +21,6 @@ from langchain_mongodb.vectorstores import MongoDBAtlasVectorSearch
|
||||
# Name the logger after the module rather than its file path, so it takes part
# in the dotted logger hierarchy and can be configured by package name.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _generate_mongo_client(connection_string: str) -> MongoClient:
    """Create a `MongoClient` for `connection_string` tagged with driver info.

    The driver metadata carries the name "Langchain" and the installed
    version of the "langchain-mongodb" distribution.
    """
    driver_info = DriverInfo(name="Langchain", version=version("langchain-mongodb"))
    return MongoClient(connection_string, driver=driver_info)
|
||||
|
||||
|
||||
def _dumps_generations(generations: RETURN_VAL_TYPE) -> str:
    """Serialize a sequence of `Generation` objects into a single string.

    Each generation is first turned into a string with `dumps`; the resulting
    list of strings is then JSON-encoded as a whole. `_loads_generations` is
    the inverse of this function.

    Args:
        generations (RETURN_VAL_TYPE): A list of language model generations.

    Returns:
        str: a single string representing a list of generations.
    """
    serialized_items = [dumps(generation) for generation in generations]
    return json.dumps(serialized_items)
|
||||
|
||||
|
||||
def _loads_generations(generations_str: str) -> Union[RETURN_VAL_TYPE, None]:
    """Deserialize a cached string into a sequence of `Generation` objects.

    Inverse of `_dumps_generations`. Two payload layouts are accepted:

    * current format: a JSON list whose items are strings, each of which is
      passed to `loads`;
    * legacy format: a JSON list of dicts, each of which is passed as keyword
      arguments to `Generation`.

    Malformed entries do not raise: a warning is logged and None is returned,
    so the caller should be prepared to treat that as a cache miss.

    Args:
        generations_str (str): A string representing a list of generations.

    Returns:
        RETURN_VAL_TYPE: A list of generations, or None if the string could
        not be parsed in either format.
    """
    try:
        # Both formats share the same outer JSON envelope, so parse it once.
        items = json.loads(generations_str)
    except (json.JSONDecodeError, TypeError):
        logger.warning(
            f"Malformed/unparsable cached blob encountered: '{generations_str}'"
        )
        return None

    try:
        return [loads(_item_str) for _item_str in items]
    except (json.JSONDecodeError, TypeError):
        # deferring the (soft) handling to after the legacy-format attempt
        pass

    # NOTE(review): only JSONDecodeError/TypeError are handled softly; any
    # other exception raised by `loads` or `Generation(...)` still propagates.
    try:
        # not relying on `_load_generations_from_json` (which could disappear):
        generations = [Generation(**generation_dict) for generation_dict in items]
    except (json.JSONDecodeError, TypeError):
        logger.warning(
            f"Malformed/unparsable cached blob encountered: '{generations_str}'"
        )
        return None

    logger.warning(f"Legacy 'Generation' cached blob encountered: '{generations_str}'")
    return generations
|
||||
|
||||
|
||||
def _wait_until(
|
||||
predicate: Callable, success_description: Any, timeout: float = 10.0
|
||||
) -> None:
|
||||
"""Wait up to 10 seconds (by default) for predicate to be true.
|
||||
|
||||
E.g.:
|
||||
|
||||
wait_until(lambda: client.primary == ('a', 1),
|
||||
'connect to the primary')
|
||||
|
||||
If the lambda-expression isn't true after 10 seconds, we raise
|
||||
AssertionError("Didn't ever connect to the primary").
|
||||
|
||||
Returns the predicate's first true value.
|
||||
"""
|
||||
start = time.time()
|
||||
interval = min(float(timeout) / 100, 0.1)
|
||||
while True:
|
||||
retval = predicate()
|
||||
if retval:
|
||||
return retval
|
||||
|
||||
if time.time() - start > timeout:
|
||||
raise TimeoutError("Didn't ever %s" % success_description)
|
||||
|
||||
time.sleep(interval)
|
||||
|
||||
|
||||
class MongoDBCache(BaseCache):
|
||||
"""MongoDB Atlas cache
|
||||
|
||||
@@ -216,7 +116,7 @@ class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch):
|
||||
collection_name: str = "default",
|
||||
database_name: str = "default",
|
||||
index_name: str = "default",
|
||||
wait_until_ready: bool = False,
|
||||
wait_until_ready: Optional[float] = None,
|
||||
score_threshold: Optional[float] = None,
|
||||
**kwargs: Dict[str, Any],
|
||||
):
|
||||
@@ -233,8 +133,8 @@ class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch):
|
||||
Defaults to "default".
|
||||
index_name: Name of the Atlas Search index.
|
||||
defaults to 'default'
|
||||
wait_until_ready (bool): Block until MongoDB Atlas finishes indexing
|
||||
the stored text. Hard timeout of 10 seconds. Defaults to False.
|
||||
wait_until_ready (float): Wait this time for Atlas to finish indexing
|
||||
the stored text. Defaults to None.
|
||||
"""
|
||||
client = _generate_mongo_client(connection_string)
|
||||
self.collection = client[database_name][collection_name]
|
||||
@@ -272,7 +172,7 @@ class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch):
|
||||
prompt: str,
|
||||
llm_string: str,
|
||||
return_val: RETURN_VAL_TYPE,
|
||||
wait_until_ready: Optional[bool] = None,
|
||||
wait_until_ready: Optional[float] = None,
|
||||
) -> None:
|
||||
"""Update cache based on prompt and llm_string."""
|
||||
self.add_texts(
|
||||
@@ -290,7 +190,7 @@ class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch):
|
||||
return self.lookup(prompt, llm_string) == return_val
|
||||
|
||||
if wait:
|
||||
_wait_until(is_indexed, return_val)
|
||||
_wait_until(is_indexed, return_val, timeout=wait)
|
||||
|
||||
def clear(self, **kwargs: Any) -> None:
|
||||
"""Clear cache that can take additional keyword arguments.
|
||||
@@ -302,3 +202,107 @@ class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch):
|
||||
self.clear(llm_string="fake-model")
|
||||
"""
|
||||
self.collection.delete_many({**kwargs})
|
||||
|
||||
|
||||
def _generate_mongo_client(connection_string: str) -> MongoClient:
    """Create a `MongoClient` for `connection_string` tagged with driver info.

    The driver metadata carries the name "Langchain" and the installed
    version of the "langchain-mongodb" distribution.
    """
    driver_info = DriverInfo(name="Langchain", version=version("langchain-mongodb"))
    return MongoClient(connection_string, driver=driver_info)
|
||||
|
||||
|
||||
def _dumps_generations(generations: RETURN_VAL_TYPE) -> str:
    """Serialize a sequence of `Generation` objects into a single string.

    Each generation is first turned into a string with `dumps`; the resulting
    list of strings is then JSON-encoded as a whole. `_loads_generations` is
    the inverse of this function.

    Args:
        generations (RETURN_VAL_TYPE): A list of language model generations.

    Returns:
        str: a single string representing a list of generations.
    """
    serialized_items = [dumps(generation) for generation in generations]
    return json.dumps(serialized_items)
|
||||
|
||||
|
||||
def _loads_generations(generations_str: str) -> Union[RETURN_VAL_TYPE, None]:
    """Deserialize a cached string into a sequence of `Generation` objects.

    Inverse of `_dumps_generations`. Two payload layouts are accepted:

    * current format: a JSON list whose items are strings, each of which is
      passed to `loads`;
    * legacy format: a JSON list of dicts, each of which is passed as keyword
      arguments to `Generation`.

    Malformed entries do not raise: a warning is logged and None is returned,
    so the caller should be prepared to treat that as a cache miss.

    Args:
        generations_str (str): A string representing a list of generations.

    Returns:
        RETURN_VAL_TYPE: A list of generations, or None if the string could
        not be parsed in either format.
    """
    try:
        # Both formats share the same outer JSON envelope, so parse it once.
        items = json.loads(generations_str)
    except (json.JSONDecodeError, TypeError):
        logger.warning(
            f"Malformed/unparsable cached blob encountered: '{generations_str}'"
        )
        return None

    try:
        return [loads(_item_str) for _item_str in items]
    except (json.JSONDecodeError, TypeError):
        # deferring the (soft) handling to after the legacy-format attempt
        pass

    # NOTE(review): only JSONDecodeError/TypeError are handled softly; any
    # other exception raised by `loads` or `Generation(...)` still propagates.
    try:
        # not relying on `_load_generations_from_json` (which could disappear):
        generations = [Generation(**generation_dict) for generation_dict in items]
    except (json.JSONDecodeError, TypeError):
        logger.warning(
            f"Malformed/unparsable cached blob encountered: '{generations_str}'"
        )
        return None

    logger.warning(f"Legacy 'Generation' cached blob encountered: '{generations_str}'")
    return generations
|
||||
|
||||
|
||||
def _wait_until(
|
||||
predicate: Callable, success_description: Any, timeout: float = 10.0
|
||||
) -> None:
|
||||
"""Wait up to 10 seconds (by default) for predicate to be true.
|
||||
|
||||
E.g.:
|
||||
|
||||
wait_until(lambda: client.primary == ('a', 1),
|
||||
'connect to the primary')
|
||||
|
||||
If the lambda-expression isn't true after 10 seconds, we raise
|
||||
AssertionError("Didn't ever connect to the primary").
|
||||
|
||||
Returns the predicate's first true value.
|
||||
"""
|
||||
start = time.time()
|
||||
interval = min(float(timeout) / 100, 0.1)
|
||||
while True:
|
||||
retval = predicate()
|
||||
if retval:
|
||||
return retval
|
||||
|
||||
if time.time() - start > timeout:
|
||||
raise TimeoutError("Didn't ever %s" % success_description)
|
||||
|
||||
time.sleep(interval)
|
||||
|
Reference in New Issue
Block a user