From 5738143d4b2bf6957373a467d53cc5c11012897c Mon Sep 17 00:00:00 2001 From: Qihui Xie Date: Wed, 14 Feb 2024 11:33:22 +0800 Subject: [PATCH] add mongodb_store (#13801) # Add MongoDB storage - **Description:** Add MongoDB Storage as an option for large doc store. Example usage: ```Python # Instantiate the MongodbStore with a MongoDB connection from langchain.storage import MongodbStore mongo_conn_str = "mongodb://localhost:27017/" mongodb_store = MongodbStore(mongo_conn_str, db_name="test-db", collection_name="test-collection") # Set values for keys doc1 = Document(page_content='test1') doc2 = Document(page_content='test2') mongodb_store.mset([("key1", doc1), ("key2", doc2)]) # Get values for keys values = mongodb_store.mget(["key1", "key2"]) # [doc1, doc2] # Iterate over keys for key in mongodb_store.yield_keys(): print(key) # Delete keys mongodb_store.mdelete(["key1", "key2"]) ``` - **Dependencies:** Use `mongomock` for integration test. --------- Co-authored-by: Bagatur Co-authored-by: Eugene Yurtsev --- .../langchain_community/storage/__init__.py | 2 + .../langchain_community/storage/mongodb.py | 126 ++++++++++++++++++ .../integration_tests/storage/test_mongodb.py | 73 ++++++++++ .../tests/unit_tests/storage/test_imports.py | 1 + .../tests/unit_tests/storage/test_mongodb.py | 11 ++ 5 files changed, 213 insertions(+) create mode 100644 libs/community/langchain_community/storage/mongodb.py create mode 100644 libs/community/tests/integration_tests/storage/test_mongodb.py create mode 100644 libs/community/tests/unit_tests/storage/test_mongodb.py diff --git a/libs/community/langchain_community/storage/__init__.py b/libs/community/langchain_community/storage/__init__.py index 494591b03c7..ffb95cab1d1 100644 --- a/libs/community/langchain_community/storage/__init__.py +++ b/libs/community/langchain_community/storage/__init__.py @@ -10,6 +10,7 @@ from langchain_community.storage.astradb import ( AstraDBByteStore, AstraDBStore, ) +from langchain_community.storage.mongodb import MongoDBStore from langchain_community.storage.redis import RedisStore from langchain_community.storage.upstash_redis import ( UpstashRedisByteStore, @@ -19,6 +20,7 @@ from langchain_community.storage.upstash_redis import ( __all__ = [ "AstraDBStore", "AstraDBByteStore", + "MongoDBStore", "RedisStore", "UpstashRedisByteStore", "UpstashRedisStore", diff --git a/libs/community/langchain_community/storage/mongodb.py b/libs/community/langchain_community/storage/mongodb.py new file mode 100644 index 00000000000..97447f83217 --- /dev/null +++ b/libs/community/langchain_community/storage/mongodb.py @@ -0,0 +1,126 @@ +from typing import Iterator, List, Optional, Sequence, Tuple + +from langchain_core.documents import Document +from langchain_core.stores import BaseStore + + +class MongoDBStore(BaseStore[str, Document]): + """BaseStore implementation using MongoDB as the underlying store. + + Examples: + Create a MongoDBStore instance and perform operations on it: + + .. code-block:: python + + # Instantiate the MongoDBStore with a MongoDB connection + from langchain.storage import MongoDBStore + + mongo_conn_str = "mongodb://localhost:27017/" + mongodb_store = MongoDBStore(mongo_conn_str, db_name="test-db", + collection_name="test-collection") + + # Set values for keys + doc1 = Document(...) + doc2 = Document(...) + mongodb_store.mset([("key1", doc1), ("key2", doc2)]) + + # Get values for keys + values = mongodb_store.mget(["key1", "key2"]) + # [doc1, doc2] + + # Iterate over keys + for key in mongodb_store.yield_keys(): + print(key) + + # Delete keys + mongodb_store.mdelete(["key1", "key2"]) + """ + + def __init__( + self, + connection_string: str, + db_name: str, + collection_name: str, + *, + client_kwargs: Optional[dict] = None, + ) -> None: + """Initialize the MongoDBStore with a MongoDB connection string. + + Args: + connection_string (str): MongoDB connection string + db_name (str): name to use + collection_name (str): collection name to use + client_kwargs (dict): Keyword arguments to pass to the Mongo client + """ + try: + from pymongo import MongoClient + except ImportError as e: + raise ImportError( + "The MongoDBStore requires the pymongo library to be " + "installed. " + "pip install pymongo" + ) from e + + if not connection_string: + raise ValueError("connection_string must be provided.") + if not db_name: + raise ValueError("db_name must be provided.") + if not collection_name: + raise ValueError("collection_name must be provided.") + + self.client = MongoClient(connection_string, **(client_kwargs or {})) + self.collection = self.client[db_name][collection_name] + + def mget(self, keys: Sequence[str]) -> List[Optional[Document]]: + """Get the list of documents associated with the given keys. + + Args: + keys (list[str]): A list of keys representing Document IDs.. + + Returns: + list[Document]: A list of Documents corresponding to the provided + keys, where each Document is either retrieved successfully or + represented as None if not found. + """ + result = self.collection.find({"_id": {"$in": keys}}) + result_dict = {doc["_id"]: Document(**doc["value"]) for doc in result} + return [result_dict.get(key) for key in keys] + + def mset(self, key_value_pairs: Sequence[Tuple[str, Document]]) -> None: + """Set the given key-value pairs. + + Args: + key_value_pairs (list[tuple[str, Document]]): A list of id-document + pairs. + Returns: + None + """ + from pymongo import UpdateOne + + updates = [{"_id": k, "value": v.__dict__} for k, v in key_value_pairs] + self.collection.bulk_write( + [UpdateOne({"_id": u["_id"]}, {"$set": u}, upsert=True) for u in updates] + ) + + def mdelete(self, keys: Sequence[str]) -> None: + """Delete the given ids. + + Args: + keys (list[str]): A list of keys representing Document IDs.. + """ + self.collection.delete_many({"_id": {"$in": keys}}) + + def yield_keys(self, prefix: Optional[str] = None) -> Iterator[str]: + """Yield keys in the store. + + Args: + prefix (str): prefix of keys to retrieve. + """ + if prefix is None: + for doc in self.collection.find(projection=["_id"]): + yield doc["_id"] + else: + for doc in self.collection.find( + {"_id": {"$regex": f"^{prefix}"}}, projection=["_id"] + ): + yield doc["_id"] diff --git a/libs/community/tests/integration_tests/storage/test_mongodb.py b/libs/community/tests/integration_tests/storage/test_mongodb.py new file mode 100644 index 00000000000..44062e994e0 --- /dev/null +++ b/libs/community/tests/integration_tests/storage/test_mongodb.py @@ -0,0 +1,73 @@ +from typing import Generator + +import pytest +from langchain_core.documents import Document + +from langchain_community.storage.mongodb import MongoDBStore + +pytest.importorskip("pymongo") + + +@pytest.fixture +def mongo_store() -> Generator: + import mongomock + + # mongomock creates a mock MongoDB instance for testing purposes + with mongomock.patch(servers=(("localhost", 27017),)): + yield MongoDBStore("mongodb://localhost:27017/", "test_db", "test_collection") + + +def test_mset_and_mget(mongo_store: MongoDBStore) -> None: + doc1 = Document(page_content="doc1") + doc2 = Document(page_content="doc2") + + # Set documents in the store + mongo_store.mset([("key1", doc1), ("key2", doc2)]) + + # Get documents from the store + retrieved_docs = mongo_store.mget(["key1", "key2"]) + + assert retrieved_docs[0] and retrieved_docs[0].page_content == "doc1" + assert retrieved_docs[1] and retrieved_docs[1].page_content == "doc2" + + +def test_yield_keys(mongo_store: MongoDBStore) -> None: + mongo_store.mset( + [ + ("key1", Document(page_content="doc1")), + ("key2", Document(page_content="doc2")), + ("another_key", Document(page_content="other")), + ] + ) + + # Test without prefix + keys = list(mongo_store.yield_keys()) + assert set(keys) == {"key1", "key2", "another_key"} + + # Test with prefix + keys_with_prefix = list(mongo_store.yield_keys(prefix="key")) + assert set(keys_with_prefix) == {"key1", "key2"} + + +def test_mdelete(mongo_store: MongoDBStore) -> None: + mongo_store.mset( + [ + ("key1", Document(page_content="doc1")), + ("key2", Document(page_content="doc2")), + ] + ) + # Delete single document + mongo_store.mdelete(["key1"]) + remaining_docs = list(mongo_store.yield_keys()) + assert "key1" not in remaining_docs + assert "key2" in remaining_docs + + # Delete multiple documents + mongo_store.mdelete(["key2"]) + remaining_docs = list(mongo_store.yield_keys()) + assert len(remaining_docs) == 0 + + +def test_init_errors() -> None: + with pytest.raises(ValueError): + MongoDBStore("", "", "") diff --git a/libs/community/tests/unit_tests/storage/test_imports.py b/libs/community/tests/unit_tests/storage/test_imports.py index 27bcec56d4b..21f79d464e7 100644 --- a/libs/community/tests/unit_tests/storage/test_imports.py +++ b/libs/community/tests/unit_tests/storage/test_imports.py @@ -3,6 +3,7 @@ from langchain_community.storage import __all__ EXPECTED_ALL = [ "AstraDBStore", "AstraDBByteStore", + "MongoDBStore", "RedisStore", "UpstashRedisByteStore", "UpstashRedisStore", diff --git a/libs/community/tests/unit_tests/storage/test_mongodb.py b/libs/community/tests/unit_tests/storage/test_mongodb.py new file mode 100644 index 00000000000..a1dc9ea9068 --- /dev/null +++ b/libs/community/tests/unit_tests/storage/test_mongodb.py @@ -0,0 +1,11 @@ +"""Light weight unit test that attempts to import MongodbStore. + +The actual code is tested in integration tests. + +This test is intended to catch errors in the import process. +""" + + +def test_import_storage() -> None: + """Attempt to import storage modules.""" + from langchain_community.storage.mongodb import MongoDBStore # noqa