partners/mongodb : Significant MongoDBVectorSearch ID enhancements (#23535)

## Description

This pull request improves the handling of document IDs in
`MongoDBAtlasVectorSearch`.

The signatures of `add_documents`, `add_texts`, `delete`, and `from_texts`
now include an `ids: Optional[List[str]]` keyword argument, giving the user
greater control over document IDs.
As before, IDs may also be inferred from `Document.metadata['_id']` if
present, but this is no longer required.
IDs can also optionally be returned from searches.
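
The ID handling described above can be pictured with a small, dependency-free sketch. `resolve_ids` is a hypothetical helper and is not part of `MongoDBAtlasVectorSearch`. It assumes that an explicit `ids` list takes precedence over `metadata['_id']`, which this description does not state, and it uses `None` to mean "no ID supplied, let MongoDB assign one on insert".

```python
from typing import Any, Dict, List, Optional


def resolve_ids(
    metadatas: List[Dict[str, Any]],
    ids: Optional[List[str]] = None,
) -> List[Optional[str]]:
    """Hypothetical illustration: choose one ID per document.

    Assumed precedence: explicit ``ids`` first, then ``metadata['_id']``.
    ``None`` marks a document whose ID is left to the database.
    """
    if ids is not None:
        if len(ids) != len(metadatas):
            raise ValueError("ids must have one entry per document")
        return list(ids)
    # No explicit ids: fall back to the optional '_id' metadata entry.
    return [md.get("_id") for md in metadatas]


metadatas = [{"_id": "doc-a", "source": "x"}, {"source": "y"}]
inferred = resolve_ids(metadatas)
explicit = resolve_ids(metadatas, ids=["id-1", "id-2"])
```

Here `inferred` picks up `"doc-a"` from the first document's metadata and leaves the second unassigned, while `explicit` uses the caller's IDs unchanged.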

This PR closes the following JIRA issues:

* [PYTHON-4446](https://jira.mongodb.org/browse/PYTHON-4446) MongoDBVectorSearch delete / add_texts function rework
* [PYTHON-4435](https://jira.mongodb.org/browse/PYTHON-4435) Add support for "Indexing"
* [PYTHON-4534](https://jira.mongodb.org/browse/PYTHON-4534) Ensure datetimes are json-serializable

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
Casey Clements
2024-07-17 13:26:20 -07:00
committed by GitHub
parent cc2cbfabfc
commit a47f69a120
4 changed files with 403 additions and 89 deletions


@@ -8,10 +8,15 @@ are duplicated in this utility respectively from modules:
- "libs/community/langchain_community/utils/math.py"
"""
from __future__ import annotations
import logging
from typing import List, Union
from datetime import date, datetime
from typing import Any, Dict, List, Union
import numpy as np
from bson import ObjectId
from bson.errors import InvalidId
logger = logging.getLogger(__name__)
@@ -88,3 +93,55 @@ def maximal_marginal_relevance(
        idxs.append(idx_to_add)
        selected = np.append(selected, [embedding_list[idx_to_add]], axis=0)
    return idxs


def str_to_oid(str_repr: str) -> ObjectId | str:
    """Attempt to cast string representation of id to MongoDB's internal BSON ObjectId.

    To be consistent with ObjectId, input must be a 24 character hex string.
    If it is not, MongoDB will happily use the string in the main _id index.
    Importantly, the str representation that comes out of MongoDB will have this form.

    Args:
        str_repr: id as string.

    Returns:
        ObjectID
    """
    try:
        return ObjectId(str_repr)
    except InvalidId:
        logger.debug(
            "ObjectIds must be 12-character byte or 24-character hex strings. "
            "Examples: b'heres12bytes', '6f6e6568656c6c6f68656768'"
        )
        return str_repr


def oid_to_str(oid: ObjectId) -> str:
    """Convert MongoDB's internal BSON ObjectId into a simple str for compatibility.

    Instructive helper to show where data is coming out of MongoDB.

    Args:
        oid: bson.ObjectId

    Returns:
        24 character hex string.
    """
    return str(oid)


def make_serializable(
    obj: Dict[str, Any],
) -> None:
    """Recursively cast values in a dict to a form able to json.dump"""
    for k, v in obj.items():
        if isinstance(v, dict):
            make_serializable(v)
        elif isinstance(v, list) and v and isinstance(v[0], (ObjectId, date, datetime)):
            obj[k] = [oid_to_str(item) for item in v]
        elif isinstance(v, ObjectId):
            obj[k] = oid_to_str(v)
        elif isinstance(v, (datetime, date)):
            obj[k] = v.isoformat()
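
To make the behaviour of `make_serializable` concrete, here is a minimal stdlib-only sketch. `make_serializable_sketch` is a hypothetical stand-in that mirrors the branches in the diff but omits the `ObjectId` cases so it runs without `bson`. It is an illustration under those assumptions, not the committed code.

```python
import json
from datetime import date, datetime
from typing import Any, Dict


def make_serializable_sketch(obj: Dict[str, Any]) -> None:
    """Hypothetical stdlib-only stand-in for make_serializable.

    Mutates ``obj`` in place, following the same branch order as the diff.
    """
    for k, v in obj.items():
        if isinstance(v, dict):
            # Recurse into nested dicts; they are modified in place.
            make_serializable_sketch(v)
        elif isinstance(v, list) and v and isinstance(v[0], (date, datetime)):
            # The diff passes list items through oid_to_str, i.e. plain str().
            obj[k] = [str(item) for item in v]
        elif isinstance(v, (datetime, date)):
            obj[k] = v.isoformat()


doc = {
    "text": "hello",
    "created": datetime(2024, 7, 17, 13, 26, 20),
    "meta": {"day": date(2024, 7, 17)},
    "history": [datetime(2024, 7, 17, 13, 26, 20)],
}
make_serializable_sketch(doc)
encoded = json.dumps(doc)  # no TypeError once dates are converted
```

One detail worth noticing: a scalar `datetime` is written with `isoformat()` (with a `T` separator), whereas datetimes inside a list go through `str()` and come out with a space separator. Only the first list element is type-checked, so mixed-type lists are converted wholesale.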