Mirror of https://github.com/hwchase17/langchain.git (synced 2025-05-30 03:28:40 +00:00)
Optimize the initialization method of GPTCache (#4522)
Optimize the initialization method of GPTCache so that users can start using GPTCache more quickly.
This commit is contained in:
parent f4d3cf2dfb
commit 7bcf238a1a
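In short: instead of the documented pattern that kept a module-level counter and wrote every model's cache to numbered files, the init function can now receive the llm string and give each model its own cache in one line. A minimal sketch of the new usage, mirroring the notebook diff below (this is the similar-match variant; directory names are only illustrative):

from gptcache import Cache
from gptcache.adapter.api import init_similar_cache

import langchain
from langchain.cache import GPTCache


def init_gptcache(cache_obj: Cache, llm: str) -> None:
    # GPTCache passes the llm string in, so every model gets its own
    # cache directory instead of sharing one counter-guarded file.
    init_similar_cache(cache_obj=cache_obj, data_dir=f"similar_cache_{llm}")


langchain.llm_cache = GPTCache(init_gptcache)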
@@ -408,25 +408,20 @@
     "metadata": {},
     "outputs": [],
     "source": [
-    "import gptcache\n",
+    "from gptcache import Cache\n",
+    "from gptcache.manager.factory import manager_factory\n",
     "from gptcache.processor.pre import get_prompt\n",
-    "from gptcache.manager.factory import get_data_manager\n",
     "from langchain.cache import GPTCache\n",
     "\n",
     "# Avoid multiple caches using the same file, causing different llm model caches to affect each other\n",
-    "i = 0\n",
-    "file_prefix = \"data_map\"\n",
     "\n",
-    "def init_gptcache_map(cache_obj: gptcache.Cache):\n",
-    "    global i\n",
-    "    cache_path = f'{file_prefix}_{i}.txt'\n",
+    "def init_gptcache(cache_obj: Cache, llm: str):\n",
     "    cache_obj.init(\n",
     "        pre_embedding_func=get_prompt,\n",
-    "        data_manager=get_data_manager(data_path=cache_path),\n",
+    "        data_manager=manager_factory(manager=\"map\", data_dir=f\"map_cache_{llm}\"),\n",
     "    )\n",
-    "    i += 1\n",
     "\n",
-    "langchain.llm_cache = GPTCache(init_gptcache_map)"
+    "langchain.llm_cache = GPTCache(init_gptcache)"
    ]
   },
   {
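For context, a hypothetical usage snippet (not part of this commit): once `langchain.llm_cache` is set as above, a GPTCache instance is created lazily for each distinct llm string, so repeated prompts to the same model are served from that model's own map cache. The model name and parameters below are only illustrative and assume an OpenAI API key is configured.

from langchain.llms import OpenAI

llm = OpenAI(model_name="text-davinci-002", n=2, best_of=2)  # illustrative model
llm("Tell me a joke")  # first call for this llm string: runs the model, stores the result
llm("Tell me a joke")  # repeat call: answered from the per-model "map_cache_..." directory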
@@ -506,37 +501,16 @@
     "metadata": {},
     "outputs": [],
     "source": [
-    "import gptcache\n",
-    "from gptcache.processor.pre import get_prompt\n",
-    "from gptcache.manager.factory import get_data_manager\n",
-    "from gptcache.manager import get_data_manager, CacheBase, VectorBase\n",
     "from gptcache import Cache\n",
-    "from gptcache.embedding import Onnx\n",
-    "from gptcache.similarity_evaluation.distance import SearchDistanceEvaluation\n",
+    "from gptcache.adapter.api import init_similar_cache\n",
     "from langchain.cache import GPTCache\n",
     "\n",
     "# Avoid multiple caches using the same file, causing different llm model caches to affect each other\n",
-    "i = 0\n",
-    "file_prefix = \"data_map\"\n",
-    "llm_cache = Cache()\n",
     "\n",
-    "def init_gptcache_map(cache_obj: gptcache.Cache):\n",
-    "    global i\n",
-    "    cache_path = f'{file_prefix}_{i}.txt'\n",
-    "    onnx = Onnx()\n",
-    "    cache_base = CacheBase('sqlite')\n",
-    "    vector_base = VectorBase('faiss', dimension=onnx.dimension)\n",
-    "    data_manager = get_data_manager(cache_base, vector_base, max_size=10, clean_size=2)\n",
-    "    cache_obj.init(\n",
-    "        pre_embedding_func=get_prompt,\n",
-    "        embedding_func=onnx.to_embeddings,\n",
-    "        data_manager=data_manager,\n",
-    "        similarity_evaluation=SearchDistanceEvaluation(),\n",
-    "    )\n",
-    "    i += 1\n",
+    "def init_gptcache(cache_obj: Cache, llm: str):\n",
+    "    init_similar_cache(cache_obj=cache_obj, data_dir=f\"similar_cache_{llm}\")\n",
     "\n",
-    "langchain.llm_cache = GPTCache(init_gptcache_map)"
+    "langchain.llm_cache = GPTCache(init_gptcache)"
    ]
   },
   {
@@ -929,7 +903,7 @@
   ],
   "metadata": {
    "kernelspec": {
-   "display_name": "Python 3 (ipykernel)",
+   "display_name": "Python 3",
    "language": "python",
    "name": "python3"
   },
@@ -943,7 +917,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.9.1"
+   "version": "3.8.8"
   }
  },
  "nbformat": 4,
@@ -1,8 +1,9 @@
 """Beta Feature: base interface for cache."""
 import hashlib
+import inspect
 import json
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, List, Optional, Tuple, Type, cast
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast
 
 from sqlalchemy import Column, Integer, String, create_engine, select
 from sqlalchemy.engine.base import Engine
@@ -274,7 +275,12 @@ class RedisSemanticCache(BaseCache):
 class GPTCache(BaseCache):
     """Cache that uses GPTCache as a backend."""
 
-    def __init__(self, init_func: Optional[Callable[[Any], None]] = None):
+    def __init__(
+        self,
+        init_func: Union[
+            Callable[[Any, str], None], Callable[[Any], None], None
+        ] = None,
+    ):
         """Initialize by passing in init function (default: `None`).
 
         Args:
@@ -291,19 +297,17 @@ class GPTCache(BaseCache):
 
            # Avoid multiple caches using the same file,
            # causing different llm model caches to affect each other
-           i = 0
-           file_prefix = "data_map"
-
-           def init_gptcache_map(cache_obj: gptcache.Cache):
-               nonlocal i
-               cache_path = f'{file_prefix}_{i}.txt'
+           def init_gptcache(cache_obj: gptcache.Cache, llm: str):
                cache_obj.init(
                    pre_embedding_func=get_prompt,
-                   data_manager=get_data_manager(data_path=cache_path),
+                   data_manager=manager_factory(
+                       manager="map",
+                       data_dir=f"map_cache_{llm}"
+                   ),
                )
-               i += 1
 
-           langchain.llm_cache = GPTCache(init_gptcache_map)
+           langchain.llm_cache = GPTCache(init_gptcache)
 
         """
         try:
@@ -314,29 +318,37 @@ class GPTCache(BaseCache):
                 "Please install it with `pip install gptcache`."
             )
 
-        self.init_gptcache_func: Optional[Callable[[Any], None]] = init_func
+        self.init_gptcache_func: Union[
+            Callable[[Any, str], None], Callable[[Any], None], None
+        ] = init_func
         self.gptcache_dict: Dict[str, Any] = {}
 
+    def _new_gptcache(self, llm_string: str) -> Any:
+        """New gptcache object"""
+        from gptcache import Cache
+        from gptcache.manager.factory import get_data_manager
+        from gptcache.processor.pre import get_prompt
+
+        _gptcache = Cache()
+        if self.init_gptcache_func is not None:
+            sig = inspect.signature(self.init_gptcache_func)
+            if len(sig.parameters) == 2:
+                self.init_gptcache_func(_gptcache, llm_string)  # type: ignore[call-arg]
+            else:
+                self.init_gptcache_func(_gptcache)  # type: ignore[call-arg]
+        else:
+            _gptcache.init(
+                pre_embedding_func=get_prompt,
+                data_manager=get_data_manager(data_path=llm_string),
+            )
+        self.gptcache_dict[llm_string] = _gptcache
+        return _gptcache
+
     def _get_gptcache(self, llm_string: str) -> Any:
         """Get a cache object.
 
         When the corresponding llm model cache does not exist, it will be created."""
-        from gptcache import Cache
-        from gptcache.manager.factory import get_data_manager
-        from gptcache.processor.pre import get_prompt
-
-        _gptcache = self.gptcache_dict.get(llm_string, None)
-        if _gptcache is None:
-            _gptcache = Cache()
-            if self.init_gptcache_func is not None:
-                self.init_gptcache_func(_gptcache)
-            else:
-                _gptcache.init(
-                    pre_embedding_func=get_prompt,
-                    data_manager=get_data_manager(data_path=llm_string),
-                )
-            self.gptcache_dict[llm_string] = _gptcache
-        return _gptcache
+        return self.gptcache_dict.get(llm_string, self._new_gptcache(llm_string))
 
     def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
         """Look up the cache data.
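The key behavioral change in `GPTCache` is the signature check above: `_new_gptcache` inspects the init callback and keeps supporting the old one-argument form while passing the llm string to the new two-argument form. A stripped-down sketch of that dispatch (the helper name is made up for illustration):

import inspect
from typing import Any, Callable, Union

InitFunc = Union[Callable[[Any, str], None], Callable[[Any], None]]


def call_init_func(init_func: InitFunc, cache_obj: Any, llm_string: str) -> None:
    # Two parameters: the callback also wants the llm string (new style).
    # One parameter: the legacy callback that only takes the cache object.
    if len(inspect.signature(init_func).parameters) == 2:
        init_func(cache_obj, llm_string)
    else:
        init_func(cache_obj)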
tests/integration_tests/cache/test_gptcache.py (vendored, 20 lines changed)
@@ -1,5 +1,5 @@
 import os
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Union
 
 import pytest
 
@@ -30,9 +30,23 @@ def init_gptcache_map(cache_obj: Cache) -> None:
     init_gptcache_map._i = i + 1  # type: ignore
 
 
+def init_gptcache_map_with_llm(cache_obj: Cache, llm: str) -> None:
+    cache_path = f"data_map_{llm}.txt"
+    if os.path.isfile(cache_path):
+        os.remove(cache_path)
+    cache_obj.init(
+        pre_embedding_func=get_prompt,
+        data_manager=get_data_manager(data_path=cache_path),
+    )
+
+
 @pytest.mark.skipif(not gptcache_installed, reason="gptcache not installed")
-@pytest.mark.parametrize("init_func", [None, init_gptcache_map])
-def test_gptcache_caching(init_func: Optional[Callable[[Any], None]]) -> None:
+@pytest.mark.parametrize(
+    "init_func", [None, init_gptcache_map, init_gptcache_map_with_llm]
+)
+def test_gptcache_caching(
+    init_func: Union[Callable[[Any, str], None], Callable[[Any], None], None]
+) -> None:
     """Test gptcache default caching behavior."""
     langchain.llm_cache = GPTCache(init_func)
     llm = FakeLLM()