community[minor]: Add async methods to CassandraCache and CassandraSemanticCache (#20654)
commit 5c77f45b06 (parent d6e9bd3011), committed via GitHub
libs/community/langchain_community/cache.py

@@ -53,6 +53,7 @@ from sqlalchemy.engine import Row
 from sqlalchemy.engine.base import Engine
 from sqlalchemy.orm import Session

+from langchain_community.utilities.cassandra import SetupMode as CassandraSetupMode
 from langchain_community.vectorstores.azure_cosmos_db import (
     CosmosDBSimilarityType,
     CosmosDBVectorSearchType,
@@ -63,7 +64,7 @@ try:
 except ImportError:
     from sqlalchemy.ext.declarative import declarative_base

-from langchain_core._api.deprecation import deprecated
+from langchain_core._api.deprecation import deprecated, warn_deprecated
 from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
 from langchain_core.embeddings import Embeddings
 from langchain_core.language_models.llms import LLM, aget_prompts, get_prompts
@@ -73,7 +74,9 @@ from langchain_core.outputs import ChatGeneration, Generation
 from langchain_core.utils import get_from_env

 from langchain_community.utilities.astradb import (
-    SetupMode,
+    SetupMode as AstraSetupMode,
+)
+from langchain_community.utilities.astradb import (
     _AstraDBCollectionEnvironment,
 )
 from langchain_community.vectorstores import AzureCosmosDBVectorSearch
@@ -1056,6 +1059,7 @@ class CassandraCache(BaseCache):
         table_name: str = CASSANDRA_CACHE_DEFAULT_TABLE_NAME,
         ttl_seconds: Optional[int] = CASSANDRA_CACHE_DEFAULT_TTL_SECONDS,
         skip_provisioning: bool = False,
+        setup_mode: CassandraSetupMode = CassandraSetupMode.SYNC,
     ):
         """
         Initialize with a ready session and a keyspace name.
@@ -1066,6 +1070,10 @@ class CassandraCache(BaseCache):
             ttl_seconds (optional int): time-to-live for cache entries
                 (default: None, i.e. forever)
         """
+        if skip_provisioning:
+            warn_deprecated(
+                "0.0.33", alternative="Use setup_mode=CassandraSetupMode.OFF instead."
+            )
         try:
             from cassio.table import ElasticCassandraTable
         except (ImportError, ModuleNotFoundError):
@@ -1079,6 +1087,10 @@ class CassandraCache(BaseCache):
         self.table_name = table_name
         self.ttl_seconds = ttl_seconds

+        kwargs = {}
+        if setup_mode == CassandraSetupMode.ASYNC:
+            kwargs["async_setup"] = True
+
         self.kv_cache = ElasticCassandraTable(
             session=self.session,
             keyspace=self.keyspace,
@@ -1086,27 +1098,31 @@ class CassandraCache(BaseCache):
             keys=["llm_string", "prompt"],
             primary_key_type=["TEXT", "TEXT"],
             ttl_seconds=self.ttl_seconds,
-            skip_provisioning=skip_provisioning,
+            skip_provisioning=skip_provisioning or setup_mode == CassandraSetupMode.OFF,
+            **kwargs,
         )

     def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
         """Look up based on prompt and llm_string."""
         item = self.kv_cache.get(
             llm_string=_hash(llm_string),
             prompt=_hash(prompt),
         )
         if item is not None:
-            generations = _loads_generations(item["body_blob"])
-            # this protects against malformed cached items:
-            if generations is not None:
-                return generations
-            else:
-                return None
+            return _loads_generations(item["body_blob"])
         else:
             return None

+    async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
+        item = await self.kv_cache.aget(
+            llm_string=_hash(llm_string),
+            prompt=_hash(prompt),
+        )
+        if item is not None:
+            return _loads_generations(item["body_blob"])
+        else:
+            return None
+
     def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
         """Update cache based on prompt and llm_string."""
         blob = _dumps_generations(return_val)
         self.kv_cache.put(
             llm_string=_hash(llm_string),
@@ -1114,6 +1130,16 @@ class CassandraCache(BaseCache):
             body_blob=blob,
         )

+    async def aupdate(
+        self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
+    ) -> None:
+        blob = _dumps_generations(return_val)
+        await self.kv_cache.aput(
+            llm_string=_hash(llm_string),
+            prompt=_hash(prompt),
+            body_blob=blob,
+        )
+
     def delete_through_llm(
         self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
     ) -> None:
@@ -1139,6 +1165,10 @@ class CassandraCache(BaseCache):
         """Clear cache. This is for all LLMs at once."""
         self.kv_cache.clear()

+    async def aclear(self, **kwargs: Any) -> None:
+        """Clear cache. This is for all LLMs at once."""
+        await self.kv_cache.aclear()
+

 CASSANDRA_SEMANTIC_CACHE_DEFAULT_DISTANCE_METRIC = "dot"
 CASSANDRA_SEMANTIC_CACHE_DEFAULT_SCORE_THRESHOLD = 0.85
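Note: taken together, the hunks above give CassandraCache a complete async surface (alookup, aupdate, aclear) plus non-blocking provisioning via setup_mode. A minimal usage sketch under assumed names (the live cassio-compatible `session` and the keyspace name are hypothetical here):

import asyncio

from langchain_community.cache import CassandraCache
from langchain_community.utilities.cassandra import SetupMode
from langchain_core.outputs import Generation


async def demo(session: object, keyspace: str) -> None:
    # SetupMode.ASYNC defers table provisioning (the async_setup flag above),
    # so __init__ does not block on schema creation.
    cache = CassandraCache(session=session, keyspace=keyspace, setup_mode=SetupMode.ASYNC)
    await cache.aupdate("my prompt", "my llm_string", [Generation(text="cached!")])
    hit = await cache.alookup("my prompt", "my llm_string")
    assert hit is not None and hit[0].text == "cached!"
    await cache.aclear()


# asyncio.run(demo(session, "demo_keyspace"))  # requires a real Cassandra session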
@@ -1170,6 +1200,7 @@ class CassandraSemanticCache(BaseCache):
         score_threshold: float = CASSANDRA_SEMANTIC_CACHE_DEFAULT_SCORE_THRESHOLD,
         ttl_seconds: Optional[int] = CASSANDRA_SEMANTIC_CACHE_DEFAULT_TTL_SECONDS,
         skip_provisioning: bool = False,
+        setup_mode: CassandraSetupMode = CassandraSetupMode.SYNC,
     ):
         """
         Initialize the cache with all relevant parameters.
@@ -1189,6 +1220,10 @@ class CassandraSemanticCache(BaseCache):
             The default score threshold is tuned to the default metric.
             Tune it carefully yourself if switching to another distance metric.
         """
+        if skip_provisioning:
+            warn_deprecated(
+                "0.0.33", alternative="Use setup_mode=CassandraSetupMode.OFF instead."
+            )
         try:
             from cassio.table import MetadataVectorCassandraTable
         except (ImportError, ModuleNotFoundError):
@@ -1214,24 +1249,42 @@ class CassandraSemanticCache(BaseCache):
             return self.embedding.embed_query(text=text)

         self._get_embedding = _cache_embedding
-        self.embedding_dimension = self._get_embedding_dimension()

+        @_async_lru_cache(maxsize=CASSANDRA_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE)
+        async def _acache_embedding(text: str) -> List[float]:
+            return await self.embedding.aembed_query(text=text)
+
+        self._aget_embedding = _acache_embedding
+
+        embedding_dimension: Union[int, Awaitable[int], None] = None
+        if setup_mode == CassandraSetupMode.ASYNC:
+            embedding_dimension = self._aget_embedding_dimension()
+        elif setup_mode == CassandraSetupMode.SYNC:
+            embedding_dimension = self._get_embedding_dimension()
+
+        kwargs = {}
+        if setup_mode == CassandraSetupMode.ASYNC:
+            kwargs["async_setup"] = True
+
         self.table = MetadataVectorCassandraTable(
             session=self.session,
             keyspace=self.keyspace,
             table=self.table_name,
             primary_key_type=["TEXT"],
-            vector_dimension=self.embedding_dimension,
+            vector_dimension=embedding_dimension,
             ttl_seconds=self.ttl_seconds,
             metadata_indexing=("allow", {"_llm_string_hash"}),
-            skip_provisioning=skip_provisioning,
+            skip_provisioning=skip_provisioning or setup_mode == CassandraSetupMode.OFF,
+            **kwargs,
         )

     def _get_embedding_dimension(self) -> int:
         return len(self._get_embedding(text="This is a sample sentence."))

+    async def _aget_embedding_dimension(self) -> int:
+        return len(await self._aget_embedding(text="This is a sample sentence."))
+
     def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
         """Update cache based on prompt and llm_string."""
         embedding_vector = self._get_embedding(text=prompt)
         llm_string_hash = _hash(llm_string)
         body = _dumps_generations(return_val)
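The `_async_lru_cache` decorator applied above is referenced but not defined in this hunk. A sketch of how such a helper can be written (an assumption about its shape, not necessarily the definition this commit ships): memoize asyncio Tasks rather than raw coroutines, since a Task, unlike a coroutine object, can safely be awaited more than once.

import asyncio
from functools import lru_cache, wraps
from typing import Any, Callable


def _async_lru_cache_sketch(maxsize: int = 128) -> Callable:
    """Hypothetical LRU cache for coroutine functions."""

    def decorator(coro_fn: Callable) -> Callable:
        @lru_cache(maxsize=maxsize)
        def _cached(*args: Any, **kwargs: Any) -> "asyncio.Task":
            # Wrap the coroutine in a Task: concurrent and repeated awaits
            # all resolve to the single underlying execution.
            return asyncio.ensure_future(coro_fn(*args, **kwargs))

        @wraps(coro_fn)
        async def wrapper(*args: Any, **kwargs: Any) -> Any:
            return await _cached(*args, **kwargs)

        return wrapper

    return decorator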
@@ -1240,7 +1293,7 @@ class CassandraSemanticCache(BaseCache):
             "_llm_string_hash": llm_string_hash,
         }
         row_id = f"{_hash(prompt)}-{llm_string_hash}"
-        #
+
         self.table.put(
             body_blob=body,
             vector=embedding_vector,
@@ -1248,14 +1301,39 @@ class CassandraSemanticCache(BaseCache):
             metadata=metadata,
         )

+    async def aupdate(
+        self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
+    ) -> None:
+        embedding_vector = await self._aget_embedding(text=prompt)
+        llm_string_hash = _hash(llm_string)
+        body = _dumps_generations(return_val)
+        metadata = {
+            "_prompt": prompt,
+            "_llm_string_hash": llm_string_hash,
+        }
+        row_id = f"{_hash(prompt)}-{llm_string_hash}"
+
+        await self.table.aput(
+            body_blob=body,
+            vector=embedding_vector,
+            row_id=row_id,
+            metadata=metadata,
+        )
+
     def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
         """Look up based on prompt and llm_string."""
         hit_with_id = self.lookup_with_id(prompt, llm_string)
         if hit_with_id is not None:
             return hit_with_id[1]
         else:
             return None

+    async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
+        hit_with_id = await self.alookup_with_id(prompt, llm_string)
+        if hit_with_id is not None:
+            return hit_with_id[1]
+        else:
+            return None
+
     def lookup_with_id(
         self, prompt: str, llm_string: str
     ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
@@ -1287,6 +1365,37 @@ class CassandraSemanticCache(BaseCache):
         else:
             return None

+    async def alookup_with_id(
+        self, prompt: str, llm_string: str
+    ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
+        """
+        Look up based on prompt and llm_string.
+        If there are hits, return (document_id, cached_entry)
+        """
+        prompt_embedding: List[float] = await self._aget_embedding(text=prompt)
+        hits = list(
+            await self.table.ametric_ann_search(
+                vector=prompt_embedding,
+                metadata={"_llm_string_hash": _hash(llm_string)},
+                n=1,
+                metric=self.distance_metric,
+                metric_threshold=self.score_threshold,
+            )
+        )
+        if hits:
+            hit = hits[0]
+            generations = _loads_generations(hit["body_blob"])
+            if generations is not None:
+                # this protects against malformed cached items:
+                return (
+                    hit["row_id"],
+                    generations,
+                )
+            else:
+                return None
+        else:
+            return None
+
     def lookup_with_id_through_llm(
         self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
     ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
@@ -1296,6 +1405,17 @@ class CassandraSemanticCache(BaseCache):
         )[1]
         return self.lookup_with_id(prompt, llm_string=llm_string)

+    async def alookup_with_id_through_llm(
+        self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
+    ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
+        llm_string = (
+            await aget_prompts(
+                {**llm.dict(), **{"stop": stop}},
+                [],
+            )
+        )[1]
+        return await self.alookup_with_id(prompt, llm_string=llm_string)
+
     def delete_by_document_id(self, document_id: str) -> None:
         """
         Given this is a "similarity search" cache, an invalidation pattern
@@ -1304,10 +1424,22 @@ class CassandraSemanticCache(BaseCache):
         """
         self.table.delete(row_id=document_id)

+    async def adelete_by_document_id(self, document_id: str) -> None:
+        """
+        Given this is a "similarity search" cache, an invalidation pattern
+        that makes sense is first a lookup to get an ID, and then deleting
+        with that ID. This is for the second step.
+        """
+        await self.table.adelete(row_id=document_id)
+
     def clear(self, **kwargs: Any) -> None:
         """Clear the *whole* semantic cache."""
         self.table.clear()

+    async def aclear(self, **kwargs: Any) -> None:
+        """Clear the *whole* semantic cache."""
+        await self.table.aclear()
+

 class FullMd5LLMCache(Base):  # type: ignore
     """SQLite table for full LLM Cache (all generations)."""
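As the delete_by_document_id docstring notes, the natural invalidation pattern for a similarity-search cache is lookup-then-delete, and with this commit both steps are available on the async path. A sketch combining the two new coroutines (`sem_cache` is an assumed, already-constructed CassandraSemanticCache):

async def ainvalidate(sem_cache, prompt: str, llm_string: str) -> bool:
    # Step 1: similarity lookup returns (document_id, cached_entry) on a hit.
    hit_with_id = await sem_cache.alookup_with_id(prompt, llm_string)
    if hit_with_id is None:
        return False  # nothing within the score threshold to invalidate
    document_id, _ = hit_with_id
    # Step 2: delete by the ID obtained in step 1.
    await sem_cache.adelete_by_document_id(document_id)
    return True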
@@ -1412,7 +1544,7 @@ class AstraDBCache(BaseCache):
         async_astra_db_client: Optional[AsyncAstraDB] = None,
         namespace: Optional[str] = None,
         pre_delete_collection: bool = False,
-        setup_mode: SetupMode = SetupMode.SYNC,
+        setup_mode: AstraSetupMode = AstraSetupMode.SYNC,
     ):
         """
         Cache that uses Astra DB as a backend.
@@ -1612,7 +1744,7 @@ class AstraDBSemanticCache(BaseCache):
         astra_db_client: Optional[AstraDB] = None,
         async_astra_db_client: Optional[AsyncAstraDB] = None,
         namespace: Optional[str] = None,
-        setup_mode: SetupMode = SetupMode.SYNC,
+        setup_mode: AstraSetupMode = AstraSetupMode.SYNC,
         pre_delete_collection: bool = False,
         embedding: Embeddings,
         metric: Optional[str] = None,
@@ -1675,9 +1807,9 @@ class AstraDBSemanticCache(BaseCache):
         self._aget_embedding = _acache_embedding

         embedding_dimension: Union[int, Awaitable[int], None] = None
-        if setup_mode == SetupMode.ASYNC:
+        if setup_mode == AstraSetupMode.ASYNC:
             embedding_dimension = self._aget_embedding_dimension()
-        elif setup_mode == SetupMode.SYNC:
+        elif setup_mode == AstraSetupMode.SYNC:
             embedding_dimension = self._get_embedding_dimension()

         self.astra_env = _AstraDBCollectionEnvironment(
libs/core/langchain_core/caches.py

@@ -93,13 +93,47 @@ class BaseCache(ABC):
         """Clear cache that can take additional keyword arguments."""

     async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
-        """Async version of lookup."""
+        """Look up based on prompt and llm_string.
+
+        A cache implementation is expected to generate a key from the 2-tuple
+        of prompt and llm_string (e.g., by concatenating them with a delimiter).
+
+        Args:
+            prompt: a string representation of the prompt.
+                In the case of a Chat model, the prompt is a non-trivial
+                serialization of the prompt into the language model.
+            llm_string: A string representation of the LLM configuration.
+                This is used to capture the invocation parameters of the LLM
+                (e.g., model name, temperature, stop tokens, max tokens, etc.).
+                These invocation parameters are serialized into a string
+                representation.
+
+        Returns:
+            On a cache miss, return None. On a cache hit, return the cached value.
+            The cached value is a list of Generations (or subclasses).
+        """
         return await run_in_executor(None, self.lookup, prompt, llm_string)

     async def aupdate(
         self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
     ) -> None:
-        """Async version of aupdate."""
+        """Update cache based on prompt and llm_string.
+
+        The prompt and llm_string are used to generate a key for the cache.
+        The key should match that of the look up method.
+
+        Args:
+            prompt: a string representation of the prompt.
+                In the case of a Chat model, the prompt is a non-trivial
+                serialization of the prompt into the language model.
+            llm_string: A string representation of the LLM configuration.
+                This is used to capture the invocation parameters of the LLM
+                (e.g., model name, temperature, stop tokens, max tokens, etc.).
+                These invocation parameters are serialized into a string
+                representation.
+            return_val: The value to be cached. The value is a list of Generations
+                (or subclasses).
+        """
         return await run_in_executor(None, self.update, prompt, llm_string, return_val)

     async def aclear(self, **kwargs: Any) -> None:
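The expanded docstrings spell out the BaseCache contract: key on the (prompt, llm_string) 2-tuple and store a list of Generations, with the async methods defaulting to run_in_executor delegation. A minimal illustrative implementation of that contract (a toy sketch; langchain_core already provides InMemoryCache for real use) only needs the three sync methods:

from typing import Any, Dict, Optional, Tuple

from langchain_core.caches import RETURN_VAL_TYPE, BaseCache


class DictCache(BaseCache):
    """Toy cache keyed on the (prompt, llm_string) 2-tuple."""

    def __init__(self) -> None:
        self._store: Dict[Tuple[str, str], RETURN_VAL_TYPE] = {}

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        return self._store.get((prompt, llm_string))  # None signals a miss

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        self._store[(prompt, llm_string)] = return_val

    def clear(self, **kwargs: Any) -> None:
        self._store.clear()


# alookup/aupdate/aclear are inherited: each runs its sync counterpart in an executor.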
libs/langchain/tests/integration_tests/cache/test_cassandra.py

@@ -1,10 +1,11 @@
 """Test Cassandra caches. Requires a running vector-capable Cassandra cluster."""

+import asyncio
 import os
 import time
 from typing import Any, Iterator, Tuple

 import pytest
+from langchain_community.utilities.cassandra import SetupMode
 from langchain_core.outputs import Generation, LLMResult

 from langchain.cache import CassandraCache, CassandraSemanticCache
@@ -47,16 +48,34 @@ def test_cassandra_cache(cassandra_connection: Tuple[Any, str]) -> None:
     llm_string = str(sorted([(k, v) for k, v in params.items()]))
     get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
     output = llm.generate(["foo"])
-    print(output)  # noqa: T201
     expected_output = LLMResult(
         generations=[[Generation(text="fizz")]],
         llm_output={},
     )
-    print(expected_output)  # noqa: T201
     assert output == expected_output
     cache.clear()


+async def test_cassandra_cache_async(cassandra_connection: Tuple[Any, str]) -> None:
+    session, keyspace = cassandra_connection
+    cache = CassandraCache(
+        session=session, keyspace=keyspace, setup_mode=SetupMode.ASYNC
+    )
+    set_llm_cache(cache)
+    llm = FakeLLM()
+    params = llm.dict()
+    params["stop"] = None
+    llm_string = str(sorted([(k, v) for k, v in params.items()]))
+    await get_llm_cache().aupdate("foo", llm_string, [Generation(text="fizz")])
+    output = await llm.agenerate(["foo"])
+    expected_output = LLMResult(
+        generations=[[Generation(text="fizz")]],
+        llm_output={},
+    )
+    assert output == expected_output
+    await cache.aclear()
+
+
 def test_cassandra_cache_ttl(cassandra_connection: Tuple[Any, str]) -> None:
     session, keyspace = cassandra_connection
     cache = CassandraCache(session=session, keyspace=keyspace, ttl_seconds=2)
@@ -79,6 +98,30 @@ def test_cassandra_cache_ttl(cassandra_connection: Tuple[Any, str]) -> None:
     cache.clear()


+async def test_cassandra_cache_ttl_async(cassandra_connection: Tuple[Any, str]) -> None:
+    session, keyspace = cassandra_connection
+    cache = CassandraCache(
+        session=session, keyspace=keyspace, ttl_seconds=2, setup_mode=SetupMode.ASYNC
+    )
+    set_llm_cache(cache)
+    llm = FakeLLM()
+    params = llm.dict()
+    params["stop"] = None
+    llm_string = str(sorted([(k, v) for k, v in params.items()]))
+    await get_llm_cache().aupdate("foo", llm_string, [Generation(text="fizz")])
+    expected_output = LLMResult(
+        generations=[[Generation(text="fizz")]],
+        llm_output={},
+    )
+    output = await llm.agenerate(["foo"])
+    assert output == expected_output
+    await asyncio.sleep(2.5)
+    # entry has expired away.
+    output = await llm.agenerate(["foo"])
+    assert output != expected_output
+    await cache.aclear()
+
+
 def test_cassandra_semantic_cache(cassandra_connection: Tuple[Any, str]) -> None:
     session, keyspace = cassandra_connection
     sem_cache = CassandraSemanticCache(
@@ -103,3 +146,32 @@ def test_cassandra_semantic_cache(cassandra_connection: Tuple[Any, str]) -> None
     output = llm.generate(["bar"])  # 'fizz' is erased away now
     assert output != expected_output
     sem_cache.clear()


+async def test_cassandra_semantic_cache_async(
+    cassandra_connection: Tuple[Any, str],
+) -> None:
+    session, keyspace = cassandra_connection
+    sem_cache = CassandraSemanticCache(
+        session=session,
+        keyspace=keyspace,
+        embedding=FakeEmbeddings(),
+        setup_mode=SetupMode.ASYNC,
+    )
+    set_llm_cache(sem_cache)
+    llm = FakeLLM()
+    params = llm.dict()
+    params["stop"] = None
+    llm_string = str(sorted([(k, v) for k, v in params.items()]))
+    await get_llm_cache().aupdate("foo", llm_string, [Generation(text="fizz")])
+    output = await llm.agenerate(["bar"])  # same embedding as 'foo'
+    expected_output = LLMResult(
+        generations=[[Generation(text="fizz")]],
+        llm_output={},
+    )
+    assert output == expected_output
+    # clear the cache
+    await sem_cache.aclear()
+    output = await llm.agenerate(["bar"])  # 'fizz' is erased away now
+    assert output != expected_output
+    await sem_cache.aclear()
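A note on the semantic tests: llm.agenerate(["bar"]) hits the entry cached under "foo" because the FakeEmbeddings fixture evidently maps every text to essentially the same vector, so any prompt lands within the score threshold. A sketch of such a degenerate embedder (the actual fixture lives in the test utilities and may differ):

from typing import List

from langchain_core.embeddings import Embeddings


class ConstantEmbeddings(Embeddings):
    """Every text maps to one fixed vector, so once anything is cached,
    every semantic lookup is a hit."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [self.embed_query(t) for t in texts]

    def embed_query(self, text: str) -> List[float]:
        return [1.0] * 9 + [0.0]

Being coroutines, these tests also rely on an async-capable runner such as pytest-asyncio (assumed here) or an equivalent setting in the repository's pytest configuration.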