community[patch]: update copy of metadata in rockset vectorstore integration (#17612)

- **Description:** This fixes an issue when working with RecordManager.
RecordManager was generating new hashes for documents because `add_texts`
was modifying the caller's metadata directly. Additionally, moved some tests to
unit tests since that was a more appropriate home.
- **Issue:** N/A
- **Dependencies:** N/A
- **Twitter handle:** `@_morgan_adams_`
This commit is contained in:
morgana 2024-02-15 22:13:40 -08:00 committed by GitHub
parent c8d96f30bd
commit 9d7ca7df6e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 63 additions and 47 deletions

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import logging
from copy import deepcopy
from enum import Enum
from typing import Any, Iterable, List, Optional, Tuple
@ -123,7 +124,7 @@ class Rockset(VectorStore):
batch = []
doc = {}
if metadatas and len(metadatas) > i:
doc = metadatas[i]
doc = deepcopy(metadatas[i])
if ids and len(ids) > i:
doc["_id"] = ids[i]
doc[self._text_key] = text

View File

@ -1,5 +1,6 @@
import logging
import os
import uuid
from langchain_core.documents import Document
@ -31,10 +32,10 @@ logger = logging.getLogger(__name__)
#
# See https://rockset.com/blog/introducing-vector-search-on-rockset/ for more details.
workspace = "langchain_tests"
collection_name = "langchain_demo"
text_key = "description"
embedding_key = "description_embedding"
# Names of the Rockset workspace, collection and document fields used by these tests.
# NOTE(review): the workspace value was "langchain_tests" before this change; "morgana"
# looks like a personal workspace name — confirm it is intended for shared test runs.
WORKSPACE = "morgana"
COLLECTION_NAME = "langchain_demo"
TEXT_KEY = "description"
EMBEDDING_KEY = "description_embedding"
class TestRockset:
@ -59,7 +60,7 @@ class TestRockset:
elif region == "dev":
host = rockset.DevRegions.usw2a1
else:
logger.warn(
logger.warning(
"Using ROCKSET_REGION:%s as it is.. \
You should know what you're doing...",
region,
@ -71,9 +72,9 @@ class TestRockset:
if os.environ.get("ROCKSET_DELETE_DOCS_ON_START") == "1":
logger.info(
"Deleting all existing documents from the Rockset collection %s",
collection_name,
COLLECTION_NAME,
)
query = f"select _id from {workspace}.{collection_name}"
query = f"select _id from {WORKSPACE}.{COLLECTION_NAME}"
query_response = client.Queries.query(sql={"query": query})
ids = [
@ -84,15 +85,15 @@ class TestRockset:
]
logger.info("Existing ids in collection: %s", ids)
client.Documents.delete_documents(
collection=collection_name,
collection=COLLECTION_NAME,
data=[rockset.models.DeleteDocumentsRequestData(id=i) for i in ids],
workspace=workspace,
workspace=WORKSPACE,
)
embeddings = ConsistentFakeEmbeddings()
embeddings.embed_documents(fake_texts)
cls.rockset_vectorstore = Rockset(
client, embeddings, collection_name, text_key, embedding_key, workspace
client, embeddings, COLLECTION_NAME, TEXT_KEY, EMBEDDING_KEY, WORKSPACE
)
def test_rockset_insert_and_search(self) -> None:
@ -120,42 +121,6 @@ class TestRockset:
)
assert output == [Document(page_content="bar", metadata={"metadata_index": 1})]
# Checks that _build_query_sql, called with a 10-element vector, COSINE_SIM and the
# value 4, returns exactly the SQL text assembled below (no WHERE clause).
def test_build_query_sql(self) -> None:
vector = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
q_str = self.rockset_vectorstore._build_query_sql(
vector,
Rockset.DistanceFunction.COSINE_SIM,
4,
)
# The vector is rendered as comma-separated str() values with no spaces.
vector_str = ",".join(map(str, vector))
# Trailing backslashes inside the literal suppress the newline, so the string starts
# directly with SELECT and the COSINE_SIM term continues on that same line.
expected = f"""\
SELECT * EXCEPT({embedding_key}), \
COSINE_SIM({embedding_key}, [{vector_str}]) as dist
FROM {workspace}.{collection_name}
ORDER BY dist DESC
LIMIT 4
"""
# Exact string comparison: any whitespace difference fails the test.
assert q_str == expected
# Same check as test_build_query_sql, but passes "age >= 10" as a fourth argument and
# expects it to appear as a WHERE line between FROM and ORDER BY.
def test_build_query_sql_with_where(self) -> None:
vector = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
q_str = self.rockset_vectorstore._build_query_sql(
vector,
Rockset.DistanceFunction.COSINE_SIM,
4,
"age >= 10",
)
# The vector is rendered as comma-separated str() values with no spaces.
vector_str = ",".join(map(str, vector))
# Trailing backslashes inside the literal suppress the newline, so the string starts
# directly with SELECT and the COSINE_SIM term continues on that same line.
expected = f"""\
SELECT * EXCEPT({embedding_key}), \
COSINE_SIM({embedding_key}, [{vector_str}]) as dist
FROM {workspace}.{collection_name}
WHERE age >= 10
ORDER BY dist DESC
LIMIT 4
"""
# Exact string comparison: any whitespace difference fails the test.
assert q_str == expected
def test_add_documents_and_delete(self) -> None:
""" "add_documents" and "delete" are requirements to support use
with RecordManager"""
@ -171,3 +136,53 @@ LIMIT 4
deleted = self.rockset_vectorstore.delete(ids)
assert deleted
# Regression test: the metadata dicts handed to add_texts must come back unchanged.
def test_add_texts_does_not_modify_metadata(self) -> None:
"""If metadata changes it will inhibit the langchain RecordManager
functionality"""
texts = ["kitty", "doggy"]
# Each metadata dict starts with exactly one key, "source".
metadatas = [{"source": "kitty.txt"}, {"source": "doggy.txt"}]
ids = [str(uuid.uuid4()), str(uuid.uuid4())]
self.rockset_vectorstore.add_texts(texts=texts, metadatas=metadatas, ids=ids)
# After the call, every caller-owned dict must still hold only its original
# "source" key, i.e. nothing was added to it by add_texts.
for metadata in metadatas:
assert len(metadata) == 1
assert list(metadata.keys())[0] == "source"
# Checks that _build_query_sql, called with a 10-element vector, COSINE_SIM and the
# value 4, returns exactly the SQL text assembled below (no WHERE clause).
def test_build_query_sql(self) -> None:
vector = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
q_str = self.rockset_vectorstore._build_query_sql(
vector,
Rockset.DistanceFunction.COSINE_SIM,
4,
)
# The vector is rendered as comma-separated str() values with no spaces.
vector_str = ",".join(map(str, vector))
# Trailing backslashes inside the literal suppress the newline, so the string starts
# directly with SELECT and the COSINE_SIM term continues on that same line.
expected = f"""\
SELECT * EXCEPT({EMBEDDING_KEY}), \
COSINE_SIM({EMBEDDING_KEY}, [{vector_str}]) as dist
FROM {WORKSPACE}.{COLLECTION_NAME}
ORDER BY dist DESC
LIMIT 4
"""
# Exact string comparison: any whitespace difference fails the test.
assert q_str == expected
# Same check as test_build_query_sql, but passes "age >= 10" as a fourth argument and
# expects it to appear as a WHERE line between FROM and ORDER BY.
def test_build_query_sql_with_where(self) -> None:
vector = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
q_str = self.rockset_vectorstore._build_query_sql(
vector,
Rockset.DistanceFunction.COSINE_SIM,
4,
"age >= 10",
)
# The vector is rendered as comma-separated str() values with no spaces.
vector_str = ",".join(map(str, vector))
# Trailing backslashes inside the literal suppress the newline, so the string starts
# directly with SELECT and the COSINE_SIM term continues on that same line.
expected = f"""\
SELECT * EXCEPT({EMBEDDING_KEY}), \
COSINE_SIM({EMBEDDING_KEY}, [{vector_str}]) as dist
FROM {WORKSPACE}.{COLLECTION_NAME}
WHERE age >= 10
ORDER BY dist DESC
LIMIT 4
"""
# Exact string comparison: any whitespace difference fails the test.
assert q_str == expected