use get_llm_cache and set_llm_cache (#11741)

Co-authored-by: Bagatur <baskaryan@gmail.com>

parent f3ad22e64a
commit 4a2f0c51a1
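
This change replaces direct reads and writes of the module-level `langchain.llm_cache` attribute with the `set_llm_cache` and `get_llm_cache` accessors from `langchain.globals`. A minimal sketch of the migration, using only names that appear in the diff below (the prompt and `llm_string` values are illustrative):

    from langchain.cache import InMemoryCache
    from langchain.globals import get_llm_cache, set_llm_cache
    from langchain.schema import Generation

    # Old style: langchain.llm_cache = InMemoryCache()
    # New style: install the global cache through the accessor.
    set_llm_cache(InMemoryCache())

    # Call sites that used langchain.llm_cache directly now go through
    # get_llm_cache(); update() and lookup() keep their signatures.
    cache = get_llm_cache()
    cache.update("some prompt", "some llm_string", [Generation(text="cached answer")])
    assert cache.lookup("some prompt", "some llm_string") == [Generation(text="cached answer")]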
@@ -463,15 +463,15 @@ class RedisSemanticCache(BaseCache):
 
         .. code-block:: python
 
-            import langchain
+            from langchain.globals import set_llm_cache
 
             from langchain.cache import RedisSemanticCache
             from langchain.embeddings import OpenAIEmbeddings
 
-            langchain.llm_cache = RedisSemanticCache(
+            set_llm_cache(RedisSemanticCache(
                 redis_url="redis://localhost:6379",
                 embedding=OpenAIEmbeddings()
-            )
+            ))
 
         """
         self._cache_dict: Dict[str, RedisVectorstore] = {}
@@ -588,6 +588,7 @@ class GPTCache(BaseCache):
             import gptcache
             from gptcache.processor.pre import get_prompt
             from gptcache.manager.factory import get_data_manager
+            from langchain.globals import set_llm_cache
 
             # Avoid multiple caches using the same file,
             causing different llm model caches to affect each other
@@ -601,7 +602,7 @@ class GPTCache(BaseCache):
                     ),
                 )
 
-            langchain.llm_cache = GPTCache(init_gptcache)
+            set_llm_cache(GPTCache(init_gptcache))
 
         """
         try:

@@ -15,7 +15,6 @@ from typing import (
     cast,
 )
 
-import langchain
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.manager import (
     AsyncCallbackManager,
@@ -24,6 +23,7 @@ from langchain.callbacks.manager import (
     CallbackManagerForLLMRun,
     Callbacks,
 )
+from langchain.globals import get_llm_cache
 from langchain.load.dump import dumpd, dumps
 from langchain.prompts.base import StringPromptValue
 from langchain.prompts.chat import ChatPromptValue
@@ -487,7 +487,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
             "run_manager"
         )
         disregard_cache = self.cache is not None and not self.cache
-        if langchain.llm_cache is None or disregard_cache:
+        llm_cache = get_llm_cache()
+        if llm_cache is None or disregard_cache:
             # This happens when langchain.cache is None, but self.cache is True
             if self.cache is not None and self.cache:
                 raise ValueError(
@@ -502,7 +503,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
         else:
             llm_string = self._get_llm_string(stop=stop, **kwargs)
             prompt = dumps(messages)
-            cache_val = langchain.llm_cache.lookup(prompt, llm_string)
+            cache_val = llm_cache.lookup(prompt, llm_string)
             if isinstance(cache_val, list):
                 return ChatResult(generations=cache_val)
             else:
@@ -512,7 +513,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
                     )
                 else:
                     result = self._generate(messages, stop=stop, **kwargs)
-                langchain.llm_cache.update(prompt, llm_string, result.generations)
+                llm_cache.update(prompt, llm_string, result.generations)
                 return result
 
     async def _agenerate_with_cache(
@@ -526,7 +527,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
             "run_manager"
         )
         disregard_cache = self.cache is not None and not self.cache
-        if langchain.llm_cache is None or disregard_cache:
+        llm_cache = get_llm_cache()
+        if llm_cache is None or disregard_cache:
             # This happens when langchain.cache is None, but self.cache is True
             if self.cache is not None and self.cache:
                 raise ValueError(
@@ -541,7 +543,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
         else:
             llm_string = self._get_llm_string(stop=stop, **kwargs)
             prompt = dumps(messages)
-            cache_val = langchain.llm_cache.lookup(prompt, llm_string)
+            cache_val = llm_cache.lookup(prompt, llm_string)
             if isinstance(cache_val, list):
                 return ChatResult(generations=cache_val)
             else:
@@ -551,7 +553,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
                     )
                 else:
                     result = await self._agenerate(messages, stop=stop, **kwargs)
-                langchain.llm_cache.update(prompt, llm_string, result.generations)
+                llm_cache.update(prompt, llm_string, result.generations)
                 return result
 
     @abstractmethod

@@ -121,7 +121,7 @@ def get_debug() -> bool:
     return _debug or old_debug
 
 
-def set_llm_cache(value: "BaseCache") -> None:
+def set_llm_cache(value: Optional["BaseCache"]) -> None:
    """Set a new LLM cache, overwriting the previous value, if any."""
    import langchain
 
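
The widened signature above makes None an accepted value, which is how the updated tests disable the global cache during teardown. A small sketch of what that permits (assuming no cache was configured beforehand):

    from langchain.cache import InMemoryCache
    from langchain.globals import get_llm_cache, set_llm_cache

    set_llm_cache(InMemoryCache())  # install a cache
    assert get_llm_cache() is not None

    set_llm_cache(None)  # now type-checks: explicitly disable global caching
    assert get_llm_cache() is None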

@@ -37,7 +37,6 @@ from tenacity import (
     wait_exponential,
 )
 
-import langchain
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.manager import (
     AsyncCallbackManager,
@@ -46,6 +45,7 @@ from langchain.callbacks.manager import (
     CallbackManagerForLLMRun,
     Callbacks,
 )
+from langchain.globals import get_llm_cache
 from langchain.load.dump import dumpd
 from langchain.prompts.base import StringPromptValue
 from langchain.prompts.chat import ChatPromptValue
@@ -124,9 +124,10 @@ def get_prompts(
     missing_prompts = []
     missing_prompt_idxs = []
     existing_prompts = {}
+    llm_cache = get_llm_cache()
     for i, prompt in enumerate(prompts):
-        if langchain.llm_cache is not None:
-            cache_val = langchain.llm_cache.lookup(prompt, llm_string)
+        if llm_cache is not None:
+            cache_val = llm_cache.lookup(prompt, llm_string)
             if isinstance(cache_val, list):
                 existing_prompts[i] = cache_val
             else:
@@ -143,11 +144,12 @@ def update_cache(
     prompts: List[str],
 ) -> Optional[dict]:
     """Update the cache and get the LLM output."""
+    llm_cache = get_llm_cache()
     for i, result in enumerate(new_results.generations):
         existing_prompts[missing_prompt_idxs[i]] = result
         prompt = prompts[missing_prompt_idxs[i]]
-        if langchain.llm_cache is not None:
-            langchain.llm_cache.update(prompt, llm_string, result)
+        if llm_cache is not None:
+            llm_cache.update(prompt, llm_string, result)
     llm_output = new_results.llm_output
     return llm_output
 
@@ -624,7 +626,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         new_arg_supported = inspect.signature(self._generate).parameters.get(
             "run_manager"
         )
-        if langchain.llm_cache is None or disregard_cache:
+        if get_llm_cache() is None or disregard_cache:
             if self.cache is not None and self.cache:
                 raise ValueError(
                     "Asked to cache, but no cache found at `langchain.cache`."
@@ -788,7 +790,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         new_arg_supported = inspect.signature(self._agenerate).parameters.get(
             "run_manager"
        )
-        if langchain.llm_cache is None or disregard_cache:
+        if get_llm_cache() is None or disregard_cache:
             if self.cache is not None and self.cache:
                 raise ValueError(
                     "Asked to cache, but no cache found at `langchain.cache`."
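
The `get_prompts` and `update_cache` hunks above are the read and write halves of a read-through cache around `get_llm_cache()`. A simplified, hypothetical sketch of that flow (not the library's exact code; the helper names here are invented, and `llm_string` stands for the serialized model configuration):

    from typing import Dict, List

    from langchain.globals import get_llm_cache
    from langchain.schema import Generation


    def cached_generations(prompts: List[str], llm_string: str) -> Dict[int, List[Generation]]:
        """Return cache hits keyed by prompt index, mirroring what get_prompts() does."""
        hits: Dict[int, List[Generation]] = {}
        llm_cache = get_llm_cache()
        if llm_cache is None:
            return hits
        for i, prompt in enumerate(prompts):
            cache_val = llm_cache.lookup(prompt, llm_string)
            if isinstance(cache_val, list):
                hits[i] = cache_val
        return hits


    def store_generations(prompts: List[str], llm_string: str, results: Dict[int, List[Generation]]) -> None:
        """Write fresh generations back, mirroring what update_cache() does."""
        llm_cache = get_llm_cache()
        if llm_cache is None:
            return
        for i, generations in results.items():
            llm_cache.update(prompts[i], llm_string, generations)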

@@ -18,8 +18,8 @@ git grep '^from langchain' langchain/utilities | grep -vE 'from langchain.(pydan
 git grep '^from langchain' langchain/storage | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|storage|utilities)' && errors=$((errors+1))
 git grep '^from langchain' langchain/prompts | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|prompts|_api)' && errors=$((errors+1))
 git grep '^from langchain' langchain/output_parsers | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|prompts|_api|output_parsers)' && errors=$((errors+1))
-git grep '^from langchain' langchain/llms | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|prompts|llms|utilities)' && errors=$((errors+1))
-git grep '^from langchain' langchain/chat_models | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|llms|prompts|adapters|chat_models|utilities)' && errors=$((errors+1))
+git grep '^from langchain' langchain/llms | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|prompts|llms|utilities|globals)' && errors=$((errors+1))
+git grep '^from langchain' langchain/chat_models | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|llms|prompts|adapters|chat_models|utilities|globals)' && errors=$((errors+1))
 git grep '^from langchain' langchain/embeddings | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|storage|llms|embeddings|utilities)' && errors=$((errors+1))
 git grep '^from langchain' langchain/docstore | grep -vE 'from langchain.(pydantic_v1|utils|schema|docstore)' && errors=$((errors+1))
 git grep '^from langchain' langchain/vectorstores | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|_api|storage|llms|docstore|vectorstores|utilities)' && errors=$((errors+1))

@@ -5,8 +5,8 @@ from typing import Any, Iterator, Tuple
 
 import pytest
 
-import langchain
 from langchain.cache import CassandraCache, CassandraSemanticCache
+from langchain.globals import get_llm_cache, set_llm_cache
 from langchain.schema import Generation, LLMResult
 from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
 from tests.unit_tests.llms.fake_llm import FakeLLM
@@ -39,12 +39,12 @@ def cassandra_connection() -> Iterator[Tuple[Any, str]]:
 def test_cassandra_cache(cassandra_connection: Tuple[Any, str]) -> None:
     session, keyspace = cassandra_connection
     cache = CassandraCache(session=session, keyspace=keyspace)
-    langchain.llm_cache = cache
+    set_llm_cache(cache)
     llm = FakeLLM()
     params = llm.dict()
     params["stop"] = None
     llm_string = str(sorted([(k, v) for k, v in params.items()]))
-    langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
+    get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
     output = llm.generate(["foo"])
     print(output)
     expected_output = LLMResult(
@@ -59,12 +59,12 @@ def test_cassandra_cache(cassandra_connection: Tuple[Any, str]) -> None:
 def test_cassandra_cache_ttl(cassandra_connection: Tuple[Any, str]) -> None:
     session, keyspace = cassandra_connection
     cache = CassandraCache(session=session, keyspace=keyspace, ttl_seconds=2)
-    langchain.llm_cache = cache
+    set_llm_cache(cache)
     llm = FakeLLM()
     params = llm.dict()
     params["stop"] = None
     llm_string = str(sorted([(k, v) for k, v in params.items()]))
-    langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
+    get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
     expected_output = LLMResult(
         generations=[[Generation(text="fizz")]],
         llm_output={},
@@ -85,12 +85,12 @@ def test_cassandra_semantic_cache(cassandra_connection: Tuple[Any, str]) -> None
         keyspace=keyspace,
         embedding=FakeEmbeddings(),
     )
-    langchain.llm_cache = sem_cache
+    set_llm_cache(sem_cache)
     llm = FakeLLM()
     params = llm.dict()
     params["stop"] = None
     llm_string = str(sorted([(k, v) for k, v in params.items()]))
-    langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
+    get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
     output = llm.generate(["bar"])  # same embedding as 'foo'
     expected_output = LLMResult(
         generations=[[Generation(text="fizz")]],

@@ -3,8 +3,8 @@ from typing import Any, Callable, Union
 
 import pytest
 
-import langchain
 from langchain.cache import GPTCache
+from langchain.globals import get_llm_cache, set_llm_cache
 from langchain.schema import Generation
 from tests.unit_tests.llms.fake_llm import FakeLLM
 
@@ -48,15 +48,15 @@ def test_gptcache_caching(
     init_func: Union[Callable[[Any, str], None], Callable[[Any], None], None]
 ) -> None:
     """Test gptcache default caching behavior."""
-    langchain.llm_cache = GPTCache(init_func)
+    set_llm_cache(GPTCache(init_func))
     llm = FakeLLM()
     params = llm.dict()
     params["stop"] = None
     llm_string = str(sorted([(k, v) for k, v in params.items()]))
-    langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
+    get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
     _ = llm.generate(["foo", "bar", "foo"])
-    cache_output = langchain.llm_cache.lookup("foo", llm_string)
+    cache_output = get_llm_cache().lookup("foo", llm_string)
     assert cache_output == [Generation(text="fizz")]
 
-    langchain.llm_cache.clear()
-    assert langchain.llm_cache.lookup("bar", llm_string) is None
+    get_llm_cache().clear()
+    assert get_llm_cache().lookup("bar", llm_string) is None

@@ -12,8 +12,8 @@ from typing import Iterator
 
 import pytest
 
-import langchain
 from langchain.cache import MomentoCache
+from langchain.globals import set_llm_cache
 from langchain.schema import Generation, LLMResult
 from tests.unit_tests.llms.fake_llm import FakeLLM
 
@@ -34,7 +34,7 @@ def momento_cache() -> Iterator[MomentoCache]:
     )
     try:
         llm_cache = MomentoCache(client, cache_name)
-        langchain.llm_cache = llm_cache
+        set_llm_cache(llm_cache)
         yield llm_cache
     finally:
         client.delete_cache(cache_name)

@@ -1,11 +1,11 @@
 """Test Redis cache functionality."""
 import uuid
-from typing import List
+from typing import List, cast
 
 import pytest
 
-import langchain
 from langchain.cache import RedisCache, RedisSemanticCache
+from langchain.globals import get_llm_cache, set_llm_cache
 from langchain.load.dump import dumps
 from langchain.schema import Generation, LLMResult
 from langchain.schema.embeddings import Embeddings
@@ -28,40 +28,42 @@ def random_string() -> str:
 def test_redis_cache_ttl() -> None:
     import redis
 
-    langchain.llm_cache = RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL), ttl=1)
-    langchain.llm_cache.update("foo", "bar", [Generation(text="fizz")])
-    key = langchain.llm_cache._key("foo", "bar")
-    assert langchain.llm_cache.redis.pttl(key) > 0
+    set_llm_cache(RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL), ttl=1))
+    llm_cache = cast(RedisCache, get_llm_cache())
+    llm_cache.update("foo", "bar", [Generation(text="fizz")])
+    key = llm_cache._key("foo", "bar")
+    assert llm_cache.redis.pttl(key) > 0
 
 
 def test_redis_cache() -> None:
     import redis
 
-    langchain.llm_cache = RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL))
+    set_llm_cache(RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL)))
     llm = FakeLLM()
     params = llm.dict()
     params["stop"] = None
     llm_string = str(sorted([(k, v) for k, v in params.items()]))
-    langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
+    get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
     output = llm.generate(["foo"])
     expected_output = LLMResult(
         generations=[[Generation(text="fizz")]],
         llm_output={},
     )
     assert output == expected_output
-    langchain.llm_cache.redis.flushall()
+    llm_cache = cast(RedisCache, get_llm_cache())
+    llm_cache.redis.flushall()
 
 
 def test_redis_cache_chat() -> None:
     import redis
 
-    langchain.llm_cache = RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL))
+    set_llm_cache(RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL)))
     llm = FakeChatModel()
     params = llm.dict()
     params["stop"] = None
     llm_string = str(sorted([(k, v) for k, v in params.items()]))
     prompt: List[BaseMessage] = [HumanMessage(content="foo")]
-    langchain.llm_cache.update(
+    get_llm_cache().update(
         dumps(prompt), llm_string, [ChatGeneration(message=AIMessage(content="fizz"))]
     )
     output = llm.generate([prompt])
@@ -70,18 +72,21 @@ def test_redis_cache_chat() -> None:
         llm_output={},
     )
     assert output == expected_output
-    langchain.llm_cache.redis.flushall()
+    llm_cache = cast(RedisCache, get_llm_cache())
+    llm_cache.redis.flushall()
 
 
 def test_redis_semantic_cache() -> None:
-    langchain.llm_cache = RedisSemanticCache(
-        embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
+    set_llm_cache(
+        RedisSemanticCache(
+            embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
+        )
     )
     llm = FakeLLM()
     params = llm.dict()
     params["stop"] = None
     llm_string = str(sorted([(k, v) for k, v in params.items()]))
-    langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
+    get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
     output = llm.generate(
         ["bar"]
     )  # foo and bar will have the same embedding produced by FakeEmbeddings
@@ -91,24 +96,26 @@ def test_redis_semantic_cache() -> None:
     )
     assert output == expected_output
     # clear the cache
-    langchain.llm_cache.clear(llm_string=llm_string)
+    get_llm_cache().clear(llm_string=llm_string)
     output = llm.generate(
         ["bar"]
     )  # foo and bar will have the same embedding produced by FakeEmbeddings
     # expect different output now without cached result
     assert output != expected_output
-    langchain.llm_cache.clear(llm_string=llm_string)
+    get_llm_cache().clear(llm_string=llm_string)
 
 
 def test_redis_semantic_cache_multi() -> None:
-    langchain.llm_cache = RedisSemanticCache(
-        embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
+    set_llm_cache(
+        RedisSemanticCache(
+            embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
+        )
     )
     llm = FakeLLM()
     params = llm.dict()
     params["stop"] = None
     llm_string = str(sorted([(k, v) for k, v in params.items()]))
-    langchain.llm_cache.update(
+    get_llm_cache().update(
         "foo", llm_string, [Generation(text="fizz"), Generation(text="Buzz")]
     )
     output = llm.generate(
@@ -120,19 +127,21 @@ def test_redis_semantic_cache_multi() -> None:
     )
     assert output == expected_output
     # clear the cache
-    langchain.llm_cache.clear(llm_string=llm_string)
+    get_llm_cache().clear(llm_string=llm_string)
 
 
 def test_redis_semantic_cache_chat() -> None:
-    langchain.llm_cache = RedisSemanticCache(
-        embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
+    set_llm_cache(
+        RedisSemanticCache(
+            embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
+        )
     )
     llm = FakeChatModel()
     params = llm.dict()
     params["stop"] = None
     llm_string = str(sorted([(k, v) for k, v in params.items()]))
     prompt: List[BaseMessage] = [HumanMessage(content="foo")]
-    langchain.llm_cache.update(
+    get_llm_cache().update(
         dumps(prompt), llm_string, [ChatGeneration(message=AIMessage(content="fizz"))]
     )
     output = llm.generate([prompt])
@@ -141,7 +150,7 @@ def test_redis_semantic_cache_chat() -> None:
         llm_output={},
     )
     assert output == expected_output
-    langchain.llm_cache.clear(llm_string=llm_string)
+    get_llm_cache().clear(llm_string=llm_string)
 
 
 @pytest.mark.parametrize("embedding", [ConsistentFakeEmbeddings()])
@@ -170,9 +179,7 @@ def test_redis_semantic_cache_chat() -> None:
 def test_redis_semantic_cache_hit(
     embedding: Embeddings, prompts: List[str], generations: List[List[str]]
 ) -> None:
-    langchain.llm_cache = RedisSemanticCache(
-        embedding=embedding, redis_url=REDIS_TEST_URL
-    )
+    set_llm_cache(RedisSemanticCache(embedding=embedding, redis_url=REDIS_TEST_URL))
 
     llm = FakeLLM()
     params = llm.dict()
@@ -189,7 +196,7 @@ def test_redis_semantic_cache_hit(
     for prompt_i, llm_generations_i in zip(prompts, llm_generations):
         print(prompt_i)
         print(llm_generations_i)
-        langchain.llm_cache.update(prompt_i, llm_string, llm_generations_i)
+        get_llm_cache().update(prompt_i, llm_string, llm_generations_i)
     llm.generate(prompts)
     assert llm.generate(prompts) == LLMResult(
         generations=llm_generations, llm_output={}
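
One side effect visible in the Redis tests above: `get_llm_cache()` hands back the generic cache interface, so code that needs backend-specific attributes such as the underlying `redis` client casts to the concrete class first. A brief sketch of that pattern, lifted from the TTL test (the Redis URL is illustrative and a reachable Redis instance is assumed):

    from typing import cast

    import redis

    from langchain.cache import RedisCache
    from langchain.globals import get_llm_cache, set_llm_cache
    from langchain.schema import Generation

    set_llm_cache(RedisCache(redis_=redis.Redis.from_url("redis://localhost:6379"), ttl=1))

    # get_llm_cache() is typed as the base cache; cast to reach Redis internals.
    llm_cache = cast(RedisCache, get_llm_cache())
    llm_cache.update("foo", "bar", [Generation(text="fizz")])
    assert llm_cache.redis.pttl(llm_cache._key("foo", "bar")) > 0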

@@ -6,25 +6,25 @@ try:
 except ImportError:
     from sqlalchemy.ext.declarative import declarative_base
 
-import langchain
 from langchain.cache import InMemoryCache, SQLAlchemyCache
+from langchain.globals import get_llm_cache, set_llm_cache
 from langchain.schema import Generation, LLMResult
 from tests.unit_tests.llms.fake_llm import FakeLLM
 
 
 def test_caching() -> None:
     """Test caching behavior."""
-    langchain.llm_cache = InMemoryCache()
+    set_llm_cache(InMemoryCache())
     llm = FakeLLM()
     params = llm.dict()
     params["stop"] = None
     llm_string = str(sorted([(k, v) for k, v in params.items()]))
-    langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
+    get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
     output = llm.generate(["foo", "bar", "foo"])
     expected_cache_output = [Generation(text="foo")]
-    cache_output = langchain.llm_cache.lookup("bar", llm_string)
+    cache_output = get_llm_cache().lookup("bar", llm_string)
     assert cache_output == expected_cache_output
-    langchain.llm_cache = None
+    set_llm_cache(None)
     expected_generations = [
         [Generation(text="fizz")],
         [Generation(text="foo")],
@@ -52,17 +52,17 @@ def test_custom_caching() -> None:
         response = Column(String)
 
     engine = create_engine("sqlite://")
-    langchain.llm_cache = SQLAlchemyCache(engine, FulltextLLMCache)
+    set_llm_cache(SQLAlchemyCache(engine, FulltextLLMCache))
     llm = FakeLLM()
     params = llm.dict()
     params["stop"] = None
     llm_string = str(sorted([(k, v) for k, v in params.items()]))
-    langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
+    get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
     output = llm.generate(["foo", "bar", "foo"])
     expected_cache_output = [Generation(text="foo")]
-    cache_output = langchain.llm_cache.lookup("bar", llm_string)
+    cache_output = get_llm_cache().lookup("bar", llm_string)
     assert cache_output == expected_cache_output
-    langchain.llm_cache = None
+    set_llm_cache(None)
     expected_generations = [
         [Generation(text="fizz")],
         [Generation(text="foo")],

@@ -6,13 +6,13 @@ from _pytest.fixtures import FixtureRequest
 from sqlalchemy import create_engine
 from sqlalchemy.orm import Session
 
-import langchain
 from langchain.cache import (
     InMemoryCache,
     SQLAlchemyCache,
 )
 from langchain.chat_models import FakeListChatModel
 from langchain.chat_models.base import BaseChatModel, dumps
+from langchain.globals import get_llm_cache, set_llm_cache
 from langchain.llms import FakeListLLM
 from langchain.llms.base import BaseLLM
 from langchain.schema import (
@@ -36,18 +36,18 @@ CACHE_OPTIONS = [
 def set_cache_and_teardown(request: FixtureRequest) -> Generator[None, None, None]:
     # Will be run before each test
     cache_instance = request.param
-    langchain.llm_cache = cache_instance()
-    if langchain.llm_cache:
-        langchain.llm_cache.clear()
+    set_llm_cache(cache_instance())
+    if get_llm_cache():
+        get_llm_cache().clear()
     else:
         raise ValueError("Cache not set. This should never happen.")
 
     yield
 
     # Will be run after each test
-    if langchain.llm_cache:
-        langchain.llm_cache.clear()
-        langchain.llm_cache = None
+    if get_llm_cache():
+        get_llm_cache().clear()
+        set_llm_cache(None)
     else:
         raise ValueError("Cache not set. This should never happen.")
 
@@ -57,8 +57,8 @@ def test_llm_caching() -> None:
     response = "Test response"
     cached_response = "Cached test response"
     llm = FakeListLLM(responses=[response])
-    if langchain.llm_cache:
-        langchain.llm_cache.update(
+    if get_llm_cache():
+        get_llm_cache().update(
             prompt=prompt,
             llm_string=create_llm_string(llm),
             return_val=[Generation(text=cached_response)],
@@ -72,20 +72,21 @@ def test_llm_caching() -> None:
 
 
 def test_old_sqlite_llm_caching() -> None:
-    if isinstance(langchain.llm_cache, SQLAlchemyCache):
+    llm_cache = get_llm_cache()
+    if isinstance(llm_cache, SQLAlchemyCache):
         prompt = "How are you?"
         response = "Test response"
         cached_response = "Cached test response"
         llm = FakeListLLM(responses=[response])
         items = [
-            langchain.llm_cache.cache_schema(
+            llm_cache.cache_schema(
                 prompt=prompt,
                 llm=create_llm_string(llm),
                 response=cached_response,
                 idx=0,
             )
         ]
-        with Session(langchain.llm_cache.engine) as session, session.begin():
+        with Session(llm_cache.engine) as session, session.begin():
             for item in items:
                 session.merge(item)
         assert llm(prompt) == cached_response
@@ -97,8 +98,8 @@ def test_chat_model_caching() -> None:
     cached_response = "Cached test response"
     cached_message = AIMessage(content=cached_response)
     llm = FakeListChatModel(responses=[response])
-    if langchain.llm_cache:
-        langchain.llm_cache.update(
+    if get_llm_cache():
+        get_llm_cache().update(
             prompt=dumps(prompt),
             llm_string=llm._get_llm_string(),
             return_val=[ChatGeneration(message=cached_message)],
@@ -119,8 +120,8 @@ def test_chat_model_caching_params() -> None:
     cached_response = "Cached test response"
     cached_message = AIMessage(content=cached_response)
     llm = FakeListChatModel(responses=[response])
-    if langchain.llm_cache:
-        langchain.llm_cache.update(
+    if get_llm_cache():
+        get_llm_cache().update(
             prompt=dumps(prompt),
             llm_string=llm._get_llm_string(functions=[]),
             return_val=[ChatGeneration(message=cached_message)],
@@ -144,13 +145,13 @@ def test_llm_cache_clear() -> None:
     response = "Test response"
     cached_response = "Cached test response"
     llm = FakeListLLM(responses=[response])
-    if langchain.llm_cache:
-        langchain.llm_cache.update(
+    if get_llm_cache():
+        get_llm_cache().update(
             prompt=prompt,
             llm_string=create_llm_string(llm),
             return_val=[Generation(text=cached_response)],
         )
-        langchain.llm_cache.clear()
+        get_llm_cache().clear()
         assert llm(prompt) == response
     else:
         raise ValueError(