mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-15 01:13:48 +00:00
removing client+namespace in favor of collection (#5610)
Remove the separate `client` and `namespace` arguments in favor of a single `collection` argument, for easier instantiation and parity with the TypeScript library. @dev2049
This commit is contained in:
parent
ad09367a92
commit
92f218207b
@ -118,15 +118,14 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"db_name = \"lanchain_db\"\n",
|
"db_name = \"lanchain_db\"\n",
|
||||||
"collection_name = \"langchain_col\"\n",
|
"collection_name = \"langchain_col\"\n",
|
||||||
"namespace = f\"{db_name}.{collection_name}\"\n",
|
"collection = client[db_name][collection_name]\n",
|
||||||
"index_name = \"langchain_demo\"\n",
|
"index_name = \"langchain_demo\"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# insert the documents in MongoDB Atlas with their embedding\n",
|
"# insert the documents in MongoDB Atlas with their embedding\n",
|
||||||
"docsearch = MongoDBAtlasVectorSearch.from_documents(\n",
|
"docsearch = MongoDBAtlasVectorSearch.from_documents(\n",
|
||||||
" docs,\n",
|
" docs,\n",
|
||||||
" embeddings,\n",
|
" embeddings,\n",
|
||||||
" client=client,\n",
|
" collection=collection,\n",
|
||||||
" namespace=namespace,\n",
|
|
||||||
" index_name=index_name\n",
|
" index_name=index_name\n",
|
||||||
")\n",
|
")\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
@ -10,6 +10,7 @@ from typing import (
|
|||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
Tuple,
|
Tuple,
|
||||||
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -18,7 +19,9 @@ from langchain.embeddings.base import Embeddings
|
|||||||
from langchain.vectorstores.base import VectorStore
|
from langchain.vectorstores.base import VectorStore
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from pymongo import MongoClient
|
from pymongo.collection import Collection
|
||||||
|
|
||||||
|
MongoDBDocumentType = TypeVar("MongoDBDocumentType", bound=Dict[str, Any])
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -41,15 +44,14 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
|||||||
from pymongo import MongoClient
|
from pymongo import MongoClient
|
||||||
|
|
||||||
mongo_client = MongoClient("<YOUR-CONNECTION-STRING>")
|
mongo_client = MongoClient("<YOUR-CONNECTION-STRING>")
|
||||||
namespace = "<db_name>.<collection_name>"
|
collection = mongo_client["<db_name>"]["<collection_name>"]
|
||||||
embeddings = OpenAIEmbeddings()
|
embeddings = OpenAIEmbeddings()
|
||||||
vectorstore = MongoDBAtlasVectorSearch(mongo_client, namespace, embeddings)
|
vectorstore = MongoDBAtlasVectorSearch(collection, embeddings)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
client: MongoClient,
|
collection: Collection[MongoDBDocumentType],
|
||||||
namespace: str,
|
|
||||||
embedding: Embeddings,
|
embedding: Embeddings,
|
||||||
*,
|
*,
|
||||||
index_name: str = "default",
|
index_name: str = "default",
|
||||||
@ -58,17 +60,14 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
|||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
client: MongoDB client.
|
collection: MongoDB collection to add the texts to.
|
||||||
namespace: MongoDB namespace to add the texts to.
|
|
||||||
embedding: Text embedding model to use.
|
embedding: Text embedding model to use.
|
||||||
text_key: MongoDB field that will contain the text for each
|
text_key: MongoDB field that will contain the text for each
|
||||||
document.
|
document.
|
||||||
embedding_key: MongoDB field that will contain the embedding for
|
embedding_key: MongoDB field that will contain the embedding for
|
||||||
each document.
|
each document.
|
||||||
"""
|
"""
|
||||||
self._client = client
|
self._collection = collection
|
||||||
db_name, collection_name = namespace.split(".")
|
|
||||||
self._collection = client[db_name][collection_name]
|
|
||||||
self._embedding = embedding
|
self._embedding = embedding
|
||||||
self._index_name = index_name
|
self._index_name = index_name
|
||||||
self._text_key = text_key
|
self._text_key = text_key
|
||||||
@ -90,7 +89,9 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
|||||||
"`pip install pymongo`."
|
"`pip install pymongo`."
|
||||||
)
|
)
|
||||||
client: MongoClient = MongoClient(connection_string)
|
client: MongoClient = MongoClient(connection_string)
|
||||||
return cls(client, namespace, embedding, **kwargs)
|
db_name, collection_name = namespace.split(".")
|
||||||
|
collection = client[db_name][collection_name]
|
||||||
|
return cls(collection, embedding, **kwargs)
|
||||||
|
|
||||||
def add_texts(
|
def add_texts(
|
||||||
self,
|
self,
|
||||||
@ -232,8 +233,7 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
|||||||
texts: List[str],
|
texts: List[str],
|
||||||
embedding: Embeddings,
|
embedding: Embeddings,
|
||||||
metadatas: Optional[List[dict]] = None,
|
metadatas: Optional[List[dict]] = None,
|
||||||
client: Optional[MongoClient] = None,
|
collection: Optional[Collection[MongoDBDocumentType]] = None,
|
||||||
namespace: Optional[str] = None,
|
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> MongoDBAtlasVectorSearch:
|
) -> MongoDBAtlasVectorSearch:
|
||||||
"""Construct MongoDBAtlasVectorSearch wrapper from raw documents.
|
"""Construct MongoDBAtlasVectorSearch wrapper from raw documents.
|
||||||
@ -253,18 +253,17 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
|||||||
from langchain.embeddings import OpenAIEmbeddings
|
from langchain.embeddings import OpenAIEmbeddings
|
||||||
|
|
||||||
client = MongoClient("<YOUR-CONNECTION-STRING>")
|
client = MongoClient("<YOUR-CONNECTION-STRING>")
|
||||||
namespace = "<db_name>.<collection_name>"
|
collection = client["<db_name>"]["<collection_name>"]
|
||||||
embeddings = OpenAIEmbeddings()
|
embeddings = OpenAIEmbeddings()
|
||||||
vectorstore = MongoDBAtlasVectorSearch.from_texts(
|
vectorstore = MongoDBAtlasVectorSearch.from_texts(
|
||||||
texts,
|
texts,
|
||||||
embeddings,
|
embeddings,
|
||||||
metadatas=metadatas,
|
metadatas=metadatas,
|
||||||
client=client,
|
collection=collection
|
||||||
namespace=namespace
|
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
if not client or not namespace:
|
if not collection:
|
||||||
raise ValueError("Must provide 'client' and 'namespace' named parameters.")
|
raise ValueError("Must provide 'collection' named parameter.")
|
||||||
vecstore = cls(client, namespace, embedding, **kwargs)
|
vecstore = cls(collection, embedding, **kwargs)
|
||||||
vecstore.add_texts(texts, metadatas=metadatas)
|
vecstore.add_texts(texts, metadatas=metadatas)
|
||||||
return vecstore
|
return vecstore
|
||||||
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
from time import sleep
|
from time import sleep
|
||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -19,37 +19,27 @@ NAMESPACE = "langchain_test_db.langchain_test_collection"
|
|||||||
CONNECTION_STRING = os.environ.get("MONGODB_ATLAS_URI")
|
CONNECTION_STRING = os.environ.get("MONGODB_ATLAS_URI")
|
||||||
DB_NAME, COLLECTION_NAME = NAMESPACE.split(".")
|
DB_NAME, COLLECTION_NAME = NAMESPACE.split(".")
|
||||||
|
|
||||||
|
|
||||||
def get_test_client() -> Optional[MongoClient]:
|
|
||||||
try:
|
|
||||||
from pymongo import MongoClient
|
|
||||||
|
|
||||||
client: MongoClient = MongoClient(CONNECTION_STRING)
|
|
||||||
return client
|
|
||||||
except: # noqa: E722
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
# Instantiate as constant instead of pytest fixture to prevent needing to make multiple
|
# Instantiate as constant instead of pytest fixture to prevent needing to make multiple
|
||||||
# connections.
|
# connections.
|
||||||
TEST_CLIENT = get_test_client()
|
TEST_CLIENT = MongoClient(CONNECTION_STRING)
|
||||||
|
collection = TEST_CLIENT[DB_NAME][COLLECTION_NAME]
|
||||||
|
|
||||||
|
|
||||||
class TestMongoDBAtlasVectorSearch:
|
class TestMongoDBAtlasVectorSearch:
|
||||||
@classmethod
|
@classmethod
|
||||||
def setup_class(cls) -> None:
|
def setup_class(cls) -> None:
|
||||||
# ensure the test collection is empty
|
# ensure the test collection is empty
|
||||||
assert TEST_CLIENT[DB_NAME][COLLECTION_NAME].count_documents({}) == 0 # type: ignore[index] # noqa: E501
|
assert collection.count_documents({}) == 0 # type: ignore[index] # noqa: E501
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def teardown_class(cls) -> None:
|
def teardown_class(cls) -> None:
|
||||||
# delete all the documents in the collection
|
# delete all the documents in the collection
|
||||||
TEST_CLIENT[DB_NAME][COLLECTION_NAME].delete_many({}) # type: ignore[index]
|
collection.delete_many({}) # type: ignore[index]
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def setup(self) -> None:
|
def setup(self) -> None:
|
||||||
# delete all the documents in the collection
|
# delete all the documents in the collection
|
||||||
TEST_CLIENT[DB_NAME][COLLECTION_NAME].delete_many({}) # type: ignore[index]
|
collection.delete_many({}) # type: ignore[index]
|
||||||
|
|
||||||
def test_from_documents(self, embedding_openai: Embeddings) -> None:
|
def test_from_documents(self, embedding_openai: Embeddings) -> None:
|
||||||
"""Test end to end construction and search."""
|
"""Test end to end construction and search."""
|
||||||
@ -62,8 +52,7 @@ class TestMongoDBAtlasVectorSearch:
|
|||||||
vectorstore = MongoDBAtlasVectorSearch.from_documents(
|
vectorstore = MongoDBAtlasVectorSearch.from_documents(
|
||||||
documents,
|
documents,
|
||||||
embedding_openai,
|
embedding_openai,
|
||||||
client=TEST_CLIENT,
|
collection=collection,
|
||||||
namespace=NAMESPACE,
|
|
||||||
index_name=INDEX_NAME,
|
index_name=INDEX_NAME,
|
||||||
)
|
)
|
||||||
sleep(1) # waits for mongot to update Lucene's index
|
sleep(1) # waits for mongot to update Lucene's index
|
||||||
@ -81,8 +70,7 @@ class TestMongoDBAtlasVectorSearch:
|
|||||||
vectorstore = MongoDBAtlasVectorSearch.from_texts(
|
vectorstore = MongoDBAtlasVectorSearch.from_texts(
|
||||||
texts,
|
texts,
|
||||||
embedding_openai,
|
embedding_openai,
|
||||||
client=TEST_CLIENT,
|
collection=collection,
|
||||||
namespace=NAMESPACE,
|
|
||||||
index_name=INDEX_NAME,
|
index_name=INDEX_NAME,
|
||||||
)
|
)
|
||||||
sleep(1) # waits for mongot to update Lucene's index
|
sleep(1) # waits for mongot to update Lucene's index
|
||||||
@ -101,8 +89,7 @@ class TestMongoDBAtlasVectorSearch:
|
|||||||
texts,
|
texts,
|
||||||
embedding_openai,
|
embedding_openai,
|
||||||
metadatas=metadatas,
|
metadatas=metadatas,
|
||||||
client=TEST_CLIENT,
|
collection=collection,
|
||||||
namespace=NAMESPACE,
|
|
||||||
index_name=INDEX_NAME,
|
index_name=INDEX_NAME,
|
||||||
)
|
)
|
||||||
sleep(1) # waits for mongot to update Lucene's index
|
sleep(1) # waits for mongot to update Lucene's index
|
||||||
@ -124,8 +111,7 @@ class TestMongoDBAtlasVectorSearch:
|
|||||||
texts,
|
texts,
|
||||||
embedding_openai,
|
embedding_openai,
|
||||||
metadatas=metadatas,
|
metadatas=metadatas,
|
||||||
client=TEST_CLIENT,
|
collection=collection,
|
||||||
namespace=NAMESPACE,
|
|
||||||
index_name=INDEX_NAME,
|
index_name=INDEX_NAME,
|
||||||
)
|
)
|
||||||
sleep(1) # waits for mongot to update Lucene's index
|
sleep(1) # waits for mongot to update Lucene's index
|
||||||
|
Loading…
Reference in New Issue
Block a user