mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-22 14:49:29 +00:00)
community[patch]: update copy of metadata in rockset vectorstore integration (#17612)
- **Description:** This fixes an issue when working with RecordManager: RecordManager was generating new hashes for documents because `add_texts` was modifying the metadata dicts in place (see the sketch below). Additionally, moved some tests to unit tests, since that was a more appropriate home for them.
- **Issue:** N/A
- **Dependencies:** N/A
- **Twitter handle:** `@_morgan_adams_`
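To illustrate the failure mode, here is a minimal sketch with hypothetical names (not the actual integration code): when `add_texts` aliases the caller's metadata dict, the `_id` and text keys it adds leak back into the caller's list, so RecordManager computes a different hash for the same document on the next indexing run. Copying the dict first keeps the caller's metadata stable.

```python
from copy import deepcopy
from typing import Any, Dict, Iterable, List, Optional


def add_texts_sketch(
    texts: Iterable[str],
    metadatas: Optional[List[Dict[str, Any]]] = None,
    ids: Optional[List[str]] = None,
    text_key: str = "description",
) -> List[Dict[str, Any]]:
    """Hypothetical, stripped-down add_texts used only to illustrate the fix."""
    docs = []
    for i, text in enumerate(texts):
        doc: Dict[str, Any] = {}
        if metadatas and len(metadatas) > i:
            # Before the fix this was `doc = metadatas[i]`, which aliases the
            # caller's dict, so the keys added below leaked back into `metadatas`.
            doc = deepcopy(metadatas[i])
        if ids and len(ids) > i:
            doc["_id"] = ids[i]
        doc[text_key] = text
        docs.append(doc)
    return docs


metadatas = [{"source": "kitty.txt"}]
add_texts_sketch(["kitty"], metadatas=metadatas, ids=["id-1"])
# The caller's metadata is unchanged, so RecordManager's hashes stay stable.
assert metadatas == [{"source": "kitty.txt"}]
```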
parent c8d96f30bd
commit 9d7ca7df6e
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import logging
+from copy import deepcopy
 from enum import Enum
 from typing import Any, Iterable, List, Optional, Tuple
 
@@ -123,7 +124,7 @@ class Rockset(VectorStore):
                 batch = []
             doc = {}
             if metadatas and len(metadatas) > i:
-                doc = metadatas[i]
+                doc = deepcopy(metadatas[i])
             if ids and len(ids) > i:
                 doc["_id"] = ids[i]
             doc[self._text_key] = text
@@ -1,5 +1,6 @@
 import logging
 import os
+import uuid
 
 from langchain_core.documents import Document
 
@@ -31,10 +32,10 @@ logger = logging.getLogger(__name__)
 #
 # See https://rockset.com/blog/introducing-vector-search-on-rockset/ for more details.
 
-workspace = "langchain_tests"
-collection_name = "langchain_demo"
-text_key = "description"
-embedding_key = "description_embedding"
+WORKSPACE = "morgana"
+COLLECTION_NAME = "langchain_demo"
+TEXT_KEY = "description"
+EMBEDDING_KEY = "description_embedding"
 
 
 class TestRockset:
@@ -59,7 +60,7 @@ class TestRockset:
         elif region == "dev":
             host = rockset.DevRegions.usw2a1
         else:
-            logger.warn(
+            logger.warning(
                 "Using ROCKSET_REGION:%s as it is.. \
                 You should know what you're doing...",
                 region,
@@ -71,9 +72,9 @@ class TestRockset:
         if os.environ.get("ROCKSET_DELETE_DOCS_ON_START") == "1":
             logger.info(
                 "Deleting all existing documents from the Rockset collection %s",
-                collection_name,
+                COLLECTION_NAME,
             )
-            query = f"select _id from {workspace}.{collection_name}"
+            query = f"select _id from {WORKSPACE}.{COLLECTION_NAME}"
 
             query_response = client.Queries.query(sql={"query": query})
             ids = [
@@ -84,15 +85,15 @@ class TestRockset:
             ]
             logger.info("Existing ids in collection: %s", ids)
             client.Documents.delete_documents(
-                collection=collection_name,
+                collection=COLLECTION_NAME,
                 data=[rockset.models.DeleteDocumentsRequestData(id=i) for i in ids],
-                workspace=workspace,
+                workspace=WORKSPACE,
             )
 
         embeddings = ConsistentFakeEmbeddings()
         embeddings.embed_documents(fake_texts)
         cls.rockset_vectorstore = Rockset(
-            client, embeddings, collection_name, text_key, embedding_key, workspace
+            client, embeddings, COLLECTION_NAME, TEXT_KEY, EMBEDDING_KEY, WORKSPACE
         )
 
     def test_rockset_insert_and_search(self) -> None:
@@ -120,42 +121,6 @@ class TestRockset:
         )
         assert output == [Document(page_content="bar", metadata={"metadata_index": 1})]
 
-    def test_build_query_sql(self) -> None:
-        vector = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
-        q_str = self.rockset_vectorstore._build_query_sql(
-            vector,
-            Rockset.DistanceFunction.COSINE_SIM,
-            4,
-        )
-        vector_str = ",".join(map(str, vector))
-        expected = f"""\
-SELECT * EXCEPT({embedding_key}), \
-COSINE_SIM({embedding_key}, [{vector_str}]) as dist
-FROM {workspace}.{collection_name}
-ORDER BY dist DESC
-LIMIT 4
-"""
-        assert q_str == expected
-
-    def test_build_query_sql_with_where(self) -> None:
-        vector = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
-        q_str = self.rockset_vectorstore._build_query_sql(
-            vector,
-            Rockset.DistanceFunction.COSINE_SIM,
-            4,
-            "age >= 10",
-        )
-        vector_str = ",".join(map(str, vector))
-        expected = f"""\
-SELECT * EXCEPT({embedding_key}), \
-COSINE_SIM({embedding_key}, [{vector_str}]) as dist
-FROM {workspace}.{collection_name}
-WHERE age >= 10
-ORDER BY dist DESC
-LIMIT 4
-"""
-        assert q_str == expected
-
     def test_add_documents_and_delete(self) -> None:
         """ "add_documents" and "delete" are requirements to support use
         with RecordManager"""
@@ -171,3 +136,53 @@ LIMIT 4
 
         deleted = self.rockset_vectorstore.delete(ids)
         assert deleted
+
+    def test_add_texts_does_not_modify_metadata(self) -> None:
+        """If metadata changes it will inhibit the langchain RecordManager
+        functionality"""
+
+        texts = ["kitty", "doggy"]
+        metadatas = [{"source": "kitty.txt"}, {"source": "doggy.txt"}]
+        ids = [str(uuid.uuid4()), str(uuid.uuid4())]
+
+        self.rockset_vectorstore.add_texts(texts=texts, metadatas=metadatas, ids=ids)
+
+        for metadata in metadatas:
+            assert len(metadata) == 1
+            assert list(metadata.keys())[0] == "source"
+
+    def test_build_query_sql(self) -> None:
+        vector = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
+        q_str = self.rockset_vectorstore._build_query_sql(
+            vector,
+            Rockset.DistanceFunction.COSINE_SIM,
+            4,
+        )
+        vector_str = ",".join(map(str, vector))
+        expected = f"""\
+SELECT * EXCEPT({EMBEDDING_KEY}), \
+COSINE_SIM({EMBEDDING_KEY}, [{vector_str}]) as dist
+FROM {WORKSPACE}.{COLLECTION_NAME}
+ORDER BY dist DESC
+LIMIT 4
+"""
+        assert q_str == expected
+
+    def test_build_query_sql_with_where(self) -> None:
+        vector = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
+        q_str = self.rockset_vectorstore._build_query_sql(
+            vector,
+            Rockset.DistanceFunction.COSINE_SIM,
+            4,
+            "age >= 10",
+        )
+        vector_str = ",".join(map(str, vector))
+        expected = f"""\
+SELECT * EXCEPT({EMBEDDING_KEY}), \
+COSINE_SIM({EMBEDDING_KEY}, [{vector_str}]) as dist
+FROM {WORKSPACE}.{COLLECTION_NAME}
+WHERE age >= 10
+ORDER BY dist DESC
+LIMIT 4
+"""
+        assert q_str == expected
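For context on why stable metadata matters here: RecordManager hashes each document (content plus metadata) to decide whether it needs to be re-indexed, so metadata mutated inside `add_texts` made every run look like a change. Below is a rough sketch of driving the Rockset vectorstore through LangChain's `index()` API with a `SQLRecordManager`; the client setup, workspace/collection names, and SQLite URL are illustrative assumptions, not part of this change.

```python
import rockset
from langchain.indexes import SQLRecordManager, index
from langchain_community.embeddings import FakeEmbeddings
from langchain_community.vectorstores import Rockset
from langchain_core.documents import Document

# Illustrative setup: region, API key, workspace/collection names, and the
# SQLite URL are placeholder assumptions.
client = rockset.RocksetClient(rockset.Regions.usw2a1, "<ROCKSET_API_KEY>")
vectorstore = Rockset(
    client,
    FakeEmbeddings(size=10),
    "langchain_demo",  # collection
    "description",  # text key
    "description_embedding",  # embedding key
    "langchain_tests",  # workspace
)

record_manager = SQLRecordManager(
    "rockset/langchain_demo", db_url="sqlite:///record_manager_cache.sql"
)
record_manager.create_schema()

docs = [
    Document(page_content="kitty", metadata={"source": "kitty.txt"}),
    Document(page_content="doggy", metadata={"source": "doggy.txt"}),
]

# With add_texts no longer mutating metadata, re-running this index() call
# sees identical hashes and skips documents instead of re-writing them.
index(docs, record_manager, vectorstore, cleanup="incremental", source_id_key="source")
```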