mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-05 15:18:32 +00:00
**Description:** This PR addresses an issue in the `MongodbLoader` where nested fields were not being correctly extracted. The loader now correctly handles nested fields specified in the `field_names` parameter. **Issue:** Fixes an issue where attempting to extract nested fields from MongoDB documents resulted in `KeyError`. **Dependencies:** No new dependencies are required for this change. **Twitter handle:** (Optional, your Twitter handle if you'd like a mention when the PR is announced) ### Changes 1. **Field Name Parsing**: - Added logic to parse nested field names and safely extract their values from the MongoDB documents. 2. **Projection Construction**: - Updated the projection dictionary to include nested fields correctly. 3. **Field Extraction**: - Updated the `aload` method to handle nested field extraction by iteratively walking the nested dictionaries one dotted-path key at a time. ### Example Usage Updated usage example to demonstrate how to specify nested fields in the `field_names` parameter: ```python loader = MongodbLoader( connection_string=MONGO_URI, db_name=MONGO_DB, collection_name=MONGO_COLLECTION, filter_criteria={"data.job.company.industry_name": "IT", "data.job.detail": { "$exists": True }}, field_names=[ "data.job.detail.id", "data.job.detail.position", "data.job.detail.intro", "data.job.detail.main_tasks", "data.job.detail.requirements", "data.job.detail.preferred_points", "data.job.detail.benefits", ], ) docs = loader.load() print(len(docs)) for doc in docs: print(doc.page_content) ``` ### Testing Tested with a MongoDB collection containing nested documents to ensure that the nested fields are correctly extracted and concatenated into a single page_content string. ### Note This change ensures backward compatibility for non-nested fields and improves functionality for nested field extraction. 
### Output Sample ```python print(docs[:3]) ``` ```shell # output sample: [ Document( # Here in this example, page_content is the combined text from the fields below # "position", "intro", "main_tasks", "requirements", "preferred_points", "benefits" page_content='all combined contents from the requested fields in the document', metadata={'database': 'Your Database name', 'collection': 'Your Collection name'} ), ... ] ``` --------- Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com> Co-authored-by: Bagatur <baskaryan@gmail.com>
107 lines
3.5 KiB
Python
107 lines
3.5 KiB
Python
import asyncio
|
|
import logging
|
|
from typing import Dict, List, Optional, Sequence
|
|
|
|
from langchain_core.documents import Document
|
|
|
|
from langchain_community.document_loaders.base import BaseLoader
|
|
|
|
# Module-level logger; MongodbLoader.aload uses it to warn when the number of
# loaded documents differs from the count reported by count_documents.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class MongodbLoader(BaseLoader):
    """Load MongoDB documents."""

    def __init__(
        self,
        connection_string: str,
        db_name: str,
        collection_name: str,
        *,
        filter_criteria: Optional[Dict] = None,
        field_names: Optional[Sequence[str]] = None,
    ) -> None:
        """Create a loader bound to a single MongoDB collection.

        Args:
            connection_string: MongoDB connection URI, passed to
                ``motor.motor_asyncio.AsyncIOMotorClient``.
            db_name: Name of the database to read from.
            collection_name: Name of the collection to read from.
            filter_criteria: Query filter used for both ``count_documents``
                and ``find``. ``None`` is replaced by ``{}`` (match all).
            field_names: Optional field paths whose values are concatenated
                into ``page_content``. Dotted paths such as ``"a.b.c"``
                address nested fields. When ``None`` the whole document is
                stringified instead.

        Raises:
            ImportError: If the ``motor`` package is not installed.
            ValueError: If ``connection_string``, ``db_name`` or
                ``collection_name`` is empty.
        """
        try:
            from motor.motor_asyncio import AsyncIOMotorClient
        except ImportError as e:
            raise ImportError(
                "Cannot import from motor, please install with `pip install motor`."
            ) from e
        if not connection_string:
            raise ValueError("connection_string must be provided.")

        if not db_name:
            raise ValueError("db_name must be provided.")

        if not collection_name:
            raise ValueError("collection_name must be provided.")

        self.client = AsyncIOMotorClient(connection_string)
        self.db_name = db_name
        self.collection_name = collection_name
        self.field_names = field_names
        self.filter_criteria = filter_criteria or {}

        self.db = self.client.get_database(db_name)
        self.collection = self.db.get_collection(collection_name)

    def load(self) -> List[Document]:
        """Load data into Document objects.

        Attention:

        This implementation starts an asyncio event loop which
        will only work if running in a sync env. In an async env, it should
        fail since there is already an event loop running.

        This code should be updated to kick off the event loop from a separate
        thread if running within an async context.
        """
        return asyncio.run(self.aload())

    @staticmethod
    def _get_nested_value(doc: Dict, field_name: str) -> object:
        """Return the value at dotted path ``field_name`` in ``doc``.

        Returns ``""`` when any path segment is missing or when an
        intermediate value is not a dict (e.g. a string, number or list), so
        that an unreachable path never raises.
        """
        value: object = doc
        for key in field_name.split("."):
            # Only descend into dicts: ``key in "some string"`` is a substring
            # test and ``key in 5`` raises TypeError, so a bare ``in`` check
            # is not a safe guard for non-mapping intermediates.
            # NOTE(review): arrays of subdocuments (e.g. "items.name") are
            # not traversed and yield "" — confirm this is the desired behavior.
            if isinstance(value, dict) and key in value:
                value = value[key]
            else:
                return ""
        return value

    async def aload(self) -> List[Document]:
        """Load data into Document objects.

        Every document matching ``filter_criteria`` becomes one ``Document``
        whose metadata records the database and collection names. With
        ``field_names`` set, ``page_content`` is the space-joined ``str()`` of
        each requested field, and missing or unreachable fields contribute
        ``""``. Otherwise ``page_content`` is ``str(doc)`` of the whole
        document.

        A warning is logged if the number of loaded documents differs from the
        count obtained before iterating.
        """
        result: List[Document] = []
        total_docs = await self.collection.count_documents(self.filter_criteria)

        # Construct the projection dictionary if field_names are specified.
        # MongoDB projections accept dotted paths, so nested fields work as-is.
        projection = (
            {field: 1 for field in self.field_names} if self.field_names else None
        )

        async for doc in self.collection.find(self.filter_criteria, projection):
            metadata = {
                "database": self.db_name,
                "collection": self.collection_name,
            }

            # Extract text content from filtered fields or use the entire document
            if self.field_names is not None:
                # Keyed by field name, so duplicate names are emitted only once.
                fields = {
                    name: self._get_nested_value(doc, name)
                    for name in self.field_names
                }
                text = " ".join(str(value) for value in fields.values())
            else:
                text = str(doc)

            result.append(Document(page_content=text, metadata=metadata))

        if len(result) != total_docs:
            # Lazy %-formatting: the message is only rendered if emitted.
            logger.warning(
                "Only partial collection of documents returned. "
                "Loaded %d docs, expected %d.",
                len(result),
                total_docs,
            )

        return result
|