mirror of
				https://github.com/hwchase17/langchain.git
				synced 2025-10-22 17:50:03 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			1558 lines
		
	
	
		
			57 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			1558 lines
		
	
	
		
			57 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """
 | |
| .. warning::
 | |
|   Beta Feature!
 | |
| 
 | |
| **Cache** provides an optional caching layer for LLMs.
 | |
| 
 | |
| Cache is useful for two reasons:
 | |
| 
 | |
| - It can save you money by reducing the number of API calls you make to the LLM
 | |
|   provider if you're often requesting the same completion multiple times.
 | |
| - It can speed up your application by reducing the number of API calls you make
 | |
|   to the LLM provider.
 | |
| 
 | |
| Cache directly competes with Memory. See documentation for Pros and Cons.
 | |
| 
 | |
| **Class hierarchy:**
 | |
| 
 | |
| .. code-block::
 | |
| 
 | |
|     BaseCache --> <name>Cache  # Examples: InMemoryCache, RedisCache, GPTCache
 | |
| """
 | |
| from __future__ import annotations
 | |
| 
 | |
| import hashlib
 | |
| import inspect
 | |
| import json
 | |
| import logging
 | |
| import uuid
 | |
| import warnings
 | |
| from datetime import timedelta
 | |
| from functools import lru_cache
 | |
| from typing import (
 | |
|     TYPE_CHECKING,
 | |
|     Any,
 | |
|     Callable,
 | |
|     Dict,
 | |
|     List,
 | |
|     Optional,
 | |
|     Tuple,
 | |
|     Type,
 | |
|     Union,
 | |
|     cast,
 | |
| )
 | |
| 
 | |
| from sqlalchemy import Column, Integer, String, create_engine, select
 | |
| from sqlalchemy.engine import Row
 | |
| from sqlalchemy.engine.base import Engine
 | |
| from sqlalchemy.orm import Session
 | |
| 
 | |
| try:
 | |
|     from sqlalchemy.orm import declarative_base
 | |
| except ImportError:
 | |
|     from sqlalchemy.ext.declarative import declarative_base
 | |
| 
 | |
| from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
 | |
| from langchain_core.embeddings import Embeddings
 | |
| from langchain_core.language_models.llms import LLM, get_prompts
 | |
| from langchain_core.load.dump import dumps
 | |
| from langchain_core.load.load import loads
 | |
| from langchain_core.outputs import ChatGeneration, Generation
 | |
| from langchain_core.utils import get_from_env
 | |
| 
 | |
| from langchain_community.vectorstores.redis import Redis as RedisVectorstore
 | |
| 
 | |
# Module-level logger used for the cache warnings emitted below.
# NOTE(review): it is registered under this file's path (``__file__``) rather
# than the dotted module name; logging configuration must target that name.
logger = logging.getLogger(__file__)

# These imports are only executed by static type checkers (``TYPE_CHECKING`` is
# False at runtime), so the optional packages are not required to import this
# module; the names are available for annotations only.
if TYPE_CHECKING:
    import momento
    from cassandra.cluster import Session as CassandraSession
 | |
| 
 | |
| 
 | |
| def _hash(_input: str) -> str:
 | |
|     """Use a deterministic hashing approach."""
 | |
|     return hashlib.md5(_input.encode()).hexdigest()
 | |
| 
 | |
| 
 | |
| def _dump_generations_to_json(generations: RETURN_VAL_TYPE) -> str:
 | |
|     """Dump generations to json.
 | |
| 
 | |
|     Args:
 | |
|         generations (RETURN_VAL_TYPE): A list of language model generations.
 | |
| 
 | |
|     Returns:
 | |
|         str: Json representing a list of generations.
 | |
| 
 | |
|     Warning: would not work well with arbitrary subclasses of `Generation`
 | |
|     """
 | |
|     return json.dumps([generation.dict() for generation in generations])
 | |
| 
 | |
| 
 | |
| def _load_generations_from_json(generations_json: str) -> RETURN_VAL_TYPE:
 | |
|     """Load generations from json.
 | |
| 
 | |
|     Args:
 | |
|         generations_json (str): A string of json representing a list of generations.
 | |
| 
 | |
|     Raises:
 | |
|         ValueError: Could not decode json string to list of generations.
 | |
| 
 | |
|     Returns:
 | |
|         RETURN_VAL_TYPE: A list of generations.
 | |
| 
 | |
|     Warning: would not work well with arbitrary subclasses of `Generation`
 | |
|     """
 | |
|     try:
 | |
|         results = json.loads(generations_json)
 | |
|         return [Generation(**generation_dict) for generation_dict in results]
 | |
|     except json.JSONDecodeError:
 | |
|         raise ValueError(
 | |
|             f"Could not decode json to list of generations: {generations_json}"
 | |
|         )
 | |
| 
 | |
| 
 | |
| def _dumps_generations(generations: RETURN_VAL_TYPE) -> str:
 | |
|     """
 | |
|     Serialization for generic RETURN_VAL_TYPE, i.e. sequence of `Generation`
 | |
| 
 | |
|     Args:
 | |
|         generations (RETURN_VAL_TYPE): A list of language model generations.
 | |
| 
 | |
|     Returns:
 | |
|         str: a single string representing a list of generations.
 | |
| 
 | |
|     This function (+ its counterpart `_loads_generations`) rely on
 | |
|     the dumps/loads pair with Reviver, so are able to deal
 | |
|     with all subclasses of Generation.
 | |
| 
 | |
|     Each item in the list can be `dumps`ed to a string,
 | |
|     then we make the whole list of strings into a json-dumped.
 | |
|     """
 | |
|     return json.dumps([dumps(_item) for _item in generations])
 | |
| 
 | |
| 
 | |
| def _loads_generations(generations_str: str) -> Union[RETURN_VAL_TYPE, None]:
 | |
|     """
 | |
|     Deserialization of a string into a generic RETURN_VAL_TYPE
 | |
|     (i.e. a sequence of `Generation`).
 | |
| 
 | |
|     See `_dumps_generations`, the inverse of this function.
 | |
| 
 | |
|     Args:
 | |
|         generations_str (str): A string representing a list of generations.
 | |
| 
 | |
|     Compatible with the legacy cache-blob format
 | |
|     Does not raise exceptions for malformed entries, just logs a warning
 | |
|     and returns none: the caller should be prepared for such a cache miss.
 | |
| 
 | |
|     Returns:
 | |
|         RETURN_VAL_TYPE: A list of generations.
 | |
|     """
 | |
|     try:
 | |
|         generations = [loads(_item_str) for _item_str in json.loads(generations_str)]
 | |
|         return generations
 | |
|     except (json.JSONDecodeError, TypeError):
 | |
|         # deferring the (soft) handling to after the legacy-format attempt
 | |
|         pass
 | |
| 
 | |
|     try:
 | |
|         gen_dicts = json.loads(generations_str)
 | |
|         # not relying on `_load_generations_from_json` (which could disappear):
 | |
|         generations = [Generation(**generation_dict) for generation_dict in gen_dicts]
 | |
|         logger.warning(
 | |
|             f"Legacy 'Generation' cached blob encountered: '{generations_str}'"
 | |
|         )
 | |
|         return generations
 | |
|     except (json.JSONDecodeError, TypeError):
 | |
|         logger.warning(
 | |
|             f"Malformed/unparsable cached blob encountered: '{generations_str}'"
 | |
|         )
 | |
|         return None
 | |
| 
 | |
| 
 | |
