mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-14 15:16:21 +00:00
mongodb[patch]: Make ObjectId JSON-serializable on generation (#21394)
This commit is contained in:
parent
12b599c47f
commit
a97473c846
@ -16,6 +16,7 @@ from typing import (
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
from bson import ObjectId, json_util
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.runnables.config import run_in_executor
|
||||
@ -210,9 +211,23 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
||||
pipeline.extend(post_filter_pipeline)
|
||||
cursor = self._collection.aggregate(pipeline) # type: ignore[arg-type]
|
||||
docs = []
|
||||
|
||||
def _make_serializable(obj: Dict[str, Any]) -> None:
    """Replace every ObjectId reachable from ``obj`` with its JSON form.

    The mapping is updated in place. Each ``ObjectId`` is swapped for the
    value produced by ``bson.json_util.default`` (the ``{"$oid": "..."}``
    shape that ``bson.json_util.loads`` turns back into an ``ObjectId``).

    Nested dicts are walked recursively, and lists are converted element
    by element, so ObjectIds are found regardless of their position in a
    list and inside dicts or lists nested within a list. Values that are
    not ObjectIds are left untouched instead of being handed to
    ``json_util.default``, which rejects types it does not know.

    Args:
        obj: Document (or sub-document) whose values should be made
            JSON-serializable.
    """

    def _convert(value: Any) -> Any:
        # Return the serializable counterpart of a single value.
        if isinstance(value, ObjectId):
            return json_util.default(value)
        if isinstance(value, dict):
            # Dicts are mutated in place, so the same object is returned.
            _make_serializable(value)
            return value
        if isinstance(value, list):
            # Convert per element: a list may mix ObjectIds, plain values
            # and nested containers.
            return [_convert(item) for item in value]
        return value

    # Only existing keys are reassigned, so the dict's size never changes
    # while it is being iterated.
    for key, value in obj.items():
        obj[key] = _convert(value)
|
||||
|
||||
for res in cursor:
|
||||
text = res.pop(self._text_key)
|
||||
score = res.pop("score")
|
||||
# Make every ObjectId found JSON-Serializable
|
||||
# following the format used by bson.json_util.loads
|
||||
# e.g. loads('{"_id": {"$oid": "664..."}}') == {'_id': ObjectId('664..')} # noqa: E501
|
||||
_make_serializable(res)
|
||||
docs.append((Document(page_content=text, metadata=res), score))
|
||||
return docs
|
||||
|
||||
|
@ -1,6 +1,8 @@
|
||||
from json import dumps, loads
|
||||
from typing import Any, Optional
|
||||
|
||||
import pytest
|
||||
from bson import ObjectId, json_util
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from pymongo.collection import Collection
|
||||
@ -75,6 +77,11 @@ class TestMongoDBAtlasVectorSearch:
|
||||
output = vectorstore.similarity_search("", k=1)
|
||||
assert output[0].page_content == page_content
|
||||
assert output[0].metadata.get("c") == metadata
|
||||
# Validate the ObjectId provided is json serializable
|
||||
assert loads(dumps(output[0].page_content)) == output[0].page_content
|
||||
assert loads(dumps(output[0].metadata)) == output[0].metadata
|
||||
json_metadata = dumps(output[0].metadata) # normal json.dumps
|
||||
assert isinstance(json_util.loads(json_metadata)["_id"], ObjectId)
|
||||
|
||||
def test_from_documents(
|
||||
self, embedding_openai: Embeddings, collection: MockCollection
|
||||
|
@ -1,9 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, List, Mapping, Optional, cast
|
||||
|
||||
from bson import ObjectId
|
||||
from langchain_core.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
@ -162,7 +162,7 @@ class MockCollection(Collection):
|
||||
|
||||
def insert_many(self, to_insert: List[Any], *args, **kwargs) -> InsertManyResult: # type: ignore
|
||||
mongodb_inserts = [
|
||||
{"_id": str(uuid.uuid4()), "score": 1, **insert} for insert in to_insert
|
||||
{"_id": ObjectId(), "score": 1, **insert} for insert in to_insert
|
||||
]
|
||||
self._data.extend(mongodb_inserts)
|
||||
return self._insert_result or InsertManyResult(
|
||||
|
Loading…
Reference in New Issue
Block a user