Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-18 21:09:00 +00:00)
Optimize the initialization method of GPTCache (#4522)
Optimize the initialization method of GPTCache so that users can get started with it more quickly: the init callback can now receive the llm string, and each llm model then gets its own cache without manual bookkeeping.
This commit is contained in: parent f4d3cf2dfb, commit 7bcf238a1a
@@ -408,25 +408,20 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "import gptcache\n",
+    "from gptcache import Cache\n",
+    "from gptcache.manager.factory import manager_factory\n",
     "from gptcache.processor.pre import get_prompt\n",
-    "from gptcache.manager.factory import get_data_manager\n",
     "from langchain.cache import GPTCache\n",
     "\n",
     "# Avoid multiple caches using the same file, causing different llm model caches to affect each other\n",
-    "i = 0\n",
-    "file_prefix = \"data_map\"\n",
     "\n",
-    "def init_gptcache_map(cache_obj: gptcache.Cache):\n",
-    "    global i\n",
-    "    cache_path = f'{file_prefix}_{i}.txt'\n",
+    "def init_gptcache(cache_obj: Cache, llm: str):\n",
     "    cache_obj.init(\n",
     "        pre_embedding_func=get_prompt,\n",
-    "        data_manager=get_data_manager(data_path=cache_path),\n",
+    "        data_manager=manager_factory(manager=\"map\", data_dir=f\"map_cache_{llm}\"),\n",
     "    )\n",
-    "    i += 1\n",
     "\n",
-    "langchain.llm_cache = GPTCache(init_gptcache_map)"
+    "langchain.llm_cache = GPTCache(init_gptcache)"
    ]
   },
   {
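With the updated cell above, a round trip might look like the sketch below; the model name and prompts are illustrative, not part of this commit:

```python
# Hypothetical usage of the exact-match ("map") GPTCache configured in the cell above.
from langchain.llms import OpenAI

llm = OpenAI(model_name="text-davinci-002", n=2, best_of=2)

llm("Tell me a joke")  # first call goes to the model and is stored under a map_cache_<llm-string> directory
llm("Tell me a joke")  # an identical prompt is then answered from the GPTCache map cache
```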
@@ -506,37 +501,16 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "import gptcache\n",
-    "from gptcache.processor.pre import get_prompt\n",
-    "from gptcache.manager.factory import get_data_manager\n",
-    "from langchain.cache import GPTCache\n",
-    "from gptcache.manager import get_data_manager, CacheBase, VectorBase\n",
     "from gptcache import Cache\n",
-    "from gptcache.embedding import Onnx\n",
-    "from gptcache.similarity_evaluation.distance import SearchDistanceEvaluation\n",
+    "from gptcache.adapter.api import init_similar_cache\n",
+    "from langchain.cache import GPTCache\n",
     "\n",
     "# Avoid multiple caches using the same file, causing different llm model caches to affect each other\n",
-    "i = 0\n",
-    "file_prefix = \"data_map\"\n",
-    "llm_cache = Cache()\n",
     "\n",
+    "def init_gptcache(cache_obj: Cache, llm: str):\n",
+    "    init_similar_cache(cache_obj=cache_obj, data_dir=f\"similar_cache_{llm}\")\n",
     "\n",
-    "def init_gptcache_map(cache_obj: gptcache.Cache):\n",
-    "    global i\n",
-    "    cache_path = f'{file_prefix}_{i}.txt'\n",
-    "    onnx = Onnx()\n",
-    "    cache_base = CacheBase('sqlite')\n",
-    "    vector_base = VectorBase('faiss', dimension=onnx.dimension)\n",
-    "    data_manager = get_data_manager(cache_base, vector_base, max_size=10, clean_size=2)\n",
-    "    cache_obj.init(\n",
-    "        pre_embedding_func=get_prompt,\n",
-    "        embedding_func=onnx.to_embeddings,\n",
-    "        data_manager=data_manager,\n",
-    "        similarity_evaluation=SearchDistanceEvaluation(),\n",
-    "    )\n",
-    "    i += 1\n",
-    "\n",
-    "langchain.llm_cache = GPTCache(init_gptcache_map)"
+    "langchain.llm_cache = GPTCache(init_gptcache)"
    ]
   },
   {
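Because this cell now goes through `init_similar_cache`, lookups are embedding-based rather than exact-match. A minimal sketch of what that enables, assuming the same `llm` object as in the previous example (prompts are illustrative):

```python
# Hypothetical usage of the similarity-based GPTCache configured in the cell above.
llm("Tell me a joke")  # cache miss: computed by the model, stored under similar_cache_<llm-string>
llm("Tell me joke")    # a similar (not identical) prompt may now be served from the cache
```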
@@ -929,7 +903,7 @@
   ],
   "metadata": {
    "kernelspec": {
-    "display_name": "Python 3 (ipykernel)",
+    "display_name": "Python 3",
     "language": "python",
     "name": "python3"
   },
@@ -943,7 +917,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.9.1"
+   "version": "3.8.8"
   }
  },
 "nbformat": 4,
@@ -1,8 +1,9 @@
 """Beta Feature: base interface for cache."""
 import hashlib
+import inspect
 import json
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, List, Optional, Tuple, Type, cast
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast
 
 from sqlalchemy import Column, Integer, String, create_engine, select
 from sqlalchemy.engine.base import Engine
@@ -274,7 +275,12 @@ class RedisSemanticCache(BaseCache):
 class GPTCache(BaseCache):
     """Cache that uses GPTCache as a backend."""
 
-    def __init__(self, init_func: Optional[Callable[[Any], None]] = None):
+    def __init__(
+        self,
+        init_func: Union[
+            Callable[[Any, str], None], Callable[[Any], None], None
+        ] = None,
+    ):
         """Initialize by passing in init function (default: `None`).
 
         Args:
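The widened `Union` type keeps both callback shapes valid. A minimal sketch, with illustrative function names, of what the constructor now accepts:

```python
from typing import Any

from langchain.cache import GPTCache


def legacy_init(cache_obj: Any) -> None:
    """Old-style callback: receives only the gptcache Cache object."""


def per_llm_init(cache_obj: Any, llm: str) -> None:
    """New-style callback: also receives the llm string, so each model can get its own cache."""


GPTCache(legacy_init)   # still accepted
GPTCache(per_llm_init)  # new two-argument form
GPTCache()              # no callback: falls back to a default cache keyed by the llm string
```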
@@ -291,19 +297,17 @@ class GPTCache(BaseCache):
 
             # Avoid multiple caches using the same file,
             causing different llm model caches to affect each other
-            i = 0
-            file_prefix = "data_map"
 
-            def init_gptcache_map(cache_obj: gptcache.Cache):
-                nonlocal i
-                cache_path = f'{file_prefix}_{i}.txt'
+            def init_gptcache(cache_obj: gptcache.Cache, llm: str):
                 cache_obj.init(
                     pre_embedding_func=get_prompt,
-                    data_manager=get_data_manager(data_path=cache_path),
+                    data_manager=manager_factory(
+                        manager="map",
+                        data_dir=f"map_cache_{llm}"
+                    ),
                 )
-                i += 1
 
-            langchain.llm_cache = GPTCache(init_gptcache_map)
+            langchain.llm_cache = GPTCache(init_gptcache)
 
         """
         try:
@@ -314,29 +318,37 @@ class GPTCache(BaseCache):
                 "Please install it with `pip install gptcache`."
             )
 
-        self.init_gptcache_func: Optional[Callable[[Any], None]] = init_func
+        self.init_gptcache_func: Union[
+            Callable[[Any, str], None], Callable[[Any], None], None
+        ] = init_func
         self.gptcache_dict: Dict[str, Any] = {}
 
+    def _new_gptcache(self, llm_string: str) -> Any:
+        """New gptcache object"""
+        from gptcache import Cache
+        from gptcache.manager.factory import get_data_manager
+        from gptcache.processor.pre import get_prompt
+
+        _gptcache = Cache()
+        if self.init_gptcache_func is not None:
+            sig = inspect.signature(self.init_gptcache_func)
+            if len(sig.parameters) == 2:
+                self.init_gptcache_func(_gptcache, llm_string)  # type: ignore[call-arg]
+            else:
+                self.init_gptcache_func(_gptcache)  # type: ignore[call-arg]
+        else:
+            _gptcache.init(
+                pre_embedding_func=get_prompt,
+                data_manager=get_data_manager(data_path=llm_string),
+            )
+        return _gptcache
+
     def _get_gptcache(self, llm_string: str) -> Any:
         """Get a cache object.
 
         When the corresponding llm model cache does not exist, it will be created."""
-        from gptcache import Cache
-        from gptcache.manager.factory import get_data_manager
-        from gptcache.processor.pre import get_prompt
-
-        _gptcache = self.gptcache_dict.get(llm_string, None)
-        if _gptcache is None:
-            _gptcache = Cache()
-            if self.init_gptcache_func is not None:
-                self.init_gptcache_func(_gptcache)
-            else:
-                _gptcache.init(
-                    pre_embedding_func=get_prompt,
-                    data_manager=get_data_manager(data_path=llm_string),
-                )
-            self.gptcache_dict[llm_string] = _gptcache
-        return _gptcache
+        return self.gptcache_dict.get(llm_string, self._new_gptcache(llm_string))
 
     def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
         """Look up the cache data.
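The core of `_new_gptcache` is the `inspect.signature` check that decides whether to pass the llm string to the user-supplied callback. A standalone sketch of that dispatch, with an illustrative helper name:

```python
import inspect
from typing import Any, Callable, Union


def dispatch_init(
    init_func: Union[Callable[[Any, str], None], Callable[[Any], None]],
    cache_obj: Any,
    llm_string: str,
) -> None:
    # Mirrors the branch in _new_gptcache: a two-parameter callback also gets the llm string.
    sig = inspect.signature(init_func)
    if len(sig.parameters) == 2:
        init_func(cache_obj, llm_string)
    else:
        init_func(cache_obj)
```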
tests/integration_tests/cache/test_gptcache.py (vendored): 20 changed lines
@@ -1,5 +1,5 @@
 import os
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Union
 
 import pytest
 
@@ -30,9 +30,23 @@ def init_gptcache_map(cache_obj: Cache) -> None:
     init_gptcache_map._i = i + 1  # type: ignore
 
 
+def init_gptcache_map_with_llm(cache_obj: Cache, llm: str) -> None:
+    cache_path = f"data_map_{llm}.txt"
+    if os.path.isfile(cache_path):
+        os.remove(cache_path)
+    cache_obj.init(
+        pre_embedding_func=get_prompt,
+        data_manager=get_data_manager(data_path=cache_path),
+    )
+
+
 @pytest.mark.skipif(not gptcache_installed, reason="gptcache not installed")
-@pytest.mark.parametrize("init_func", [None, init_gptcache_map])
-def test_gptcache_caching(init_func: Optional[Callable[[Any], None]]) -> None:
+@pytest.mark.parametrize(
+    "init_func", [None, init_gptcache_map, init_gptcache_map_with_llm]
+)
+def test_gptcache_caching(
+    init_func: Union[Callable[[Any, str], None], Callable[[Any], None], None]
+) -> None:
     """Test gptcache default caching behavior."""
     langchain.llm_cache = GPTCache(init_func)
     llm = FakeLLM()
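To exercise the new parametrized case locally, something like the following should work once `gptcache` is installed (otherwise the module-level skip applies); the invocation is an illustration, not part of this commit:

```python
# Run only the vendored GPTCache integration test.
import pytest

pytest.main(["tests/integration_tests/cache/test_gptcache.py", "-q"])
```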