diff --git a/libs/langchain/langchain/cache.py b/libs/langchain/langchain/cache.py
index 32eb310d44e..1d238086bff 100644
--- a/libs/langchain/langchain/cache.py
+++ b/libs/langchain/langchain/cache.py
@@ -80,6 +80,8 @@ def _dump_generations_to_json(generations: RETURN_VAL_TYPE) -> str:
 
     Returns:
         str: Json representing a list of generations.
+
+    Warning: does not work well with arbitrary subclasses of `Generation`
     """
     return json.dumps([generation.dict() for generation in generations])
 
@@ -95,6 +97,8 @@ def _load_generations_from_json(generations_json: str) -> RETURN_VAL_TYPE:
 
     Returns:
         RETURN_VAL_TYPE: A list of generations.
+
+    Warning: does not work well with arbitrary subclasses of `Generation`
     """
     try:
         results = json.loads(generations_json)
@@ -105,6 +109,65 @@ def _load_generations_from_json(generations_json: str) -> RETURN_VAL_TYPE:
         )
 
 
+def _dumps_generations(generations: RETURN_VAL_TYPE) -> str:
+    """
+    Serialization for generic RETURN_VAL_TYPE, i.e. a sequence of `Generation`.
+
+    Args:
+        generations (RETURN_VAL_TYPE): A list of language model generations.
+
+    Returns:
+        str: a single string representing a list of generations.
+
+    This function, together with its counterpart `_loads_generations`,
+    relies on the dumps/loads pair with Reviver, and so can deal
+    with all subclasses of Generation.
+
+    Each item in the list is `dumps`ed to a string; the whole list
+    of strings is then json-dumped into a single string.
+    """
+    return json.dumps([dumps(_item) for _item in generations])
+
+
+def _loads_generations(generations_str: str) -> Union[RETURN_VAL_TYPE, None]:
+    """
+    Deserialization of a string into a generic RETURN_VAL_TYPE
+    (i.e. a sequence of `Generation`).
+
+    See `_dumps_generations`, the inverse of this function.
+
+    Args:
+        generations_str (str): A string representing a list of generations.
+
+    Compatible with the legacy cache-blob format.
+    Does not raise exceptions for malformed entries; it logs a warning
+    and returns None: the caller should be prepared for such a cache miss.
+
+    Returns:
+        RETURN_VAL_TYPE: A list of generations.
+    """
+    try:
+        generations = [loads(_item_str) for _item_str in json.loads(generations_str)]
+        return generations
+    except (json.JSONDecodeError, TypeError):
+        # deferring the (soft) handling to after the legacy-format attempt
+        pass
+
+    try:
+        gen_dicts = json.loads(generations_str)
+        # not relying on `_load_generations_from_json` (which could disappear):
+        generations = [Generation(**generation_dict) for generation_dict in gen_dicts]
+        logger.warning(
+            f"Legacy 'Generation' cached blob encountered: '{generations_str}'"
+        )
+        return generations
+    except (json.JSONDecodeError, TypeError):
+        logger.warning(
+            f"Malformed/unparsable cached blob encountered: '{generations_str}'"
+        )
+        return None
+
+
 class InMemoryCache(BaseCache):
     """Cache that stores things in memory."""
 
@@ -733,10 +796,11 @@ class CassandraCache(BaseCache):
 
     def __init__(
         self,
-        session: CassandraSession,
-        keyspace: str,
+        session: Optional[CassandraSession] = None,
+        keyspace: Optional[str] = None,
         table_name: str = CASSANDRA_CACHE_DEFAULT_TABLE_NAME,
         ttl_seconds: Optional[int] = CASSANDRA_CACHE_DEFAULT_TTL_SECONDS,
+        skip_provisioning: bool = False,
     ):
         """
         Initialize with a ready session and a keyspace name.
@@ -767,6 +831,7 @@ class CassandraCache(BaseCache):
             keys=["llm_string", "prompt"],
             primary_key_type=["TEXT", "TEXT"],
             ttl_seconds=self.ttl_seconds,
+            skip_provisioning=skip_provisioning,
         )
 
     def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
@@ -775,14 +840,19 @@
             llm_string=_hash(llm_string),
             prompt=_hash(prompt),
         )
-        if item:
-            return _load_generations_from_json(item["body_blob"])
+        if item is not None:
+            generations = _loads_generations(item["body_blob"])
+            # this protects against malformed cached items:
+            if generations is not None:
+                return generations
+            else:
+                return None
         else:
             return None
 
     def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
         """Update cache based on prompt and llm_string."""
-        blob = _dump_generations_to_json(return_val)
+        blob = _dumps_generations(return_val)
         self.kv_cache.put(
             llm_string=_hash(llm_string),
             prompt=_hash(prompt),
@@ -836,13 +906,14 @@ class CassandraSemanticCache(BaseCache):
 
     def __init__(
         self,
-        session: CassandraSession,
-        keyspace: str,
+        session: Optional[CassandraSession],
+        keyspace: Optional[str],
         embedding: Embeddings,
         table_name: str = CASSANDRA_SEMANTIC_CACHE_DEFAULT_TABLE_NAME,
         distance_metric: str = CASSANDRA_SEMANTIC_CACHE_DEFAULT_DISTANCE_METRIC,
         score_threshold: float = CASSANDRA_SEMANTIC_CACHE_DEFAULT_SCORE_THRESHOLD,
         ttl_seconds: Optional[int] = CASSANDRA_SEMANTIC_CACHE_DEFAULT_TTL_SECONDS,
+        skip_provisioning: bool = False,
     ):
         """
         Initialize the cache with all relevant parameters.
@@ -897,6 +968,7 @@ class CassandraSemanticCache(BaseCache):
             vector_dimension=self.embedding_dimension,
             ttl_seconds=self.ttl_seconds,
             metadata_indexing=("allow", {"_llm_string_hash"}),
+            skip_provisioning=skip_provisioning,
         )
 
     def _get_embedding_dimension(self) -> int:
@@ -906,7 +978,7 @@
         """Update cache based on prompt and llm_string."""
         embedding_vector = self._get_embedding(text=prompt)
         llm_string_hash = _hash(llm_string)
-        body = _dump_generations_to_json(return_val)
+        body = _dumps_generations(return_val)
         metadata = {
             "_prompt": prompt,
             "_llm_string_hash": llm_string_hash,
@@ -947,11 +1019,15 @@
         )
         if hits:
             hit = hits[0]
-            generations_str = hit["body_blob"]
-            return (
-                hit["row_id"],
-                _load_generations_from_json(generations_str),
-            )
+            generations = _loads_generations(hit["body_blob"])
+            if generations is not None:
+                # this protects against malformed cached items:
+                return (
+                    hit["row_id"],
+                    generations,
+                )
+            else:
+                return None
         else:
             return None
 
diff --git a/libs/langchain/tests/integration_tests/cache/test_cassandra.py b/libs/langchain/tests/integration_tests/cache/test_cassandra.py
index 760458f12b0..9dd2af8d6e0 100644
--- a/libs/langchain/tests/integration_tests/cache/test_cassandra.py
+++ b/libs/langchain/tests/integration_tests/cache/test_cassandra.py
@@ -38,7 +38,7 @@ def cassandra_connection() -> Iterator[Tuple[Any, str]]:
 
 def test_cassandra_cache(cassandra_connection: Tuple[Any, str]) -> None:
     session, keyspace = cassandra_connection
-    cache = CassandraCache(session, keyspace)
+    cache = CassandraCache(session=session, keyspace=keyspace)
     langchain.llm_cache = cache
     llm = FakeLLM()
     params = llm.dict()
@@ -58,7 +58,7 @@
 
 def test_cassandra_cache_ttl(cassandra_connection: Tuple[Any, str]) -> None:
     session, keyspace = cassandra_connection
-    cache = CassandraCache(session, keyspace, ttl_seconds=2)
+    cache = CassandraCache(session=session, keyspace=keyspace, ttl_seconds=2)
     langchain.llm_cache = cache
     llm = FakeLLM()
     params = llm.dict()
@@ -80,7 +80,11 @@
 
 def test_cassandra_semantic_cache(cassandra_connection: Tuple[Any, str]) -> None:
     session, keyspace = cassandra_connection
-    sem_cache = CassandraSemanticCache(session, keyspace, embedding=FakeEmbeddings())
+    sem_cache = CassandraSemanticCache(
+        session=session,
+        keyspace=keyspace,
+        embedding=FakeEmbeddings(),
+    )
    langchain.llm_cache = sem_cache
     llm = FakeLLM()
     params = llm.dict()
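
For context, a minimal round-trip sketch of the new helpers (not part of the patch; it assumes `dumps`/`loads` are the Reviver-aware serializers already imported at the top of `cache.py`, and that `Generation` equality compares field values as in pydantic models):

    from langchain.cache import _dumps_generations, _loads_generations
    from langchain.schema import Generation

    # Full-fidelity path: each Generation is `dumps`ed individually, so
    # subclasses survive the round trip (unlike the .dict()-based legacy path).
    blob = _dumps_generations([Generation(text="foo")])
    assert _loads_generations(blob) == [Generation(text="foo")]

    # Legacy path: a pre-existing cache blob (a json list of plain dicts)
    # is still readable; a warning is logged and Generation objects come back.
    assert _loads_generations('[{"text": "foo"}]') == [Generation(text="foo")]

    # Malformed blobs degrade to a soft cache miss (None) instead of raising.
    assert _loads_generations("not-a-blob") is None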
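A usage sketch for the relaxed constructor signature and the new `skip_provisioning` flag follows; the contact point and keyspace name are placeholders for a real Cassandra setup:

    import langchain
    from cassandra.cluster import Cluster
    from langchain.cache import CassandraCache

    session = Cluster(["127.0.0.1"]).connect()

    # First process: let the cache create its table (the default behavior).
    langchain.llm_cache = CassandraCache(session=session, keyspace="demo_keyspace")

    # Subsequent hot-path workers can skip the schema round-trips,
    # since the table is known to exist already.
    langchain.llm_cache = CassandraCache(
        session=session,
        keyspace="demo_keyspace",
        skip_provisioning=True,
    )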
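The semantic cache follows the same pattern; `OpenAIEmbeddings` below is just one concrete `Embeddings` implementation (any other would do), reusing the `session` from the sketch above:

    from langchain.cache import CassandraSemanticCache
    from langchain.embeddings import OpenAIEmbeddings

    langchain.llm_cache = CassandraSemanticCache(
        session=session,
        keyspace="demo_keyspace",
        embedding=OpenAIEmbeddings(),
        skip_provisioning=True,  # only safe once the table and index exist
    )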