mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-13 16:36:06 +00:00
CassandraCache and CassandraSemanticCache can handle any "Generation" (#10563)
Hello, this PR improves coverage for caching by the two Cassandra-related caches (i.e. exact-match and semantic alike) by switching to the more general `dumps`/`loads` serdes utilities. This enables cache usage within e.g. `ChatOpenAI` contexts (which need to store lists of `ChatGeneration` instead of `Generation`s), which was not possible as long as the cache classes were relying on the legacy `_dump_generations_to_json` and `_load_generations_from_json`). Additionally, a slightly different init signature is introduced for the cache objects: - named parameters required for init, to pave the way for easier changes in the future connect-to-db flow (and tests adjusted accordingly) - added a `skip_provisioning` optional passthrough parameter for use cases where the user knows the underlying DB table, etc already exist. Thank you for a review!
This commit is contained in:
parent
e1e01d6586
commit
49b65a1b57
@ -80,6 +80,8 @@ def _dump_generations_to_json(generations: RETURN_VAL_TYPE) -> str:
|
||||
|
||||
Returns:
|
||||
str: Json representing a list of generations.
|
||||
|
||||
Warning: would not work well with arbitrary subclasses of `Generation`
|
||||
"""
|
||||
return json.dumps([generation.dict() for generation in generations])
|
||||
|
||||
@ -95,6 +97,8 @@ def _load_generations_from_json(generations_json: str) -> RETURN_VAL_TYPE:
|
||||
|
||||
Returns:
|
||||
RETURN_VAL_TYPE: A list of generations.
|
||||
|
||||
Warning: would not work well with arbitrary subclasses of `Generation`
|
||||
"""
|
||||
try:
|
||||
results = json.loads(generations_json)
|
||||
@ -105,6 +109,65 @@ def _load_generations_from_json(generations_json: str) -> RETURN_VAL_TYPE:
|
||||
)
|
||||
|
||||
|
||||
def _dumps_generations(generations: RETURN_VAL_TYPE) -> str:
|
||||
"""
|
||||
Serialization for generic RETURN_VAL_TYPE, i.e. sequence of `Generation`
|
||||
|
||||
Args:
|
||||
generations (RETURN_VAL_TYPE): A list of language model generations.
|
||||
|
||||
Returns:
|
||||
str: a single string representing a list of generations.
|
||||
|
||||
This function (+ its counterpart `_loads_generations`) rely on
|
||||
the dumps/loads pair with Reviver, so are able to deal
|
||||
with all subclasses of Generation.
|
||||
|
||||
Each item in the list can be `dumps`ed to a string,
|
||||
then we make the whole list of strings into a json-dumped.
|
||||
"""
|
||||
return json.dumps([dumps(_item) for _item in generations])
|
||||
|
||||
|
||||
def _loads_generations(generations_str: str) -> Union[RETURN_VAL_TYPE, None]:
|
||||
"""
|
||||
Deserialization of a string into a generic RETURN_VAL_TYPE
|
||||
(i.e. a sequence of `Generation`).
|
||||
|
||||
See `_dumps_generations`, the inverse of this function.
|
||||
|
||||
Args:
|
||||
generations_str (str): A string representing a list of generations.
|
||||
|
||||
Compatible with the legacy cache-blob format
|
||||
Does not raise exceptions for malformed entries, just logs a warning
|
||||
and returns none: the caller should be prepared for such a cache miss.
|
||||
|
||||
Returns:
|
||||
RETURN_VAL_TYPE: A list of generations.
|
||||
"""
|
||||
try:
|
||||
generations = [loads(_item_str) for _item_str in json.loads(generations_str)]
|
||||
return generations
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
# deferring the (soft) handling to after the legacy-format attempt
|
||||
pass
|
||||
|
||||
try:
|
||||
gen_dicts = json.loads(generations_str)
|
||||
# not relying on `_load_generations_from_json` (which could disappear):
|
||||
generations = [Generation(**generation_dict) for generation_dict in gen_dicts]
|
||||
logger.warning(
|
||||
f"Legacy 'Generation' cached blob encountered: '{generations_str}'"
|
||||
)
|
||||
return generations
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
logger.warning(
|
||||
f"Malformed/unparsable cached blob encountered: '{generations_str}'"
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
class InMemoryCache(BaseCache):
|
||||
"""Cache that stores things in memory."""
|
||||
|
||||
@ -733,10 +796,11 @@ class CassandraCache(BaseCache):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: CassandraSession,
|
||||
keyspace: str,
|
||||
session: Optional[CassandraSession] = None,
|
||||
keyspace: Optional[str] = None,
|
||||
table_name: str = CASSANDRA_CACHE_DEFAULT_TABLE_NAME,
|
||||
ttl_seconds: Optional[int] = CASSANDRA_CACHE_DEFAULT_TTL_SECONDS,
|
||||
skip_provisioning: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize with a ready session and a keyspace name.
|
||||
@ -767,6 +831,7 @@ class CassandraCache(BaseCache):
|
||||
keys=["llm_string", "prompt"],
|
||||
primary_key_type=["TEXT", "TEXT"],
|
||||
ttl_seconds=self.ttl_seconds,
|
||||
skip_provisioning=skip_provisioning,
|
||||
)
|
||||
|
||||
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
|
||||
@ -775,14 +840,19 @@ class CassandraCache(BaseCache):
|
||||
llm_string=_hash(llm_string),
|
||||
prompt=_hash(prompt),
|
||||
)
|
||||
if item:
|
||||
return _load_generations_from_json(item["body_blob"])
|
||||
if item is not None:
|
||||
generations = _loads_generations(item["body_blob"])
|
||||
# this protects against malformed cached items:
|
||||
if generations is not None:
|
||||
return generations
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
|
||||
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
|
||||
"""Update cache based on prompt and llm_string."""
|
||||
blob = _dump_generations_to_json(return_val)
|
||||
blob = _dumps_generations(return_val)
|
||||
self.kv_cache.put(
|
||||
llm_string=_hash(llm_string),
|
||||
prompt=_hash(prompt),
|
||||
@ -836,13 +906,14 @@ class CassandraSemanticCache(BaseCache):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: CassandraSession,
|
||||
keyspace: str,
|
||||
session: Optional[CassandraSession],
|
||||
keyspace: Optional[str],
|
||||
embedding: Embeddings,
|
||||
table_name: str = CASSANDRA_SEMANTIC_CACHE_DEFAULT_TABLE_NAME,
|
||||
distance_metric: str = CASSANDRA_SEMANTIC_CACHE_DEFAULT_DISTANCE_METRIC,
|
||||
score_threshold: float = CASSANDRA_SEMANTIC_CACHE_DEFAULT_SCORE_THRESHOLD,
|
||||
ttl_seconds: Optional[int] = CASSANDRA_SEMANTIC_CACHE_DEFAULT_TTL_SECONDS,
|
||||
skip_provisioning: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize the cache with all relevant parameters.
|
||||
@ -897,6 +968,7 @@ class CassandraSemanticCache(BaseCache):
|
||||
vector_dimension=self.embedding_dimension,
|
||||
ttl_seconds=self.ttl_seconds,
|
||||
metadata_indexing=("allow", {"_llm_string_hash"}),
|
||||
skip_provisioning=skip_provisioning,
|
||||
)
|
||||
|
||||
def _get_embedding_dimension(self) -> int:
|
||||
@ -906,7 +978,7 @@ class CassandraSemanticCache(BaseCache):
|
||||
"""Update cache based on prompt and llm_string."""
|
||||
embedding_vector = self._get_embedding(text=prompt)
|
||||
llm_string_hash = _hash(llm_string)
|
||||
body = _dump_generations_to_json(return_val)
|
||||
body = _dumps_generations(return_val)
|
||||
metadata = {
|
||||
"_prompt": prompt,
|
||||
"_llm_string_hash": llm_string_hash,
|
||||
@ -947,11 +1019,15 @@ class CassandraSemanticCache(BaseCache):
|
||||
)
|
||||
if hits:
|
||||
hit = hits[0]
|
||||
generations_str = hit["body_blob"]
|
||||
return (
|
||||
hit["row_id"],
|
||||
_load_generations_from_json(generations_str),
|
||||
)
|
||||
generations = _loads_generations(hit["body_blob"])
|
||||
if generations is not None:
|
||||
# this protects against malformed cached items:
|
||||
return (
|
||||
hit["row_id"],
|
||||
generations,
|
||||
)
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
|
||||
|
@ -38,7 +38,7 @@ def cassandra_connection() -> Iterator[Tuple[Any, str]]:
|
||||
|
||||
def test_cassandra_cache(cassandra_connection: Tuple[Any, str]) -> None:
|
||||
session, keyspace = cassandra_connection
|
||||
cache = CassandraCache(session, keyspace)
|
||||
cache = CassandraCache(session=session, keyspace=keyspace)
|
||||
langchain.llm_cache = cache
|
||||
llm = FakeLLM()
|
||||
params = llm.dict()
|
||||
@ -58,7 +58,7 @@ def test_cassandra_cache(cassandra_connection: Tuple[Any, str]) -> None:
|
||||
|
||||
def test_cassandra_cache_ttl(cassandra_connection: Tuple[Any, str]) -> None:
|
||||
session, keyspace = cassandra_connection
|
||||
cache = CassandraCache(session, keyspace, ttl_seconds=2)
|
||||
cache = CassandraCache(session=session, keyspace=keyspace, ttl_seconds=2)
|
||||
langchain.llm_cache = cache
|
||||
llm = FakeLLM()
|
||||
params = llm.dict()
|
||||
@ -80,7 +80,11 @@ def test_cassandra_cache_ttl(cassandra_connection: Tuple[Any, str]) -> None:
|
||||
|
||||
def test_cassandra_semantic_cache(cassandra_connection: Tuple[Any, str]) -> None:
|
||||
session, keyspace = cassandra_connection
|
||||
sem_cache = CassandraSemanticCache(session, keyspace, embedding=FakeEmbeddings())
|
||||
sem_cache = CassandraSemanticCache(
|
||||
session=session,
|
||||
keyspace=keyspace,
|
||||
embedding=FakeEmbeddings(),
|
||||
)
|
||||
langchain.llm_cache = sem_cache
|
||||
llm = FakeLLM()
|
||||
params = llm.dict()
|
||||
|
Loading…
Reference in New Issue
Block a user