mongodb[patch]: Make ObjectId JSON-serializable on generation (#21394)

This commit is contained in:
Jib 2024-05-14 14:52:29 -04:00 committed by GitHub
parent 12b599c47f
commit a97473c846
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 24 additions and 2 deletions

View File

@ -16,6 +16,7 @@ from typing import (
) )
import numpy as np import numpy as np
from bson import ObjectId, json_util
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_core.runnables.config import run_in_executor from langchain_core.runnables.config import run_in_executor
@ -210,9 +211,23 @@ class MongoDBAtlasVectorSearch(VectorStore):
pipeline.extend(post_filter_pipeline) pipeline.extend(post_filter_pipeline)
cursor = self._collection.aggregate(pipeline) # type: ignore[arg-type] cursor = self._collection.aggregate(pipeline) # type: ignore[arg-type]
docs = [] docs = []
def _make_serializable(obj: Dict[str, Any]) -> None:
for k, v in obj.items():
if isinstance(v, dict):
_make_serializable(v)
elif isinstance(v, list) and v and isinstance(v[0], ObjectId):
obj[k] = [json_util.default(item) for item in v]
elif isinstance(v, ObjectId):
obj[k] = json_util.default(v)
for res in cursor: for res in cursor:
text = res.pop(self._text_key) text = res.pop(self._text_key)
score = res.pop("score") score = res.pop("score")
# Make every ObjectId found JSON-Serializable
# following format used in bson.json_util.loads
# e.g. loads('{"_id": {"$oid": "664..."}}') == {'_id': ObjectId('664..')} # noqa: E501
_make_serializable(res)
docs.append((Document(page_content=text, metadata=res), score)) docs.append((Document(page_content=text, metadata=res), score))
return docs return docs

View File

@ -1,6 +1,8 @@
from json import dumps, loads
from typing import Any, Optional from typing import Any, Optional
import pytest import pytest
from bson import ObjectId, json_util
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from pymongo.collection import Collection from pymongo.collection import Collection
@ -75,6 +77,11 @@ class TestMongoDBAtlasVectorSearch:
output = vectorstore.similarity_search("", k=1) output = vectorstore.similarity_search("", k=1)
assert output[0].page_content == page_content assert output[0].page_content == page_content
assert output[0].metadata.get("c") == metadata assert output[0].metadata.get("c") == metadata
# Validate the ObjectId provided is json serializable
assert loads(dumps(output[0].page_content)) == output[0].page_content
assert loads(dumps(output[0].metadata)) == output[0].metadata
json_metadata = dumps(output[0].metadata) # normal json.dumps
assert isinstance(json_util.loads(json_metadata)["_id"], ObjectId)
def test_from_documents( def test_from_documents(
self, embedding_openai: Embeddings, collection: MockCollection self, embedding_openai: Embeddings, collection: MockCollection

View File

@ -1,9 +1,9 @@
from __future__ import annotations from __future__ import annotations
import uuid
from copy import deepcopy from copy import deepcopy
from typing import Any, Dict, List, Mapping, Optional, cast from typing import Any, Dict, List, Mapping, Optional, cast
from bson import ObjectId
from langchain_core.callbacks.manager import ( from langchain_core.callbacks.manager import (
AsyncCallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun, CallbackManagerForLLMRun,
@ -162,7 +162,7 @@ class MockCollection(Collection):
def insert_many(self, to_insert: List[Any], *args, **kwargs) -> InsertManyResult: # type: ignore def insert_many(self, to_insert: List[Any], *args, **kwargs) -> InsertManyResult: # type: ignore
mongodb_inserts = [ mongodb_inserts = [
{"_id": str(uuid.uuid4()), "score": 1, **insert} for insert in to_insert {"_id": ObjectId(), "score": 1, **insert} for insert in to_insert
] ]
self._data.extend(mongodb_inserts) self._data.extend(mongodb_inserts)
return self._insert_result or InsertManyResult( return self._insert_result or InsertManyResult(