From a97473c846919742e0d752fc145dcdfebe9c289b Mon Sep 17 00:00:00 2001 From: Jib Date: Tue, 14 May 2024 14:52:29 -0400 Subject: [PATCH] mongodb[patch]: Make ObjectId JSON-serializable on generation (#21394) --- .../mongodb/langchain_mongodb/vectorstores.py | 15 +++++++++++++++ .../mongodb/tests/unit_tests/test_vectorstores.py | 7 +++++++ libs/partners/mongodb/tests/utils.py | 4 ++-- 3 files changed, 24 insertions(+), 2 deletions(-) diff --git a/libs/partners/mongodb/langchain_mongodb/vectorstores.py b/libs/partners/mongodb/langchain_mongodb/vectorstores.py index c87b60b7030..1977045ccbd 100644 --- a/libs/partners/mongodb/langchain_mongodb/vectorstores.py +++ b/libs/partners/mongodb/langchain_mongodb/vectorstores.py @@ -16,6 +16,7 @@ from typing import ( ) import numpy as np +from bson import ObjectId, json_util from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.runnables.config import run_in_executor @@ -210,9 +211,23 @@ class MongoDBAtlasVectorSearch(VectorStore): pipeline.extend(post_filter_pipeline) cursor = self._collection.aggregate(pipeline) # type: ignore[arg-type] docs = [] + + def _make_serializable(obj: Dict[str, Any]) -> None: + for k, v in obj.items(): + if isinstance(v, dict): + _make_serializable(v) + elif isinstance(v, list) and v and isinstance(v[0], ObjectId): + obj[k] = [json_util.default(item) for item in v] + elif isinstance(v, ObjectId): + obj[k] = json_util.default(v) + for res in cursor: text = res.pop(self._text_key) score = res.pop("score") + # Make every ObjectId found JSON-Serializable + # following format used in bson.json_util.loads + # e.g. loads('{"_id": {"$oid": "664..."}}') == {'_id': ObjectId('664..')} # noqa: E501 + _make_serializable(res) docs.append((Document(page_content=text, metadata=res), score)) return docs diff --git a/libs/partners/mongodb/tests/unit_tests/test_vectorstores.py b/libs/partners/mongodb/tests/unit_tests/test_vectorstores.py index 9d3def60462..9c6c781208c 100644 --- a/libs/partners/mongodb/tests/unit_tests/test_vectorstores.py +++ b/libs/partners/mongodb/tests/unit_tests/test_vectorstores.py @@ -1,6 +1,8 @@ +from json import dumps, loads from typing import Any, Optional import pytest +from bson import ObjectId, json_util from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from pymongo.collection import Collection @@ -75,6 +77,11 @@ class TestMongoDBAtlasVectorSearch: output = vectorstore.similarity_search("", k=1) assert output[0].page_content == page_content assert output[0].metadata.get("c") == metadata + # Validate the ObjectId provided is json serializable + assert loads(dumps(output[0].page_content)) == output[0].page_content + assert loads(dumps(output[0].metadata)) == output[0].metadata + json_metadata = dumps(output[0].metadata) # normal json.dumps + assert isinstance(json_util.loads(json_metadata)["_id"], ObjectId) def test_from_documents( self, embedding_openai: Embeddings, collection: MockCollection diff --git a/libs/partners/mongodb/tests/utils.py b/libs/partners/mongodb/tests/utils.py index 3716ac69478..7b06991da82 100644 --- a/libs/partners/mongodb/tests/utils.py +++ b/libs/partners/mongodb/tests/utils.py @@ -1,9 +1,9 @@ from __future__ import annotations -import uuid from copy import deepcopy from typing import Any, Dict, List, Mapping, Optional, cast +from bson import ObjectId from langchain_core.callbacks.manager import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, @@ -162,7 +162,7 @@ class MockCollection(Collection): def insert_many(self, to_insert: List[Any], *args, **kwargs) -> InsertManyResult: # type: ignore mongodb_inserts = [ - {"_id": str(uuid.uuid4()), "score": 1, **insert} for insert in to_insert + {"_id": ObjectId(), "score": 1, **insert} for insert in to_insert ] self._data.extend(mongodb_inserts) return self._insert_result or InsertManyResult(