community: Add async methods to AstraDBCache (#17415)

Adds async methods to AstraDBCache
This commit is contained in:
Christophe Bornet
2024-02-15 05:10:08 +01:00
committed by GitHub
parent e438fe6be9
commit ca2d4078f3
6 changed files with 465 additions and 106 deletions

View File

@@ -12,9 +12,12 @@ Required to run this test:
"""
import os
from typing import Iterator
from typing import AsyncIterator, Iterator
import pytest
from langchain_community.utilities.astradb import SetupMode
from langchain_core.caches import BaseCache
from langchain_core.language_models import LLM
from langchain_core.outputs import Generation, LLMResult
from langchain.cache import AstraDBCache, AstraDBSemanticCache
@@ -41,7 +44,22 @@ def astradb_cache() -> Iterator[AstraDBCache]:
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
)
yield cache
cache.astra_db.delete_collection("lc_integration_test_cache")
cache.collection.astra_db.delete_collection("lc_integration_test_cache")
@pytest.fixture
async def async_astradb_cache() -> AsyncIterator[AstraDBCache]:
cache = AstraDBCache(
collection_name="lc_integration_test_cache_async",
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
setup_mode=SetupMode.ASYNC,
)
yield cache
await cache.async_collection.astra_db.delete_collection(
"lc_integration_test_cache_async"
)
@pytest.fixture(scope="module")
@@ -55,46 +73,87 @@ def astradb_semantic_cache() -> Iterator[AstraDBSemanticCache]:
embedding=fake_embe,
)
yield sem_cache
sem_cache.astra_db.delete_collection("lc_integration_test_cache")
sem_cache.collection.astra_db.delete_collection("lc_integration_test_sem_cache")
@pytest.fixture
async def async_astradb_semantic_cache() -> AsyncIterator[AstraDBSemanticCache]:
fake_embe = FakeEmbeddings()
sem_cache = AstraDBSemanticCache(
collection_name="lc_integration_test_sem_cache_async",
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
embedding=fake_embe,
setup_mode=SetupMode.ASYNC,
)
yield sem_cache
sem_cache.collection.astra_db.delete_collection(
"lc_integration_test_sem_cache_async"
)
@pytest.mark.requires("astrapy")
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
class TestAstraDBCaches:
def test_astradb_cache(self, astradb_cache: AstraDBCache) -> None:
set_llm_cache(astradb_cache)
self.do_cache_test(FakeLLM(), astradb_cache, "foo")
async def test_astradb_cache_async(self, async_astradb_cache: AstraDBCache) -> None:
await self.ado_cache_test(FakeLLM(), async_astradb_cache, "foo")
def test_astradb_semantic_cache(
self, astradb_semantic_cache: AstraDBSemanticCache
) -> None:
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
output = llm.generate(["foo"])
print(output) # noqa: T201
expected_output = LLMResult(
self.do_cache_test(llm, astradb_semantic_cache, "bar")
output = llm.generate(["bar"]) # 'fizz' is erased away now
assert output != LLMResult(
generations=[[Generation(text="fizz")]],
llm_output={},
)
print(expected_output) # noqa: T201
assert output == expected_output
astradb_cache.clear()
astradb_semantic_cache.clear()
def test_cassandra_semantic_cache(
self, astradb_semantic_cache: AstraDBSemanticCache
async def test_astradb_semantic_cache_async(
self, async_astradb_semantic_cache: AstraDBSemanticCache
) -> None:
set_llm_cache(astradb_semantic_cache)
llm = FakeLLM()
await self.ado_cache_test(llm, async_astradb_semantic_cache, "bar")
output = await llm.agenerate(["bar"]) # 'fizz' is erased away now
assert output != LLMResult(
generations=[[Generation(text="fizz")]],
llm_output={},
)
await async_astradb_semantic_cache.aclear()
@staticmethod
def do_cache_test(llm: LLM, cache: BaseCache, prompt: str) -> None:
set_llm_cache(cache)
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
output = llm.generate(["bar"]) # same embedding as 'foo'
output = llm.generate([prompt])
expected_output = LLMResult(
generations=[[Generation(text="fizz")]],
llm_output={},
)
assert output == expected_output
# clear the cache
astradb_semantic_cache.clear()
output = llm.generate(["bar"]) # 'fizz' is erased away now
assert output != expected_output
astradb_semantic_cache.clear()
cache.clear()
@staticmethod
async def ado_cache_test(llm: LLM, cache: BaseCache, prompt: str) -> None:
set_llm_cache(cache)
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
await get_llm_cache().aupdate("foo", llm_string, [Generation(text="fizz")])
output = await llm.agenerate([prompt])
expected_output = LLMResult(
generations=[[Generation(text="fizz")]],
llm_output={},
)
assert output == expected_output
# clear the cache
await cache.aclear()