Mirror of https://github.com/hwchase17/langchain.git, synced 2025-09-23 19:39:58 +00:00.
Commit ca2d4078f3 (parent e438fe6be9), committed via GitHub:
"community: Add async methods to AstraDBCache (#17415)" — adds async methods to AstraDBCache.
@@ -12,9 +12,12 @@ Required to run this test:
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Iterator
|
||||
from typing import AsyncIterator, Iterator
|
||||
|
||||
import pytest
|
||||
from langchain_community.utilities.astradb import SetupMode
|
||||
from langchain_core.caches import BaseCache
|
||||
from langchain_core.language_models import LLM
|
||||
from langchain_core.outputs import Generation, LLMResult
|
||||
|
||||
from langchain.cache import AstraDBCache, AstraDBSemanticCache
|
||||
@@ -41,7 +44,22 @@ def astradb_cache() -> Iterator[AstraDBCache]:
|
||||
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
|
||||
)
|
||||
yield cache
|
||||
cache.astra_db.delete_collection("lc_integration_test_cache")
|
||||
cache.collection.astra_db.delete_collection("lc_integration_test_cache")
|
||||
|
||||
|
||||
@pytest.fixture
async def async_astradb_cache() -> AsyncIterator[AstraDBCache]:
    """Yield an AstraDBCache created in async setup mode.

    The backing collection is dropped on teardown so repeated test runs
    start from a clean slate.
    """
    env = os.environ
    async_cache = AstraDBCache(
        collection_name="lc_integration_test_cache_async",
        token=env["ASTRA_DB_APPLICATION_TOKEN"],
        api_endpoint=env["ASTRA_DB_API_ENDPOINT"],
        namespace=env.get("ASTRA_DB_KEYSPACE"),
        setup_mode=SetupMode.ASYNC,
    )
    yield async_cache
    # Teardown: async setup mode initializes the async client, so the
    # collection is removed through it.
    astra_db = async_cache.async_collection.astra_db
    await astra_db.delete_collection("lc_integration_test_cache_async")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@@ -55,46 +73,87 @@ def astradb_semantic_cache() -> Iterator[AstraDBSemanticCache]:
|
||||
embedding=fake_embe,
|
||||
)
|
||||
yield sem_cache
|
||||
sem_cache.astra_db.delete_collection("lc_integration_test_cache")
|
||||
sem_cache.collection.astra_db.delete_collection("lc_integration_test_sem_cache")
|
||||
|
||||
|
||||
@pytest.fixture
async def async_astradb_semantic_cache() -> AsyncIterator[AstraDBSemanticCache]:
    """Yield an AstraDBSemanticCache created in async setup mode.

    Uses FakeEmbeddings so semantic lookups are deterministic; the backing
    collection is dropped on teardown.
    """
    fake_embe = FakeEmbeddings()
    sem_cache = AstraDBSemanticCache(
        collection_name="lc_integration_test_sem_cache_async",
        token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
        api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
        namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
        embedding=fake_embe,
        setup_mode=SetupMode.ASYNC,
    )
    yield sem_cache
    # Fix: with SetupMode.ASYNC only the async client is guaranteed to be
    # initialized, so tear down through async_collection (awaited), matching
    # the async_astradb_cache fixture — not the sync `sem_cache.collection`.
    await sem_cache.async_collection.astra_db.delete_collection(
        "lc_integration_test_sem_cache_async"
    )
|
||||
|
||||
|
||||
@pytest.mark.requires("astrapy")
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
class TestAstraDBCaches:
    """Integration tests for AstraDBCache and AstraDBSemanticCache.

    Sync and async variants share the same flow via the do_cache_test /
    ado_cache_test helpers: seed the cache under prompt "foo", then verify
    a lookup (exact for the plain cache, semantic for the semantic cache)
    returns the cached generation.
    """

    def test_astradb_cache(self, astradb_cache: AstraDBCache) -> None:
        # Exact-match cache: lookup prompt equals the seeded prompt "foo".
        self.do_cache_test(FakeLLM(), astradb_cache, "foo")

    async def test_astradb_cache_async(self, async_astradb_cache: AstraDBCache) -> None:
        # Async counterpart of test_astradb_cache, against the async fixture.
        await self.ado_cache_test(FakeLLM(), async_astradb_cache, "foo")

    def test_astradb_semantic_cache(
        self, astradb_semantic_cache: AstraDBSemanticCache
    ) -> None:
        llm = FakeLLM()
        # Semantic cache: "bar" is looked up against the entry seeded under
        # "foo" (FakeEmbeddings makes them match semantically).
        self.do_cache_test(llm, astradb_semantic_cache, "bar")
        output = llm.generate(["bar"])  # 'fizz' is erased away now
        assert output != LLMResult(
            generations=[[Generation(text="fizz")]],
            llm_output={},
        )

    async def test_astradb_semantic_cache_async(
        self, async_astradb_semantic_cache: AstraDBSemanticCache
    ) -> None:
        llm = FakeLLM()
        # Async counterpart of test_astradb_semantic_cache.
        await self.ado_cache_test(llm, async_astradb_semantic_cache, "bar")
        output = await llm.agenerate(["bar"])  # 'fizz' is erased away now
        assert output != LLMResult(
            generations=[[Generation(text="fizz")]],
            llm_output={},
        )
        await async_astradb_semantic_cache.aclear()

    @staticmethod
    def do_cache_test(llm: LLM, cache: BaseCache, prompt: str) -> None:
        """Seed `cache` under prompt "foo", generate for `prompt`, and assert
        the cached generation is returned; clears the cache afterwards."""
        set_llm_cache(cache)
        params = llm.dict()
        params["stop"] = None
        # llm_string must match the key the LLM itself uses for cache lookups.
        llm_string = str(sorted([(k, v) for k, v in params.items()]))
        get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
        output = llm.generate([prompt])
        expected_output = LLMResult(
            generations=[[Generation(text="fizz")]],
            llm_output={},
        )
        assert output == expected_output
        # clear the cache
        cache.clear()

    @staticmethod
    async def ado_cache_test(llm: LLM, cache: BaseCache, prompt: str) -> None:
        """Async variant of do_cache_test, using aupdate/agenerate/aclear."""
        set_llm_cache(cache)
        params = llm.dict()
        params["stop"] = None
        # llm_string must match the key the LLM itself uses for cache lookups.
        llm_string = str(sorted([(k, v) for k, v in params.items()]))
        await get_llm_cache().aupdate("foo", llm_string, [Generation(text="fizz")])
        output = await llm.agenerate([prompt])
        expected_output = LLMResult(
            generations=[[Generation(text="fizz")]],
            llm_output={},
        )
        assert output == expected_output
        # clear the cache
        await cache.aclear()
|
||||
|
Reference in New Issue
Block a user