class InMemoryCache(BaseCache):
    """Cache that keeps every entry in a dictionary held by the instance."""

    def __init__(self) -> None:
        """Start with no cached entries."""
        self._cache: Dict[Tuple[str, str], RETURN_VAL_TYPE] = {}

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Return the entry stored for (prompt, llm_string), or None if absent."""
        cache_key = (prompt, llm_string)
        return self._cache.get(cache_key)

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Store ``return_val`` under the (prompt, llm_string) key."""
        cache_key = (prompt, llm_string)
        self._cache[cache_key] = return_val

    def clear(self, **kwargs: Any) -> None:
        """Drop all entries by rebinding the store to a new empty dict."""
        self._cache = dict()
 | |
| 
 | |
| 
 | |
# Declarative base that the ORM table model below derives from.
Base = declarative_base()


class FullLLMCache(Base):  # type: ignore
    """SQLite table for full LLM Cache (all generations).

    One row is stored per generation; the primary key is the combination of
    ``prompt``, ``llm`` and ``idx``.
    """

    __tablename__ = "full_llm_cache"
    # Prompt text the generation belongs to (primary-key component).
    prompt = Column(String, primary_key=True)
    # String identifying the LLM configuration (primary-key component).
    llm = Column(String, primary_key=True)
    # Position of this generation within the cached list (primary-key component).
    idx = Column(Integer, primary_key=True)
    # Stored payload for the generation.
    response = Column(String)
 | |
| 
 | |
| 
 | |
class SQLAlchemyCache(BaseCache):
    """Cache that uses SQLAlchemy as a backend."""

    def __init__(self, engine: Engine, cache_schema: Type[FullLLMCache] = FullLLMCache):
        """Remember the engine and schema, then create the schema's tables.

        Args:
            engine: SQLAlchemy engine used for every session opened by the cache.
            cache_schema: Declarative model describing the cache table.
        """
        self.engine = engine
        self.cache_schema = cache_schema
        cache_schema.metadata.create_all(engine)

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Fetch the generations cached for (prompt, llm_string), ordered by idx.

        Returns None when no row matches. Rows that cannot be deserialized with
        `loads` are treated as raw text from an older cache format.
        """
        schema = self.cache_schema
        query = (
            select(schema.response)
            .where(schema.prompt == prompt)  # type: ignore
            .where(schema.llm == llm_string)
            .order_by(schema.idx)
        )
        with Session(self.engine) as session:
            rows = session.execute(query).fetchall()
            if not rows:
                return None
            try:
                return [loads(row[0]) for row in rows]
            except Exception:
                logger.warning(
                    "Retrieving a cache value that could not be deserialized "
                    "properly. This is likely due to the cache being in an "
                    "older format. Please recreate your cache to avoid this "
                    "error."
                )
                # Older caches stored the raw generation text in the column,
                # so fall back to wrapping each value as-is.
                return [Generation(text=row[0]) for row in rows]

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Write one row per generation, merging on the primary key."""
        # Serialize everything up front, before a session is opened.
        records = []
        for position, generation in enumerate(return_val):
            records.append(
                self.cache_schema(
                    prompt=prompt,
                    llm=llm_string,
                    response=dumps(generation),
                    idx=position,
                )
            )
        with Session(self.engine) as session, session.begin():
            for record in records:
                session.merge(record)

    def clear(self, **kwargs: Any) -> None:
        """Delete every row of the cache table and commit."""
        with Session(self.engine) as session:
            session.query(self.cache_schema).delete()
            session.commit()
 | |
| 
 | |
| 
 | |
class SQLiteCache(SQLAlchemyCache):
    """SQLAlchemy-backed cache bound to a SQLite database file."""

    def __init__(self, database_path: str = ".langchain.db"):
        """Build a SQLite engine for ``database_path`` and create the tables."""
        sqlite_url = f"sqlite:///{database_path}"
        super().__init__(create_engine(sqlite_url))
 | |
| 
 | |
| 
 | |
class UpstashRedisCache(BaseCache):
    """Cache that uses Upstash Redis as a backend."""

    def __init__(self, redis_: Any, *, ttl: Optional[int] = None):
        """
        Create an UpstashRedisCache around an existing Upstash Redis client.

        Parameters:
            redis_: An instance of the ``upstash_redis.Redis`` client class,
                used for all cache reads and writes.
            ttl (int, optional): Time-to-live for cached items in seconds.
                When given, every written key gets this expiry; when omitted,
                cached items have no automatic expiration.

        Raises:
            ValueError: if ``upstash_redis`` cannot be imported or ``redis_``
                is not an ``upstash_redis.Redis`` instance.
        """
        try:
            from upstash_redis import Redis
        except ImportError:
            raise ValueError(
                "Could not import upstash_redis python package. "
                "Please install it with `pip install upstash_redis`."
            )
        if not isinstance(redis_, Redis):
            raise ValueError("Please pass in Upstash Redis object.")
        self.ttl = ttl
        self.redis = redis_

    def _key(self, prompt: str, llm_string: str) -> str:
        """Hash the concatenation of prompt and llm_string into a key."""
        combined = prompt + llm_string
        return _hash(combined)

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Return the cached generations for (prompt, llm_string), or None."""
        # Entries live in a HASH whose values are the generation texts.
        stored = self.redis.hgetall(self._key(prompt, llm_string))
        if not stored:
            return None
        generations = [Generation(text=text) for text in stored.values()]
        return generations or None

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Store the generation texts for (prompt, llm_string).

        Nothing is written (only a warning is issued) when a chat generation
        is encountered.
        """
        for gen in return_val:
            if not isinstance(gen, Generation):
                raise ValueError(
                    "UpstashRedisCache supports caching of normal LLM generations, "
                    f"got {type(gen)}"
                )
            if isinstance(gen, ChatGeneration):
                warnings.warn(
                    "NOTE: Generation has not been cached. UpstashRedisCache does not"
                    " support caching ChatModel outputs."
                )
                return

        # Write to a HASH: field = position in the list, value = generation text.
        key = self._key(prompt, llm_string)
        mapping = {}
        for idx, generation in enumerate(return_val):
            mapping[str(idx)] = generation.text
        self.redis.hset(key=key, values=mapping)

        if self.ttl is not None:
            self.redis.expire(key, self.ttl)

    def clear(self, **kwargs: Any) -> None:
        """
        Clear cache. If `asynchronous` is True, flush asynchronously.
        This flushes the *whole* db.
        """
        flush_type = "ASYNC" if kwargs.get("asynchronous", False) else "SYNC"
        self.redis.flushdb(flush_type=flush_type)
 | |
| 
 | |
| 
 | |
