Mirror of https://github.com/hwchase17/langchain.git (synced 2025-07-05 12:48:12 +00:00)
mongodb[minor]: MongoDB Partner Package -- Porting MongoDBAtlasVectorSearch (#17652)
This PR migrates the existing MongoDBAtlasVectorSearch abstraction from
`langchain_community` to the partner packages section of the codebase.
- [x] Run the partner package script as advised in the partner-packages
documentation.
- [x] Add Unit Tests
- [x] Migrate Integration Tests
- [x] Refactor `MongoDBAtlasVectorStore` (autogenerated) to
`MongoDBAtlasVectorSearch`
- [x] ~Remove~ deprecate the old `langchain_community` VectorStore
references.
## Additional Callouts
- Implemented the `delete` method
- Included any missing async function implementations
- `amax_marginal_relevance_search_by_vector`
- `adelete`
- Added new unit tests covering the functionality of
`MongoDBAtlasVectorSearch` methods
- Removed [`del res[self._embedding_key]`](e0c81e1cb0/libs/community/langchain_community/vectorstores/mongodb_atlas.py#L218)
in the `_similarity_search_with_score` function, since deleting the key
made `maximal_marginal_relevance` fail: the `Document` needs to keep the
embedding key in its metadata for MMR to work.
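
To make that constraint concrete, here is a minimal sketch (using the class's default `"embedding"` metadata key; the documents and vectors are illustrative) of the lookup that `del res[self._embedding_key]` used to break:

```python
import numpy as np
from langchain_core.documents import Document

from langchain_mongodb.utils import maximal_marginal_relevance

# Candidates as returned by `_similarity_search_with_score`: the embedding
# key is kept in each Document's metadata rather than deleted.
docs = [
    Document(page_content="doc a", metadata={"embedding": [1.0, 0.0]}),
    Document(page_content="doc b", metadata={"embedding": [0.9, 0.1]}),
    Document(page_content="doc c", metadata={"embedding": [0.0, 1.0]}),
]

# This metadata lookup is exactly what fails if the key has been deleted.
mmr_indexes = maximal_marginal_relevance(
    np.array([1.0, 0.0]),
    [d.metadata["embedding"] for d in docs],
    k=2,
    lambda_mult=0.1,
)
print([docs[i].page_content for i in mmr_indexes])  # ['doc a', 'doc c']
```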
Checklist:
- [x] PR title: Please title your PR "package: description", where
"package" is whichever of langchain, community, core, experimental, etc.
is being modified. Use "docs: ..." for purely docs changes, "templates:
..." for template changes, "infra: ..." for CI changes.
- Example: "community: add foobar LLM"
- [x] PR message
- [x] Pass lint and test: Run `make format`, `make lint` and `make test`
from the root of the package(s) you've modified to check that you're
passing lint and testing. See contribution guidelines for more
information on how to write/run tests, lint, etc:
https://python.langchain.com/docs/contributing/
- [x] Add tests and docs: If you're adding a new integration, please
include
1. Existing tests supplied in docs/docs do not change; docstrings were
updated for new functions like `delete`.
2. An example notebook showing its use, which lives in the
`docs/docs/integrations` directory. (This already exists.)
If no one reviews your PR within a few days, please @-mention one of
baskaryan, efriis, eyurtsev, hwchase17.
---------
Co-authored-by: Steven Silvester <steven.silvester@ieee.org>
Co-authored-by: Erick Friis <erick@langchain.dev>
parent 412148773c
commit 72bfc1d3db
@@ -8,10 +8,10 @@
 See [detail configuration instructions](/docs/integrations/vectorstores/mongodb_atlas).
 
-We need to install `pymongo` python package.
+We need to install `langchain-mongodb` python package.
 
 ```bash
-pip install pymongo
+pip install langchain-mongodb
 ```
 
 ## Vector Store
 
@@ -19,6 +19,6 @@ pip install pymongo
 See a [usage example](/docs/integrations/vectorstores/mongodb_atlas).
 
 ```python
-from langchain_community.vectorstores import MongoDBAtlasVectorSearch
+from langchain_mongodb import MongoDBAtlasVectorSearch
 ```
@@ -16,6 +16,7 @@ from typing import (
 )
 
 import numpy as np
+from langchain_core._api.deprecation import deprecated
 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
 from langchain_core.vectorstores import VectorStore
@@ -32,6 +33,11 @@ logger = logging.getLogger(__name__)
 DEFAULT_INSERT_BATCH_SIZE = 100
 
 
+@deprecated(
+    since="0.0.25",
+    removal="0.2.0",
+    alternative_import="langchain_mongodb.MongoDBAtlasVectorSearch",
+)
 class MongoDBAtlasVectorSearch(VectorStore):
     """`MongoDB Atlas Vector Search` vector store.
 
libs/partners/mongodb/.gitignore (new file, vendored)

__pycache__
libs/partners/mongodb/LICENSE (new file)

MIT License

Copyright (c) 2024 LangChain, Inc.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
libs/partners/mongodb/Makefile (new file)

.PHONY: all format lint test tests integration_tests docker_tests help extended_tests

# Default target executed when no arguments are given to make.
all: help

# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/
integration_test integration_tests: TEST_FILE=tests/integration_tests/

test tests integration_test integration_tests:
	poetry run pytest $(TEST_FILE)


######################
# LINTING AND FORMATTING
######################

# Define a variable for Python and notebook files.
PYTHON_FILES=.
MYPY_CACHE=.mypy_cache
lint format: PYTHON_FILES=.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/mongodb --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
lint_package: PYTHON_FILES=langchain_mongodb
lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test

lint lint_diff lint_package lint_tests:
	poetry run ruff .
	poetry run ruff format $(PYTHON_FILES) --diff
	poetry run ruff --select I $(PYTHON_FILES)
	mkdir $(MYPY_CACHE); poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)

format format_diff:
	poetry run ruff format $(PYTHON_FILES)
	poetry run ruff --select I --fix $(PYTHON_FILES)

spell_check:
	poetry run codespell --toml pyproject.toml

spell_fix:
	poetry run codespell --toml pyproject.toml -w

check_imports: $(shell find langchain_mongodb -name '*.py')
	poetry run python ./scripts/check_imports.py $^

######################
# HELP
######################

help:
	@echo '----'
	@echo 'check_imports - check imports'
	@echo 'format - run code formatters'
	@echo 'lint - run linters'
	@echo 'test - run unit tests'
	@echo 'tests - run unit tests'
	@echo 'test TEST_FILE=<test_file> - run all tests in file'
libs/partners/mongodb/README.md (new file)

# langchain-mongodb

# Installation
```
pip install -U langchain-mongodb
```

# Usage
- See [integrations doc](../../../docs/docs/integrations/vectorstores/mongodb.ipynb) for more in-depth usage instructions.
- See [Getting Started with the LangChain Integration](https://www.mongodb.com/docs/atlas/atlas-vector-search/ai-integrations/langchain/#get-started-with-the-langchain-integration) for a walkthrough on using your first LangChain implementation with MongoDB Atlas.

## Using MongoDBAtlasVectorSearch
```python
import os

from langchain_openai import OpenAIEmbeddings  # any Embeddings implementation works
from pymongo import MongoClient

from langchain_mongodb import MongoDBAtlasVectorSearch

# Pull MongoDB Atlas URI from environment variables
MONGODB_ATLAS_CLUSTER_URI = os.environ.get("MONGODB_ATLAS_CLUSTER_URI")

DB_NAME = "langchain_db"
COLLECTION_NAME = "test"
ATLAS_VECTOR_SEARCH_INDEX_NAME = "index_name"

# Create the vector search via `from_connection_string`
vector_search = MongoDBAtlasVectorSearch.from_connection_string(
    MONGODB_ATLAS_CLUSTER_URI,
    DB_NAME + "." + COLLECTION_NAME,
    OpenAIEmbeddings(disallowed_special=()),
    index_name=ATLAS_VECTOR_SEARCH_INDEX_NAME,
)

# Initialize MongoDB python client
client = MongoClient(MONGODB_ATLAS_CLUSTER_URI)
MONGODB_COLLECTION = client[DB_NAME][COLLECTION_NAME]

# Create the vector search via instantiation
vector_search_2 = MongoDBAtlasVectorSearch(
    collection=MONGODB_COLLECTION,
    embedding=OpenAIEmbeddings(disallowed_special=()),
    index_name=ATLAS_VECTOR_SEARCH_INDEX_NAME,
)
```
libs/partners/mongodb/langchain_mongodb/__init__.py (new file)

from langchain_mongodb.vectorstores import (
    MongoDBAtlasVectorSearch,
)

__all__ = [
    "MongoDBAtlasVectorSearch",
]

libs/partners/mongodb/langchain_mongodb/py.typed (new file, empty)
libs/partners/mongodb/langchain_mongodb/utils.py (new file)

"""
Tools for the Maximal Marginal Relevance (MMR) reranking.
Duplicated from langchain_community to avoid cross-dependencies.

Functions "maximal_marginal_relevance" and "cosine_similarity"
are duplicated in this utility respectively from modules:
    - "libs/community/langchain_community/vectorstores/utils.py"
    - "libs/community/langchain_community/utils/math.py"
"""

import logging
from typing import List, Union

import numpy as np

logger = logging.getLogger(__name__)

Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray]


def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
    """Row-wise cosine similarity between two equal-width matrices."""
    if len(X) == 0 or len(Y) == 0:
        return np.array([])

    X = np.array(X)
    Y = np.array(Y)
    if X.shape[1] != Y.shape[1]:
        raise ValueError(
            f"Number of columns in X and Y must be the same. X has shape {X.shape} "
            f"and Y has shape {Y.shape}."
        )
    try:
        import simsimd as simd  # type: ignore

        X = np.array(X, dtype=np.float32)
        Y = np.array(Y, dtype=np.float32)
        Z = 1 - simd.cdist(X, Y, metric="cosine")
        if isinstance(Z, float):
            return np.array([Z])
        return Z
    except ImportError:
        logger.info(
            "Unable to import simsimd, defaulting to NumPy implementation. If you want "
            "to use simsimd please install with `pip install simsimd`."
        )
        X_norm = np.linalg.norm(X, axis=1)
        Y_norm = np.linalg.norm(Y, axis=1)
        # Ignore divide-by-zero and invalid-value runtime warnings;
        # those entries are zeroed out below.
        with np.errstate(divide="ignore", invalid="ignore"):
            similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm)
            similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0
        return similarity


def maximal_marginal_relevance(
    query_embedding: np.ndarray,
    embedding_list: list,
    lambda_mult: float = 0.5,
    k: int = 4,
) -> List[int]:
    """Calculate maximal marginal relevance."""
    if min(k, len(embedding_list)) <= 0:
        return []
    if query_embedding.ndim == 1:
        query_embedding = np.expand_dims(query_embedding, axis=0)
    similarity_to_query = cosine_similarity(query_embedding, embedding_list)[0]
    most_similar = int(np.argmax(similarity_to_query))
    idxs = [most_similar]
    selected = np.array([embedding_list[most_similar]])
    while len(idxs) < min(k, len(embedding_list)):
        best_score = -np.inf
        idx_to_add = -1
        similarity_to_selected = cosine_similarity(embedding_list, selected)
        for i, query_score in enumerate(similarity_to_query):
            if i in idxs:
                continue
            redundant_score = max(similarity_to_selected[i])
            equation_score = (
                lambda_mult * query_score - (1 - lambda_mult) * redundant_score
            )
            if equation_score > best_score:
                best_score = equation_score
                idx_to_add = i
        idxs.append(idx_to_add)
        selected = np.append(selected, [embedding_list[idx_to_add]], axis=0)
    return idxs
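
For reference, the row-wise contract of `cosine_similarity` (values shown are from the NumPy fallback path; with `simsimd` installed the result is numerically equivalent):

```python
import numpy as np

from langchain_mongodb.utils import cosine_similarity

X = np.array([[1.0, 0.0], [0.0, 1.0]])
Y = np.array([[1.0, 0.0], [1.0, 1.0]])

# Entry [i, j] is the cosine similarity of row i of X with row j of Y.
print(cosine_similarity(X, Y))
# [[1.         0.70710678]
#  [0.         0.70710678]]
```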
libs/partners/mongodb/langchain_mongodb/vectorstores.py (new file)

from __future__ import annotations

import logging
from importlib.metadata import version
from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    Iterable,
    List,
    Optional,
    Tuple,
    TypeVar,
    Union,
)

import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.runnables.config import run_in_executor
from langchain_core.vectorstores import VectorStore
from pymongo import MongoClient
from pymongo.collection import Collection
from pymongo.driver_info import DriverInfo

from langchain_mongodb.utils import maximal_marginal_relevance

MongoDBDocumentType = TypeVar("MongoDBDocumentType", bound=Dict[str, Any])
VST = TypeVar("VST", bound=VectorStore)

logger = logging.getLogger(__name__)

DEFAULT_INSERT_BATCH_SIZE = 100


class MongoDBAtlasVectorSearch(VectorStore):
    """`MongoDB Atlas Vector Search` vector store.

    To use, you should have both:
    - the ``pymongo`` python package installed
    - a connection string associated with a MongoDB Atlas Cluster having deployed an
      Atlas Search index

    Example:
        .. code-block:: python

            from langchain_mongodb import MongoDBAtlasVectorSearch
            from langchain_community.embeddings.openai import OpenAIEmbeddings
            from pymongo import MongoClient

            mongo_client = MongoClient("<YOUR-CONNECTION-STRING>")
            collection = mongo_client["<db_name>"]["<collection_name>"]
            embeddings = OpenAIEmbeddings()
            vectorstore = MongoDBAtlasVectorSearch(collection, embeddings)
    """

    def __init__(
        self,
        collection: Collection[MongoDBDocumentType],
        embedding: Embeddings,
        *,
        index_name: str = "default",
        text_key: str = "text",
        embedding_key: str = "embedding",
        relevance_score_fn: str = "cosine",
    ):
        """
        Args:
            collection: MongoDB collection to add the texts to.
            embedding: Text embedding model to use.
            text_key: MongoDB field that will contain the text for each
                document. Defaults to 'text'.
            embedding_key: MongoDB field that will contain the embedding for
                each document. Defaults to 'embedding'.
            index_name: Name of the Atlas Search index. Defaults to 'default'.
            relevance_score_fn: The similarity score used for the index.
                Defaults to 'cosine'. Currently supported: 'euclidean',
                'cosine', and 'dotProduct'.
        """
        self._collection = collection
        self._embedding = embedding
        self._index_name = index_name
        self._text_key = text_key
        self._embedding_key = embedding_key
        self._relevance_score_fn = relevance_score_fn

    @property
    def embeddings(self) -> Embeddings:
        return self._embedding

    def _select_relevance_score_fn(self) -> Callable[[float], float]:
        scoring: dict[str, Callable] = {
            "euclidean": self._euclidean_relevance_score_fn,
            "dotProduct": self._max_inner_product_relevance_score_fn,
            "cosine": self._cosine_relevance_score_fn,
        }
        if self._relevance_score_fn in scoring:
            return scoring[self._relevance_score_fn]
        else:
            raise NotImplementedError(
                f"No relevance score function for {self._relevance_score_fn}"
            )

    @classmethod
    def from_connection_string(
        cls,
        connection_string: str,
        namespace: str,
        embedding: Embeddings,
        **kwargs: Any,
    ) -> MongoDBAtlasVectorSearch:
        """Construct a `MongoDB Atlas Vector Search` vector store
        from a MongoDB connection URI.

        Args:
            connection_string: A valid MongoDB connection URI.
            namespace: A valid MongoDB namespace (database and collection).
            embedding: The text embedding model to use for the vector store.

        Returns:
            A new MongoDBAtlasVectorSearch instance.
        """
        client: MongoClient = MongoClient(
            connection_string,
            driver=DriverInfo(name="Langchain", version=version("langchain")),
        )
        db_name, collection_name = namespace.split(".")
        collection = client[db_name][collection_name]
        return cls(collection, embedding, **kwargs)

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[Dict[str, Any]]] = None,
        **kwargs: Any,
    ) -> List:
        """Run more texts through the embeddings and add to the vectorstore.

        Args:
            texts: Iterable of strings to add to the vectorstore.
            metadatas: Optional list of metadatas associated with the texts.

        Returns:
            List of ids from adding the texts into the vectorstore.
        """
        batch_size = kwargs.get("batch_size", DEFAULT_INSERT_BATCH_SIZE)
        _metadatas: Union[List, Generator] = metadatas or ({} for _ in texts)
        texts_batch = []
        metadatas_batch = []
        result_ids = []
        for i, (text, metadata) in enumerate(zip(texts, _metadatas)):
            texts_batch.append(text)
            metadatas_batch.append(metadata)
            # Flush a full batch to Atlas before accumulating the next one.
            if (i + 1) % batch_size == 0:
                result_ids.extend(self._insert_texts(texts_batch, metadatas_batch))
                texts_batch = []
                metadatas_batch = []
        if texts_batch:
            result_ids.extend(self._insert_texts(texts_batch, metadatas_batch))
        return result_ids

    def _insert_texts(self, texts: List[str], metadatas: List[Dict[str, Any]]) -> List:
        if not texts:
            return []
        # Embed and create the documents
        embeddings = self._embedding.embed_documents(texts)
        to_insert = [
            {self._text_key: t, self._embedding_key: embedding, **m}
            for t, m, embedding in zip(texts, metadatas, embeddings)
        ]
        # insert the documents in MongoDB Atlas
        insert_result = self._collection.insert_many(to_insert)  # type: ignore
        return insert_result.inserted_ids

    def _similarity_search_with_score(
        self,
        embedding: List[float],
        k: int = 4,
        pre_filter: Optional[Dict] = None,
        post_filter_pipeline: Optional[List[Dict]] = None,
    ) -> List[Tuple[Document, float]]:
        params = {
            "queryVector": embedding,
            "path": self._embedding_key,
            "numCandidates": k * 10,
            "limit": k,
            "index": self._index_name,
        }
        if pre_filter:
            params["filter"] = pre_filter
        query = {"$vectorSearch": params}

        pipeline = [
            query,
            {"$set": {"score": {"$meta": "vectorSearchScore"}}},
        ]
        if post_filter_pipeline is not None:
            pipeline.extend(post_filter_pipeline)
        cursor = self._collection.aggregate(pipeline)  # type: ignore[arg-type]
        docs = []
        for res in cursor:
            text = res.pop(self._text_key)
            score = res.pop("score")
            # The embedding key stays in the metadata so that
            # maximal_marginal_relevance can re-read it later.
            docs.append((Document(page_content=text, metadata=res), score))
        return docs

    def similarity_search_with_score(
        self,
        query: str,
        k: int = 4,
        pre_filter: Optional[Dict] = None,
        post_filter_pipeline: Optional[List[Dict]] = None,
    ) -> List[Tuple[Document, float]]:
        """Return MongoDB documents most similar to the given query and their scores.

        Uses the vectorSearch operator available in MongoDB Atlas Search.
        For more: https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/

        Args:
            query: Text to look up documents similar to.
            k: (Optional) number of documents to return. Defaults to 4.
            pre_filter: (Optional) dictionary of argument(s) to prefilter document
                fields on.
            post_filter_pipeline: (Optional) Pipeline of MongoDB aggregation stages
                following the vectorSearch stage.

        Returns:
            List of documents most similar to the query and their scores.
        """
        embedding = self._embedding.embed_query(query)
        docs = self._similarity_search_with_score(
            embedding,
            k=k,
            pre_filter=pre_filter,
            post_filter_pipeline=post_filter_pipeline,
        )
        return docs

    def similarity_search(
        self,
        query: str,
        k: int = 4,
        pre_filter: Optional[Dict] = None,
        post_filter_pipeline: Optional[List[Dict]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return MongoDB documents most similar to the given query.

        Uses the vectorSearch operator available in MongoDB Atlas Search.
        For more: https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/

        Args:
            query: Text to look up documents similar to.
            k: (Optional) number of documents to return. Defaults to 4.
            pre_filter: (Optional) dictionary of argument(s) to prefilter document
                fields on.
            post_filter_pipeline: (Optional) Pipeline of MongoDB aggregation stages
                following the vectorSearch stage.

        Returns:
            List of documents most similar to the query.
        """
        additional = kwargs.get("additional")
        docs_and_scores = self.similarity_search_with_score(
            query,
            k=k,
            pre_filter=pre_filter,
            post_filter_pipeline=post_filter_pipeline,
        )

        if additional and "similarity_score" in additional:
            for doc, score in docs_and_scores:
                doc.metadata["score"] = score
        return [doc for doc, _ in docs_and_scores]

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        pre_filter: Optional[Dict] = None,
        post_filter_pipeline: Optional[List[Dict]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return documents selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            query: Text to look up documents similar to.
            k: (Optional) number of documents to return. Defaults to 4.
            fetch_k: (Optional) number of documents to fetch before passing to MMR
                algorithm. Defaults to 20.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Defaults to 0.5.
            pre_filter: (Optional) dictionary of argument(s) to prefilter on document
                fields.
            post_filter_pipeline: (Optional) pipeline of MongoDB aggregation stages
                following the vectorSearch stage.

        Returns:
            List of documents selected by maximal marginal relevance.
        """
        query_embedding = self._embedding.embed_query(query)
        docs = self._similarity_search_with_score(
            query_embedding,
            k=fetch_k,
            pre_filter=pre_filter,
            post_filter_pipeline=post_filter_pipeline,
        )
        mmr_doc_indexes = maximal_marginal_relevance(
            np.array(query_embedding),
            [doc.metadata[self._embedding_key] for doc, _ in docs],
            k=k,
            lambda_mult=lambda_mult,
        )
        mmr_docs = [docs[i][0] for i in mmr_doc_indexes]
        return mmr_docs

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[Dict]] = None,
        collection: Optional[Collection[MongoDBDocumentType]] = None,
        **kwargs: Any,
    ) -> MongoDBAtlasVectorSearch:
        """Construct a `MongoDB Atlas Vector Search` vector store from raw documents.

        This is a user-friendly interface that:
        1. Embeds documents.
        2. Adds the documents to a provided MongoDB Atlas Vector Search index
           (Lucene)

        This is intended to be a quick way to get started.

        Example:
            .. code-block:: python

                from pymongo import MongoClient

                from langchain_mongodb import MongoDBAtlasVectorSearch
                from langchain_community.embeddings import OpenAIEmbeddings

                mongo_client = MongoClient("<YOUR-CONNECTION-STRING>")
                collection = mongo_client["<db_name>"]["<collection_name>"]
                embeddings = OpenAIEmbeddings()
                vectorstore = MongoDBAtlasVectorSearch.from_texts(
                    texts,
                    embeddings,
                    metadatas=metadatas,
                    collection=collection
                )
        """
        if collection is None:
            raise ValueError("Must provide 'collection' named parameter.")
        vectorstore = cls(collection, embedding, **kwargs)
        vectorstore.add_texts(texts, metadatas=metadatas)
        return vectorstore

    def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
        """Delete by ObjectId or other criteria.

        Args:
            ids: List of ids to delete.
            **kwargs: Other keyword arguments that subclasses might use.

        Returns:
            Optional[bool]: True if deletion is successful,
            False otherwise, None if not implemented.
        """
        search_params: dict[str, Any] = {}
        if ids:
            # Match documents whose text field is one of the given ids.
            search_params[self._text_key] = {"$in": ids}

        return self._collection.delete_many({**search_params, **kwargs}).acknowledged

    async def adelete(
        self, ids: Optional[List[str]] = None, **kwargs: Any
    ) -> Optional[bool]:
        """Delete by vector ID or other criteria.

        Args:
            ids: List of ids to delete.
            **kwargs: Other keyword arguments that subclasses might use.

        Returns:
            Optional[bool]: True if deletion is successful,
            False otherwise, None if not implemented.
        """
        return await run_in_executor(None, self.delete, ids=ids, **kwargs)

    def max_marginal_relevance_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        pre_filter: Optional[Dict] = None,
        post_filter_pipeline: Optional[List[Dict]] = None,
        **kwargs: Any,
    ) -> List[Document]:  # type: ignore
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Defaults to 0.5.
            pre_filter: (Optional) dictionary of argument(s) to prefilter on document
                fields.
            post_filter_pipeline: (Optional) pipeline of MongoDB aggregation stages
                following the vectorSearch stage.

        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        docs = self._similarity_search_with_score(
            embedding,
            k=fetch_k,
            pre_filter=pre_filter,
            post_filter_pipeline=post_filter_pipeline,
        )
        mmr_doc_indexes = maximal_marginal_relevance(
            np.array(embedding),
            [doc.metadata[self._embedding_key] for doc, _ in docs],
            k=k,
            lambda_mult=lambda_mult,
        )
        mmr_docs = [docs[i][0] for i in mmr_doc_indexes]
        return mmr_docs

    async def amax_marginal_relevance_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance."""
        return await run_in_executor(
            None,
            self.max_marginal_relevance_search_by_vector,
            embedding,
            k=k,
            fetch_k=fetch_k,
            lambda_mult=lambda_mult,
            **kwargs,
        )
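
A sketch of the newly added `delete`/`adelete` paths against a live cluster. The database, collection, and index names mirror the integration tests below, and `ConsistentFakeEmbeddings` is the test double shipped in `tests/utils.py`, so this is only runnable from inside the package with a reachable Atlas deployment:

```python
import asyncio
import os

from pymongo import MongoClient

from langchain_mongodb import MongoDBAtlasVectorSearch
from tests.utils import ConsistentFakeEmbeddings

client = MongoClient(os.environ["MONGODB_ATLAS_URI"])  # same env var as the tests
collection = client["langchain_test_db"]["langchain_test_collection"]

store = MongoDBAtlasVectorSearch(
    collection,
    ConsistentFakeEmbeddings(dimensionality=1536),
    index_name="langchain-test-index",
)
store.add_texts(["alpha", "beta"], metadatas=[{"tag": 1}, {"tag": 2}])

# delete() forwards extra keyword criteria straight into delete_many():
store.delete(tag=1)

# adelete() runs the same call in an executor, keeping the event loop free:
asyncio.run(store.adelete(tag=2))
```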
libs/partners/mongodb/poetry.lock (new file, generated, 1036 lines): file diff suppressed because it is too large.
libs/partners/mongodb/pyproject.toml (new file)

[tool.poetry]
name = "langchain-mongodb"
version = "0.1.0"
description = "An integration package connecting MongoDB and LangChain"
authors = []
readme = "README.md"
repository = "https://github.com/langchain-ai/langchain"
license = "MIT"

[tool.poetry.urls]
"Source Code" = "https://github.com/langchain-ai/langchain/tree/master/libs/partners/mongodb"

[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
pymongo = ">=4.6.1,<5.0"
langchain-core = "^0.1"
numpy = "^1"

[tool.poetry.group.test]
optional = true

[tool.poetry.group.test.dependencies]
pytest = "^7.3.0"
freezegun = "^1.2.2"
pytest-mock = "^3.10.0"
syrupy = "^4.0.2"
pytest-watcher = "^0.3.4"
pytest-asyncio = "^0.21.1"
langchain-core = { path = "../../core", develop = true }

[tool.poetry.group.codespell]
optional = true

[tool.poetry.group.codespell.dependencies]
codespell = "^2.2.0"

[tool.poetry.group.test_integration]
optional = true

[tool.poetry.group.test_integration.dependencies]

[tool.poetry.group.lint]
optional = true

[tool.poetry.group.lint.dependencies]
ruff = "^0.1.5"

[tool.poetry.group.typing.dependencies]
mypy = "^0.991"
langchain-core = { path = "../../core", develop = true }

[tool.poetry.group.dev]
optional = true

[tool.poetry.group.dev.dependencies]
langchain-core = { path = "../../core", develop = true }

[tool.ruff]
select = [
  "E",  # pycodestyle
  "F",  # pyflakes
  "I",  # isort
]

[tool.mypy]
disallow_untyped_defs = "True"

[tool.coverage.run]
omit = ["tests/*"]

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[tool.pytest.ini_options]
# --strict-markers will raise errors on unknown marks.
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
#
# https://docs.pytest.org/en/7.1.x/reference/reference.html
# --strict-config: any warnings encountered while parsing the `pytest`
# section of the configuration file raise errors.
#
# https://github.com/tophat/syrupy
# --snapshot-warn-unused: prints a warning on unused snapshots rather than fail the test suite.
addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
# Registering custom markers.
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
markers = [
  "requires: mark tests as requiring a specific library",
  "asyncio: mark tests as requiring asyncio",
  "compile: mark placeholder test used to compile integration tests without running them",
]
asyncio_mode = "auto"
libs/partners/mongodb/scripts/check_imports.py (new file)

import sys
import traceback
from importlib.machinery import SourceFileLoader

if __name__ == "__main__":
    files = sys.argv[1:]
    has_failure = False
    for file in files:
        try:
            SourceFileLoader("x", file).load_module()
        except Exception:
            has_failure = True
            print(file)
            traceback.print_exc()
            print()

    sys.exit(1 if has_failure else 0)
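
The `check_imports` Makefile target feeds every `.py` file under `langchain_mongodb` to this script; by hand the invocation looks like the following (the file list is illustrative):

```bash
poetry run python ./scripts/check_imports.py \
    langchain_mongodb/__init__.py langchain_mongodb/vectorstores.py
```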
libs/partners/mongodb/scripts/check_pydantic.sh (new executable file)

#!/bin/bash
#
# This script searches for lines starting with "import pydantic" or "from pydantic"
# in tracked files within a Git repository.
#
# Usage: ./scripts/check_pydantic.sh /path/to/repository

# Check if a path argument is provided
if [ $# -ne 1 ]; then
  echo "Usage: $0 /path/to/repository"
  exit 1
fi

repository_path="$1"

# Search for lines matching the pattern within the specified repository
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')

# Check if any matching lines were found
if [ -n "$result" ]; then
  echo "ERROR: The following lines need to be updated:"
  echo "$result"
  echo "Please replace the code with an import from langchain_core.pydantic_v1."
  echo "For example, replace 'from pydantic import BaseModel'"
  echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
  exit 1
fi
libs/partners/mongodb/scripts/lint_imports.sh (new executable file)

#!/bin/bash

set -eu

# Initialize a variable to keep track of errors
errors=0

# make sure not importing from langchain or langchain_experimental
git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))

# Decide on an exit status based on the errors
if [ "$errors" -gt 0 ]; then
  exit 1
else
  exit 0
fi
libs/partners/mongodb/tests/__init__.py (new file, empty)

@@ -0,0 +1,7 @@
import pytest


@pytest.mark.compile
def test_placeholder() -> None:
    """Used for compiling integration tests without running any real tests."""
    pass
@@ -0,0 +1,170 @@
"""Test MongoDB Atlas Vector Search functionality."""

from __future__ import annotations

import os
from time import sleep
from typing import Any, Dict, List

import pytest
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from pymongo import MongoClient
from pymongo.collection import Collection

from langchain_mongodb import MongoDBAtlasVectorSearch
from tests.utils import ConsistentFakeEmbeddings

INDEX_NAME = "langchain-test-index"
NAMESPACE = "langchain_test_db.langchain_test_collection"
CONNECTION_STRING = os.environ.get("MONGODB_ATLAS_URI")
DB_NAME, COLLECTION_NAME = NAMESPACE.split(".")
DIMENSIONS = 1536
TIMEOUT = 10.0
INTERVAL = 0.5


class PatchedMongoDBAtlasVectorSearch(MongoDBAtlasVectorSearch):
    def _insert_texts(self, texts: List[str], metadatas: List[Dict[str, Any]]) -> List:
        """Patched insert_texts that waits for data to be indexed before returning"""
        ids = super()._insert_texts(texts, metadatas)
        timeout = TIMEOUT
        # Poll until every inserted document is visible to the search index
        # (or the timeout elapses).
        while (
            len(self.similarity_search("sandwich", k=len(ids))) != len(ids)
            and timeout >= 0
        ):
            sleep(INTERVAL)
            timeout -= INTERVAL
        return ids


def get_collection() -> Collection:
    test_client: MongoClient = MongoClient(CONNECTION_STRING)
    return test_client[DB_NAME][COLLECTION_NAME]


@pytest.fixture()
def collection() -> Collection:
    return get_collection()


class TestMongoDBAtlasVectorSearch:
    @classmethod
    def setup_class(cls) -> None:
        # ensure the test collection is empty
        collection = get_collection()
        assert collection.count_documents({}) == 0  # type: ignore[index]  # noqa: E501

    @classmethod
    def teardown_class(cls) -> None:
        collection = get_collection()
        # delete all the documents in the collection
        collection.delete_many({})  # type: ignore[index]

    @pytest.fixture(autouse=True)
    def setup(self) -> None:
        collection = get_collection()
        # delete all the documents in the collection
        collection.delete_many({})  # type: ignore[index]

    @pytest.fixture
    def embedding_openai(self) -> Embeddings:
        return ConsistentFakeEmbeddings(DIMENSIONS)

    def test_from_documents(
        self, embedding_openai: Embeddings, collection: Any
    ) -> None:
        """Test end to end construction and search."""
        documents = [
            Document(page_content="Dogs are tough.", metadata={"a": 1}),
            Document(page_content="Cats have fluff.", metadata={"b": 1}),
            Document(page_content="What is a sandwich?", metadata={"c": 1}),
            Document(page_content="That fence is purple.", metadata={"d": 1, "e": 2}),
        ]
        vectorstore = PatchedMongoDBAtlasVectorSearch.from_documents(
            documents,
            embedding_openai,
            collection=collection,
            index_name=INDEX_NAME,
        )
        # sleep(5)  # waits for mongot to update Lucene's index
        output = vectorstore.similarity_search("Sandwich", k=1)
        assert len(output) == 1
        # Check that the result matches one of the input documents
        assert any([key.page_content == output[0].page_content for key in documents])

    def test_from_texts(self, embedding_openai: Embeddings, collection: Any) -> None:
        texts = [
            "Dogs are tough.",
            "Cats have fluff.",
            "What is a sandwich?",
            "That fence is purple.",
        ]
        vectorstore = PatchedMongoDBAtlasVectorSearch.from_texts(
            texts,
            embedding_openai,
            collection=collection,
            index_name=INDEX_NAME,
        )
        # sleep(5)  # waits for mongot to update Lucene's index
        output = vectorstore.similarity_search("Sandwich", k=1)
        assert len(output) == 1

    def test_from_texts_with_metadatas(
        self, embedding_openai: Embeddings, collection: Any
    ) -> None:
        texts = [
            "Dogs are tough.",
            "Cats have fluff.",
            "What is a sandwich?",
            "The fence is purple.",
        ]
        metadatas = [{"a": 1}, {"b": 1}, {"c": 1}, {"d": 1, "e": 2}]
        metakeys = ["a", "b", "c", "d", "e"]
        vectorstore = PatchedMongoDBAtlasVectorSearch.from_texts(
            texts,
            embedding_openai,
            metadatas=metadatas,
            collection=collection,
            index_name=INDEX_NAME,
        )
        # sleep(5)  # waits for mongot to update Lucene's index
        output = vectorstore.similarity_search("Sandwich", k=1)
        assert len(output) == 1
        # Check for the presence of the metadata key
        assert any([key in output[0].metadata for key in metakeys])

    def test_from_texts_with_metadatas_and_pre_filter(
        self, embedding_openai: Embeddings, collection: Any
    ) -> None:
        texts = [
            "Dogs are tough.",
            "Cats have fluff.",
            "What is a sandwich?",
            "The fence is purple.",
        ]
        metadatas = [{"a": 1}, {"b": 1}, {"c": 1}, {"d": 1, "e": 2}]
        vectorstore = PatchedMongoDBAtlasVectorSearch.from_texts(
            texts,
            embedding_openai,
            metadatas=metadatas,
            collection=collection,
            index_name=INDEX_NAME,
        )
        # sleep(5)  # waits for mongot to update Lucene's index
        output = vectorstore.similarity_search(
            "Sandwich", k=1, pre_filter={"c": {"$lte": 0}}
        )
        assert output == []

    def test_mmr(self, embedding_openai: Embeddings, collection: Any) -> None:
        texts = ["foo", "foo", "fou", "foy"]
        vectorstore = PatchedMongoDBAtlasVectorSearch.from_texts(
            texts,
            embedding_openai,
            collection=collection,
            index_name=INDEX_NAME,
        )
        # sleep(5)  # waits for mongot to update Lucene's index
        query = "foo"
        output = vectorstore.max_marginal_relevance_search(query, k=10, lambda_mult=0.1)
        assert len(output) == len(texts)
        assert output[0].page_content == "foo"
        assert output[1].page_content != "foo"
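
These tests require a live Atlas cluster reachable through the `MONGODB_ATLAS_URI` environment variable, with a 1536-dimension vector index named `langchain-test-index` on `langchain_test_db.langchain_test_collection`. A sketch of a run (the URI is a placeholder):

```bash
export MONGODB_ATLAS_URI="mongodb+srv://<user>:<password>@<cluster>.mongodb.net"
make integration_tests  # pytest tests/integration_tests/
```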
libs/partners/mongodb/tests/unit_tests/__init__.py (new file, empty)

libs/partners/mongodb/tests/unit_tests/test_imports.py (new file)

from langchain_mongodb import __all__

EXPECTED_ALL = [
    "MongoDBAtlasVectorSearch",
]


def test_all_imports() -> None:
    assert sorted(EXPECTED_ALL) == sorted(__all__)
libs/partners/mongodb/tests/unit_tests/test_vectorstores.py (new file)

import uuid
from copy import deepcopy
from typing import Any, List, Optional

import pytest
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from pymongo.collection import Collection
from pymongo.results import DeleteResult, InsertManyResult

from langchain_mongodb import MongoDBAtlasVectorSearch
from tests.utils import ConsistentFakeEmbeddings

INDEX_NAME = "langchain-test-index"
NAMESPACE = "langchain_test_db.langchain_test_collection"
DB_NAME, COLLECTION_NAME = NAMESPACE.split(".")


class MockCollection(Collection):
    """Mocked Mongo Collection"""

    _aggregate_result: List[Any]
    _insert_result: Optional[InsertManyResult]
    _data: List[Any]

    def __init__(self) -> None:
        self._data = []
        self._aggregate_result = []
        self._insert_result = None

    def delete_many(self, *args, **kwargs) -> DeleteResult:  # type: ignore
        old_len = len(self._data)
        self._data = []
        return DeleteResult({"n": old_len}, acknowledged=True)

    def insert_many(self, to_insert: List[Any], *args, **kwargs) -> InsertManyResult:  # type: ignore
        mongodb_inserts = [
            {"_id": str(uuid.uuid4()), "score": 1, **insert} for insert in to_insert
        ]
        self._data.extend(mongodb_inserts)
        return self._insert_result or InsertManyResult(
            [k["_id"] for k in mongodb_inserts], acknowledged=True
        )

    def aggregate(self, *args, **kwargs) -> List[Any]:  # type: ignore
        return deepcopy(self._aggregate_result)

    def count_documents(self, *args, **kwargs) -> int:  # type: ignore
        return len(self._data)

    def __repr__(self) -> str:
        return "FakeCollection"


def get_collection() -> MockCollection:
    return MockCollection()


@pytest.fixture()
def collection() -> MockCollection:
    return get_collection()


@pytest.fixture()
def embedding_openai() -> Embeddings:
    return ConsistentFakeEmbeddings()


def test_initialization(collection: Collection, embedding_openai: Embeddings) -> None:
    """Test initialization of vector store class"""
    assert MongoDBAtlasVectorSearch(collection, embedding_openai)


def test_init_from_texts(collection: Collection, embedding_openai: Embeddings) -> None:
    """Test from_texts operation on an empty list"""
    assert MongoDBAtlasVectorSearch.from_texts(
        [], embedding_openai, collection=collection
    )


class TestMongoDBAtlasVectorSearch:
    @classmethod
    def setup_class(cls) -> None:
        # ensure the test collection is empty
        collection = get_collection()
        assert collection.count_documents({}) == 0  # type: ignore[index]  # noqa: E501

    @classmethod
    def teardown_class(cls) -> None:
        collection = get_collection()
        # delete all the documents in the collection
        collection.delete_many({})  # type: ignore[index]

    @pytest.fixture(autouse=True)
    def setup(self) -> None:
        collection = get_collection()
        # delete all the documents in the collection
        collection.delete_many({})  # type: ignore[index]

    def _validate_search(
        self,
        vectorstore: MongoDBAtlasVectorSearch,
        collection: MockCollection,
        search_term: str = "sandwich",
        page_content: str = "What is a sandwich?",
        metadata: Optional[Any] = 1,
    ) -> None:
        collection._aggregate_result = list(
            filter(
                lambda x: search_term.lower() in x[vectorstore._text_key].lower(),
                collection._data,
            )
        )
        output = vectorstore.similarity_search("", k=1)
        assert output[0].page_content == page_content
        assert output[0].metadata.get("c") == metadata

    def test_from_documents(
        self, embedding_openai: Embeddings, collection: MockCollection
    ) -> None:
        """Test end to end construction and search."""
        documents = [
            Document(page_content="Dogs are tough.", metadata={"a": 1}),
            Document(page_content="Cats have fluff.", metadata={"b": 1}),
            Document(page_content="What is a sandwich?", metadata={"c": 1}),
            Document(page_content="That fence is purple.", metadata={"d": 1, "e": 2}),
        ]
        vectorstore = MongoDBAtlasVectorSearch.from_documents(
            documents,
            embedding_openai,
            collection=collection,
            index_name=INDEX_NAME,
        )
        self._validate_search(
            vectorstore, collection, metadata=documents[2].metadata["c"]
        )

    def test_from_texts(
        self, embedding_openai: Embeddings, collection: MockCollection
    ) -> None:
        texts = [
            "Dogs are tough.",
            "Cats have fluff.",
            "What is a sandwich?",
            "That fence is purple.",
        ]
        vectorstore = MongoDBAtlasVectorSearch.from_texts(
            texts,
            embedding_openai,
            collection=collection,
            index_name=INDEX_NAME,
        )
        self._validate_search(vectorstore, collection, metadata=None)

    def test_from_texts_with_metadatas(
        self, embedding_openai: Embeddings, collection: MockCollection
    ) -> None:
        texts = [
            "Dogs are tough.",
            "Cats have fluff.",
            "What is a sandwich?",
            "The fence is purple.",
        ]
        metadatas = [{"a": 1}, {"b": 1}, {"c": 1}, {"d": 1, "e": 2}]
        vectorstore = MongoDBAtlasVectorSearch.from_texts(
            texts,
            embedding_openai,
            metadatas=metadatas,
            collection=collection,
            index_name=INDEX_NAME,
        )
        self._validate_search(vectorstore, collection, metadata=metadatas[2]["c"])

    def test_from_texts_with_metadatas_and_pre_filter(
        self, embedding_openai: Embeddings, collection: MockCollection
    ) -> None:
        texts = [
            "Dogs are tough.",
            "Cats have fluff.",
            "What is a sandwich?",
            "The fence is purple.",
        ]
        metadatas = [{"a": 1}, {"b": 1}, {"c": 1}, {"d": 1, "e": 2}]
        vectorstore = MongoDBAtlasVectorSearch.from_texts(
            texts,
            embedding_openai,
            metadatas=metadatas,
            collection=collection,
            index_name=INDEX_NAME,
        )
        collection._aggregate_result = list(
            filter(
                lambda x: "sandwich" in x[vectorstore._text_key].lower()
                and x.get("c") < 0,
                collection._data,
            )
        )
        output = vectorstore.similarity_search(
            "Sandwich", k=1, pre_filter={"range": {"lte": 0, "path": "c"}}
        )
        assert output == []

    def test_mmr(
        self, embedding_openai: Embeddings, collection: MockCollection
    ) -> None:
        texts = ["foo", "foo", "fou", "foy"]
        vectorstore = MongoDBAtlasVectorSearch.from_texts(
            texts,
            embedding_openai,
            collection=collection,
            index_name=INDEX_NAME,
        )
        query = "foo"
        self._validate_search(
            vectorstore,
            collection,
            search_term=query[0:2],
            page_content=query,
            metadata=None,
        )
        output = vectorstore.max_marginal_relevance_search(query, k=10, lambda_mult=0.1)
        assert len(output) == len(texts)
        assert output[0].page_content == "foo"
        assert output[1].page_content != "foo"
libs/partners/mongodb/tests/utils.py (new file)

from __future__ import annotations

from typing import List

from langchain_core.embeddings import Embeddings


class ConsistentFakeEmbeddings(Embeddings):
    """Fake embeddings functionality for testing."""

    def __init__(self, dimensionality: int = 10) -> None:
        self.known_texts: List[str] = []
        self.dimensionality = dimensionality

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return consistent embeddings for each text seen so far."""
        out_vectors = []
        for text in texts:
            if text not in self.known_texts:
                self.known_texts.append(text)
            vector = [float(1.0)] * (self.dimensionality - 1) + [
                float(self.known_texts.index(text))
            ]
            out_vectors.append(vector)
        return out_vectors

    def embed_query(self, text: str) -> List[float]:
        """Return consistent embeddings for the text, adding it to the known
        texts if it has not been seen before."""
        return self.embed_documents([text])[0]

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        return self.embed_documents(texts)

    async def aembed_query(self, text: str) -> List[float]:
        return self.embed_query(text)
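
The "consistent" part in action: repeated texts map to the same vector, and the final component encodes first-seen order:

```python
from tests.utils import ConsistentFakeEmbeddings

emb = ConsistentFakeEmbeddings(dimensionality=4)
print(emb.embed_documents(["cat", "dog", "cat"]))
# [[1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 0.0]]
print(emb.embed_query("dog"))  # [1.0, 1.0, 1.0, 1.0]
```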