mirror of
https://github.com/hwchase17/langchain.git
synced 2025-04-29 12:25:37 +00:00
150 lines
5.4 KiB
Python
150 lines
5.4 KiB
Python
import asyncio
|
|
import logging
|
|
from typing import Dict, List, Optional, Sequence
|
|
|
|
from langchain_core.documents import Document
|
|
|
|
from langchain_community.document_loaders.base import BaseLoader
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class MongodbLoader(BaseLoader):
|
|
"""Load MongoDB documents."""
|
|
|
|
def __init__(
|
|
self,
|
|
connection_string: str,
|
|
db_name: str,
|
|
collection_name: str,
|
|
*,
|
|
filter_criteria: Optional[Dict] = None,
|
|
field_names: Optional[Sequence[str]] = None,
|
|
metadata_names: Optional[Sequence[str]] = None,
|
|
include_db_collection_in_metadata: bool = True,
|
|
) -> None:
|
|
"""
|
|
Initializes the MongoDB loader with necessary database connection
|
|
details and configurations.
|
|
|
|
Args:
|
|
connection_string (str): MongoDB connection URI.
|
|
db_name (str):Name of the database to connect to.
|
|
collection_name (str): Name of the collection to fetch documents from.
|
|
filter_criteria (Optional[Dict]): MongoDB filter criteria for querying
|
|
documents.
|
|
field_names (Optional[Sequence[str]]): List of field names to retrieve
|
|
from documents.
|
|
metadata_names (Optional[Sequence[str]]): Additional metadata fields to
|
|
extract from documents.
|
|
include_db_collection_in_metadata (bool): Flag to include database and
|
|
collection names in metadata.
|
|
|
|
Raises:
|
|
ImportError: If the motor library is not installed.
|
|
ValueError: If any necessary argument is missing.
|
|
"""
|
|
try:
|
|
from motor.motor_asyncio import AsyncIOMotorClient
|
|
except ImportError as e:
|
|
raise ImportError(
|
|
"Cannot import from motor, please install with `pip install motor`."
|
|
) from e
|
|
|
|
if not connection_string:
|
|
raise ValueError("connection_string must be provided.")
|
|
|
|
if not db_name:
|
|
raise ValueError("db_name must be provided.")
|
|
|
|
if not collection_name:
|
|
raise ValueError("collection_name must be provided.")
|
|
|
|
self.client = AsyncIOMotorClient(connection_string)
|
|
self.db_name = db_name
|
|
self.collection_name = collection_name
|
|
self.field_names = field_names or []
|
|
self.filter_criteria = filter_criteria or {}
|
|
self.metadata_names = metadata_names or []
|
|
self.include_db_collection_in_metadata = include_db_collection_in_metadata
|
|
|
|
self.db = self.client.get_database(db_name)
|
|
self.collection = self.db.get_collection(collection_name)
|
|
|
|
def load(self) -> List[Document]:
|
|
"""Load data into Document objects.
|
|
|
|
Attention:
|
|
|
|
This implementation starts an asyncio event loop which
|
|
will only work if running in a sync env. In an async env, it should
|
|
fail since there is already an event loop running.
|
|
|
|
This code should be updated to kick off the event loop from a separate
|
|
thread if running within an async context.
|
|
"""
|
|
return asyncio.run(self.aload())
|
|
|
|
async def aload(self) -> List[Document]:
|
|
"""Asynchronously loads data into Document objects."""
|
|
result = []
|
|
total_docs = await self.collection.count_documents(self.filter_criteria)
|
|
|
|
projection = self._construct_projection()
|
|
|
|
async for doc in self.collection.find(self.filter_criteria, projection):
|
|
metadata = self._extract_fields(doc, self.metadata_names, default="")
|
|
|
|
# Optionally add database and collection names to metadata
|
|
if self.include_db_collection_in_metadata:
|
|
metadata.update(
|
|
{
|
|
"database": self.db_name,
|
|
"collection": self.collection_name,
|
|
}
|
|
)
|
|
|
|
# Extract text content from filtered fields or use the entire document
|
|
if self.field_names is not None:
|
|
fields = self._extract_fields(doc, self.field_names, default="")
|
|
texts = [str(value) for value in fields.values()]
|
|
text = " ".join(texts)
|
|
else:
|
|
text = str(doc)
|
|
|
|
result.append(Document(page_content=text, metadata=metadata))
|
|
|
|
if len(result) != total_docs:
|
|
logger.warning(
|
|
f"Only partial collection of documents returned. "
|
|
f"Loaded {len(result)} docs, expected {total_docs}."
|
|
)
|
|
|
|
return result
|
|
|
|
def _construct_projection(self) -> Optional[Dict]:
|
|
"""Constructs the projection dictionary for MongoDB query based
|
|
on the specified field names and metadata names."""
|
|
field_names = list(self.field_names) or []
|
|
metadata_names = list(self.metadata_names) or []
|
|
all_fields = field_names + metadata_names
|
|
return {field: 1 for field in all_fields} if all_fields else None
|
|
|
|
def _extract_fields(
|
|
self,
|
|
document: Dict,
|
|
fields: Sequence[str],
|
|
default: str = "",
|
|
) -> Dict:
|
|
"""Extracts and returns values for specified fields from a document."""
|
|
extracted = {}
|
|
for field in fields or []:
|
|
value = document
|
|
for key in field.split("."):
|
|
value = value.get(key, default)
|
|
if value == default:
|
|
break
|
|
new_field_name = field.replace(".", "_")
|
|
extracted[new_field_name] = value
|
|
return extracted
|