class RedisCache(BaseCache):
    """Cache that uses Redis as a backend."""

    def __init__(self, redis_: Any, *, ttl: Optional[int] = None):
        """
        Initialize an instance of RedisCache.

        Parameters:
            redis_ (Any): An instance of a Redis client class
                (e.g., redis.Redis) used for all cache reads and writes.
            ttl (int, optional): Time-to-live (TTL) for cached items in seconds.
                If provided, every written key gets this expiry. If not
                provided, cached items will not have an automatic expiration.

        Raises:
            ValueError: if the ``redis`` package cannot be imported or
                ``redis_`` is not a ``redis.Redis`` instance.
        """
        try:
            from redis import Redis
        except ImportError:
            raise ValueError(
                "Could not import redis python package. "
                "Please install it with `pip install redis`."
            )
        if not isinstance(redis_, Redis):
            raise ValueError("Please pass in Redis object.")
        self.redis = redis_
        self.ttl = ttl

    def _key(self, prompt: str, llm_string: str) -> str:
        """Compute key from prompt and llm_string"""
        return _hash(prompt + llm_string)

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string.

        Returns the cached generations, or None when the key holds nothing.
        """
        generations = []
        # Read from a Redis HASH
        results = self.redis.hgetall(self._key(prompt, llm_string))
        if results:
            # Only the values are needed; the hash fields are list positions.
            for text in results.values():
                try:
                    generations.append(loads(text))
                except Exception:
                    logger.warning(
                        "Retrieving a cache value that could not be deserialized "
                        "properly. This is likely due to the cache being in an "
                        "older format. Please recreate your cache to avoid this "
                        "error."
                    )
                    # In a previous life we stored the raw text directly
                    # in the table, so assume it's in that format.
                    generations.append(Generation(text=text))
        return generations if generations else None

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache based on prompt and llm_string.

        Raises:
            ValueError: if any item of ``return_val`` is not a `Generation`.
        """
        for gen in return_val:
            if not isinstance(gen, Generation):
                raise ValueError(
                    "RedisCache only supports caching of normal LLM generations, "
                    f"got {type(gen)}"
                )
        # Write to a Redis HASH
        key = self._key(prompt, llm_string)

        # Queue the write and the optional expiry on one pipeline so they are
        # sent together by a single execute().
        with self.redis.pipeline() as pipe:
            pipe.hset(
                key,
                mapping={
                    str(idx): dumps(generation)
                    for idx, generation in enumerate(return_val)
                },
            )
            if self.ttl is not None:
                pipe.expire(key, self.ttl)

            pipe.execute()

    def clear(self, **kwargs: Any) -> None:
        """Clear cache. If `asynchronous` is True, flush asynchronously.

        Any remaining keyword arguments are forwarded to ``flushdb``.
        """
        # Pop (not get) the flag: it is passed explicitly below, and leaving it
        # in ``kwargs`` would supply ``asynchronous`` twice and raise TypeError.
        asynchronous = kwargs.pop("asynchronous", False)
        self.redis.flushdb(asynchronous=asynchronous, **kwargs)
 | |
| 
 | |
| 
 | |
class RedisSemanticCache(BaseCache):
    """Cache that uses Redis as a vector-store backend."""

    # TODO - implement a TTL policy in Redis

    DEFAULT_SCHEMA = {
        "content_key": "prompt",
        "text": [
            {"name": "prompt"},
        ],
        "extra": [{"name": "return_val"}, {"name": "llm_string"}],
    }

    def __init__(
        self, redis_url: str, embedding: Embeddings, score_threshold: float = 0.2
    ):
        """Initialize the semantic cache.

        Args:
            redis_url (str): URL to connect to Redis.
            embedding (Embedding): Embedding provider for semantic encoding and search.
            score_threshold (float, 0.2): Passed as ``distance_threshold`` to the
                similarity search performed on lookup.

        Example:

        .. code-block:: python

            from langchain_community.globals import set_llm_cache

            from langchain_community.cache import RedisSemanticCache
            from langchain_community.embeddings import OpenAIEmbeddings

            set_llm_cache(RedisSemanticCache(
                redis_url="redis://localhost:6379",
                embedding=OpenAIEmbeddings()
            ))

        """
        self.redis_url = redis_url
        self.embedding = embedding
        self.score_threshold = score_threshold
        # One vector-store client per index name, created lazily.
        self._cache_dict: Dict[str, RedisVectorstore] = {}

    def _index_name(self, llm_string: str) -> str:
        """Derive the index name from the hash of ``llm_string``."""
        return f"cache:{_hash(llm_string)}"

    def _get_llm_cache(self, llm_string: str) -> RedisVectorstore:
        """Return the vector-store client for ``llm_string``, creating it if needed."""
        index_name = self._index_name(llm_string)

        # Reuse the client already created for this llm string, if any.
        try:
            return self._cache_dict[index_name]
        except KeyError:
            pass

        # Otherwise attach to an existing index, or build a new one when
        # attaching fails with ValueError.
        try:
            self._cache_dict[index_name] = RedisVectorstore.from_existing_index(
                embedding=self.embedding,
                index_name=index_name,
                redis_url=self.redis_url,
                schema=cast(Dict, self.DEFAULT_SCHEMA),
            )
        except ValueError:
            vectorstore = RedisVectorstore(
                embedding=self.embedding,
                index_name=index_name,
                redis_url=self.redis_url,
                index_schema=cast(Dict, self.DEFAULT_SCHEMA),
            )
            # Embed a probe string to learn the vector dimension for the index.
            probe_vector = self.embedding.embed_query(text="test")
            vectorstore._create_index_if_not_exist(dim=len(probe_vector))
            self._cache_dict[index_name] = vectorstore

        return self._cache_dict[index_name]

    def clear(self, **kwargs: Any) -> None:
        """Clear semantic cache for a given llm_string.

        Expects an ``llm_string`` keyword argument. Only an index whose client
        is currently held by this instance is dropped.
        """
        index_name = self._index_name(kwargs["llm_string"])
        if index_name not in self._cache_dict:
            return
        self._cache_dict[index_name].drop_index(
            index_name=index_name, delete_documents=True, redis_url=self.redis_url
        )
        del self._cache_dict[index_name]

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Return generations cached for the closest stored prompt, or None."""
        llm_cache = self._get_llm_cache(llm_string)
        # Ask for the single nearest prompt within the distance threshold.
        matches = llm_cache.similarity_search(
            query=prompt,
            k=1,
            distance_threshold=self.score_threshold,
        )
        if not matches:
            return None

        generations: List = []
        for document in matches:
            try:
                generations.extend(loads(document.metadata["return_val"]))
            except Exception:
                logger.warning(
                    "Retrieving a cache value that could not be deserialized "
                    "properly. This is likely due to the cache being in an "
                    "older format. Please recreate your cache to avoid this "
                    "error."
                )
                # Fall back to the older JSON-dict format for this entry.
                generations.extend(
                    _load_generations_from_json(document.metadata["return_val"])
                )
        return generations or None

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Store ``return_val`` for ``prompt`` in the index of ``llm_string``.

        Raises:
            ValueError: if any item of ``return_val`` is not a `Generation`.
        """
        for gen in return_val:
            if not isinstance(gen, Generation):
                raise ValueError(
                    "RedisSemanticCache only supports caching of "
                    f"normal LLM generations, got {type(gen)}"
                )
        llm_cache = self._get_llm_cache(llm_string)

        serialized_generations = dumps(list(return_val))
        metadata = {
            "llm_string": llm_string,
            "prompt": prompt,
            "return_val": serialized_generations,
        }
        llm_cache.add_texts(texts=[prompt], metadatas=[metadata])
 | |
| 
 | |
| 
 | |
