mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-02 01:23:07 +00:00
Add async methods for the AstraDB VectorStore (#16391)
- **Description:** Fully async versions are available for astrapy 0.7+. For older astrapy versions, or if the user provides a sync client without an async one, the async methods will call the sync ones wrapped in `run_in_executor`.
- **Twitter handle:** cbornet_
This commit is contained in:
parent
f8f2649f12
commit
744070ee85
File diff suppressed because it is too large
Load Diff
@ -148,6 +148,33 @@ class TestAstraDB:
|
||||
)
|
||||
v_store_2.delete_collection()
|
||||
|
||||
async def test_astradb_vectorstore_create_delete_async(self) -> None:
|
||||
"""Create and delete."""
|
||||
emb = SomeEmbeddings(dimension=2)
|
||||
# creation by passing the connection secrets
|
||||
v_store = AstraDB(
|
||||
embedding=emb,
|
||||
collection_name="lc_test_1_async",
|
||||
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
|
||||
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
|
||||
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
|
||||
)
|
||||
await v_store.adelete_collection()
|
||||
# Creation by passing a ready-made astrapy client:
|
||||
from astrapy.db import AsyncAstraDB
|
||||
|
||||
astra_db_client = AsyncAstraDB(
|
||||
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
|
||||
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
|
||||
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
|
||||
)
|
||||
v_store_2 = AstraDB(
|
||||
embedding=emb,
|
||||
collection_name="lc_test_2_async",
|
||||
async_astra_db_client=astra_db_client,
|
||||
)
|
||||
await v_store_2.adelete_collection()
|
||||
|
||||
def test_astradb_vectorstore_pre_delete_collection(self) -> None:
|
||||
"""Create and delete."""
|
||||
emb = SomeEmbeddings(dimension=2)
|
||||
@ -183,6 +210,41 @@ class TestAstraDB:
|
||||
finally:
|
||||
v_store.delete_collection()
|
||||
|
||||
async def test_astradb_vectorstore_pre_delete_collection_async(self) -> None:
|
||||
"""Create and delete."""
|
||||
emb = SomeEmbeddings(dimension=2)
|
||||
# creation by passing the connection secrets
|
||||
|
||||
v_store = AstraDB(
|
||||
embedding=emb,
|
||||
collection_name="lc_test_pre_del_async",
|
||||
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
|
||||
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
|
||||
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
|
||||
)
|
||||
try:
|
||||
await v_store.aadd_texts(
|
||||
texts=["aa"],
|
||||
metadatas=[
|
||||
{"k": "a", "ord": 0},
|
||||
],
|
||||
ids=["a"],
|
||||
)
|
||||
res1 = await v_store.asimilarity_search("aa", k=5)
|
||||
assert len(res1) == 1
|
||||
v_store = AstraDB(
|
||||
embedding=emb,
|
||||
pre_delete_collection=True,
|
||||
collection_name="lc_test_pre_del_async",
|
||||
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
|
||||
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
|
||||
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
|
||||
)
|
||||
res1 = await v_store.asimilarity_search("aa", k=5)
|
||||
assert len(res1) == 0
|
||||
finally:
|
||||
await v_store.adelete_collection()
|
||||
|
||||
def test_astradb_vectorstore_from_x(self) -> None:
|
||||
"""from_texts and from_documents methods."""
|
||||
emb = SomeEmbeddings(dimension=2)
|
||||
@ -200,7 +262,7 @@ class TestAstraDB:
|
||||
finally:
|
||||
v_store.delete_collection()
|
||||
|
||||
# from_texts
|
||||
# from_documents
|
||||
v_store_2 = AstraDB.from_documents(
|
||||
[
|
||||
Document(page_content="Hee"),
|
||||
@ -217,6 +279,42 @@ class TestAstraDB:
|
||||
finally:
|
||||
v_store_2.delete_collection()
|
||||
|
||||
async def test_astradb_vectorstore_from_x_async(self) -> None:
|
||||
"""from_texts and from_documents methods."""
|
||||
emb = SomeEmbeddings(dimension=2)
|
||||
# from_texts
|
||||
v_store = await AstraDB.afrom_texts(
|
||||
texts=["Hi", "Ho"],
|
||||
embedding=emb,
|
||||
collection_name="lc_test_ft_async",
|
||||
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
|
||||
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
|
||||
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
|
||||
)
|
||||
try:
|
||||
assert (await v_store.asimilarity_search("Ho", k=1))[0].page_content == "Ho"
|
||||
finally:
|
||||
await v_store.adelete_collection()
|
||||
|
||||
# from_documents
|
||||
v_store_2 = await AstraDB.afrom_documents(
|
||||
[
|
||||
Document(page_content="Hee"),
|
||||
Document(page_content="Hoi"),
|
||||
],
|
||||
embedding=emb,
|
||||
collection_name="lc_test_fd_async",
|
||||
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
|
||||
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
|
||||
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
|
||||
)
|
||||
try:
|
||||
assert (await v_store_2.asimilarity_search("Hoi", k=1))[
|
||||
0
|
||||
].page_content == "Hoi"
|
||||
finally:
|
||||
await v_store_2.adelete_collection()
|
||||
|
||||
def test_astradb_vectorstore_crud(self, store_someemb: AstraDB) -> None:
|
||||
"""Basic add/delete/update behaviour."""
|
||||
res0 = store_someemb.similarity_search("Abc", k=2)
|
||||
@ -275,25 +373,106 @@ class TestAstraDB:
|
||||
res4 = store_someemb.similarity_search("ww", k=1, filter={"k": "w"})
|
||||
assert res4[0].metadata["ord"] == 205
|
||||
|
||||
async def test_astradb_vectorstore_crud_async(self, store_someemb: AstraDB) -> None:
|
||||
"""Basic add/delete/update behaviour."""
|
||||
res0 = await store_someemb.asimilarity_search("Abc", k=2)
|
||||
assert res0 == []
|
||||
# write and check again
|
||||
await store_someemb.aadd_texts(
|
||||
texts=["aa", "bb", "cc"],
|
||||
metadatas=[
|
||||
{"k": "a", "ord": 0},
|
||||
{"k": "b", "ord": 1},
|
||||
{"k": "c", "ord": 2},
|
||||
],
|
||||
ids=["a", "b", "c"],
|
||||
)
|
||||
res1 = await store_someemb.asimilarity_search("Abc", k=5)
|
||||
assert {doc.page_content for doc in res1} == {"aa", "bb", "cc"}
|
||||
# partial overwrite and count total entries
|
||||
await store_someemb.aadd_texts(
|
||||
texts=["cc", "dd"],
|
||||
metadatas=[
|
||||
{"k": "c_new", "ord": 102},
|
||||
{"k": "d_new", "ord": 103},
|
||||
],
|
||||
ids=["c", "d"],
|
||||
)
|
||||
res2 = await store_someemb.asimilarity_search("Abc", k=10)
|
||||
assert len(res2) == 4
|
||||
# pick one that was just updated and check its metadata
|
||||
res3 = await store_someemb.asimilarity_search_with_score_id(
|
||||
query="cc", k=1, filter={"k": "c_new"}
|
||||
)
|
||||
print(str(res3))
|
||||
doc3, score3, id3 = res3[0]
|
||||
assert doc3.page_content == "cc"
|
||||
assert doc3.metadata == {"k": "c_new", "ord": 102}
|
||||
assert score3 > 0.999 # leaving some leeway for approximations...
|
||||
assert id3 == "c"
|
||||
# delete and count again
|
||||
del1_res = await store_someemb.adelete(["b"])
|
||||
assert del1_res is True
|
||||
del2_res = await store_someemb.adelete(["a", "c", "Z!"])
|
||||
assert del2_res is False # a non-existing ID was supplied
|
||||
assert len(await store_someemb.asimilarity_search("xy", k=10)) == 1
|
||||
# clear store
|
||||
await store_someemb.aclear()
|
||||
assert await store_someemb.asimilarity_search("Abc", k=2) == []
|
||||
# add_documents with "ids" arg passthrough
|
||||
await store_someemb.aadd_documents(
|
||||
[
|
||||
Document(page_content="vv", metadata={"k": "v", "ord": 204}),
|
||||
Document(page_content="ww", metadata={"k": "w", "ord": 205}),
|
||||
],
|
||||
ids=["v", "w"],
|
||||
)
|
||||
assert len(await store_someemb.asimilarity_search("xy", k=10)) == 2
|
||||
res4 = await store_someemb.asimilarity_search("ww", k=1, filter={"k": "w"})
|
||||
assert res4[0].metadata["ord"] == 205
|
||||
|
||||
@staticmethod
|
||||
def _v_from_i(i: int, N: int) -> str:
|
||||
angle = 2 * math.pi * i / N
|
||||
vector = [math.cos(angle), math.sin(angle)]
|
||||
return json.dumps(vector)
|
||||
|
||||
def test_astradb_vectorstore_mmr(self, store_parseremb: AstraDB) -> None:
|
||||
"""
|
||||
MMR testing. We work on the unit circle with angle multiples
|
||||
of 2*pi/20 and prepare a store with known vectors for a controlled
|
||||
MMR outcome.
|
||||
"""
|
||||
|
||||
def _v_from_i(i: int, N: int) -> str:
|
||||
angle = 2 * math.pi * i / N
|
||||
vector = [math.cos(angle), math.sin(angle)]
|
||||
return json.dumps(vector)
|
||||
|
||||
i_vals = [0, 4, 5, 13]
|
||||
N_val = 20
|
||||
store_parseremb.add_texts(
|
||||
[_v_from_i(i, N_val) for i in i_vals], metadatas=[{"i": i} for i in i_vals]
|
||||
[self._v_from_i(i, N_val) for i in i_vals],
|
||||
metadatas=[{"i": i} for i in i_vals],
|
||||
)
|
||||
res1 = store_parseremb.max_marginal_relevance_search(
|
||||
_v_from_i(3, N_val),
|
||||
self._v_from_i(3, N_val),
|
||||
k=2,
|
||||
fetch_k=3,
|
||||
)
|
||||
res_i_vals = {doc.metadata["i"] for doc in res1}
|
||||
assert res_i_vals == {0, 4}
|
||||
|
||||
async def test_astradb_vectorstore_mmr_async(
|
||||
self, store_parseremb: AstraDB
|
||||
) -> None:
|
||||
"""
|
||||
MMR testing. We work on the unit circle with angle multiples
|
||||
of 2*pi/20 and prepare a store with known vectors for a controlled
|
||||
MMR outcome.
|
||||
"""
|
||||
i_vals = [0, 4, 5, 13]
|
||||
N_val = 20
|
||||
await store_parseremb.aadd_texts(
|
||||
[self._v_from_i(i, N_val) for i in i_vals],
|
||||
metadatas=[{"i": i} for i in i_vals],
|
||||
)
|
||||
res1 = await store_parseremb.amax_marginal_relevance_search(
|
||||
self._v_from_i(3, N_val),
|
||||
k=2,
|
||||
fetch_k=3,
|
||||
)
|
||||
@ -381,6 +560,25 @@ class TestAstraDB:
|
||||
sco_near, sco_far = scores
|
||||
assert abs(1 - sco_near) < 0.001 and abs(sco_far) < 0.001
|
||||
|
||||
async def test_astradb_vectorstore_similarity_scale_async(
|
||||
self, store_parseremb: AstraDB
|
||||
) -> None:
|
||||
"""Scale of the similarity scores."""
|
||||
await store_parseremb.aadd_texts(
|
||||
texts=[
|
||||
json.dumps([1, 1]),
|
||||
json.dumps([-1, -1]),
|
||||
],
|
||||
ids=["near", "far"],
|
||||
)
|
||||
res1 = await store_parseremb.asimilarity_search_with_score(
|
||||
json.dumps([0.5, 0.5]),
|
||||
k=2,
|
||||
)
|
||||
scores = [sco for _, sco in res1]
|
||||
sco_near, sco_far = scores
|
||||
assert abs(1 - sco_near) < 0.001 and abs(sco_far) < 0.001
|
||||
|
||||
def test_astradb_vectorstore_massive_delete(self, store_someemb: AstraDB) -> None:
|
||||
"""Larger-scale bulk deletes."""
|
||||
M = 50
|
||||
@ -458,6 +656,40 @@ class TestAstraDB:
|
||||
finally:
|
||||
v_store.delete_collection()
|
||||
|
||||
async def test_astradb_vectorstore_custom_params_async(self) -> None:
|
||||
"""Custom batch size and concurrency params."""
|
||||
emb = SomeEmbeddings(dimension=2)
|
||||
v_store = AstraDB(
|
||||
embedding=emb,
|
||||
collection_name="lc_test_c_async",
|
||||
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
|
||||
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
|
||||
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
|
||||
batch_size=17,
|
||||
bulk_insert_batch_concurrency=13,
|
||||
bulk_insert_overwrite_concurrency=7,
|
||||
bulk_delete_concurrency=19,
|
||||
)
|
||||
try:
|
||||
# add_texts
|
||||
N = 50
|
||||
texts = [str(i + 1 / 7.0) for i in range(N)]
|
||||
ids = ["doc_%i" % i for i in range(N)]
|
||||
await v_store.aadd_texts(texts=texts, ids=ids)
|
||||
await v_store.aadd_texts(
|
||||
texts=texts,
|
||||
ids=ids,
|
||||
batch_size=19,
|
||||
batch_concurrency=7,
|
||||
overwrite_concurrency=13,
|
||||
)
|
||||
#
|
||||
await v_store.adelete(ids[: N // 2])
|
||||
await v_store.adelete(ids[N // 2 :], concurrency=23)
|
||||
#
|
||||
finally:
|
||||
await v_store.adelete_collection()
|
||||
|
||||
def test_astradb_vectorstore_metrics(self) -> None:
|
||||
"""
|
||||
Different choices of similarity metric.
|
||||
|
Loading…
Reference in New Issue
Block a user