community[minor]: added a feature to filter documents in Mongoloader (#18253)

"community: added a feature to filter documents in Mongoloader"
- **Description:** added a feature to filter documents in Mongoloader
    - **Feature:** the feature #18251
    - **Dependencies:** No
    - **Twitter handle:** https://twitter.com/im_Kushagra
This commit is contained in:
Kushagra 2024-03-09 01:36:35 +05:30 committed by GitHub
parent c0bdd4d45b
commit b1f22bf76c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,6 +1,6 @@
import asyncio
import logging
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Sequence
from langchain_core.documents import Document
@ -19,6 +19,7 @@ class MongodbLoader(BaseLoader):
collection_name: str,
*,
filter_criteria: Optional[Dict] = None,
field_names: Optional[Sequence[str]] = None,
) -> None:
try:
from motor.motor_asyncio import AsyncIOMotorClient
@ -38,6 +39,7 @@ class MongodbLoader(BaseLoader):
self.client = AsyncIOMotorClient(connection_string)
self.db_name = db_name
self.collection_name = collection_name
self.field_names = field_names
self.filter_criteria = filter_criteria or {}
self.db = self.client.get_database(db_name)
@ -61,17 +63,32 @@ class MongodbLoader(BaseLoader):
"""Load data into Document objects."""
result = []
total_docs = await self.collection.count_documents(self.filter_criteria)
async for doc in self.collection.find(self.filter_criteria):
# Construct the projection dictionary if field_names are specified
projection = (
{field: 1 for field in self.field_names} if self.field_names else None
)
async for doc in self.collection.find(self.filter_criteria, projection):
metadata = {
"database": self.db_name,
"collection": self.collection_name,
}
result.append(Document(page_content=str(doc), metadata=metadata))
# Extract text content from filtered fields or use the entire document
if self.field_names is not None:
fields = {name: doc[name] for name in self.field_names}
texts = [str(value) for value in fields.values()]
text = " ".join(texts)
else:
text = str(doc)
result.append(Document(page_content=text, metadata=metadata))
if len(result) != total_docs:
logger.warning(
f"Only partial collection of documents returned. Loaded {len(result)} "
f"docs, expected {total_docs}."
f"Only partial collection of documents returned. "
f"Loaded {len(result)} docs, expected {total_docs}."
)
return result