class GPTCache(BaseCache):
    """Cache that uses GPTCache as a backend."""

    def __init__(
        self,
        init_func: Union[
            Callable[[Any, str], None], Callable[[Any], None], None
        ] = None,
    ):
        """Initialize by passing in init function (default: `None`).

        Args:
            init_func (Optional[Callable[[Any], None]]): function used to
                initialize each new `gptcache.Cache` object. It is called with
                the cache object and, if it declares exactly two parameters,
                also with the llm string. (default: `None`)

        Example:
        .. code-block:: python

            # Initialize GPTCache with a custom init function
            import gptcache
            from gptcache.processor.pre import get_prompt
            from gptcache.manager.factory import manager_factory
            from langchain_community.globals import set_llm_cache

            # Avoid multiple caches using the same file,
            # causing different llm model caches to affect each other

            def init_gptcache(cache_obj: gptcache.Cache, llm: str):
                cache_obj.init(
                    pre_embedding_func=get_prompt,
                    data_manager=manager_factory(
                        manager="map",
                        data_dir=f"map_cache_{llm}"
                    ),
                )

            set_llm_cache(GPTCache(init_gptcache))

        """
        try:
            import gptcache  # noqa: F401
        except ImportError:
            raise ImportError(
                "Could not import gptcache python package. "
                "Please install it with `pip install gptcache`."
            )

        self.gptcache_dict: Dict[str, Any] = {}
        self.init_gptcache_func: Union[
            Callable[[Any, str], None], Callable[[Any], None], None
        ] = init_func

    def _new_gptcache(self, llm_string: str) -> Any:
        """Create, initialize and register a cache object for ``llm_string``."""
        from gptcache import Cache
        from gptcache.manager.factory import get_data_manager
        from gptcache.processor.pre import get_prompt

        _gptcache = Cache()
        init_func = self.init_gptcache_func
        if init_func is None:
            # No custom initializer: use the default setup keyed on the llm string.
            _gptcache.init(
                pre_embedding_func=get_prompt,
                data_manager=get_data_manager(data_path=llm_string),
            )
        elif len(inspect.signature(init_func).parameters) == 2:
            # Two-parameter initializers also receive the llm string.
            init_func(_gptcache, llm_string)  # type: ignore[call-arg]
        else:
            init_func(_gptcache)  # type: ignore[call-arg]

        self.gptcache_dict[llm_string] = _gptcache
        return _gptcache

    def _get_gptcache(self, llm_string: str) -> Any:
        """Get a cache object.

        When the corresponding llm model cache does not exist, it will be created."""
        existing = self.gptcache_dict.get(llm_string, None)
        if existing:
            return existing
        return self._new_gptcache(llm_string)

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up the cache data.

        The cache object for ``llm_string`` is obtained first; the entry for
        ``prompt`` is then read from it and decoded from JSON. Returns None
        when the cache yields nothing for the prompt.
        """
        from gptcache.adapter.api import get

        _gptcache = self._get_gptcache(llm_string)

        cached_json = get(prompt, cache_obj=_gptcache)
        if not cached_json:
            return None
        return [
            Generation(**generation_dict)
            for generation_dict in json.loads(cached_json)
        ]

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache.

        The cache object for ``llm_string`` is obtained first; ``return_val``
        is then JSON-encoded and stored under ``prompt``.

        Raises:
            ValueError: if any item of ``return_val`` is not a `Generation`.
        """
        for gen in return_val:
            if not isinstance(gen, Generation):
                raise ValueError(
                    "GPTCache only supports caching of normal LLM generations, "
                    f"got {type(gen)}"
                )
        from gptcache.adapter.api import put

        _gptcache = self._get_gptcache(llm_string)
        generation_dicts = []
        for generation in return_val:
            generation_dicts.append(generation.dict())
        put(prompt, json.dumps(generation_dicts), cache_obj=_gptcache)
        return None

    def clear(self, **kwargs: Any) -> None:
        """Flush every cache object created so far and forget them."""
        from gptcache import Cache

        for gptcache_instance in self.gptcache_dict.values():
            cast(Cache, gptcache_instance).flush()

        self.gptcache_dict.clear()
 | |
| 
 | |
| 
 | |
def _ensure_cache_exists(cache_client: momento.CacheClient, cache_name: str) -> None:
    """Request creation of ``cache_name`` and accept "already exists" as success.

    Raises:
        SdkException: Momento service or network error
        Exception: Unexpected response
    """
    from momento.responses import CreateCache

    response = cache_client.create_cache(cache_name)
    if isinstance(response, (CreateCache.Success, CreateCache.CacheAlreadyExists)):
        return None
    if isinstance(response, CreateCache.Error):
        # Surface the error carried by the response object.
        raise response.inner_exception
    raise Exception(f"Unexpected response cache creation: {response}")
 | |
| 
 | |
| 
 | |
| def _validate_ttl(ttl: Optional[timedelta]) -> None:
 | |
|     if ttl is not None and ttl <= timedelta(seconds=0):
 | |
|         raise ValueError(f"ttl must be positive but was {ttl}.")
 | |
| 
 | |
| 
 | |
class MomentoCache(BaseCache):
    """Cache that uses Momento as a backend. See https://gomomento.com/"""

    def __init__(
        self,
        cache_client: momento.CacheClient,
        cache_name: str,
        *,
        ttl: Optional[timedelta] = None,
        ensure_cache_exists: bool = True,
    ):
        """Instantiate a prompt cache using Momento as a backend.

        Note: to instantiate the cache client passed to MomentoCache,
        you must have a Momento account. See https://gomomento.com/.

        Args:
            cache_client (CacheClient): The Momento cache client.
            cache_name (str): The name of the cache to use to store the data.
            ttl (Optional[timedelta], optional): The time to live for the cache items.
                Defaults to None, ie use the client default TTL.
            ensure_cache_exists (bool, optional): Create the cache if it doesn't
                exist. Defaults to True.

        Raises:
            ImportError: Momento python package is not installed.
            TypeError: cache_client is not of type momento.CacheClient
            ValueError: ttl is non-null and not strictly positive
            SdkException: Momento service or network error while ensuring the
                cache exists (only when ``ensure_cache_exists`` is True).
        """
        try:
            from momento import CacheClient
        except ImportError as e:
            raise ImportError(
                "Could not import momento python package. "
                "Please install it with `pip install momento`."
            ) from e
        if not isinstance(cache_client, CacheClient):
            raise TypeError("cache_client must be a momento.CacheClient object.")
        _validate_ttl(ttl)
        if ensure_cache_exists:
            _ensure_cache_exists(cache_client, cache_name)

        self.cache_client = cache_client
        self.cache_name = cache_name
        self.ttl = ttl

    @classmethod
    def from_client_params(
        cls,
        cache_name: str,
        ttl: timedelta,
        *,
        configuration: Optional[momento.config.Configuration] = None,
        api_key: Optional[str] = None,
        auth_token: Optional[str] = None,  # for backwards compatibility
        **kwargs: Any,
    ) -> MomentoCache:
        """Construct cache from CacheClient parameters.

        Args:
            cache_name (str): The name of the cache to use to store the data.
            ttl (timedelta): Used both as the client's default TTL and as the
                TTL of items written by the returned cache.
            configuration (Optional[Configuration]): Momento client configuration.
                Defaults to ``Configurations.Laptop.v1()``.
            api_key (Optional[str]): Momento API key.
            auth_token (Optional[str]): Legacy name for ``api_key``; when both are
                passed, ``auth_token`` is used.
            **kwargs: Forwarded to the ``MomentoCache`` constructor.

        Credentials are resolved in this order: explicit ``auth_token``,
        explicit ``api_key``, the ``MOMENTO_AUTH_TOKEN`` environment variable,
        then the ``MOMENTO_API_KEY`` environment variable.

        Raises:
            ImportError: Momento python package is not installed.
            ValueError: No credential was passed or found in the environment.
        """
        try:
            from momento import CacheClient, Configurations, CredentialProvider
        except ImportError as e:
            raise ImportError(
                "Could not import momento python package. "
                "Please install it with `pip install momento`."
            ) from e
        if configuration is None:
            configuration = Configurations.Laptop.v1()

        # Explicitly passed credentials must win over the environment;
        # previously a set MOMENTO_AUTH_TOKEN silently replaced a passed api_key.
        api_key = auth_token or api_key
        if not api_key:
            # Check `MOMENTO_AUTH_TOKEN` first for backwards compatibility.
            try:
                api_key = get_from_env("auth_token", "MOMENTO_AUTH_TOKEN")
            except ValueError:
                api_key = get_from_env("api_key", "MOMENTO_API_KEY")
        credentials = CredentialProvider.from_string(api_key)
        cache_client = CacheClient(configuration, credentials, default_ttl=ttl)
        return cls(cache_client, cache_name, ttl=ttl, **kwargs)

    def __key(self, prompt: str, llm_string: str) -> str:
        """Compute cache key from prompt and associated model and settings.

        Args:
            prompt (str): The prompt run through the language model.
            llm_string (str): The language model version and settings.

        Returns:
            str: The cache key.
        """
        # NOTE(review): the two strings are concatenated without a delimiter
        # before hashing. Left unchanged because altering the key derivation
        # would orphan entries that are already stored.
        return _hash(prompt + llm_string)

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Lookup llm generations in cache by prompt and associated model and settings.

        Args:
            prompt (str): The prompt run through the language model.
            llm_string (str): The language model version and settings.

        Raises:
            SdkException: Momento service or network error

        Returns:
            Optional[RETURN_VAL_TYPE]: A list of language model generations,
                or None on a miss or when the stored list is empty.
        """
        from momento.responses import CacheGet

        generations: RETURN_VAL_TYPE = []

        get_response = self.cache_client.get(
            self.cache_name, self.__key(prompt, llm_string)
        )
        if isinstance(get_response, CacheGet.Hit):
            value = get_response.value_string
            generations = _load_generations_from_json(value)
        elif isinstance(get_response, CacheGet.Miss):
            pass
        elif isinstance(get_response, CacheGet.Error):
            raise get_response.inner_exception
        return generations if generations else None

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Store llm generations in cache.

        Args:
            prompt (str): The prompt run through the language model.
            llm_string (str): The language model string.
            return_val (RETURN_VAL_TYPE): A list of language model generations.

        Raises:
            ValueError: An item of ``return_val`` is not a ``Generation``.
            SdkException: Momento service or network error
            Exception: Unexpected response
        """
        from momento.responses import CacheSet

        for gen in return_val:
            if not isinstance(gen, Generation):
                raise ValueError(
                    "Momento only supports caching of normal LLM generations, "
                    f"got {type(gen)}"
                )
        key = self.__key(prompt, llm_string)
        value = _dump_generations_to_json(return_val)
        set_response = self.cache_client.set(self.cache_name, key, value, self.ttl)

        if isinstance(set_response, CacheSet.Success):
            pass
        elif isinstance(set_response, CacheSet.Error):
            raise set_response.inner_exception
        else:
            raise Exception(f"Unexpected response: {set_response}")

    def clear(self, **kwargs: Any) -> None:
        """Clear the cache.

        Args:
            **kwargs: Accepted for interface compatibility; not read here.

        Raises:
            SdkException: Momento service or network error
        """
        from momento.responses import CacheFlush

        flush_response = self.cache_client.flush_cache(self.cache_name)
        if isinstance(flush_response, CacheFlush.Success):
            pass
        elif isinstance(flush_response, CacheFlush.Error):
            raise flush_response.inner_exception
 | |
