mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-05 22:53:30 +00:00
community[minor]: added a feature to filter documents in Mongoloader (#18253)
"community: added a feature to filter documents in Mongoloader" - **Description:** added a feature to filter documents in Mongoloader - **Feature:** the feature #18251 - **Dependencies:** No - **Twitter handle:** https://twitter.com/im_Kushagra
This commit is contained in:
parent
c0bdd4d45b
commit
b1f22bf76c
@ -1,6 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional, Sequence
|
||||||
|
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
|
|
||||||
@ -19,6 +19,7 @@ class MongodbLoader(BaseLoader):
|
|||||||
collection_name: str,
|
collection_name: str,
|
||||||
*,
|
*,
|
||||||
filter_criteria: Optional[Dict] = None,
|
filter_criteria: Optional[Dict] = None,
|
||||||
|
field_names: Optional[Sequence[str]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
try:
|
try:
|
||||||
from motor.motor_asyncio import AsyncIOMotorClient
|
from motor.motor_asyncio import AsyncIOMotorClient
|
||||||
@ -38,6 +39,7 @@ class MongodbLoader(BaseLoader):
|
|||||||
self.client = AsyncIOMotorClient(connection_string)
|
self.client = AsyncIOMotorClient(connection_string)
|
||||||
self.db_name = db_name
|
self.db_name = db_name
|
||||||
self.collection_name = collection_name
|
self.collection_name = collection_name
|
||||||
|
self.field_names = field_names
|
||||||
self.filter_criteria = filter_criteria or {}
|
self.filter_criteria = filter_criteria or {}
|
||||||
|
|
||||||
self.db = self.client.get_database(db_name)
|
self.db = self.client.get_database(db_name)
|
||||||
@ -61,17 +63,32 @@ class MongodbLoader(BaseLoader):
|
|||||||
"""Load data into Document objects."""
|
"""Load data into Document objects."""
|
||||||
result = []
|
result = []
|
||||||
total_docs = await self.collection.count_documents(self.filter_criteria)
|
total_docs = await self.collection.count_documents(self.filter_criteria)
|
||||||
async for doc in self.collection.find(self.filter_criteria):
|
|
||||||
|
# Construct the projection dictionary if field_names are specified
|
||||||
|
projection = (
|
||||||
|
{field: 1 for field in self.field_names} if self.field_names else None
|
||||||
|
)
|
||||||
|
|
||||||
|
async for doc in self.collection.find(self.filter_criteria, projection):
|
||||||
metadata = {
|
metadata = {
|
||||||
"database": self.db_name,
|
"database": self.db_name,
|
||||||
"collection": self.collection_name,
|
"collection": self.collection_name,
|
||||||
}
|
}
|
||||||
result.append(Document(page_content=str(doc), metadata=metadata))
|
|
||||||
|
# Extract text content from filtered fields or use the entire document
|
||||||
|
if self.field_names is not None:
|
||||||
|
fields = {name: doc[name] for name in self.field_names}
|
||||||
|
texts = [str(value) for value in fields.values()]
|
||||||
|
text = " ".join(texts)
|
||||||
|
else:
|
||||||
|
text = str(doc)
|
||||||
|
|
||||||
|
result.append(Document(page_content=text, metadata=metadata))
|
||||||
|
|
||||||
if len(result) != total_docs:
|
if len(result) != total_docs:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Only partial collection of documents returned. Loaded {len(result)} "
|
f"Only partial collection of documents returned. "
|
||||||
f"docs, expected {total_docs}."
|
f"Loaded {len(result)} docs, expected {total_docs}."
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
Loading…
Reference in New Issue
Block a user