Optimize the initialization method of GPTCache (#4522)

Optimize the initialization method of GPTCache, so that users can use GPTCache more quickly.
This commit is contained in:
SimFG 2023-05-12 07:15:23 +08:00 committed by GitHub
parent f4d3cf2dfb
commit 7bcf238a1a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 67 additions and 67 deletions

View File

@ -408,25 +408,20 @@
"metadata": {},
"outputs": [],
"source": [
"import gptcache\n",
"from gptcache import Cache\n",
"from gptcache.manager.factory import manager_factory\n",
"from gptcache.processor.pre import get_prompt\n",
"from gptcache.manager.factory import get_data_manager\n",
"from langchain.cache import GPTCache\n",
"\n",
"# Avoid multiple caches using the same file, causing different llm model caches to affect each other\n",
"i = 0\n",
"file_prefix = \"data_map\"\n",
"\n",
"def init_gptcache_map(cache_obj: gptcache.Cache):\n",
" global i\n",
" cache_path = f'{file_prefix}_{i}.txt'\n",
"def init_gptcache(cache_obj: Cache, llm: str):\n",
" cache_obj.init(\n",
" pre_embedding_func=get_prompt,\n",
" data_manager=get_data_manager(data_path=cache_path),\n",
" data_manager=manager_factory(manager=\"map\", data_dir=f\"map_cache_{llm}\"),\n",
" )\n",
" i += 1\n",
"\n",
"langchain.llm_cache = GPTCache(init_gptcache_map)"
"langchain.llm_cache = GPTCache(init_gptcache)"
]
},
{
@ -506,37 +501,16 @@
"metadata": {},
"outputs": [],
"source": [
"import gptcache\n",
"from gptcache.processor.pre import get_prompt\n",
"from gptcache.manager.factory import get_data_manager\n",
"from langchain.cache import GPTCache\n",
"from gptcache.manager import get_data_manager, CacheBase, VectorBase\n",
"from gptcache import Cache\n",
"from gptcache.embedding import Onnx\n",
"from gptcache.similarity_evaluation.distance import SearchDistanceEvaluation\n",
"from gptcache.adapter.api import init_similar_cache\n",
"from langchain.cache import GPTCache\n",
"\n",
"# Avoid multiple caches using the same file, causing different llm model caches to affect each other\n",
"i = 0\n",
"file_prefix = \"data_map\"\n",
"llm_cache = Cache()\n",
"\n",
"def init_gptcache(cache_obj: Cache, llm: str):\n",
" init_similar_cache(cache_obj=cache_obj, data_dir=f\"similar_cache_{llm}\")\n",
"\n",
"def init_gptcache_map(cache_obj: gptcache.Cache):\n",
" global i\n",
" cache_path = f'{file_prefix}_{i}.txt'\n",
" onnx = Onnx()\n",
" cache_base = CacheBase('sqlite')\n",
" vector_base = VectorBase('faiss', dimension=onnx.dimension)\n",
" data_manager = get_data_manager(cache_base, vector_base, max_size=10, clean_size=2)\n",
" cache_obj.init(\n",
" pre_embedding_func=get_prompt,\n",
" embedding_func=onnx.to_embeddings,\n",
" data_manager=data_manager,\n",
" similarity_evaluation=SearchDistanceEvaluation(),\n",
" )\n",
" i += 1\n",
"\n",
"langchain.llm_cache = GPTCache(init_gptcache_map)"
"langchain.llm_cache = GPTCache(init_gptcache)"
]
},
{
@ -929,7 +903,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
@ -943,7 +917,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
"version": "3.8.8"
}
},
"nbformat": 4,

View File