| 
 | |
| 
 | |
# Defaults for the CassandraCache constructor arguments.
CASSANDRA_CACHE_DEFAULT_TABLE_NAME: str = "langchain_llm_cache"
# No time-to-live unless the caller asks for one.
CASSANDRA_CACHE_DEFAULT_TTL_SECONDS: Optional[int] = None
 | |
| 
 | |
| 
 | |
class CassandraCache(BaseCache):
    """
    Cache that uses Cassandra / Astra DB as a backend.

    A single Cassandra table stores every entry. Its primary key is made of
    two hashed lookup keys:
        - prompt, a string
        - llm_string, a deterministic str representation of the model parameters
          (needed to prevent same-prompt-different-model collisions)
    """

    def __init__(
        self,
        session: Optional[CassandraSession] = None,
        keyspace: Optional[str] = None,
        table_name: str = CASSANDRA_CACHE_DEFAULT_TABLE_NAME,
        ttl_seconds: Optional[int] = CASSANDRA_CACHE_DEFAULT_TTL_SECONDS,
        skip_provisioning: bool = False,
    ):
        """
        Initialize with a ready session and a keyspace name.
        Args:
            session (cassandra.cluster.Session): an open Cassandra session
            keyspace (str): the keyspace to use for storing the cache
            table_name (str): name of the Cassandra table to use as cache
            ttl_seconds (optional int): time-to-live for cache entries
                (default: None, i.e. forever)
            skip_provisioning (bool): forwarded unchanged to the cassio
                table constructor
        Raises:
            ValueError: the cassio package cannot be imported
        """
        try:
            from cassio.table import ElasticCassandraTable
        except ImportError:
            # ModuleNotFoundError is a subclass of ImportError, so it lands here too.
            raise ValueError(
                "Could not import cassio python package. "
                "Please install it with `pip install cassio`."
            )

        self.table_name = table_name
        self.ttl_seconds = ttl_seconds
        self.session = session
        self.keyspace = keyspace

        table_options = {
            "session": self.session,
            "keyspace": self.keyspace,
            "table": self.table_name,
            "keys": ["llm_string", "prompt"],
            "primary_key_type": ["TEXT", "TEXT"],
            "ttl_seconds": self.ttl_seconds,
            "skip_provisioning": skip_provisioning,
        }
        self.kv_cache = ElasticCassandraTable(**table_options)

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Return the generations stored for (prompt, llm_string), else None."""
        row = self.kv_cache.get(
            llm_string=_hash(llm_string),
            prompt=_hash(prompt),
        )
        if row is None:
            return None
        # A None result from the loader (malformed cached item) is passed on
        # as a plain miss.
        return _loads_generations(row["body_blob"])

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Write ``return_val`` under the hashed (llm_string, prompt) key."""
        serialized = _dumps_generations(return_val)
        self.kv_cache.put(
            llm_string=_hash(llm_string),
            prompt=_hash(prompt),
            body_blob=serialized,
        )

    def delete_through_llm(
        self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
    ) -> None:
        """
        Evict an entry, deriving llm_string from the LLM object itself.
        If the llm(prompt) calls were made with a `stop` param, pass it here too.
        """
        llm_params = {**llm.dict(), "stop": stop}
        llm_string = get_prompts(llm_params, [])[1]
        return self.delete(prompt, llm_string=llm_string)

    def delete(self, prompt: str, llm_string: str) -> None:
        """Remove the entry for (prompt, llm_string), if one is stored."""
        hashed_llm = _hash(llm_string)
        hashed_prompt = _hash(prompt)
        return self.kv_cache.delete(llm_string=hashed_llm, prompt=hashed_prompt)

    def clear(self, **kwargs: Any) -> None:
        """Empty the cache table. This affects all LLMs at once."""
        self.kv_cache.clear()
 | |
| 
 | |
| 
 | |
# Defaults for the CassandraSemanticCache constructor arguments.
CASSANDRA_SEMANTIC_CACHE_DEFAULT_DISTANCE_METRIC: str = "dot"
CASSANDRA_SEMANTIC_CACHE_DEFAULT_SCORE_THRESHOLD: float = 0.85
CASSANDRA_SEMANTIC_CACHE_DEFAULT_TABLE_NAME: str = "langchain_llm_semantic_cache"
# No time-to-live unless the caller asks for one.
CASSANDRA_SEMANTIC_CACHE_DEFAULT_TTL_SECONDS: Optional[int] = None
# Capacity of the per-instance LRU that memoizes prompt embeddings.
CASSANDRA_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE: int = 16
 | |
| 
 | |
| 
 | |
