Optimize the initialization method of GPTCache (#4522)

Optimize the initialization method of GPTCache so that users can set it up more quickly: the init function passed to GPTCache can now also receive the llm string, so each model gets its own cache without manual bookkeeping.
SimFG 2023-05-12 07:15:23 +08:00 committed by GitHub
parent f4d3cf2dfb
commit 7bcf238a1a
3 changed files with 67 additions and 67 deletions
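
What this means for users of langchain.cache.GPTCache: the init callback may now take the llm string as a second argument, and per-model caches are built with gptcache's factory helpers instead of module-level counters. A minimal sketch of the new exact-match ("map") setup, taken from the updated notebook cell below (assumes gptcache and langchain are installed):

    import langchain
    from gptcache import Cache
    from gptcache.manager.factory import manager_factory
    from gptcache.processor.pre import get_prompt
    from langchain.cache import GPTCache

    def init_gptcache(cache_obj: Cache, llm: str) -> None:
        # Each llm string gets its own cache directory, so different model
        # configurations no longer share (and pollute) one cache file.
        cache_obj.init(
            pre_embedding_func=get_prompt,
            data_manager=manager_factory(manager="map", data_dir=f"map_cache_{llm}"),
        )

    langchain.llm_cache = GPTCache(init_gptcache)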


@@ -408,25 +408,20 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "import gptcache\n",
+    "from gptcache import Cache\n",
+    "from gptcache.manager.factory import manager_factory\n",
     "from gptcache.processor.pre import get_prompt\n",
-    "from gptcache.manager.factory import get_data_manager\n",
     "from langchain.cache import GPTCache\n",
     "\n",
     "# Avoid multiple caches using the same file, causing different llm model caches to affect each other\n",
-    "i = 0\n",
-    "file_prefix = \"data_map\"\n",
     "\n",
-    "def init_gptcache_map(cache_obj: gptcache.Cache):\n",
-    "    global i\n",
-    "    cache_path = f'{file_prefix}_{i}.txt'\n",
+    "def init_gptcache(cache_obj: Cache, llm: str):\n",
     "    cache_obj.init(\n",
     "        pre_embedding_func=get_prompt,\n",
-    "        data_manager=get_data_manager(data_path=cache_path),\n",
+    "        data_manager=manager_factory(manager=\"map\", data_dir=f\"map_cache_{llm}\"),\n",
     "    )\n",
-    "    i += 1\n",
     "\n",
-    "langchain.llm_cache = GPTCache(init_gptcache_map)"
+    "langchain.llm_cache = GPTCache(init_gptcache)"
    ]
   },
   {
@@ -506,37 +501,16 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "import gptcache\n",
-    "from gptcache.processor.pre import get_prompt\n",
-    "from gptcache.manager.factory import get_data_manager\n",
-    "from langchain.cache import GPTCache\n",
-    "from gptcache.manager import get_data_manager, CacheBase, VectorBase\n",
     "from gptcache import Cache\n",
-    "from gptcache.embedding import Onnx\n",
-    "from gptcache.similarity_evaluation.distance import SearchDistanceEvaluation\n",
+    "from gptcache.adapter.api import init_similar_cache\n",
+    "from langchain.cache import GPTCache\n",
     "\n",
     "# Avoid multiple caches using the same file, causing different llm model caches to affect each other\n",
-    "i = 0\n",
-    "file_prefix = \"data_map\"\n",
-    "llm_cache = Cache()\n",
     "\n",
+    "def init_gptcache(cache_obj: Cache, llm: str):\n",
+    "    init_similar_cache(cache_obj=cache_obj, data_dir=f\"similar_cache_{llm}\")\n",
     "\n",
-    "def init_gptcache_map(cache_obj: gptcache.Cache):\n",
-    "    global i\n",
-    "    cache_path = f'{file_prefix}_{i}.txt'\n",
-    "    onnx = Onnx()\n",
-    "    cache_base = CacheBase('sqlite')\n",
-    "    vector_base = VectorBase('faiss', dimension=onnx.dimension)\n",
-    "    data_manager = get_data_manager(cache_base, vector_base, max_size=10, clean_size=2)\n",
-    "    cache_obj.init(\n",
-    "        pre_embedding_func=get_prompt,\n",
-    "        embedding_func=onnx.to_embeddings,\n",
-    "        data_manager=data_manager,\n",
-    "        similarity_evaluation=SearchDistanceEvaluation(),\n",
-    "    )\n",
-    "    i += 1\n",
-    "\n",
-    "langchain.llm_cache = GPTCache(init_gptcache_map)"
+    "langchain.llm_cache = GPTCache(init_gptcache)"
    ]
   },
   {
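
For context, the surrounding notebook cells exercise the cache roughly as follows; the OpenAI model and prompts here are illustrative assumptions, not part of this diff, and require an OPENAI_API_KEY:

    from langchain.llms import OpenAI

    llm = OpenAI(model_name="text-davinci-002", temperature=0)

    llm("Tell me a joke")   # first call: cache miss, the answer is stored under similar_cache_<llm>
    llm("Tell me joke")     # a near-duplicate prompt can then be served from the semantic cache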
@@ -929,7 +903,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3 (ipykernel)",
+   "display_name": "Python 3",
    "language": "python",
    "name": "python3"
   },
@@ -943,7 +917,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.9.1"
+   "version": "3.8.8"
   }
  },
  "nbformat": 4,


@@ -1,8 +1,9 @@
 """Beta Feature: base interface for cache."""
 import hashlib
+import inspect
 import json
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, List, Optional, Tuple, Type, cast
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast

 from sqlalchemy import Column, Integer, String, create_engine, select
 from sqlalchemy.engine.base import Engine
@@ -274,7 +275,12 @@ class RedisSemanticCache(BaseCache):
 class GPTCache(BaseCache):
     """Cache that uses GPTCache as a backend."""

-    def __init__(self, init_func: Optional[Callable[[Any], None]] = None):
+    def __init__(
+        self,
+        init_func: Union[
+            Callable[[Any, str], None], Callable[[Any], None], None
+        ] = None,
+    ):
         """Initialize by passing in init function (default: `None`).

         Args:
@@ -291,19 +297,17 @@ class GPTCache(BaseCache):

            # Avoid multiple caches using the same file,
            causing different llm model caches to affect each other
-           i = 0
-           file_prefix = "data_map"

-           def init_gptcache_map(cache_obj: gptcache.Cache):
-               nonlocal i
-               cache_path = f'{file_prefix}_{i}.txt'
+           def init_gptcache(cache_obj: gptcache.Cache, llm: str):
                cache_obj.init(
                    pre_embedding_func=get_prompt,
-                   data_manager=get_data_manager(data_path=cache_path),
+                   data_manager=manager_factory(
+                       manager="map",
+                       data_dir=f"map_cache_{llm}"
+                   ),
                )
-               i += 1

-           langchain.llm_cache = GPTCache(init_gptcache_map)
+           langchain.llm_cache = GPTCache(init_gptcache)

        """

        try:
@@ -314,29 +318,37 @@ class GPTCache(BaseCache):
                 "Please install it with `pip install gptcache`."
             )

-        self.init_gptcache_func: Optional[Callable[[Any], None]] = init_func
+        self.init_gptcache_func: Union[
+            Callable[[Any, str], None], Callable[[Any], None], None
+        ] = init_func
         self.gptcache_dict: Dict[str, Any] = {}

+    def _new_gptcache(self, llm_string: str) -> Any:
+        """New gptcache object"""
+        from gptcache import Cache
+        from gptcache.manager.factory import get_data_manager
+        from gptcache.processor.pre import get_prompt
+
+        _gptcache = Cache()
+        if self.init_gptcache_func is not None:
+            sig = inspect.signature(self.init_gptcache_func)
+            if len(sig.parameters) == 2:
+                self.init_gptcache_func(_gptcache, llm_string)  # type: ignore[call-arg]
+            else:
+                self.init_gptcache_func(_gptcache)  # type: ignore[call-arg]
+        else:
+            _gptcache.init(
+                pre_embedding_func=get_prompt,
+                data_manager=get_data_manager(data_path=llm_string),
+            )
+        return _gptcache
+
     def _get_gptcache(self, llm_string: str) -> Any:
         """Get a cache object.

         When the corresponding llm model cache does not exist, it will be created."""
-        from gptcache import Cache
-        from gptcache.manager.factory import get_data_manager
-        from gptcache.processor.pre import get_prompt
-
-        _gptcache = self.gptcache_dict.get(llm_string, None)
-        if _gptcache is None:
-            _gptcache = Cache()
-            if self.init_gptcache_func is not None:
-                self.init_gptcache_func(_gptcache)
-            else:
-                _gptcache.init(
-                    pre_embedding_func=get_prompt,
-                    data_manager=get_data_manager(data_path=llm_string),
-                )
-            self.gptcache_dict[llm_string] = _gptcache
-        return _gptcache
+        return self.gptcache_dict.get(llm_string, self._new_gptcache(llm_string))

     def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
         """Look up the cache data.


@@ -1,5 +1,5 @@
 import os
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Union

 import pytest
@@ -30,9 +30,23 @@ def init_gptcache_map(cache_obj: Cache) -> None:
     init_gptcache_map._i = i + 1  # type: ignore


+def init_gptcache_map_with_llm(cache_obj: Cache, llm: str) -> None:
+    cache_path = f"data_map_{llm}.txt"
+    if os.path.isfile(cache_path):
+        os.remove(cache_path)
+    cache_obj.init(
+        pre_embedding_func=get_prompt,
+        data_manager=get_data_manager(data_path=cache_path),
+    )
+
+
 @pytest.mark.skipif(not gptcache_installed, reason="gptcache not installed")
-@pytest.mark.parametrize("init_func", [None, init_gptcache_map])
-def test_gptcache_caching(init_func: Optional[Callable[[Any], None]]) -> None:
+@pytest.mark.parametrize(
+    "init_func", [None, init_gptcache_map, init_gptcache_map_with_llm]
+)
+def test_gptcache_caching(
+    init_func: Union[Callable[[Any, str], None], Callable[[Any], None], None]
+) -> None:
     """Test gptcache default caching behavior."""
     langchain.llm_cache = GPTCache(init_func)
     llm = FakeLLM()