mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-14 15:16:21 +00:00
mongodb[patch]: Make ObjectId JSON-serializable on generation (#21394)
This commit is contained in:
parent
12b599c47f
commit
a97473c846
@ -16,6 +16,7 @@ from typing import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from bson import ObjectId, json_util
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from langchain_core.runnables.config import run_in_executor
|
from langchain_core.runnables.config import run_in_executor
|
||||||
@ -210,9 +211,23 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
|||||||
pipeline.extend(post_filter_pipeline)
|
pipeline.extend(post_filter_pipeline)
|
||||||
cursor = self._collection.aggregate(pipeline) # type: ignore[arg-type]
|
cursor = self._collection.aggregate(pipeline) # type: ignore[arg-type]
|
||||||
docs = []
|
docs = []
|
||||||
|
|
||||||
|
def _make_serializable(obj: Dict[str, Any]) -> None:
|
||||||
|
for k, v in obj.items():
|
||||||
|
if isinstance(v, dict):
|
||||||
|
_make_serializable(v)
|
||||||
|
elif isinstance(v, list) and v and isinstance(v[0], ObjectId):
|
||||||
|
obj[k] = [json_util.default(item) for item in v]
|
||||||
|
elif isinstance(v, ObjectId):
|
||||||
|
obj[k] = json_util.default(v)
|
||||||
|
|
||||||
for res in cursor:
|
for res in cursor:
|
||||||
text = res.pop(self._text_key)
|
text = res.pop(self._text_key)
|
||||||
score = res.pop("score")
|
score = res.pop("score")
|
||||||
|
# Make every ObjectId found JSON-Serializable
|
||||||
|
# following format used in bson.json_util.loads
|
||||||
|
# e.g. loads('{"_id": {"$oid": "664..."}}') == {'_id': ObjectId('664..')} # noqa: E501
|
||||||
|
_make_serializable(res)
|
||||||
docs.append((Document(page_content=text, metadata=res), score))
|
docs.append((Document(page_content=text, metadata=res), score))
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
|
from json import dumps, loads
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from bson import ObjectId, json_util
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from pymongo.collection import Collection
|
from pymongo.collection import Collection
|
||||||
@ -75,6 +77,11 @@ class TestMongoDBAtlasVectorSearch:
|
|||||||
output = vectorstore.similarity_search("", k=1)
|
output = vectorstore.similarity_search("", k=1)
|
||||||
assert output[0].page_content == page_content
|
assert output[0].page_content == page_content
|
||||||
assert output[0].metadata.get("c") == metadata
|
assert output[0].metadata.get("c") == metadata
|
||||||
|
# Validate the ObjectId provided is json serializable
|
||||||
|
assert loads(dumps(output[0].page_content)) == output[0].page_content
|
||||||
|
assert loads(dumps(output[0].metadata)) == output[0].metadata
|
||||||
|
json_metadata = dumps(output[0].metadata) # normal json.dumps
|
||||||
|
assert isinstance(json_util.loads(json_metadata)["_id"], ObjectId)
|
||||||
|
|
||||||
def test_from_documents(
|
def test_from_documents(
|
||||||
self, embedding_openai: Embeddings, collection: MockCollection
|
self, embedding_openai: Embeddings, collection: MockCollection
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import uuid
|
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Any, Dict, List, Mapping, Optional, cast
|
from typing import Any, Dict, List, Mapping, Optional, cast
|
||||||
|
|
||||||
|
from bson import ObjectId
|
||||||
from langchain_core.callbacks.manager import (
|
from langchain_core.callbacks.manager import (
|
||||||
AsyncCallbackManagerForLLMRun,
|
AsyncCallbackManagerForLLMRun,
|
||||||
CallbackManagerForLLMRun,
|
CallbackManagerForLLMRun,
|
||||||
@ -162,7 +162,7 @@ class MockCollection(Collection):
|
|||||||
|
|
||||||
def insert_many(self, to_insert: List[Any], *args, **kwargs) -> InsertManyResult: # type: ignore
|
def insert_many(self, to_insert: List[Any], *args, **kwargs) -> InsertManyResult: # type: ignore
|
||||||
mongodb_inserts = [
|
mongodb_inserts = [
|
||||||
{"_id": str(uuid.uuid4()), "score": 1, **insert} for insert in to_insert
|
{"_id": ObjectId(), "score": 1, **insert} for insert in to_insert
|
||||||
]
|
]
|
||||||
self._data.extend(mongodb_inserts)
|
self._data.extend(mongodb_inserts)
|
||||||
return self._insert_result or InsertManyResult(
|
return self._insert_result or InsertManyResult(
|
||||||
|
Loading…
Reference in New Issue
Block a user