class CassandraSemanticCache(BaseCache):
    """
    Cache that uses Cassandra as a vector-store backend for semantic
    (i.e. similarity-based) lookup.

    One (vector) Cassandra table holds cached values that may come from
    several LLMs; a hash of the LLM's llm_string is therefore stored with
    each row and used to filter searches.

    The similarity is based on one of several distance metrics (default: "dot").
    If choosing another metric, the default threshold is to be re-tuned accordingly.
    """

    def __init__(
        self,
        session: Optional[CassandraSession],
        keyspace: Optional[str],
        embedding: Embeddings,
        table_name: str = CASSANDRA_SEMANTIC_CACHE_DEFAULT_TABLE_NAME,
        distance_metric: str = CASSANDRA_SEMANTIC_CACHE_DEFAULT_DISTANCE_METRIC,
        score_threshold: float = CASSANDRA_SEMANTIC_CACHE_DEFAULT_SCORE_THRESHOLD,
        ttl_seconds: Optional[int] = CASSANDRA_SEMANTIC_CACHE_DEFAULT_TTL_SECONDS,
        skip_provisioning: bool = False,
    ):
        """
        Initialize the cache with all relevant parameters.
        Args:
            session (cassandra.cluster.Session): an open Cassandra session
            keyspace (str): the keyspace to use for storing the cache
            embedding (Embedding): Embedding provider for semantic
                encoding and search.
            table_name (str): name of the Cassandra (vector) table
                to use as cache
            distance_metric (str, 'dot'): which measure to adopt for
                similarity searches
            score_threshold (optional float): numeric value to use as
                cutoff for the similarity searches
            ttl_seconds (optional int): time-to-live for cache entries
                (default: None, i.e. forever)
            skip_provisioning (bool): forwarded unchanged to the cassio
                table constructor
        The default score threshold is tuned to the default metric.
        Tune it carefully yourself if switching to another distance metric.
        Note: one embedding call is made here to learn the vector dimension.
        """
        try:
            from cassio.table import MetadataVectorCassandraTable
        except ImportError:
            # ModuleNotFoundError is a subclass of ImportError, so it lands here too.
            raise ValueError(
                "Could not import cassio python package. "
                "Please install it with `pip install cassio`."
            )

        self.embedding = embedding
        self.session = session
        self.keyspace = keyspace
        self.table_name = table_name
        self.ttl_seconds = ttl_seconds
        self.distance_metric = distance_metric
        self.score_threshold = score_threshold

        # lookup() and update() are separate calls for the same prompt, so the
        # embedding is memoized to avoid computing it twice. The memo is built
        # per instance: each cache object owns its own LRU.
        def _cache_embedding(text: str) -> List[float]:
            return self.embedding.embed_query(text=text)

        self._get_embedding = lru_cache(
            maxsize=CASSANDRA_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE
        )(_cache_embedding)
        self.embedding_dimension = self._get_embedding_dimension()

        self.table = MetadataVectorCassandraTable(
            session=self.session,
            keyspace=self.keyspace,
            table=self.table_name,
            primary_key_type=["TEXT"],
            vector_dimension=self.embedding_dimension,
            ttl_seconds=self.ttl_seconds,
            metadata_indexing=("allow", {"_llm_string_hash"}),
            skip_provisioning=skip_provisioning,
        )

    def _get_embedding_dimension(self) -> int:
        """Return the length of the vector produced for a fixed sample sentence."""
        sample_vector = self._get_embedding(text="This is a sample sentence.")
        return len(sample_vector)

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Store ``return_val`` with the prompt's embedding and the llm_string hash."""
        prompt_vector = self._get_embedding(text=prompt)
        llm_string_hash = _hash(llm_string)
        serialized = _dumps_generations(return_val)
        row_metadata = {
            "_prompt": prompt,
            "_llm_string_hash": llm_string_hash,
        }
        row_id = f"{_hash(prompt)}-{llm_string_hash}"
        self.table.put(
            body_blob=serialized,
            vector=prompt_vector,
            row_id=row_id,
            metadata=row_metadata,
        )

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Return the cached generations of the best match, or None."""
        hit_with_id = self.lookup_with_id(prompt, llm_string)
        return None if hit_with_id is None else hit_with_id[1]

    def lookup_with_id(
        self, prompt: str, llm_string: str
    ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
        """
        Search for the single nearest entry stored for this llm_string.
        On a hit, return (document_id, cached_entry); otherwise None.
        """
        prompt_embedding: List[float] = self._get_embedding(text=prompt)
        hits = list(
            self.table.metric_ann_search(
                vector=prompt_embedding,
                metadata={"_llm_string_hash": _hash(llm_string)},
                n=1,
                metric=self.distance_metric,
                metric_threshold=self.score_threshold,
            )
        )
        if not hits:
            return None
        best = hits[0]
        generations = _loads_generations(best["body_blob"])
        if generations is None:
            # Malformed cached item: report it as a miss.
            return None
        return (best["row_id"], generations)

    def lookup_with_id_through_llm(
        self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
    ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
        """
        Same as `lookup_with_id`, deriving llm_string from the LLM object.
        If the llm(prompt) calls were made with a `stop` param, pass it here too.
        """
        llm_params = {**llm.dict(), "stop": stop}
        llm_string = get_prompts(llm_params, [])[1]
        return self.lookup_with_id(prompt, llm_string=llm_string)

    def delete_by_document_id(self, document_id: str) -> None:
        """
        Delete the row with the given ID.

        For a similarity-search cache, a sensible invalidation pattern is to
        look up first to obtain an ID and then delete by that ID; this method
        is the second step.
        """
        self.table.delete(row_id=document_id)

    def clear(self, **kwargs: Any) -> None:
        """Clear the *whole* semantic cache."""
        self.table.clear()
 | |
| 
 | |
| 
 | |
class FullMd5LLMCache(Base):  # type: ignore
    """SQLAlchemy table model for the full LLM cache (all generations).

    As written by SQLAlchemyMd5Cache.update, each row holds one serialized
    generation of a cached answer.
    """

    __tablename__ = "full_md5_llm_cache"
    # Row identifier (SQLAlchemyMd5Cache.update fills it with a uuid1 string).
    id = Column(String, primary_key=True)
    # MD5 hex digest of `prompt`; indexed, and used as a filter in lookups.
    prompt_md5 = Column(String, index=True)
    # The llm_string the generation was produced with.
    llm = Column(String, index=True)
    # Position of this generation within the cached list; results are ordered by it.
    idx = Column(Integer, index=True)
    # Full prompt text, compared in addition to its digest.
    prompt = Column(String)
    # Serialized generation.
    response = Column(String)
 | |
| 
 | |
| 
 | |
class SQLAlchemyMd5Cache(BaseCache):
    """Cache that uses SQLAlchemy as a backend and filters rows by prompt MD5."""

    def __init__(
        self, engine: Engine, cache_schema: Type[FullMd5LLMCache] = FullMd5LLMCache
    ):
        """Initialize by creating all tables.

        Args:
            engine: SQLAlchemy engine the cache reads from and writes to.
            cache_schema: Declarative model class used as the cache table.
        """
        self.engine = engine
        self.cache_schema = cache_schema
        self.cache_schema.metadata.create_all(self.engine)

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string.

        Returns:
            The deserialized generations in stored order, or None when no row
            matches.
        """
        rows = self._search_rows(prompt, llm_string)
        if rows:
            return [loads(row[0]) for row in rows]
        return None

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Replace the rows stored for (prompt, llm_string) with ``return_val``.

        Existing rows are deleted first, then one row per generation is
        written. The two steps run in separate transactions.
        """
        self._delete_previous(prompt, llm_string)
        prompt_md5 = self.get_md5(prompt)
        items = [
            self.cache_schema(
                id=str(uuid.uuid1()),
                prompt=prompt,
                prompt_md5=prompt_md5,
                llm=llm_string,
                response=dumps(gen),
                idx=i,
            )
            for i, gen in enumerate(return_val)
        ]
        with Session(self.engine) as session, session.begin():
            for item in items:
                session.merge(item)

    def _delete_previous(self, prompt: str, llm_string: str) -> None:
        """Delete every row currently stored for (prompt, llm_string)."""
        # Select the mapped entity, not just its `response` column:
        # Session.delete() only accepts mapped instances, so passing the
        # column-only Row objects fails as soon as a previous entry exists.
        stmt = (
            select(self.cache_schema)
            .where(self.cache_schema.prompt_md5 == self.get_md5(prompt))  # type: ignore
            .where(self.cache_schema.llm == llm_string)
            .where(self.cache_schema.prompt == prompt)
        )
        with Session(self.engine) as session, session.begin():
            for item in session.execute(stmt).scalars():
                session.delete(item)

    def _search_rows(self, prompt: str, llm_string: str) -> List[Row]:
        """Fetch the ``response`` column of matching rows, ordered by ``idx``."""
        prompt_md5 = self.get_md5(prompt)
        stmt = (
            select(self.cache_schema.response)
            .where(self.cache_schema.prompt_md5 == prompt_md5)  # type: ignore
            .where(self.cache_schema.llm == llm_string)
            .where(self.cache_schema.prompt == prompt)
            .order_by(self.cache_schema.idx)
        )
        with Session(self.engine) as session:
            return session.execute(stmt).fetchall()

    def clear(self, **kwargs: Any) -> None:
        """Clear cache by deleting every row of the cache table."""
        # A declarative class has no `.delete()` of its own; the DELETE
        # construct lives on its `__table__`. `session.begin()` commits on
        # exit, so the deletion is not discarded when the session closes.
        with Session(self.engine) as session, session.begin():
            session.execute(self.cache_schema.__table__.delete())  # type: ignore

    @staticmethod
    def get_md5(input_string: str) -> str:
        """Return the MD5 hex digest of ``input_string`` (default str encoding)."""
        return hashlib.md5(input_string.encode()).hexdigest()
 | |
| 
 | |
| 
 | |
# Default collection for AstraDBCache. NOTE: this name is bound again further
# down the module, so only the AstraDBCache signature default (captured when
# the class is defined) is guaranteed to keep this value.
ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME: str = "langchain_astradb_cache"
 | |
| 
 | |
| 
 | |
class AstraDBCache(BaseCache):
    """
    Cache that uses Astra DB as a backend.

    A single collection is used as a key-value store.
    Each document's _id combines the hashes of two lookup keys:
        - prompt, a string
        - llm_string, a deterministic str representation of the model parameters
          (needed to prevent same-prompt-different-model collisions)
    """

    def __init__(
        self,
        *,
        collection_name: str = ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME,
        token: Optional[str] = None,
        api_endpoint: Optional[str] = None,
        astra_db_client: Optional[Any] = None,  # 'astrapy.db.AstraDB' if passed
        namespace: Optional[str] = None,
    ):
        """
        Create an AstraDB cache using a collection for storage.

        Args (only keyword-arguments accepted):
            collection_name (str): name of the Astra DB collection to create/use.
            token (Optional[str]): API token for Astra DB usage.
            api_endpoint (Optional[str]): full URL to the API endpoint,
                such as "https://<DB-ID>-us-east1.apps.astra.datastax.com".
            astra_db_client (Optional[Any]): *alternative to token+api_endpoint*,
                you can pass an already-created 'astrapy.db.AstraDB' instance.
            namespace (Optional[str]): namespace (aka keyspace) where the
                collection is created. Defaults to the database's "default namespace".

        Raises:
            ImportError: the astrapy package cannot be imported.
            ValueError: 'astra_db_client' is combined with 'token' or
                'api_endpoint'.
        """
        try:
            from astrapy.db import AstraDB as LibAstraDB
        except ImportError:
            # ModuleNotFoundError is a subclass of ImportError, so it lands here too.
            raise ImportError(
                "Could not import a recent astrapy python package. "
                "Please install it with `pip install --upgrade astrapy`."
            )

        # A ready-made client and raw connection parameters are mutually exclusive.
        has_connection_params = token is not None or api_endpoint is not None
        if astra_db_client is not None and has_connection_params:
            raise ValueError(
                "You cannot pass 'astra_db_client' to AstraDB if passing "
                "'token' and 'api_endpoint'."
            )

        self.token = token
        self.api_endpoint = api_endpoint
        self.namespace = namespace
        self.collection_name = collection_name

        if astra_db_client is None:
            self.astra_db = LibAstraDB(
                token=self.token,
                api_endpoint=self.api_endpoint,
                namespace=self.namespace,
            )
        else:
            self.astra_db = astra_db_client
        self.collection = self.astra_db.create_collection(
            collection_name=self.collection_name,
        )

    @staticmethod
    def _make_id(prompt: str, llm_string: str) -> str:
        """Build the document _id from the hashes of prompt and llm_string."""
        prompt_hash = _hash(prompt)
        llm_hash = _hash(llm_string)
        return "{}#{}".format(prompt_hash, llm_hash)

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Return the generations stored for (prompt, llm_string), else None."""
        doc_id = self._make_id(prompt, llm_string)
        response = self.collection.find_one(
            filter={"_id": doc_id},
            projection={"body_blob": 1},
        )
        document = response["data"]["document"]
        if document is None:
            return None
        # A None result from the loader (malformed cached item) is passed on
        # as a plain miss.
        return _loads_generations(document["body_blob"])

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Upsert the serialized ``return_val`` under the combined document _id."""
        doc_id = self._make_id(prompt, llm_string)
        serialized = _dumps_generations(return_val)
        document = {
            "_id": doc_id,
            "body_blob": serialized,
        }
        self.collection.upsert(document)

    def delete_through_llm(
        self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
    ) -> None:
        """
        Evict an entry, deriving llm_string from the LLM object itself.
        If the llm(prompt) calls were made with a `stop` param, pass it here too.
        """
        llm_params = {**llm.dict(), "stop": stop}
        llm_string = get_prompts(llm_params, [])[1]
        return self.delete(prompt, llm_string=llm_string)

    def delete(self, prompt: str, llm_string: str) -> None:
        """Remove the entry for (prompt, llm_string), if one is stored."""
        return self.collection.delete_one(self._make_id(prompt, llm_string))

    def clear(self, **kwargs: Any) -> None:
        """Truncate the collection. This affects all LLMs at once."""
        self.astra_db.truncate_collection(self.collection_name)
 | |
| 
 | |
| 
 | |
# Defaults for the AstraDBSemanticCache constructor arguments.
ASTRA_DB_SEMANTIC_CACHE_DEFAULT_THRESHOLD = 0.85
# Properly named constant for the semantic cache's default collection.
ASTRA_DB_SEMANTIC_CACHE_DEFAULT_COLLECTION_NAME = "langchain_astradb_semantic_cache"
# NOTE(review): the original code stored the value above by rebinding the
# *non-semantic* constant, which clobbers the AstraDBCache default for anyone
# reading it after import. The rebinding is kept so that the
# AstraDBSemanticCache signature, which reads this name, keeps its default;
# new code should use ASTRA_DB_SEMANTIC_CACHE_DEFAULT_COLLECTION_NAME.
ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME = ASTRA_DB_SEMANTIC_CACHE_DEFAULT_COLLECTION_NAME
# Capacity of the per-instance LRU that memoizes prompt embeddings.
ASTRA_DB_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE = 16
 | |
| 
 | |
| 
 | |
| class AstraDBSemanticCache(BaseCache):
 | |
|     """
 | |
