community[patch]: update copy of metadata in rockset vectorstore integration (#17612)

- **Description:** This fixes an issue when using the Rockset vector store
with RecordManager. RecordManager was generating new hashes for documents
because `add_texts` was modifying the caller's metadata in place; it now works
on a deep copy. Additionally, some tests were moved to unit tests, since that
is a more appropriate home for them.
- **Issue:** N/A
- **Dependencies:** N/A
- **Twitter handle:** `@_morgan_adams_`
This commit is contained in:
morgana 2024-02-15 22:13:40 -08:00 committed by GitHub
parent c8d96f30bd
commit 9d7ca7df6e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 63 additions and 47 deletions

View File

@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from copy import deepcopy
from enum import Enum from enum import Enum
from typing import Any, Iterable, List, Optional, Tuple from typing import Any, Iterable, List, Optional, Tuple
@ -123,7 +124,7 @@ class Rockset(VectorStore):
batch = [] batch = []
doc = {} doc = {}
if metadatas and len(metadatas) > i: if metadatas and len(metadatas) > i:
doc = metadatas[i] doc = deepcopy(metadatas[i])
if ids and len(ids) > i: if ids and len(ids) > i:
doc["_id"] = ids[i] doc["_id"] = ids[i]
doc[self._text_key] = text doc[self._text_key] = text

View File

@ -1,5 +1,6 @@
import logging import logging
import os import os
import uuid
from langchain_core.documents import Document from langchain_core.documents import Document
@ -31,10 +32,10 @@ logger = logging.getLogger(__name__)
# #
# See https://rockset.com/blog/introducing-vector-search-on-rockset/ for more details. # See https://rockset.com/blog/introducing-vector-search-on-rockset/ for more details.
workspace = "langchain_tests" WORKSPACE = "morgana"
collection_name = "langchain_demo" COLLECTION_NAME = "langchain_demo"
text_key = "description" TEXT_KEY = "description"
embedding_key = "description_embedding" EMBEDDING_KEY = "description_embedding"
class TestRockset: class TestRockset:
@ -59,7 +60,7 @@ class TestRockset:
elif region == "dev": elif region == "dev":
host = rockset.DevRegions.usw2a1 host = rockset.DevRegions.usw2a1
else: else:
logger.warn( logger.warning(
"Using ROCKSET_REGION:%s as it is.. \ "Using ROCKSET_REGION:%s as it is.. \
You should know what you're doing...", You should know what you're doing...",
region, region,
@ -71,9 +72,9 @@ class TestRockset:
if os.environ.get("ROCKSET_DELETE_DOCS_ON_START") == "1": if os.environ.get("ROCKSET_DELETE_DOCS_ON_START") == "1":
logger.info( logger.info(
"Deleting all existing documents from the Rockset collection %s", "Deleting all existing documents from the Rockset collection %s",
collection_name, COLLECTION_NAME,
) )
query = f"select _id from {workspace}.{collection_name}" query = f"select _id from {WORKSPACE}.{COLLECTION_NAME}"
query_response = client.Queries.query(sql={"query": query}) query_response = client.Queries.query(sql={"query": query})
ids = [ ids = [
@ -84,15 +85,15 @@ class TestRockset:
] ]
logger.info("Existing ids in collection: %s", ids) logger.info("Existing ids in collection: %s", ids)
client.Documents.delete_documents( client.Documents.delete_documents(
collection=collection_name, collection=COLLECTION_NAME,
data=[rockset.models.DeleteDocumentsRequestData(id=i) for i in ids], data=[rockset.models.DeleteDocumentsRequestData(id=i) for i in ids],
workspace=workspace, workspace=WORKSPACE,
) )
embeddings = ConsistentFakeEmbeddings() embeddings = ConsistentFakeEmbeddings()
embeddings.embed_documents(fake_texts) embeddings.embed_documents(fake_texts)
cls.rockset_vectorstore = Rockset( cls.rockset_vectorstore = Rockset(
client, embeddings, collection_name, text_key, embedding_key, workspace client, embeddings, COLLECTION_NAME, TEXT_KEY, EMBEDDING_KEY, WORKSPACE
) )
def test_rockset_insert_and_search(self) -> None: def test_rockset_insert_and_search(self) -> None:
@ -120,42 +121,6 @@ class TestRockset:
) )
assert output == [Document(page_content="bar", metadata={"metadata_index": 1})] assert output == [Document(page_content="bar", metadata={"metadata_index": 1})]
def test_build_query_sql(self) -> None:
# Verifies the exact SQL text returned by `_build_query_sql` when it is called
# with a query vector, the COSINE_SIM distance function and a limit of 4,
# and no filter argument.
vector = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
q_str = self.rockset_vectorstore._build_query_sql(
vector,
Rockset.DistanceFunction.COSINE_SIM,
4,
)
# The expected SQL embeds the vector as its str() values joined by commas.
vector_str = ",".join(map(str, vector))
# The backslashes at the end of lines inside the literal are line
# continuations: no newline follows the opening quotes or the EXCEPT(...) item.
expected = f"""\
SELECT * EXCEPT({embedding_key}), \
COSINE_SIM({embedding_key}, [{vector_str}]) as dist
FROM {workspace}.{collection_name}
ORDER BY dist DESC
LIMIT 4
"""
# Exact string comparison, so whitespace and line breaks must match as well.
assert q_str == expected
def test_build_query_sql_with_where(self) -> None:
# Verifies the exact SQL text returned by `_build_query_sql` when a fourth
# argument ("age >= 10") is passed in addition to the vector, the COSINE_SIM
# distance function and the limit of 4.
vector = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
q_str = self.rockset_vectorstore._build_query_sql(
vector,
Rockset.DistanceFunction.COSINE_SIM,
4,
"age >= 10",
)
# The expected SQL embeds the vector as its str() values joined by commas.
vector_str = ",".join(map(str, vector))
# The fourth argument is expected verbatim after WHERE, between the FROM and
# ORDER BY lines. Backslashes at line ends inside the literal are line
# continuations, so they do not produce newlines.
expected = f"""\
SELECT * EXCEPT({embedding_key}), \
COSINE_SIM({embedding_key}, [{vector_str}]) as dist
FROM {workspace}.{collection_name}
WHERE age >= 10
ORDER BY dist DESC
LIMIT 4
"""
# Exact string comparison, so whitespace and line breaks must match as well.
assert q_str == expected
def test_add_documents_and_delete(self) -> None: def test_add_documents_and_delete(self) -> None:
""" "add_documents" and "delete" are requirements to support use """ "add_documents" and "delete" are requirements to support use
with RecordManager""" with RecordManager"""
@ -171,3 +136,53 @@ LIMIT 4
deleted = self.rockset_vectorstore.delete(ids) deleted = self.rockset_vectorstore.delete(ids)
assert deleted assert deleted
def test_add_texts_does_not_modify_metadata(self) -> None:
    """Check that ``add_texts`` leaves the caller's metadata dicts untouched.

    If metadata changes it will inhibit the langchain RecordManager
    functionality.
    """
    page_texts = ["kitty", "doggy"]
    original_metadatas = [{"source": "kitty.txt"}, {"source": "doggy.txt"}]
    # One fresh random id per text, generated in the same order as the texts.
    doc_ids = [str(uuid.uuid4()) for _ in page_texts]

    self.rockset_vectorstore.add_texts(
        texts=page_texts,
        metadatas=original_metadatas,
        ids=doc_ids,
    )

    # Each dict must still hold exactly one key, and that key must be "source".
    for entry in original_metadatas:
        assert list(entry) == ["source"]
def test_build_query_sql(self) -> None:
# Verifies the exact SQL text returned by `_build_query_sql` when it is called
# with a query vector, the COSINE_SIM distance function and a limit of 4,
# and no filter argument.
vector = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
q_str = self.rockset_vectorstore._build_query_sql(
vector,
Rockset.DistanceFunction.COSINE_SIM,
4,
)
# The expected SQL embeds the vector as its str() values joined by commas.
vector_str = ",".join(map(str, vector))
# The backslashes at the end of lines inside the literal are line
# continuations: no newline follows the opening quotes or the EXCEPT(...) item.
expected = f"""\
SELECT * EXCEPT({EMBEDDING_KEY}), \
COSINE_SIM({EMBEDDING_KEY}, [{vector_str}]) as dist
FROM {WORKSPACE}.{COLLECTION_NAME}
ORDER BY dist DESC
LIMIT 4
"""
# Exact string comparison, so whitespace and line breaks must match as well.
assert q_str == expected
def test_build_query_sql_with_where(self) -> None:
# Verifies the exact SQL text returned by `_build_query_sql` when a fourth
# argument ("age >= 10") is passed in addition to the vector, the COSINE_SIM
# distance function and the limit of 4.
vector = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
q_str = self.rockset_vectorstore._build_query_sql(
vector,
Rockset.DistanceFunction.COSINE_SIM,
4,
"age >= 10",
)
# The expected SQL embeds the vector as its str() values joined by commas.
vector_str = ",".join(map(str, vector))
# The fourth argument is expected verbatim after WHERE, between the FROM and
# ORDER BY lines. Backslashes at line ends inside the literal are line
# continuations, so they do not produce newlines.
expected = f"""\
SELECT * EXCEPT({EMBEDDING_KEY}), \
COSINE_SIM({EMBEDDING_KEY}, [{vector_str}]) as dist
FROM {WORKSPACE}.{COLLECTION_NAME}
WHERE age >= 10
ORDER BY dist DESC
LIMIT 4
"""
# Exact string comparison, so whitespace and line breaks must match as well.
assert q_str == expected