mirror of https://github.com/hwchase17/langchain.git (synced 2025-05-22 15:38:06 +00:00)

use get_llm_cache and set_llm_cache (#11741)

Co-authored-by: Bagatur <baskaryan@gmail.com>

parent f3ad22e64a
commit 4a2f0c51a1
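
The change this commit makes everywhere below is mechanical: module-level reads and writes of `langchain.llm_cache` become calls to `get_llm_cache()` and `set_llm_cache()` from `langchain.globals`. A minimal before/after sketch of the new calling convention (using `InMemoryCache` purely as an illustrative cache backend):

# Before this commit: callers mutated a module attribute directly.
#
#     import langchain
#     from langchain.cache import InMemoryCache
#     langchain.llm_cache = InMemoryCache()
#
# After this commit: callers go through the accessors in langchain.globals.
from langchain.cache import InMemoryCache
from langchain.globals import get_llm_cache, set_llm_cache

set_llm_cache(InMemoryCache())   # install a cache; set_llm_cache(None) disables caching
cache = get_llm_cache()          # read back the currently configured cache
if cache is not None:
    cache.clear()                # BaseCache instances expose lookup/update/clear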
@@ -463,15 +463,15 @@ class RedisSemanticCache(BaseCache):

        .. code-block:: python

-           import langchain
+           from langchain.globals import set_llm_cache

            from langchain.cache import RedisSemanticCache
            from langchain.embeddings import OpenAIEmbeddings

-           langchain.llm_cache = RedisSemanticCache(
+           set_llm_cache(RedisSemanticCache(
                redis_url="redis://localhost:6379",
                embedding=OpenAIEmbeddings()
-           )
+           ))

        """
        self._cache_dict: Dict[str, RedisVectorstore] = {}
@@ -588,6 +588,7 @@ class GPTCache(BaseCache):

            import gptcache
            from gptcache.processor.pre import get_prompt
            from gptcache.manager.factory import get_data_manager
+           from langchain.globals import set_llm_cache

            # Avoid multiple caches using the same file,
            # causing different llm model caches to affect each other
@@ -601,7 +602,7 @@ class GPTCache(BaseCache):
                ),
            )

-           langchain.llm_cache = GPTCache(init_gptcache)
+           set_llm_cache(GPTCache(init_gptcache))

        """
        try:
@@ -15,7 +15,6 @@ from typing import (
    cast,
)

-import langchain
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import (
    AsyncCallbackManager,
@@ -24,6 +23,7 @@ from langchain.callbacks.manager import (
    CallbackManagerForLLMRun,
    Callbacks,
)
+from langchain.globals import get_llm_cache
from langchain.load.dump import dumpd, dumps
from langchain.prompts.base import StringPromptValue
from langchain.prompts.chat import ChatPromptValue
@@ -487,7 +487,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
            "run_manager"
        )
        disregard_cache = self.cache is not None and not self.cache
-       if langchain.llm_cache is None or disregard_cache:
+       llm_cache = get_llm_cache()
+       if llm_cache is None or disregard_cache:
            # This happens when langchain.cache is None, but self.cache is True
            if self.cache is not None and self.cache:
                raise ValueError(
@@ -502,7 +503,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
        else:
            llm_string = self._get_llm_string(stop=stop, **kwargs)
            prompt = dumps(messages)
-           cache_val = langchain.llm_cache.lookup(prompt, llm_string)
+           cache_val = llm_cache.lookup(prompt, llm_string)
            if isinstance(cache_val, list):
                return ChatResult(generations=cache_val)
            else:
@@ -512,7 +513,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
            )
        else:
            result = self._generate(messages, stop=stop, **kwargs)
-           langchain.llm_cache.update(prompt, llm_string, result.generations)
+           llm_cache.update(prompt, llm_string, result.generations)
            return result

    async def _agenerate_with_cache(
@@ -526,7 +527,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
            "run_manager"
        )
        disregard_cache = self.cache is not None and not self.cache
-       if langchain.llm_cache is None or disregard_cache:
+       llm_cache = get_llm_cache()
+       if llm_cache is None or disregard_cache:
            # This happens when langchain.cache is None, but self.cache is True
            if self.cache is not None and self.cache:
                raise ValueError(
@@ -541,7 +543,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
        else:
            llm_string = self._get_llm_string(stop=stop, **kwargs)
            prompt = dumps(messages)
-           cache_val = langchain.llm_cache.lookup(prompt, llm_string)
+           cache_val = llm_cache.lookup(prompt, llm_string)
            if isinstance(cache_val, list):
                return ChatResult(generations=cache_val)
            else:
@@ -551,7 +553,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
            )
        else:
            result = await self._agenerate(messages, stop=stop, **kwargs)
-           langchain.llm_cache.update(prompt, llm_string, result.generations)
+           llm_cache.update(prompt, llm_string, result.generations)
            return result

    @abstractmethod
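
Taken together, the chat_models/base.py hunks above implement one flow in both `_generate_with_cache` and `_agenerate_with_cache`: resolve the global cache once, key lookups on the dumped messages plus the `llm_string`, and only call the model on a miss. A condensed, illustrative sketch of that flow (the free function `generate_with_cache` and the `model` parameter are hypothetical stand-ins for the real `BaseChatModel` methods):

from typing import Any, List

from langchain.globals import get_llm_cache
from langchain.load.dump import dumps
from langchain.schema import BaseMessage, ChatResult


def generate_with_cache(model: Any, messages: List[BaseMessage], **kwargs: Any) -> ChatResult:
    """Sketch of the cached-generation flow introduced by this diff."""
    llm_cache = get_llm_cache()  # resolved once, instead of reading langchain.llm_cache
    if llm_cache is None:
        # No global cache configured: call the model directly.
        return model._generate(messages, **kwargs)

    llm_string = model._get_llm_string(**kwargs)  # identifies the model and its params
    prompt = dumps(messages)                      # serialized messages act as the cache key
    cache_val = llm_cache.lookup(prompt, llm_string)
    if isinstance(cache_val, list):
        return ChatResult(generations=cache_val)  # cache hit: reuse stored generations

    result = model._generate(messages, **kwargs)  # cache miss: run the model
    llm_cache.update(prompt, llm_string, result.generations)
    return result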
@@ -121,7 +121,7 @@ def get_debug() -> bool:
    return _debug or old_debug


-def set_llm_cache(value: "BaseCache") -> None:
+def set_llm_cache(value: Optional["BaseCache"]) -> None:
    """Set a new LLM cache, overwriting the previous value, if any."""
    import langchain

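
The `Optional` widening above matters because callers (see the unit tests later in this diff) now write `set_llm_cache(None)` to turn caching off again. A simplified sketch of the getter/setter pattern used by `langchain.globals`, assuming a module-level variable plus the legacy `langchain.llm_cache` attribute kept in sync (the real module adds handling not shown here):

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    from langchain.schema import BaseCache

_llm_cache: Optional["BaseCache"] = None  # module-level global behind the accessors


def set_llm_cache(value: Optional["BaseCache"]) -> None:
    """Set a new LLM cache, overwriting the previous value, if any."""
    import langchain  # imported lazily to avoid a circular import at module load

    global _llm_cache
    langchain.llm_cache = value  # keep the legacy attribute in sync for old callers
    _llm_cache = value


def get_llm_cache() -> Optional["BaseCache"]:
    """Return the currently configured LLM cache, if any."""
    return _llm_cache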
@@ -37,7 +37,6 @@ from tenacity import (
    wait_exponential,
)

-import langchain
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import (
    AsyncCallbackManager,
@@ -46,6 +45,7 @@ from langchain.callbacks.manager import (
    CallbackManagerForLLMRun,
    Callbacks,
)
+from langchain.globals import get_llm_cache
from langchain.load.dump import dumpd
from langchain.prompts.base import StringPromptValue
from langchain.prompts.chat import ChatPromptValue
@@ -124,9 +124,10 @@ def get_prompts(
    missing_prompts = []
    missing_prompt_idxs = []
    existing_prompts = {}
+   llm_cache = get_llm_cache()
    for i, prompt in enumerate(prompts):
-       if langchain.llm_cache is not None:
-           cache_val = langchain.llm_cache.lookup(prompt, llm_string)
+       if llm_cache is not None:
+           cache_val = llm_cache.lookup(prompt, llm_string)
            if isinstance(cache_val, list):
                existing_prompts[i] = cache_val
            else:
@@ -143,11 +144,12 @@ def update_cache(
    prompts: List[str],
) -> Optional[dict]:
    """Update the cache and get the LLM output."""
+   llm_cache = get_llm_cache()
    for i, result in enumerate(new_results.generations):
        existing_prompts[missing_prompt_idxs[i]] = result
        prompt = prompts[missing_prompt_idxs[i]]
-       if langchain.llm_cache is not None:
-           langchain.llm_cache.update(prompt, llm_string, result)
+       if llm_cache is not None:
+           llm_cache.update(prompt, llm_string, result)
    llm_output = new_results.llm_output
    return llm_output
@@ -624,7 +626,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
        new_arg_supported = inspect.signature(self._generate).parameters.get(
            "run_manager"
        )
-       if langchain.llm_cache is None or disregard_cache:
+       if get_llm_cache() is None or disregard_cache:
            if self.cache is not None and self.cache:
                raise ValueError(
                    "Asked to cache, but no cache found at `langchain.cache`."
@@ -788,7 +790,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
        new_arg_supported = inspect.signature(self._agenerate).parameters.get(
            "run_manager"
        )
-       if langchain.llm_cache is None or disregard_cache:
+       if get_llm_cache() is None or disregard_cache:
            if self.cache is not None and self.cache:
                raise ValueError(
                    "Asked to cache, but no cache found at `langchain.cache`."
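
On the completions (`BaseLLM`) side, caching is split across the two module-level helpers edited above: `get_prompts` partitions a batch into already-cached prompts and misses, and `update_cache` writes newly generated results back. A hypothetical driver showing how the pieces fit together (the `generate_with_prompt_cache` function is illustrative, not part of the library, and the helper signatures are inferred from the hunks above):

from typing import Any, List

from langchain.llms.base import get_prompts, update_cache
from langchain.schema import LLMResult


def generate_with_prompt_cache(llm: Any, prompts: List[str], **params: Any) -> LLMResult:
    """Illustrative driver around the batch-caching helpers changed in this diff."""
    # Partition the batch: cached generations vs. prompts the model still has to run.
    existing_prompts, llm_string, missing_idxs, missing_prompts = get_prompts(
        params, prompts
    )
    llm_output = None
    if missing_prompts:
        new_results = llm._generate(missing_prompts)  # only the cache misses hit the model
        # Store the fresh generations under the same (prompt, llm_string) keys.
        llm_output = update_cache(
            existing_prompts, llm_string, missing_idxs, new_results, prompts
        )
    generations = [existing_prompts[i] for i in range(len(prompts))]
    return LLMResult(generations=generations, llm_output=llm_output)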
@@ -18,8 +18,8 @@ git grep '^from langchain' langchain/utilities | grep -vE 'from langchain.(pydan
git grep '^from langchain' langchain/storage | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|storage|utilities)' && errors=$((errors+1))
git grep '^from langchain' langchain/prompts | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|prompts|_api)' && errors=$((errors+1))
git grep '^from langchain' langchain/output_parsers | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|prompts|_api|output_parsers)' && errors=$((errors+1))
-git grep '^from langchain' langchain/llms | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|prompts|llms|utilities)' && errors=$((errors+1))
-git grep '^from langchain' langchain/chat_models | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|llms|prompts|adapters|chat_models|utilities)' && errors=$((errors+1))
+git grep '^from langchain' langchain/llms | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|prompts|llms|utilities|globals)' && errors=$((errors+1))
+git grep '^from langchain' langchain/chat_models | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|llms|prompts|adapters|chat_models|utilities|globals)' && errors=$((errors+1))
git grep '^from langchain' langchain/embeddings | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|storage|llms|embeddings|utilities)' && errors=$((errors+1))
git grep '^from langchain' langchain/docstore | grep -vE 'from langchain.(pydantic_v1|utils|schema|docstore)' && errors=$((errors+1))
git grep '^from langchain' langchain/vectorstores | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|_api|storage|llms|docstore|vectorstores|utilities)' && errors=$((errors+1))
@@ -5,8 +5,8 @@ from typing import Any, Iterator, Tuple

import pytest

-import langchain
from langchain.cache import CassandraCache, CassandraSemanticCache
+from langchain.globals import get_llm_cache, set_llm_cache
from langchain.schema import Generation, LLMResult
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
from tests.unit_tests.llms.fake_llm import FakeLLM
@@ -39,12 +39,12 @@ def cassandra_connection() -> Iterator[Tuple[Any, str]]:
def test_cassandra_cache(cassandra_connection: Tuple[Any, str]) -> None:
    session, keyspace = cassandra_connection
    cache = CassandraCache(session=session, keyspace=keyspace)
-   langchain.llm_cache = cache
+   set_llm_cache(cache)
    llm = FakeLLM()
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
-   langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
+   get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
    output = llm.generate(["foo"])
    print(output)
    expected_output = LLMResult(
@@ -59,12 +59,12 @@ def test_cassandra_cache(cassandra_connection: Tuple[Any, str]) -> None:
def test_cassandra_cache_ttl(cassandra_connection: Tuple[Any, str]) -> None:
    session, keyspace = cassandra_connection
    cache = CassandraCache(session=session, keyspace=keyspace, ttl_seconds=2)
-   langchain.llm_cache = cache
+   set_llm_cache(cache)
    llm = FakeLLM()
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
-   langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
+   get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
    expected_output = LLMResult(
        generations=[[Generation(text="fizz")]],
        llm_output={},
@@ -85,12 +85,12 @@ def test_cassandra_semantic_cache(cassandra_connection: Tuple[Any, str]) -> None
        keyspace=keyspace,
        embedding=FakeEmbeddings(),
    )
-   langchain.llm_cache = sem_cache
+   set_llm_cache(sem_cache)
    llm = FakeLLM()
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
-   langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
+   get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
    output = llm.generate(["bar"])  # same embedding as 'foo'
    expected_output = LLMResult(
        generations=[[Generation(text="fizz")]],
@@ -3,8 +3,8 @@ from typing import Any, Callable, Union

import pytest

-import langchain
from langchain.cache import GPTCache
+from langchain.globals import get_llm_cache, set_llm_cache
from langchain.schema import Generation
from tests.unit_tests.llms.fake_llm import FakeLLM
@@ -48,15 +48,15 @@ def test_gptcache_caching(
    init_func: Union[Callable[[Any, str], None], Callable[[Any], None], None]
) -> None:
    """Test gptcache default caching behavior."""
-   langchain.llm_cache = GPTCache(init_func)
+   set_llm_cache(GPTCache(init_func))
    llm = FakeLLM()
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
-   langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
+   get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
    _ = llm.generate(["foo", "bar", "foo"])
-   cache_output = langchain.llm_cache.lookup("foo", llm_string)
+   cache_output = get_llm_cache().lookup("foo", llm_string)
    assert cache_output == [Generation(text="fizz")]

-   langchain.llm_cache.clear()
-   assert langchain.llm_cache.lookup("bar", llm_string) is None
+   get_llm_cache().clear()
+   assert get_llm_cache().lookup("bar", llm_string) is None
@@ -12,8 +12,8 @@ from typing import Iterator

import pytest

-import langchain
from langchain.cache import MomentoCache
+from langchain.globals import set_llm_cache
from langchain.schema import Generation, LLMResult
from tests.unit_tests.llms.fake_llm import FakeLLM
@@ -34,7 +34,7 @@ def momento_cache() -> Iterator[MomentoCache]:
    )
    try:
        llm_cache = MomentoCache(client, cache_name)
-       langchain.llm_cache = llm_cache
+       set_llm_cache(llm_cache)
        yield llm_cache
    finally:
        client.delete_cache(cache_name)
@@ -1,11 +1,11 @@
"""Test Redis cache functionality."""
import uuid
-from typing import List
+from typing import List, cast

import pytest

-import langchain
from langchain.cache import RedisCache, RedisSemanticCache
+from langchain.globals import get_llm_cache, set_llm_cache
from langchain.load.dump import dumps
from langchain.schema import Generation, LLMResult
from langchain.schema.embeddings import Embeddings
@@ -28,40 +28,42 @@ def random_string() -> str:
def test_redis_cache_ttl() -> None:
    import redis

-   langchain.llm_cache = RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL), ttl=1)
-   langchain.llm_cache.update("foo", "bar", [Generation(text="fizz")])
-   key = langchain.llm_cache._key("foo", "bar")
-   assert langchain.llm_cache.redis.pttl(key) > 0
+   set_llm_cache(RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL), ttl=1))
+   llm_cache = cast(RedisCache, get_llm_cache())
+   llm_cache.update("foo", "bar", [Generation(text="fizz")])
+   key = llm_cache._key("foo", "bar")
+   assert llm_cache.redis.pttl(key) > 0


def test_redis_cache() -> None:
    import redis

-   langchain.llm_cache = RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL))
+   set_llm_cache(RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL)))
    llm = FakeLLM()
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
-   langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
+   get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
    output = llm.generate(["foo"])
    expected_output = LLMResult(
        generations=[[Generation(text="fizz")]],
        llm_output={},
    )
    assert output == expected_output
-   langchain.llm_cache.redis.flushall()
+   llm_cache = cast(RedisCache, get_llm_cache())
+   llm_cache.redis.flushall()


def test_redis_cache_chat() -> None:
    import redis

-   langchain.llm_cache = RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL))
+   set_llm_cache(RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL)))
    llm = FakeChatModel()
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
    prompt: List[BaseMessage] = [HumanMessage(content="foo")]
-   langchain.llm_cache.update(
+   get_llm_cache().update(
        dumps(prompt), llm_string, [ChatGeneration(message=AIMessage(content="fizz"))]
    )
    output = llm.generate([prompt])
@@ -70,18 +72,21 @@ def test_redis_cache_chat() -> None:
        llm_output={},
    )
    assert output == expected_output
-   langchain.llm_cache.redis.flushall()
+   llm_cache = cast(RedisCache, get_llm_cache())
+   llm_cache.redis.flushall()


def test_redis_semantic_cache() -> None:
-   langchain.llm_cache = RedisSemanticCache(
-       embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
+   set_llm_cache(
+       RedisSemanticCache(
+           embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
+       )
    )
    llm = FakeLLM()
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
-   langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
+   get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
    output = llm.generate(
        ["bar"]
    )  # foo and bar will have the same embedding produced by FakeEmbeddings
@@ -91,24 +96,26 @@ def test_redis_semantic_cache() -> None:
    )
    assert output == expected_output
    # clear the cache
-   langchain.llm_cache.clear(llm_string=llm_string)
+   get_llm_cache().clear(llm_string=llm_string)
    output = llm.generate(
        ["bar"]
    )  # foo and bar will have the same embedding produced by FakeEmbeddings
    # expect different output now without cached result
    assert output != expected_output
-   langchain.llm_cache.clear(llm_string=llm_string)
+   get_llm_cache().clear(llm_string=llm_string)


def test_redis_semantic_cache_multi() -> None:
-   langchain.llm_cache = RedisSemanticCache(
-       embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
+   set_llm_cache(
+       RedisSemanticCache(
+           embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
+       )
    )
    llm = FakeLLM()
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
-   langchain.llm_cache.update(
+   get_llm_cache().update(
        "foo", llm_string, [Generation(text="fizz"), Generation(text="Buzz")]
    )
    output = llm.generate(
@@ -120,19 +127,21 @@ def test_redis_semantic_cache_multi() -> None:
    )
    assert output == expected_output
    # clear the cache
-   langchain.llm_cache.clear(llm_string=llm_string)
+   get_llm_cache().clear(llm_string=llm_string)


def test_redis_semantic_cache_chat() -> None:
-   langchain.llm_cache = RedisSemanticCache(
-       embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
+   set_llm_cache(
+       RedisSemanticCache(
+           embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
+       )
    )
    llm = FakeChatModel()
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
    prompt: List[BaseMessage] = [HumanMessage(content="foo")]
-   langchain.llm_cache.update(
+   get_llm_cache().update(
        dumps(prompt), llm_string, [ChatGeneration(message=AIMessage(content="fizz"))]
    )
    output = llm.generate([prompt])
@@ -141,7 +150,7 @@ def test_redis_semantic_cache_chat() -> None:
        llm_output={},
    )
    assert output == expected_output
-   langchain.llm_cache.clear(llm_string=llm_string)
+   get_llm_cache().clear(llm_string=llm_string)


@pytest.mark.parametrize("embedding", [ConsistentFakeEmbeddings()])
@@ -170,9 +179,7 @@ def test_redis_semantic_cache_chat() -> None:
def test_redis_semantic_cache_hit(
    embedding: Embeddings, prompts: List[str], generations: List[List[str]]
) -> None:
-   langchain.llm_cache = RedisSemanticCache(
-       embedding=embedding, redis_url=REDIS_TEST_URL
-   )
+   set_llm_cache(RedisSemanticCache(embedding=embedding, redis_url=REDIS_TEST_URL))

    llm = FakeLLM()
    params = llm.dict()
@@ -189,7 +196,7 @@ def test_redis_semantic_cache_hit(
    for prompt_i, llm_generations_i in zip(prompts, llm_generations):
        print(prompt_i)
        print(llm_generations_i)
-       langchain.llm_cache.update(prompt_i, llm_string, llm_generations_i)
+       get_llm_cache().update(prompt_i, llm_string, llm_generations_i)
    llm.generate(prompts)
    assert llm.generate(prompts) == LLMResult(
        generations=llm_generations, llm_output={}
@@ -6,25 +6,25 @@ try:
except ImportError:
    from sqlalchemy.ext.declarative import declarative_base

-import langchain
from langchain.cache import InMemoryCache, SQLAlchemyCache
+from langchain.globals import get_llm_cache, set_llm_cache
from langchain.schema import Generation, LLMResult
from tests.unit_tests.llms.fake_llm import FakeLLM


def test_caching() -> None:
    """Test caching behavior."""
-   langchain.llm_cache = InMemoryCache()
+   set_llm_cache(InMemoryCache())
    llm = FakeLLM()
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
-   langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
+   get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
    output = llm.generate(["foo", "bar", "foo"])
    expected_cache_output = [Generation(text="foo")]
-   cache_output = langchain.llm_cache.lookup("bar", llm_string)
+   cache_output = get_llm_cache().lookup("bar", llm_string)
    assert cache_output == expected_cache_output
-   langchain.llm_cache = None
+   set_llm_cache(None)
    expected_generations = [
        [Generation(text="fizz")],
        [Generation(text="foo")],
@@ -52,17 +52,17 @@ def test_custom_caching() -> None:
        response = Column(String)

    engine = create_engine("sqlite://")
-   langchain.llm_cache = SQLAlchemyCache(engine, FulltextLLMCache)
+   set_llm_cache(SQLAlchemyCache(engine, FulltextLLMCache))
    llm = FakeLLM()
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
-   langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
+   get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
    output = llm.generate(["foo", "bar", "foo"])
    expected_cache_output = [Generation(text="foo")]
-   cache_output = langchain.llm_cache.lookup("bar", llm_string)
+   cache_output = get_llm_cache().lookup("bar", llm_string)
    assert cache_output == expected_cache_output
-   langchain.llm_cache = None
+   set_llm_cache(None)
    expected_generations = [
        [Generation(text="fizz")],
        [Generation(text="foo")],
@@ -6,13 +6,13 @@ from _pytest.fixtures import FixtureRequest
from sqlalchemy import create_engine
from sqlalchemy.orm import Session

-import langchain
from langchain.cache import (
    InMemoryCache,
    SQLAlchemyCache,
)
from langchain.chat_models import FakeListChatModel
from langchain.chat_models.base import BaseChatModel, dumps
+from langchain.globals import get_llm_cache, set_llm_cache
from langchain.llms import FakeListLLM
from langchain.llms.base import BaseLLM
from langchain.schema import (
@@ -36,18 +36,18 @@ CACHE_OPTIONS = [
def set_cache_and_teardown(request: FixtureRequest) -> Generator[None, None, None]:
    # Will be run before each test
    cache_instance = request.param
-   langchain.llm_cache = cache_instance()
-   if langchain.llm_cache:
-       langchain.llm_cache.clear()
+   set_llm_cache(cache_instance())
+   if get_llm_cache():
+       get_llm_cache().clear()
    else:
        raise ValueError("Cache not set. This should never happen.")

    yield

    # Will be run after each test
-   if langchain.llm_cache:
-       langchain.llm_cache.clear()
-       langchain.llm_cache = None
+   if get_llm_cache():
+       get_llm_cache().clear()
+       set_llm_cache(None)
    else:
        raise ValueError("Cache not set. This should never happen.")
@@ -57,8 +57,8 @@ def test_llm_caching() -> None:
    response = "Test response"
    cached_response = "Cached test response"
    llm = FakeListLLM(responses=[response])
-   if langchain.llm_cache:
-       langchain.llm_cache.update(
+   if get_llm_cache():
+       get_llm_cache().update(
            prompt=prompt,
            llm_string=create_llm_string(llm),
            return_val=[Generation(text=cached_response)],
@@ -72,20 +72,21 @@ def test_llm_caching() -> None:


def test_old_sqlite_llm_caching() -> None:
-   if isinstance(langchain.llm_cache, SQLAlchemyCache):
+   llm_cache = get_llm_cache()
+   if isinstance(llm_cache, SQLAlchemyCache):
        prompt = "How are you?"
        response = "Test response"
        cached_response = "Cached test response"
        llm = FakeListLLM(responses=[response])
        items = [
-           langchain.llm_cache.cache_schema(
+           llm_cache.cache_schema(
                prompt=prompt,
                llm=create_llm_string(llm),
                response=cached_response,
                idx=0,
            )
        ]
-       with Session(langchain.llm_cache.engine) as session, session.begin():
+       with Session(llm_cache.engine) as session, session.begin():
            for item in items:
                session.merge(item)
        assert llm(prompt) == cached_response
@@ -97,8 +98,8 @@ def test_chat_model_caching() -> None:
    cached_response = "Cached test response"
    cached_message = AIMessage(content=cached_response)
    llm = FakeListChatModel(responses=[response])
-   if langchain.llm_cache:
-       langchain.llm_cache.update(
+   if get_llm_cache():
+       get_llm_cache().update(
            prompt=dumps(prompt),
            llm_string=llm._get_llm_string(),
            return_val=[ChatGeneration(message=cached_message)],
@@ -119,8 +120,8 @@ def test_chat_model_caching_params() -> None:
    cached_response = "Cached test response"
    cached_message = AIMessage(content=cached_response)
    llm = FakeListChatModel(responses=[response])
-   if langchain.llm_cache:
-       langchain.llm_cache.update(
+   if get_llm_cache():
+       get_llm_cache().update(
            prompt=dumps(prompt),
            llm_string=llm._get_llm_string(functions=[]),
            return_val=[ChatGeneration(message=cached_message)],
@@ -144,13 +145,13 @@ def test_llm_cache_clear() -> None:
    response = "Test response"
    cached_response = "Cached test response"
    llm = FakeListLLM(responses=[response])
-   if langchain.llm_cache:
-       langchain.llm_cache.update(
+   if get_llm_cache():
+       get_llm_cache().update(
            prompt=prompt,
            llm_string=create_llm_string(llm),
            return_val=[Generation(text=cached_response)],
        )
-       langchain.llm_cache.clear()
+       get_llm_cache().clear()
        assert llm(prompt) == response
    else:
        raise ValueError(