|     Cache that uses Astra DB as a vector-store backend for semantic
 | |
|     (i.e. similarity-based) lookup.
 | |
| 
 | |
|     It uses a single (vector) collection and can store
 | |
|     cached values from several LLMs, so the LLM's 'llm_string' is stored
 | |
|     in the document metadata.
 | |
| 
 | |
|     You can choose the preferred similarity (or use the API default) --
 | |
|     remember the threshold might require metric-dependend tuning.
 | |
|     """
 | |
| 
 | |
|     def __init__(
 | |
|         self,
 | |
|         *,
 | |
|         collection_name: str = ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME,
 | |
|         token: Optional[str] = None,
 | |
|         api_endpoint: Optional[str] = None,
 | |
|         astra_db_client: Optional[Any] = None,  # 'astrapy.db.AstraDB' if passed
 | |
|         namespace: Optional[str] = None,
 | |
|         embedding: Embeddings,
 | |
|         metric: Optional[str] = None,
 | |
|         similarity_threshold: float = ASTRA_DB_SEMANTIC_CACHE_DEFAULT_THRESHOLD,
 | |
|     ):
 | |
|         """
 | |
|         Initialize the cache with all relevant parameters.
 | |
|         Args:
 | |
| 
 | |
|             collection_name (str): name of the Astra DB collection to create/use.
 | |
