mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-14 17:07:25 +00:00
removing client+namespace in favor of collection (#5610)
removing client+namespace in favor of collection for an easier instantiation and to be similar to the typescript library @dev2049
This commit is contained in:
parent
ad09367a92
commit
92f218207b
@ -118,15 +118,14 @@
|
||||
"\n",
|
||||
"db_name = \"lanchain_db\"\n",
|
||||
"collection_name = \"langchain_col\"\n",
|
||||
"namespace = f\"{db_name}.{collection_name}\"\n",
|
||||
"collection = client[db_name][collection_name]\n",
|
||||
"index_name = \"langchain_demo\"\n",
|
||||
"\n",
|
||||
"# insert the documents in MongoDB Atlas with their embedding\n",
|
||||
"docsearch = MongoDBAtlasVectorSearch.from_documents(\n",
|
||||
" docs,\n",
|
||||
" embeddings,\n",
|
||||
" client=client,\n",
|
||||
" namespace=namespace,\n",
|
||||
" collection=collection,\n",
|
||||
" index_name=index_name\n",
|
||||
")\n",
|
||||
"\n",
|
||||
|
@ -10,6 +10,7 @@ from typing import (
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
@ -18,7 +19,9 @@ from langchain.embeddings.base import Embeddings
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo import MongoClient
|
||||
from pymongo.collection import Collection
|
||||
|
||||
MongoDBDocumentType = TypeVar("MongoDBDocumentType", bound=Dict[str, Any])
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -41,15 +44,14 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
||||
from pymongo import MongoClient
|
||||
|
||||
mongo_client = MongoClient("<YOUR-CONNECTION-STRING>")
|
||||
namespace = "<db_name>.<collection_name>"
|
||||
collection = mongo_client["<db_name>"]["<collection_name>"]
|
||||
embeddings = OpenAIEmbeddings()
|
||||
vectorstore = MongoDBAtlasVectorSearch(mongo_client, namespace, embeddings)
|
||||
vectorstore = MongoDBAtlasVectorSearch(collection, embeddings)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: MongoClient,
|
||||
namespace: str,
|
||||
collection: Collection[MongoDBDocumentType],
|
||||
embedding: Embeddings,
|
||||
*,
|
||||
index_name: str = "default",
|
||||
@ -58,17 +60,14 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
client: MongoDB client.
|
||||
namespace: MongoDB namespace to add the texts to.
|
||||
collection: MongoDB collection to add the texts to.
|
||||
embedding: Text embedding model to use.
|
||||
text_key: MongoDB field that will contain the text for each
|
||||
document.
|
||||
embedding_key: MongoDB field that will contain the embedding for
|
||||
each document.
|
||||
"""
|
||||
self._client = client
|
||||
db_name, collection_name = namespace.split(".")
|
||||
self._collection = client[db_name][collection_name]
|
||||
self._collection = collection
|
||||
self._embedding = embedding
|
||||
self._index_name = index_name
|
||||
self._text_key = text_key
|
||||
@ -90,7 +89,9 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
||||
"`pip install pymongo`."
|
||||
)
|
||||
client: MongoClient = MongoClient(connection_string)
|
||||
return cls(client, namespace, embedding, **kwargs)
|
||||
db_name, collection_name = namespace.split(".")
|
||||
collection = client[db_name][collection_name]
|
||||
return cls(collection, embedding, **kwargs)
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
@ -232,8 +233,7 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
||||
texts: List[str],
|
||||
embedding: Embeddings,
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
client: Optional[MongoClient] = None,
|
||||
namespace: Optional[str] = None,
|
||||
collection: Optional[Collection[MongoDBDocumentType]] = None,
|
||||
**kwargs: Any,
|
||||
) -> MongoDBAtlasVectorSearch:
|
||||
"""Construct MongoDBAtlasVectorSearch wrapper from raw documents.
|
||||
@ -253,18 +253,17 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
|
||||
client = MongoClient("<YOUR-CONNECTION-STRING>")
|
||||
namespace = "<db_name>.<collection_name>"
|
||||
collection = mongo_client["<db_name>"]["<collection_name>"]
|
||||
embeddings = OpenAIEmbeddings()
|
||||
vectorstore = MongoDBAtlasVectorSearch.from_texts(
|
||||
texts,
|
||||
embeddings,
|
||||
metadatas=metadatas,
|
||||
client=client,
|
||||
namespace=namespace
|
||||
collection=collection
|
||||
)
|
||||
"""
|
||||
if not client or not namespace:
|
||||
raise ValueError("Must provide 'client' and 'namespace' named parameters.")
|
||||
vecstore = cls(client, namespace, embedding, **kwargs)
|
||||
if not collection:
|
||||
raise ValueError("Must provide 'collection' named parameter.")
|
||||
vecstore = cls(collection, embedding, **kwargs)
|
||||
vecstore.add_texts(texts, metadatas=metadatas)
|
||||
return vecstore
|
||||
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
from time import sleep
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
|
||||
@ -19,37 +19,27 @@ NAMESPACE = "langchain_test_db.langchain_test_collection"
|
||||
CONNECTION_STRING = os.environ.get("MONGODB_ATLAS_URI")
|
||||
DB_NAME, COLLECTION_NAME = NAMESPACE.split(".")
|
||||
|
||||
|
||||
def get_test_client() -> Optional[MongoClient]:
|
||||
try:
|
||||
from pymongo import MongoClient
|
||||
|
||||
client: MongoClient = MongoClient(CONNECTION_STRING)
|
||||
return client
|
||||
except: # noqa: E722
|
||||
return None
|
||||
|
||||
|
||||
# Instantiate as constant instead of pytest fixture to prevent needing to make multiple
|
||||
# connections.
|
||||
TEST_CLIENT = get_test_client()
|
||||
TEST_CLIENT = MongoClient(CONNECTION_STRING)
|
||||
collection = TEST_CLIENT[DB_NAME][COLLECTION_NAME]
|
||||
|
||||
|
||||
class TestMongoDBAtlasVectorSearch:
|
||||
@classmethod
|
||||
def setup_class(cls) -> None:
|
||||
# insure the test collection is empty
|
||||
assert TEST_CLIENT[DB_NAME][COLLECTION_NAME].count_documents({}) == 0 # type: ignore[index] # noqa: E501
|
||||
assert collection.count_documents({}) == 0 # type: ignore[index] # noqa: E501
|
||||
|
||||
@classmethod
|
||||
def teardown_class(cls) -> None:
|
||||
# delete all the documents in the collection
|
||||
TEST_CLIENT[DB_NAME][COLLECTION_NAME].delete_many({}) # type: ignore[index]
|
||||
collection.delete_many({}) # type: ignore[index]
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self) -> None:
|
||||
# delete all the documents in the collection
|
||||
TEST_CLIENT[DB_NAME][COLLECTION_NAME].delete_many({}) # type: ignore[index]
|
||||
collection.delete_many({}) # type: ignore[index]
|
||||
|
||||
def test_from_documents(self, embedding_openai: Embeddings) -> None:
|
||||
"""Test end to end construction and search."""
|
||||
@ -62,8 +52,7 @@ class TestMongoDBAtlasVectorSearch:
|
||||
vectorstore = MongoDBAtlasVectorSearch.from_documents(
|
||||
documents,
|
||||
embedding_openai,
|
||||
client=TEST_CLIENT,
|
||||
namespace=NAMESPACE,
|
||||
collection=collection,
|
||||
index_name=INDEX_NAME,
|
||||
)
|
||||
sleep(1) # waits for mongot to update Lucene's index
|
||||
@ -81,8 +70,7 @@ class TestMongoDBAtlasVectorSearch:
|
||||
vectorstore = MongoDBAtlasVectorSearch.from_texts(
|
||||
texts,
|
||||
embedding_openai,
|
||||
client=TEST_CLIENT,
|
||||
namespace=NAMESPACE,
|
||||
collection=collection,
|
||||
index_name=INDEX_NAME,
|
||||
)
|
||||
sleep(1) # waits for mongot to update Lucene's index
|
||||
@ -101,8 +89,7 @@ class TestMongoDBAtlasVectorSearch:
|
||||
texts,
|
||||
embedding_openai,
|
||||
metadatas=metadatas,
|
||||
client=TEST_CLIENT,
|
||||
namespace=NAMESPACE,
|
||||
collection=collection,
|
||||
index_name=INDEX_NAME,
|
||||
)
|
||||
sleep(1) # waits for mongot to update Lucene's index
|
||||
@ -124,8 +111,7 @@ class TestMongoDBAtlasVectorSearch:
|
||||
texts,
|
||||
embedding_openai,
|
||||
metadatas=metadatas,
|
||||
client=TEST_CLIENT,
|
||||
namespace=NAMESPACE,
|
||||
collection=collection,
|
||||
index_name=INDEX_NAME,
|
||||
)
|
||||
sleep(1) # waits for mongot to update Lucene's index
|
||||
|
Loading…
Reference in New Issue
Block a user