Allow clearing cache and fix gptcache (#3493)

This PR

* Adds a `clear` method to `BaseCache` and implements it for the various
caches (a usage sketch follows below)
* Makes `init_func=None` the default and fixes the gptcache integration test
* Since the integration tests are not currently run in CI, I've verified the
changes by running `docs/modules/models/llms/examples/llm_caching.ipynb`
(until a proper e2e integration test runs in CI)
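
A minimal sketch of the new interface in use (illustrative only, not part of this diff; `"llm-config"` is a stand-in for a real serialized `llm_string`):

```python
import langchain
from langchain.cache import InMemoryCache
from langchain.schema import Generation

# Install the in-memory cache globally, seed one entry, then wipe everything.
langchain.llm_cache = InMemoryCache()
langchain.llm_cache.update("foo", "llm-config", [Generation(text="fizz")])
assert langchain.llm_cache.lookup("foo", "llm-config") == [Generation(text="fizz")]

langchain.llm_cache.clear()
assert langchain.llm_cache.lookup("foo", "llm-config") is None
```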
Ehsan M. Kermani 2023-04-26 22:03:50 -07:00, committed by GitHub
parent 83e871f1ff, commit 4a246e2fd6
5 changed files with 97 additions and 60 deletions

.gitignore

@@ -144,4 +144,8 @@ wandb/
 /.ruff_cache/
 *.pkl
 *.bin
+
+# integration test artifacts
+data_map*
+\[('_type', 'fake'), ('stop', None)]

docs/modules/models/llms/examples/llm_caching.ipynb

@@ -785,7 +785,9 @@
    "id": "9df0dab8",
    "metadata": {},
    "outputs": [],
-   "source": []
+   "source": [
+    "!rm .langchain.db sqlite.db"
+   ]
   }
  ],
  "metadata": {

langchain/cache.py

@ -1,7 +1,7 @@
"""Beta Feature: base interface for cache.""" """Beta Feature: base interface for cache."""
import json import json
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Tuple from typing import Any, Callable, Dict, List, Optional, Tuple, Type, cast
from sqlalchemy import Column, Integer, String, create_engine, select from sqlalchemy import Column, Integer, String, create_engine, select
from sqlalchemy.engine.base import Engine from sqlalchemy.engine.base import Engine
@@ -28,6 +28,10 @@ class BaseCache(ABC):
     def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
         """Update cache based on prompt and llm_string."""

+    @abstractmethod
+    def clear(self, **kwargs: Any) -> None:
+        """Clear cache; implementations may accept additional keyword arguments."""
+

 class InMemoryCache(BaseCache):
     """Cache that stores things in memory."""
@@ -44,6 +48,10 @@ class InMemoryCache(BaseCache):
         """Update cache based on prompt and llm_string."""
         self._cache[(prompt, llm_string)] = return_val

+    def clear(self, **kwargs: Any) -> None:
+        """Clear cache."""
+        self._cache = {}
+

 Base = declarative_base()
@@ -61,7 +69,7 @@ class FullLLMCache(Base):  # type: ignore
 class SQLAlchemyCache(BaseCache):
     """Cache that uses SQLAlchemy as a backend."""

-    def __init__(self, engine: Engine, cache_schema: Any = FullLLMCache):
+    def __init__(self, engine: Engine, cache_schema: Type[FullLLMCache] = FullLLMCache):
         """Initialize by creating all tables."""
         self.engine = engine
         self.cache_schema = cache_schema
@@ -76,20 +84,26 @@
             .order_by(self.cache_schema.idx)
         )
         with Session(self.engine) as session:
-            generations = [Generation(text=row[0]) for row in session.execute(stmt)]
-            if len(generations) > 0:
-                return generations
+            rows = session.execute(stmt).fetchall()
+            if rows:
+                return [Generation(text=row[0]) for row in rows]
         return None

     def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
-        """Look up based on prompt and llm_string."""
-        for i, generation in enumerate(return_val):
-            item = self.cache_schema(
-                prompt=prompt, llm=llm_string, response=generation.text, idx=i
-            )
-            with Session(self.engine) as session, session.begin():
-                session.merge(item)
+        """Update based on prompt and llm_string."""
+        items = [
+            self.cache_schema(prompt=prompt, llm=llm_string, response=gen.text, idx=i)
+            for i, gen in enumerate(return_val)
+        ]
+        with Session(self.engine) as session, session.begin():
+            for item in items:
+                session.merge(item)

+    def clear(self, **kwargs: Any) -> None:
+        """Clear cache."""
+        with Session(self.engine) as session, session.begin():
+            session.query(self.cache_schema).delete()
+

 class SQLiteCache(SQLAlchemyCache):
     """Cache that uses SQLite as a backend."""
@@ -139,19 +153,26 @@ class RedisCache(BaseCache):
         for i, generation in enumerate(return_val):
             self.redis.set(self._key(prompt, llm_string, i), generation.text)

+    def clear(self, **kwargs: Any) -> None:
+        """Clear cache. If `asynchronous` is True, flush asynchronously."""
+        asynchronous = kwargs.pop("asynchronous", False)
+        self.redis.flushdb(asynchronous=asynchronous, **kwargs)
+

 class GPTCache(BaseCache):
     """Cache that uses GPTCache as a backend."""

-    def __init__(self, init_func: Callable[[Any], None]):
-        """Initialize by passing in the `init` GPTCache func
+    def __init__(self, init_func: Optional[Callable[[Any], None]] = None):
+        """Initialize by passing in an init function (default: `None`).

         Args:
-            init_func (Callable[[Any], None]): init `GPTCache` function
+            init_func (Optional[Callable[[Any], None]]): init `GPTCache` function
+                (default: `None`)

         Example:
         .. code-block:: python

+            # Initialize GPTCache with a custom init function
             import gptcache
             from gptcache.processor.pre import get_prompt
             from gptcache.manager.factory import get_data_manager
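
`RedisCache.clear` delegates to Redis's `FLUSHDB`. A sketch of opting into the asynchronous flush, assuming a local Redis server and that the constructor takes the client as `redis_`:

```python
import langchain
from langchain.cache import RedisCache
from redis import Redis

langchain.llm_cache = RedisCache(redis_=Redis())

# FLUSHDB ASYNC: Redis reclaims the keys in a background thread.
langchain.llm_cache.clear(asynchronous=True)
```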
@@ -180,7 +201,8 @@ class GPTCache(BaseCache):
                 "Could not import gptcache python package. "
                 "Please install it with `pip install gptcache`."
             )
-        self.init_gptcache_func: Callable[[Any], None] = init_func
+
+        self.init_gptcache_func: Optional[Callable[[Any], None]] = init_func
         self.gptcache_dict: Dict[str, Any] = {}

     @staticmethod
@@ -205,11 +227,19 @@
         When the corresponding llm model cache does not exist, it will be created."""
         from gptcache import Cache
+        from gptcache.manager.factory import get_data_manager
+        from gptcache.processor.pre import get_prompt

         _gptcache = self.gptcache_dict.get(llm_string, None)
         if _gptcache is None:
             _gptcache = Cache()
-            self.init_gptcache_func(_gptcache)
+            if self.init_gptcache_func is not None:
+                self.init_gptcache_func(_gptcache)
+            else:
+                _gptcache.init(
+                    pre_embedding_func=get_prompt,
+                    data_manager=get_data_manager(data_path=llm_string),
+                )
             self.gptcache_dict[llm_string] = _gptcache
         return _gptcache
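
The practical effect is that `GPTCache()` now works with no arguments: on first use of a given `llm_string`, a fresh `Cache` is initialized with the `get_prompt` pre-embedding function and a map data manager whose `data_path` is derived from that `llm_string`. Sketch (illustrative only):

```python
import langchain
from langchain.cache import GPTCache

# No init_func needed anymore; per-model caches are built lazily on first use.
langchain.llm_cache = GPTCache()
```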
@@ -220,7 +250,7 @@
         """
         from gptcache.adapter.adapter import adapt

-        _gptcache = self.gptcache_dict.get(llm_string)
+        _gptcache = self.gptcache_dict.get(llm_string, None)
         if _gptcache is None:
             return None
         res = adapt(
@@ -234,7 +264,10 @@

     @staticmethod
     def _update_cache_callback(
-        llm_data: RETURN_VAL_TYPE, update_cache_func: Callable[[Any], None]
+        llm_data: RETURN_VAL_TYPE,
+        update_cache_func: Callable[[Any], None],
+        *args: Any,
+        **kwargs: Any,
     ) -> None:
         """Save the `llm_data` to cache storage"""
         handled_data = json.dumps([generation.dict() for generation in llm_data])
@@ -260,3 +293,13 @@
             cache_skip=True,
             prompt=prompt,
         )

+    def clear(self, **kwargs: Any) -> None:
+        """Clear cache."""
+        from gptcache import Cache
+
+        for gptcache_instance in self.gptcache_dict.values():
+            gptcache_instance = cast(Cache, gptcache_instance)
+            gptcache_instance.flush()
+
+        self.gptcache_dict.clear()

langchain/memory/entity.py

@@ -235,4 +235,5 @@ class ConversationEntityMemory(BaseChatMemory):
     def clear(self) -> None:
         """Clear memory contents."""
         self.chat_memory.clear()
+        self.entity_cache.clear()
         self.entity_store.clear()
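
Previously `clear` left `entity_cache` populated, so entity names from a wiped conversation could linger. A sketch of the now-expected behavior (the `FakeLLM` stand-in comes from the repo's unit tests and is used here only for illustration):

```python
from langchain.memory import ConversationEntityMemory
from tests.unit_tests.llms.fake_llm import FakeLLM

memory = ConversationEntityMemory(llm=FakeLLM())
memory.entity_cache = ["Alice", "Bob"]

memory.clear()  # now also empties entity_cache, not just chat history
assert memory.entity_cache == []
```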

tests/integration_tests/cache/test_gptcache.py

@@ -1,61 +1,48 @@
 import os
+from typing import Any, Callable, Optional

 import pytest

 import langchain
 from langchain.cache import GPTCache
-from langchain.schema import Generation, LLMResult
+from langchain.schema import Generation
 from tests.unit_tests.llms.fake_llm import FakeLLM

 try:
-    import gptcache  # noqa: F401
+    from gptcache import Cache  # noqa: F401
+    from gptcache.manager.factory import get_data_manager
+    from gptcache.processor.pre import get_prompt

     gptcache_installed = True
 except ImportError:
     gptcache_installed = False


+def init_gptcache_map(cache_obj: Cache) -> None:
+    i = getattr(init_gptcache_map, "_i", 0)
+    cache_path = f"data_map_{i}.txt"
+    if os.path.isfile(cache_path):
+        os.remove(cache_path)
+    cache_obj.init(
+        pre_embedding_func=get_prompt,
+        data_manager=get_data_manager(data_path=cache_path),
+    )
+    init_gptcache_map._i = i + 1  # type: ignore
+
+
 @pytest.mark.skipif(not gptcache_installed, reason="gptcache not installed")
-def test_gptcache_map_caching() -> None:
-    """Test gptcache caching behavior."""
-    from gptcache import Cache
-    from gptcache.manager.factory import get_data_manager
-    from gptcache.processor.pre import get_prompt
-
-    i = 0
-    file_prefix = "data_map"
-
-    def init_gptcache_map(cache_obj: Cache) -> None:
-        nonlocal i
-        cache_path = f"{file_prefix}_{i}.txt"
-        if os.path.isfile(cache_path):
-            os.remove(cache_path)
-        cache_obj.init(
-            pre_embedding_func=get_prompt,
-            data_manager=get_data_manager(data_path=cache_path),
-        )
-        i += 1
-
-    langchain.llm_cache = GPTCache(init_gptcache_map)
+@pytest.mark.parametrize("init_func", [None, init_gptcache_map])
+def test_gptcache_caching(init_func: Optional[Callable[[Any], None]]) -> None:
+    """Test gptcache default caching behavior."""
+    langchain.llm_cache = GPTCache(init_func)
     llm = FakeLLM()
     params = llm.dict()
     params["stop"] = None
     llm_string = str(sorted([(k, v) for k, v in params.items()]))
     langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
-    output = llm.generate(["foo", "bar", "foo"])
-    expected_cache_output = [Generation(text="foo")]
-    cache_output = langchain.llm_cache.lookup("bar", llm_string)
-    assert cache_output == expected_cache_output
-    langchain.llm_cache = None
-    expected_generations = [
-        [Generation(text="fizz")],
-        [Generation(text="foo")],
-        [Generation(text="fizz")],
-    ]
-    expected_output = LLMResult(
-        generations=expected_generations,
-        llm_output=None,
-    )
-    assert output == expected_output
+    _ = llm.generate(["foo", "bar", "foo"])
+    cache_output = langchain.llm_cache.lookup("foo", llm_string)
+    assert cache_output == [Generation(text="fizz")]
+
+    langchain.llm_cache.clear()
+    assert langchain.llm_cache.lookup("bar", llm_string) is None