|             token (Optional[str]): API token for Astra DB usage.
 | |
|             api_endpoint (Optional[str]): full URL to the API endpoint,
 | |
|                 such as "https://<DB-ID>-us-east1.apps.astra.datastax.com".
 | |
|             astra_db_client (Optional[Any]): *alternative to token+api_endpoint*,
 | |
|                 you can pass an already-created 'astrapy.db.AstraDB' instance.
 | |
|             namespace (Optional[str]): namespace (aka keyspace) where the
 | |
|                 collection is created. Defaults to the database's "default namespace".
 | |
|             embedding (Embedding): Embedding provider for semantic
 | |
|                 encoding and search.
 | |
|             metric: the function to use for evaluating similarity of text embeddings.
 | |
|                 Defaults to 'cosine' (alternatives: 'euclidean', 'dot_product')
 | |
|             similarity_threshold (float, optional): the minimum similarity
 | |
|                 for accepting a (semantic-search) match.
 | |
| 
 | |
|         The default score threshold is tuned to the default metric.
 | |
|         Tune it carefully yourself if switching to another distance metric.
 | |
|         """
 | |
|         try:
 | |
|             from astrapy.db import (
 | |
|                 AstraDB as LibAstraDB,
 | |
|             )
 | |
|         except (ImportError, ModuleNotFoundError):
 | |
|             raise ImportError(
 | |
|                 "Could not import a recent astrapy python package. "
 | |
|                 "Please install it with `pip install --upgrade astrapy`."
 | |
|             )
 | |
|         # Conflicting-arg checks:
 | |
|         if astra_db_client is not None:
 | |
|             if token is not None or api_endpoint is not None:
 | |
|                 raise ValueError(
 | |
|                     "You cannot pass 'astra_db_client' to AstraDB if passing "
 | |
|                     "'token' and 'api_endpoint'."
 | |
|                 )
 | |
| 
 | |
|         self.embedding = embedding
 | |
|         self.metric = metric
 | |
|         self.similarity_threshold = similarity_threshold
 | |
| 
 | |
|         # The contract for this class has separate lookup and update:
 | |
|         # in order to spare some embedding calculations we cache them between
 | |
|         # the two calls.
 | |
|         # Note: each instance of this class has its own `_get_embedding` with
 | |
|         # its own lru.
 | |
|         @lru_cache(maxsize=ASTRA_DB_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE)
 | |
|         def _cache_embedding(text: str) -> List[float]:
 | |
|             return self.embedding.embed_query(text=text)
 | |
| 
 | |
|         self._get_embedding = _cache_embedding
 | |
|         self.embedding_dimension = self._get_embedding_dimension()
 | |
| 
 | |
|         self.collection_name = collection_name
 | |
|         self.token = token
 | |
|         self.api_endpoint = api_endpoint
 | |
|         self.namespace = namespace
 | |
| 
 | |
|         if astra_db_client is not None:
 | |
|             self.astra_db = astra_db_client
 | |
|         else:
 | |
|             self.astra_db = LibAstraDB(
 | |
|                 token=self.token,
 | |
|                 api_endpoint=self.api_endpoint,
 | |
|                 namespace=self.namespace,
 | |
|             )
 | |
|         self.collection = self.astra_db.create_collection(
 | |
|             collection_name=self.collection_name,
 | |
|             dimension=self.embedding_dimension,
 | |
|             metric=self.metric,
 | |
|         )
 | |
| 
 | |
|     def _get_embedding_dimension(self) -> int:
 | |
|         return len(self._get_embedding(text="This is a sample sentence."))
 | |
| 
 | |
|     @staticmethod
 | |
|     def _make_id(prompt: str, llm_string: str) -> str:
 | |
|         return f"{_hash(prompt)}#{_hash(llm_string)}"
 | |
| 
 | |
|     def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
 | |
|         """Update cache based on prompt and llm_string."""
 | |
|         doc_id = self._make_id(prompt, llm_string)
 | |
|         llm_string_hash = _hash(llm_string)
 | |
|         embedding_vector = self._get_embedding(text=prompt)
 | |
|         body = _dumps_generations(return_val)
 | |
|         #
 | |
|         self.collection.upsert(
 | |
|             {
 | |
|                 "_id": doc_id,
 | |
|                 "body_blob": body,
 | |
|                 "llm_string_hash": llm_string_hash,
 | |
|                 "$vector": embedding_vector,
 | |
|             }
 | |
|         )
 | |
| 
 | |
|     def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
 | |
|         """Look up based on prompt and llm_string."""
 | |
|         hit_with_id = self.lookup_with_id(prompt, llm_string)
 | |
|         if hit_with_id is not None:
 | |
|             return hit_with_id[1]
 | |
|         else:
 | |
|             return None
 | |
| 
 | |
|     def lookup_with_id(
 | |
|         self, prompt: str, llm_string: str
 | |
|     ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
 | |
|         """
 | |
|         Look up based on prompt and llm_string.
 | |
|         If there are hits, return (document_id, cached_entry) for the top hit
 | |
|         """
 | |
|         prompt_embedding: List[float] = self._get_embedding(text=prompt)
 | |
|         llm_string_hash = _hash(llm_string)
 | |
| 
 | |
|         hit = self.collection.vector_find_one(
 | |
|             vector=prompt_embedding,
 | |
|             filter={
 | |
|                 "llm_string_hash": llm_string_hash,
 | |
|             },
 | |
|             fields=["body_blob", "_id"],
 | |
|             include_similarity=True,
 | |
|         )
 | |
| 
 | |
|         if hit is None or hit["$similarity"] < self.similarity_threshold:
 | |
|             return None
 | |
|         else:
 | |
|             generations = _loads_generations(hit["body_blob"])
 | |
|             if generations is not None:
 | |
|                 # this protects against malformed cached items:
 | |
|                 return (hit["_id"], generations)
 | |
|             else:
 | |
|                 return None
 | |
| 
 | |
|     def lookup_with_id_through_llm(
 | |
|         self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
 | |
|     ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
 | |
|         llm_string = get_prompts(
 | |
|             {**llm.dict(), **{"stop": stop}},
 | |
|             [],
 | |
|         )[1]
 | |
|         return self.lookup_with_id(prompt, llm_string=llm_string)
 | |
| 
 | |
|     def delete_by_document_id(self, document_id: str) -> None:
 | |
|         """
 | |
|         Given this is a "similarity search" cache, an invalidation pattern
 | |
|         that makes sense is first a lookup to get an ID, and then deleting
 | |
|         with that ID. This is for the second step.
 | |
|         """
 | |
|         self.collection.delete_one(document_id)
 | |
| 
 | |
|     def clear(self, **kwargs: Any) -> None:
 | |
|         """Clear the *whole* semantic cache."""
 | |
|         self.astra_db.truncate_collection(self.collection_name)
 |