mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-25 04:49:17 +00:00
mongodb[minor]: MongoDB Partner Package -- Porting MongoDBAtlasVectorSearch (#17652)
This PR migrates the existing MongoDBAtlasVectorSearch abstraction from
the `langchain_community` section to the partners package section of the
codebase.
- [x] Run the partner package script as advised in the partner-packages
documentation.
- [x] Add Unit Tests
- [x] Migrate Integration Tests
- [x] Refactor `MongoDBAtlasVectorStore` (autogenerated) to
`MongoDBAtlasVectorSearch`
- [x] ~Remove~ deprecate the old `langchain_community` VectorStore
references.
## Additional Callouts
- Implemented the `delete` method
- Included any missing async function implementations
- `amax_marginal_relevance_search_by_vector`
- `adelete`
- Added new Unit Tests that test for functionality of
`MongoDBVectorSearch` methods
- Removed the `del res[self._embedding_key]` statement (previously at
`libs/community/langchain_community/vectorstores/mongodb_atlas.py`, line 218)
from the `_similarity_search_with_score` function, since it would otherwise
make the `maximal_marginal_relevance` function fail: the returned `Document`
needs to keep the embedding key in its metadata for MMR to work.
Checklist:
- [x] PR title: Please title your PR "package: description", where
"package" is whichever of langchain, community, core, experimental, etc.
is being modified. Use "docs: ..." for purely docs changes, "templates:
..." for template changes, "infra: ..." for CI changes.
- Example: "community: add foobar LLM"
- [x] PR message
- [x] Pass lint and test: Run `make format`, `make lint` and `make test`
from the root of the package(s) you've modified to check that you're
passing lint and testing. See contribution guidelines for more
information on how to write/run tests, lint, etc:
https://python.langchain.com/docs/contributing/
- [x] Add tests and docs: If you're adding a new integration, please
include
1. Existing tests supplied in docs/docs do not change. Updated
docstrings for new functions like `delete`
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory. (This already exists)
If no one reviews your PR within a few days, please @-mention one of
baskaryan, efriis, eyurtsev, hwchase17.
---------
Co-authored-by: Steven Silvester <steven.silvester@ieee.org>
Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
0
libs/partners/mongodb/tests/__init__.py
Normal file
0
libs/partners/mongodb/tests/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.compile
def test_placeholder() -> None:
    """Used for compiling integration tests without running any real tests.

    The docstring alone is a valid function body, so the trailing ``pass``
    was redundant and has been removed.
    """
|
@@ -0,0 +1,170 @@
|
||||
"""Test MongoDB Atlas Vector Search functionality."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from time import sleep
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import pytest
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from pymongo import MongoClient
|
||||
from pymongo.collection import Collection
|
||||
|
||||
from langchain_mongodb import MongoDBAtlasVectorSearch
|
||||
from tests.utils import ConsistentFakeEmbeddings
|
||||
|
||||
# Name of the Atlas Search index the integration tests expect on the collection.
INDEX_NAME = "langchain-test-index"
# "<database>.<collection>" namespace used by every integration test.
NAMESPACE = "langchain_test_db.langchain_test_collection"
# Atlas connection string; must be supplied via the environment (None otherwise).
CONNECTION_STRING = os.environ.get("MONGODB_ATLAS_URI")
DB_NAME, COLLECTION_NAME = NAMESPACE.split(".")
# Dimensionality used for the fake embeddings in these tests.
DIMENSIONS = 1536
# Maximum time (seconds) to wait for newly inserted documents to be indexed.
TIMEOUT = 10.0
# Polling interval (seconds) while waiting for indexing to catch up.
INTERVAL = 0.5
|
||||
|
||||
|
||||
class PatchedMongoDBAtlasVectorSearch(MongoDBAtlasVectorSearch):
    """Test-only subclass whose inserts block until the search index sees them."""

    def _insert_texts(self, texts: List[str], metadatas: List[Dict[str, Any]]) -> List:
        """Patched _insert_texts that waits for data to be indexed before returning.

        Polls the search index every ``INTERVAL`` seconds, for at most
        ``TIMEOUT`` seconds, until a similarity search returns as many
        documents as were inserted, so later assertions see the new data.
        """
        ids = super()._insert_texts(texts, metadatas)
        timeout = TIMEOUT
        # Bug fix: the original condition was
        #     len(ids) != self.similarity_search("sandwich")
        # which compares an int to the *list* of Documents returned by
        # similarity_search — always unequal — so every insert spun for the
        # full TIMEOUT. Compare counts instead, and request k=len(ids) so the
        # search is able to return every inserted document.
        while (
            len(self.similarity_search("sandwich", k=len(ids))) != len(ids)
            and timeout >= 0
        ):
            sleep(INTERVAL)
            timeout -= INTERVAL
        return ids
|
||||
|
||||
|
||||
def get_collection() -> Collection:
    """Open a client against the configured cluster and return the test collection."""
    client: MongoClient = MongoClient(CONNECTION_STRING)
    database = client[DB_NAME]
    return database[COLLECTION_NAME]
|
||||
|
||||
|
||||
@pytest.fixture()
def collection() -> Collection:
    """Provide each test with a handle to the shared Atlas test collection."""
    return get_collection()
|
||||
|
||||
|
||||
class TestMongoDBAtlasVectorSearch:
    """Integration tests that exercise MongoDBAtlasVectorSearch against a
    live Atlas cluster (connection taken from MONGODB_ATLAS_URI)."""

    @classmethod
    def setup_class(cls) -> None:
        # ensure the test collection is empty before the suite starts
        collection = get_collection()
        assert collection.count_documents({}) == 0  # type: ignore[index] # noqa: E501

    @classmethod
    def teardown_class(cls) -> None:
        collection = get_collection()
        # delete all the documents in the collection
        collection.delete_many({})  # type: ignore[index]

    @pytest.fixture(autouse=True)
    def setup(self) -> None:
        # Runs before every test: start each test from an empty collection.
        collection = get_collection()
        # delete all the documents in the collection
        collection.delete_many({})  # type: ignore[index]

    @pytest.fixture
    def embedding_openai(self) -> Embeddings:
        # Deterministic fake embeddings so the tests never call a real
        # embedding API; DIMENSIONS matches the search index definition.
        return ConsistentFakeEmbeddings(DIMENSIONS)

    def test_from_documents(
        self, embedding_openai: Embeddings, collection: Any
    ) -> None:
        """Test end to end construction and search."""
        documents = [
            Document(page_content="Dogs are tough.", metadata={"a": 1}),
            Document(page_content="Cats have fluff.", metadata={"b": 1}),
            Document(page_content="What is a sandwich?", metadata={"c": 1}),
            Document(page_content="That fence is purple.", metadata={"d": 1, "e": 2}),
        ]
        vectorstore = PatchedMongoDBAtlasVectorSearch.from_documents(
            documents,
            embedding_openai,
            collection=collection,
            index_name=INDEX_NAME,
        )
        # sleep(5)  # waits for mongot to update Lucene's index
        output = vectorstore.similarity_search("Sandwich", k=1)
        assert len(output) == 1
        # Check that the returned page content comes from one of the inputs
        assert any([key.page_content == output[0].page_content for key in documents])

    def test_from_texts(self, embedding_openai: Embeddings, collection: Any) -> None:
        """Build the store from raw texts (no metadata) and search it."""
        texts = [
            "Dogs are tough.",
            "Cats have fluff.",
            "What is a sandwich?",
            "That fence is purple.",
        ]
        vectorstore = PatchedMongoDBAtlasVectorSearch.from_texts(
            texts,
            embedding_openai,
            collection=collection,
            index_name=INDEX_NAME,
        )
        # sleep(5)  # waits for mongot to update Lucene's index
        output = vectorstore.similarity_search("Sandwich", k=1)
        assert len(output) == 1

    def test_from_texts_with_metadatas(
        self, embedding_openai: Embeddings, collection: Any
    ) -> None:
        """Texts plus per-document metadata; the hit should carry metadata."""
        texts = [
            "Dogs are tough.",
            "Cats have fluff.",
            "What is a sandwich?",
            "The fence is purple.",
        ]
        metadatas = [{"a": 1}, {"b": 1}, {"c": 1}, {"d": 1, "e": 2}]
        metakeys = ["a", "b", "c", "d", "e"]
        vectorstore = PatchedMongoDBAtlasVectorSearch.from_texts(
            texts,
            embedding_openai,
            metadatas=metadatas,
            collection=collection,
            index_name=INDEX_NAME,
        )
        # sleep(5)  # waits for mongot to update Lucene's index
        output = vectorstore.similarity_search("Sandwich", k=1)
        assert len(output) == 1
        # Check for the presence of at least one of the metadata keys
        assert any([key in output[0].metadata for key in metakeys])

    def test_from_texts_with_metadatas_and_pre_filter(
        self, embedding_openai: Embeddings, collection: Any
    ) -> None:
        """A pre-filter that matches no documents must yield no results."""
        texts = [
            "Dogs are tough.",
            "Cats have fluff.",
            "What is a sandwich?",
            "The fence is purple.",
        ]
        metadatas = [{"a": 1}, {"b": 1}, {"c": 1}, {"d": 1, "e": 2}]
        vectorstore = PatchedMongoDBAtlasVectorSearch.from_texts(
            texts,
            embedding_openai,
            metadatas=metadatas,
            collection=collection,
            index_name=INDEX_NAME,
        )
        # sleep(5)  # waits for mongot to update Lucene's index
        # every "c" value in the data is >= 1, so {"$lte": 0} excludes all docs
        output = vectorstore.similarity_search(
            "Sandwich", k=1, pre_filter={"c": {"$lte": 0}}
        )
        assert output == []

    def test_mmr(self, embedding_openai: Embeddings, collection: Any) -> None:
        """MMR should rank the exact match first and then diversify."""
        texts = ["foo", "foo", "fou", "foy"]
        vectorstore = PatchedMongoDBAtlasVectorSearch.from_texts(
            texts,
            embedding_openai,
            collection=collection,
            index_name=INDEX_NAME,
        )
        # sleep(5)  # waits for mongot to update Lucene's index
        query = "foo"
        output = vectorstore.max_marginal_relevance_search(query, k=10, lambda_mult=0.1)
        assert len(output) == len(texts)
        assert output[0].page_content == "foo"
        # low lambda_mult favors diversity, so the runner-up is not the duplicate "foo"
        assert output[1].page_content != "foo"
|
0
libs/partners/mongodb/tests/unit_tests/__init__.py
Normal file
0
libs/partners/mongodb/tests/unit_tests/__init__.py
Normal file
9
libs/partners/mongodb/tests/unit_tests/test_imports.py
Normal file
9
libs/partners/mongodb/tests/unit_tests/test_imports.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from langchain_mongodb import __all__
|
||||
|
||||
EXPECTED_ALL = [
    "MongoDBAtlasVectorSearch",
]


def test_all_imports() -> None:
    """The package's public API matches the expected export list exactly."""
    assert sorted(__all__) == sorted(EXPECTED_ALL)
|
224
libs/partners/mongodb/tests/unit_tests/test_vectorstores.py
Normal file
224
libs/partners/mongodb/tests/unit_tests/test_vectorstores.py
Normal file
@@ -0,0 +1,224 @@
|
||||
import uuid
|
||||
from copy import deepcopy
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import pytest
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from pymongo.collection import Collection
|
||||
from pymongo.results import DeleteResult, InsertManyResult
|
||||
|
||||
from langchain_mongodb import MongoDBAtlasVectorSearch
|
||||
from tests.utils import ConsistentFakeEmbeddings
|
||||
|
||||
# Search index name passed when constructing test vector stores (never hit:
# the MockCollection below short-circuits all server calls).
INDEX_NAME = "langchain-test-index"
# "<database>.<collection>" namespace, mirroring the integration tests.
NAMESPACE = "langchain_test_db.langchain_test_collection"
DB_NAME, COLLECTION_NAME = NAMESPACE.split(".")
|
||||
|
||||
|
||||
class MockCollection(Collection):
    """In-memory stand-in for a pymongo ``Collection``.

    Stores inserted documents in a plain list and replays a pre-seeded
    aggregation result instead of talking to a server.
    """

    # Result that ``aggregate`` will return; seeded directly by tests.
    _aggregate_result: List[Any]
    # Optional canned result for ``insert_many``; autogenerated when ``None``.
    _insert_result: Optional[InsertManyResult]
    # All documents currently "stored" in this fake collection.
    _data: List[Any]

    def __init__(self) -> None:
        # Deliberately does not call Collection.__init__ — no server needed.
        self._data = []
        self._aggregate_result = []
        self._insert_result = None

    def delete_many(self, *args, **kwargs) -> DeleteResult:  # type: ignore
        """Drop every stored document and report how many were removed."""
        removed = len(self._data)
        self._data = []
        return DeleteResult({"n": removed}, acknowledged=True)

    def insert_many(self, to_insert: List[Any], *args, **kwargs) -> InsertManyResult:  # type: ignore
        """Store the documents, stamping each with an ``_id`` and a score."""
        stamped = [
            {"_id": str(uuid.uuid4()), "score": 1, **doc} for doc in to_insert
        ]
        self._data.extend(stamped)
        if self._insert_result:
            return self._insert_result
        return InsertManyResult(
            [entry["_id"] for entry in stamped], acknowledged=True
        )

    def aggregate(self, *args, **kwargs) -> List[Any]:  # type: ignore
        """Replay the seeded aggregation result (copied, so callers may mutate)."""
        return deepcopy(self._aggregate_result)

    def count_documents(self, *args, **kwargs) -> int:  # type: ignore
        """Number of documents currently stored."""
        return len(self._data)

    def __repr__(self) -> str:
        return "FakeCollection"
|
||||
|
||||
|
||||
def get_collection() -> MockCollection:
    """Construct a fresh, empty in-memory mock collection."""
    return MockCollection()
|
||||
|
||||
|
||||
@pytest.fixture()
def collection() -> MockCollection:
    """Provide each test with its own fresh mock collection."""
    return get_collection()
|
||||
|
||||
|
||||
@pytest.fixture()
def embedding_openai() -> Embeddings:
    """Deterministic fake embeddings (default dimensionality)."""
    return ConsistentFakeEmbeddings()
|
||||
|
||||
|
||||
def test_initialization(collection: Collection, embedding_openai: Embeddings) -> None:
    """The vector store can be constructed directly from a collection."""
    store = MongoDBAtlasVectorSearch(collection, embedding_openai)
    assert store
|
||||
|
||||
|
||||
def test_init_from_texts(collection: Collection, embedding_openai: Embeddings) -> None:
    """from_texts with an empty text list still yields a usable store."""
    store = MongoDBAtlasVectorSearch.from_texts(
        [], embedding_openai, collection=collection
    )
    assert store
|
||||
|
||||
|
||||
class TestMongoDBAtlasVectorSearch:
    """Unit tests for MongoDBAtlasVectorSearch driven entirely by MockCollection."""

    @classmethod
    def setup_class(cls) -> None:
        # ensure a freshly constructed mock collection starts empty
        collection = get_collection()
        assert collection.count_documents({}) == 0  # type: ignore[index] # noqa: E501

    @classmethod
    def teardown_class(cls) -> None:
        collection = get_collection()
        # delete all the documents in the collection
        collection.delete_many({})  # type: ignore[index]

    @pytest.fixture(autouse=True)
    def setup(self) -> None:
        # NOTE(review): get_collection() builds a *new* MockCollection each
        # call, so this clears an object no test uses — effectively a no-op
        # copied from the integration suite; confirm intent.
        collection = get_collection()
        # delete all the documents in the collection
        collection.delete_many({})  # type: ignore[index]

    def _validate_search(
        self,
        vectorstore: MongoDBAtlasVectorSearch,
        collection: MockCollection,
        search_term: str = "sandwich",
        page_content: str = "What is a sandwich?",
        metadata: Optional[Any] = 1,
    ) -> None:
        """Stage matching docs as the mock's aggregation result, then assert
        similarity_search surfaces the expected content and metadata."""
        # The mock's aggregate() replays whatever is staged here, so filter
        # the stored documents down to those containing the search term.
        collection._aggregate_result = list(
            filter(
                lambda x: search_term.lower() in x[vectorstore._text_key].lower(),
                collection._data,
            )
        )
        output = vectorstore.similarity_search("", k=1)
        assert output[0].page_content == page_content
        # "c" is the metadata key attached to the sandwich document below.
        assert output[0].metadata.get("c") == metadata

    def test_from_documents(
        self, embedding_openai: Embeddings, collection: MockCollection
    ) -> None:
        """Test end to end construction and search."""
        documents = [
            Document(page_content="Dogs are tough.", metadata={"a": 1}),
            Document(page_content="Cats have fluff.", metadata={"b": 1}),
            Document(page_content="What is a sandwich?", metadata={"c": 1}),
            Document(page_content="That fence is purple.", metadata={"d": 1, "e": 2}),
        ]
        vectorstore = MongoDBAtlasVectorSearch.from_documents(
            documents,
            embedding_openai,
            collection=collection,
            index_name=INDEX_NAME,
        )
        self._validate_search(
            vectorstore, collection, metadata=documents[2].metadata["c"]
        )

    def test_from_texts(
        self, embedding_openai: Embeddings, collection: MockCollection
    ) -> None:
        """Raw texts without metadata: the hit should have no "c" metadata."""
        texts = [
            "Dogs are tough.",
            "Cats have fluff.",
            "What is a sandwich?",
            "That fence is purple.",
        ]
        vectorstore = MongoDBAtlasVectorSearch.from_texts(
            texts,
            embedding_openai,
            collection=collection,
            index_name=INDEX_NAME,
        )
        self._validate_search(vectorstore, collection, metadata=None)

    def test_from_texts_with_metadatas(
        self, embedding_openai: Embeddings, collection: MockCollection
    ) -> None:
        """Texts plus metadata: the hit should carry the sandwich doc's "c"."""
        texts = [
            "Dogs are tough.",
            "Cats have fluff.",
            "What is a sandwich?",
            "The fence is purple.",
        ]
        metadatas = [{"a": 1}, {"b": 1}, {"c": 1}, {"d": 1, "e": 2}]
        vectorstore = MongoDBAtlasVectorSearch.from_texts(
            texts,
            embedding_openai,
            metadatas=metadatas,
            collection=collection,
            index_name=INDEX_NAME,
        )
        self._validate_search(vectorstore, collection, metadata=metadatas[2]["c"])

    def test_from_texts_with_metadatas_and_pre_filter(
        self, embedding_openai: Embeddings, collection: MockCollection
    ) -> None:
        """A pre-filter matching nothing must yield an empty result."""
        texts = [
            "Dogs are tough.",
            "Cats have fluff.",
            "What is a sandwich?",
            "The fence is purple.",
        ]
        metadatas = [{"a": 1}, {"b": 1}, {"c": 1}, {"d": 1, "e": 2}]
        vectorstore = MongoDBAtlasVectorSearch.from_texts(
            texts,
            embedding_openai,
            metadatas=metadatas,
            collection=collection,
            index_name=INDEX_NAME,
        )
        # Stage an aggregation result that mimics the server applying the
        # filter: "c" is 1 in the data, so c < 0 matches nothing.
        collection._aggregate_result = list(
            filter(
                lambda x: "sandwich" in x[vectorstore._text_key].lower()
                and x.get("c") < 0,
                collection._data,
            )
        )
        output = vectorstore.similarity_search(
            "Sandwich", k=1, pre_filter={"range": {"lte": 0, "path": "c"}}
        )
        assert output == []

    def test_mmr(
        self, embedding_openai: Embeddings, collection: MockCollection
    ) -> None:
        """MMR over near-duplicates ranks the match first, then diversifies."""
        texts = ["foo", "foo", "fou", "foy"]
        vectorstore = MongoDBAtlasVectorSearch.from_texts(
            texts,
            embedding_openai,
            collection=collection,
            index_name=INDEX_NAME,
        )
        query = "foo"
        # Stage the mock's aggregation result (also sanity-checks plain search).
        self._validate_search(
            vectorstore,
            collection,
            search_term=query[0:2],
            page_content=query,
            metadata=None,
        )
        output = vectorstore.max_marginal_relevance_search(query, k=10, lambda_mult=0.1)
        assert len(output) == len(texts)
        assert output[0].page_content == "foo"
        # low lambda_mult favors diversity: runner-up is not the duplicate "foo"
        assert output[1].page_content != "foo"
|
36
libs/partners/mongodb/tests/utils.py
Normal file
36
libs/partners/mongodb/tests/utils.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
|
||||
class ConsistentFakeEmbeddings(Embeddings):
    """Deterministic fake embeddings for testing.

    Each distinct text is assigned a stable position in ``known_texts``;
    its vector is all ones except the final component, which encodes that
    position. Re-embedding the same text always yields the same vector.
    """

    def __init__(self, dimensionality: int = 10) -> None:
        self.known_texts: List[str] = []
        self.dimensionality = dimensionality

    def _vector_for(self, text: str) -> List[float]:
        """Embed one text, registering it on first sight."""
        if text not in self.known_texts:
            self.known_texts.append(text)
        tail = float(self.known_texts.index(text))
        return [1.0] * (self.dimensionality - 1) + [tail]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return consistent embeddings for each text seen so far."""
        return [self._vector_for(text) for text in texts]

    def embed_query(self, text: str) -> List[float]:
        """Return the consistent embedding for the text, registering it
        (and assigning it a new index) if it has not been seen before."""
        return self.embed_documents([text])[0]

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        # Async facade over the synchronous implementation.
        return self.embed_documents(texts)

    async def aembed_query(self, text: str) -> List[float]:
        # Async facade over the synchronous implementation.
        return self.embed_query(text)
|
Reference in New Issue
Block a user