@ -1,8 +1,9 @@
"""Beta Feature: base interface for cache."""
import hashlib
import inspect
import json
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, cast
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast
from sqlalchemy import Column, Integer, String, create_engine, select
from sqlalchemy.engine.base import Engine
@ -274,7 +275,12 @@ class RedisSemanticCache(BaseCache):
class GPTCache(BaseCache):
"""Cache that uses GPTCache as a backend."""
def __init__(self, init_func: Optional[Callable[[Any], None]] = None):
def __init__(
self,
init_func: Union[
Callable[[Any, str], None], Callable[[Any], None], None
] = None,
):
"""Initialize by passing in init function (default: `None`).
Args:
@ -291,19 +297,17 @@ class GPTCache(BaseCache):
# Avoid multiple caches using the same file,
causing different llm model caches to affect each other
i = 0
file_prefix = "data_map"
def init_gptcache_map(cache_obj: gptcache.Cache):
nonlocal i
cache_path = f'{file_prefix}_{i}.txt'
def init_gptcache(cache_obj: gptcache.Cache, llm: str):
cache_obj.init(
pre_embedding_func=get_prompt,
data_manager=get_data_manager(data_path=cache_path),
data_manager=manager_factory(
manager="map",
data_dir=f"map_cache_{llm}"
),
)
i += 1
langchain.llm_cache = GPTCache(init_gptcache_map)
langchain.llm_cache = GPTCache(init_gptcache)
"""
try:
@ -314,29 +318,37 @@ class GPTCache(BaseCache):
"Please install it with `pip install gptcache`."
)
self.init_gptcache_func: Optional[Callable[[Any], None]] = init_func
self.init_gptcache_func: Union[
Callable[[Any, str], None], Callable[[Any], None], None
] = init_func
self.gptcache_dict: Dict[str, Any] = {}
def _new_gptcache(self, llm_string: str) -> Any:
    """Build and initialise a fresh gptcache ``Cache`` object.

    When no init function was supplied, the cache is initialised with
    ``get_prompt`` as the pre-embedding function and a data manager whose
    ``data_path`` is ``llm_string``.  Otherwise the user-supplied init
    function is called: with ``(cache, llm_string)`` when its signature has
    exactly two parameters, and with ``(cache,)`` in every other case.

    Args:
        llm_string: String identifying the LLM the cache is created for.

    Returns:
        The newly created cache object.
    """
    from gptcache import Cache
    from gptcache.manager.factory import get_data_manager
    from gptcache.processor.pre import get_prompt

    fresh_cache = Cache()

    # No custom initialiser: fall back to the default map-style setup.
    if self.init_gptcache_func is None:
        fresh_cache.init(
            pre_embedding_func=get_prompt,
            data_manager=get_data_manager(data_path=llm_string),
        )
        return fresh_cache

    # Both the one-argument and the two-argument init signatures are
    # supported; the parameter count decides which one is used.
    param_count = len(inspect.signature(self.init_gptcache_func).parameters)
    if param_count == 2:
        self.init_gptcache_func(fresh_cache, llm_string)  # type: ignore[call-arg]
    else:
        self.init_gptcache_func(fresh_cache)  # type: ignore[call-arg]
    return fresh_cache
def _get_gptcache(self, llm_string: str) -> Any:
    """Get the cache object for ``llm_string``, creating it on first use.

    The cache object is looked up in ``self.gptcache_dict``.  On a miss a
    new one is built with ``self._new_gptcache`` and stored in the dict, so
    later calls with the same ``llm_string`` return the same object.

    Args:
        llm_string: String identifying the LLM whose cache is requested.

    Returns:
        The cache object associated with ``llm_string``.
    """
    _gptcache = self.gptcache_dict.get(llm_string)
    if _gptcache is None:
        # Create lazily: passing ``self._new_gptcache(...)`` as the default
        # of ``dict.get`` would evaluate it on every call, and the result
        # would never be remembered.
        _gptcache = self._new_gptcache(llm_string)
        self.gptcache_dict[llm_string] = _gptcache
    return _gptcache
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up the cache data.

View File

@ -1,5 +1,5 @@
import os
from typing import Any, Callable, Optional
from typing import Any, Callable, Union
import pytest
@ -30,9 +30,23 @@ def init_gptcache_map(cache_obj: Cache) -> None:
init_gptcache_map._i = i + 1 # type: ignore
def init_gptcache_map_with_llm(cache_obj: Cache, llm: str) -> None:
    """Initialise ``cache_obj`` with a data file named after ``llm``.

    Any existing ``data_map_{llm}.txt`` file is removed before the cache is
    initialised with a data manager pointing at that path.
    """
    map_file = f"data_map_{llm}.txt"
    # Remove a file left over from an earlier run.
    if os.path.isfile(map_file):
        os.remove(map_file)
    cache_obj.init(
        data_manager=get_data_manager(data_path=map_file),
        pre_embedding_func=get_prompt,
    )
@pytest.mark.skipif(not gptcache_installed, reason="gptcache not installed")
@pytest.mark.parametrize("init_func", [None, init_gptcache_map])
def test_gptcache_caching(init_func: Optional[Callable[[Any], None]]) -> None:
@pytest.mark.parametrize(
"init_func", [None, init_gptcache_map, init_gptcache_map_with_llm]
)
def test_gptcache_caching(
init_func: Union[Callable[[Any, str], None], Callable[[Any], None], None]
) -> None:
"""Test gptcache default caching behavior."""
langchain.llm_cache = GPTCache(init_func)
llm = FakeLLM()