use get_llm_cache and set_llm_cache (#11741)

Co-authored-by: Bagatur <baskaryan@gmail.com>
Author: Harrison Chase, 2023-10-14 09:29:30 -07:00, committed by GitHub
parent f3ad22e64a
commit 4a2f0c51a1
11 changed files with 106 additions and 93 deletions
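This commit replaces direct reads and writes of the langchain.llm_cache module attribute with the get_llm_cache / set_llm_cache accessors from langchain.globals. A minimal sketch of the new calling pattern, modeled on the updated tests below (InMemoryCache and the literal prompt/llm_string values are illustrative, not part of this diff):

    from langchain.cache import InMemoryCache
    from langchain.globals import get_llm_cache, set_llm_cache
    from langchain.schema import Generation

    # Install a process-wide LLM cache through the accessor instead of
    # assigning to langchain.llm_cache directly.
    set_llm_cache(InMemoryCache())

    # Library code reads the same object back via get_llm_cache().
    cache = get_llm_cache()
    cache.update("foo", "llm_string", [Generation(text="fizz")])
    assert cache.lookup("foo", "llm_string") == [Generation(text="fizz")]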

View File

@@ -463,15 +463,15 @@ class RedisSemanticCache(BaseCache):
 .. code-block:: python
-import langchain
+from langchain.globals import set_llm_cache
 from langchain.cache import RedisSemanticCache
 from langchain.embeddings import OpenAIEmbeddings
-langchain.llm_cache = RedisSemanticCache(
+set_llm_cache(RedisSemanticCache(
 redis_url="redis://localhost:6379",
 embedding=OpenAIEmbeddings()
-)
+))
 """
 self._cache_dict: Dict[str, RedisVectorstore] = {}
@@ -588,6 +588,7 @@ class GPTCache(BaseCache):
 import gptcache
 from gptcache.processor.pre import get_prompt
 from gptcache.manager.factory import get_data_manager
+from langchain.globals import set_llm_cache
 # Avoid multiple caches using the same file,
 causing different llm model caches to affect each other
@@ -601,7 +602,7 @@ class GPTCache(BaseCache):
 ),
 )
-langchain.llm_cache = GPTCache(init_gptcache)
+set_llm_cache(GPTCache(init_gptcache))
 """
 try:

View File

@@ -15,7 +15,6 @@ from typing import (
 cast,
 )
-import langchain
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.manager import (
 AsyncCallbackManager,
@@ -24,6 +23,7 @@ from langchain.callbacks.manager import (
 CallbackManagerForLLMRun,
 Callbacks,
 )
+from langchain.globals import get_llm_cache
 from langchain.load.dump import dumpd, dumps
 from langchain.prompts.base import StringPromptValue
 from langchain.prompts.chat import ChatPromptValue
@@ -487,7 +487,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
 "run_manager"
 )
 disregard_cache = self.cache is not None and not self.cache
-if langchain.llm_cache is None or disregard_cache:
+llm_cache = get_llm_cache()
+if llm_cache is None or disregard_cache:
 # This happens when langchain.cache is None, but self.cache is True
 if self.cache is not None and self.cache:
 raise ValueError(
@@ -502,7 +503,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
 else:
 llm_string = self._get_llm_string(stop=stop, **kwargs)
 prompt = dumps(messages)
-cache_val = langchain.llm_cache.lookup(prompt, llm_string)
+cache_val = llm_cache.lookup(prompt, llm_string)
 if isinstance(cache_val, list):
 return ChatResult(generations=cache_val)
 else:
@@ -512,7 +513,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
 )
 else:
 result = self._generate(messages, stop=stop, **kwargs)
-langchain.llm_cache.update(prompt, llm_string, result.generations)
+llm_cache.update(prompt, llm_string, result.generations)
 return result
 async def _agenerate_with_cache(
@@ -526,7 +527,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
 "run_manager"
 )
 disregard_cache = self.cache is not None and not self.cache
-if langchain.llm_cache is None or disregard_cache:
+llm_cache = get_llm_cache()
+if llm_cache is None or disregard_cache:
 # This happens when langchain.cache is None, but self.cache is True
 if self.cache is not None and self.cache:
 raise ValueError(
@@ -541,7 +543,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
 else:
 llm_string = self._get_llm_string(stop=stop, **kwargs)
 prompt = dumps(messages)
-cache_val = langchain.llm_cache.lookup(prompt, llm_string)
+cache_val = llm_cache.lookup(prompt, llm_string)
 if isinstance(cache_val, list):
 return ChatResult(generations=cache_val)
 else:
@@ -551,7 +553,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
 )
 else:
 result = await self._agenerate(messages, stop=stop, **kwargs)
-langchain.llm_cache.update(prompt, llm_string, result.generations)
+llm_cache.update(prompt, llm_string, result.generations)
 return result
 @abstractmethod

View File

@@ -121,7 +121,7 @@ def get_debug() -> bool:
 return _debug or old_debug
-def set_llm_cache(value: "BaseCache") -> None:
+def set_llm_cache(value: Optional["BaseCache"]) -> None:
 """Set a new LLM cache, overwriting the previous value, if any."""
 import langchain
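Because the signature is now Optional["BaseCache"], passing None is the supported way to drop the global cache again; the updated tests below use exactly this for teardown. A small sketch, assuming the in-memory cache:

    from langchain.cache import InMemoryCache
    from langchain.globals import get_llm_cache, set_llm_cache

    set_llm_cache(InMemoryCache())      # install a global cache
    assert get_llm_cache() is not None

    set_llm_cache(None)                 # allowed by Optional["BaseCache"]
    assert get_llm_cache() is None      # caching is disabled again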

View File

@@ -37,7 +37,6 @@ from tenacity import (
 wait_exponential,
 )
-import langchain
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.manager import (
 AsyncCallbackManager,
@@ -46,6 +45,7 @@ from langchain.callbacks.manager import (
 CallbackManagerForLLMRun,
 Callbacks,
 )
+from langchain.globals import get_llm_cache
 from langchain.load.dump import dumpd
 from langchain.prompts.base import StringPromptValue
 from langchain.prompts.chat import ChatPromptValue
@@ -124,9 +124,10 @@ def get_prompts(
 missing_prompts = []
 missing_prompt_idxs = []
 existing_prompts = {}
+llm_cache = get_llm_cache()
 for i, prompt in enumerate(prompts):
-if langchain.llm_cache is not None:
-cache_val = langchain.llm_cache.lookup(prompt, llm_string)
+if llm_cache is not None:
+cache_val = llm_cache.lookup(prompt, llm_string)
 if isinstance(cache_val, list):
 existing_prompts[i] = cache_val
 else:
@@ -143,11 +144,12 @@ def update_cache(
 prompts: List[str],
 ) -> Optional[dict]:
 """Update the cache and get the LLM output."""
+llm_cache = get_llm_cache()
 for i, result in enumerate(new_results.generations):
 existing_prompts[missing_prompt_idxs[i]] = result
 prompt = prompts[missing_prompt_idxs[i]]
-if langchain.llm_cache is not None:
-langchain.llm_cache.update(prompt, llm_string, result)
+if llm_cache is not None:
+llm_cache.update(prompt, llm_string, result)
 llm_output = new_results.llm_output
 return llm_output
@@ -624,7 +626,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
 new_arg_supported = inspect.signature(self._generate).parameters.get(
 "run_manager"
 )
-if langchain.llm_cache is None or disregard_cache:
+if get_llm_cache() is None or disregard_cache:
 if self.cache is not None and self.cache:
 raise ValueError(
 "Asked to cache, but no cache found at `langchain.cache`."
@@ -788,7 +790,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
 new_arg_supported = inspect.signature(self._agenerate).parameters.get(
 "run_manager"
 )
-if langchain.llm_cache is None or disregard_cache:
+if get_llm_cache() is None or disregard_cache:
 if self.cache is not None and self.cache:
 raise ValueError(
 "Asked to cache, but no cache found at `langchain.cache`."

View File

@@ -18,8 +18,8 @@ git grep '^from langchain' langchain/utilities | grep -vE 'from langchain.(pydan
 git grep '^from langchain' langchain/storage | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|storage|utilities)' && errors=$((errors+1))
 git grep '^from langchain' langchain/prompts | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|prompts|_api)' && errors=$((errors+1))
 git grep '^from langchain' langchain/output_parsers | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|prompts|_api|output_parsers)' && errors=$((errors+1))
-git grep '^from langchain' langchain/llms | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|prompts|llms|utilities)' && errors=$((errors+1))
-git grep '^from langchain' langchain/chat_models | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|llms|prompts|adapters|chat_models|utilities)' && errors=$((errors+1))
+git grep '^from langchain' langchain/llms | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|prompts|llms|utilities|globals)' && errors=$((errors+1))
+git grep '^from langchain' langchain/chat_models | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|llms|prompts|adapters|chat_models|utilities|globals)' && errors=$((errors+1))
 git grep '^from langchain' langchain/embeddings | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|storage|llms|embeddings|utilities)' && errors=$((errors+1))
 git grep '^from langchain' langchain/docstore | grep -vE 'from langchain.(pydantic_v1|utils|schema|docstore)' && errors=$((errors+1))
 git grep '^from langchain' langchain/vectorstores | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|_api|storage|llms|docstore|vectorstores|utilities)' && errors=$((errors+1))

View File

@@ -5,8 +5,8 @@ from typing import Any, Iterator, Tuple
 import pytest
-import langchain
 from langchain.cache import CassandraCache, CassandraSemanticCache
+from langchain.globals import get_llm_cache, set_llm_cache
 from langchain.schema import Generation, LLMResult
 from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
 from tests.unit_tests.llms.fake_llm import FakeLLM
@@ -39,12 +39,12 @@ def cassandra_connection() -> Iterator[Tuple[Any, str]]:
 def test_cassandra_cache(cassandra_connection: Tuple[Any, str]) -> None:
 session, keyspace = cassandra_connection
 cache = CassandraCache(session=session, keyspace=keyspace)
-langchain.llm_cache = cache
+set_llm_cache(cache)
 llm = FakeLLM()
 params = llm.dict()
 params["stop"] = None
 llm_string = str(sorted([(k, v) for k, v in params.items()]))
-langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
+get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
 output = llm.generate(["foo"])
 print(output)
 expected_output = LLMResult(
@@ -59,12 +59,12 @@ def test_cassandra_cache(cassandra_connection: Tuple[Any, str]) -> None:
 def test_cassandra_cache_ttl(cassandra_connection: Tuple[Any, str]) -> None:
 session, keyspace = cassandra_connection
 cache = CassandraCache(session=session, keyspace=keyspace, ttl_seconds=2)
-langchain.llm_cache = cache
+set_llm_cache(cache)
 llm = FakeLLM()
 params = llm.dict()
 params["stop"] = None
 llm_string = str(sorted([(k, v) for k, v in params.items()]))
-langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
+get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
 expected_output = LLMResult(
 generations=[[Generation(text="fizz")]],
 llm_output={},
@@ -85,12 +85,12 @@ def test_cassandra_semantic_cache(cassandra_connection: Tuple[Any, str]) -> None
 keyspace=keyspace,
 embedding=FakeEmbeddings(),
 )
-langchain.llm_cache = sem_cache
+set_llm_cache(sem_cache)
 llm = FakeLLM()
 params = llm.dict()
 params["stop"] = None
 llm_string = str(sorted([(k, v) for k, v in params.items()]))
-langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
+get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
 output = llm.generate(["bar"])  # same embedding as 'foo'
 expected_output = LLMResult(
 generations=[[Generation(text="fizz")]],

View File

@@ -3,8 +3,8 @@ from typing import Any, Callable, Union
 import pytest
-import langchain
 from langchain.cache import GPTCache
+from langchain.globals import get_llm_cache, set_llm_cache
 from langchain.schema import Generation
 from tests.unit_tests.llms.fake_llm import FakeLLM
@@ -48,15 +48,15 @@ def test_gptcache_caching(
 init_func: Union[Callable[[Any, str], None], Callable[[Any], None], None]
 ) -> None:
 """Test gptcache default caching behavior."""
-langchain.llm_cache = GPTCache(init_func)
+set_llm_cache(GPTCache(init_func))
 llm = FakeLLM()
 params = llm.dict()
 params["stop"] = None
 llm_string = str(sorted([(k, v) for k, v in params.items()]))
-langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
+get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
 _ = llm.generate(["foo", "bar", "foo"])
-cache_output = langchain.llm_cache.lookup("foo", llm_string)
+cache_output = get_llm_cache().lookup("foo", llm_string)
 assert cache_output == [Generation(text="fizz")]
-langchain.llm_cache.clear()
-assert langchain.llm_cache.lookup("bar", llm_string) is None
+get_llm_cache().clear()
+assert get_llm_cache().lookup("bar", llm_string) is None

View File

@@ -12,8 +12,8 @@ from typing import Iterator
 import pytest
-import langchain
 from langchain.cache import MomentoCache
+from langchain.globals import set_llm_cache
 from langchain.schema import Generation, LLMResult
 from tests.unit_tests.llms.fake_llm import FakeLLM
@@ -34,7 +34,7 @@ def momento_cache() -> Iterator[MomentoCache]:
 )
 try:
 llm_cache = MomentoCache(client, cache_name)
-langchain.llm_cache = llm_cache
+set_llm_cache(llm_cache)
 yield llm_cache
 finally:
 client.delete_cache(cache_name)

View File

@@ -1,11 +1,11 @@
 """Test Redis cache functionality."""
 import uuid
-from typing import List
+from typing import List, cast
 import pytest
-import langchain
 from langchain.cache import RedisCache, RedisSemanticCache
+from langchain.globals import get_llm_cache, set_llm_cache
 from langchain.load.dump import dumps
 from langchain.schema import Generation, LLMResult
 from langchain.schema.embeddings import Embeddings
@@ -28,40 +28,42 @@ def random_string() -> str:
 def test_redis_cache_ttl() -> None:
 import redis
-langchain.llm_cache = RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL), ttl=1)
-langchain.llm_cache.update("foo", "bar", [Generation(text="fizz")])
-key = langchain.llm_cache._key("foo", "bar")
-assert langchain.llm_cache.redis.pttl(key) > 0
+set_llm_cache(RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL), ttl=1))
+llm_cache = cast(RedisCache, get_llm_cache())
+llm_cache.update("foo", "bar", [Generation(text="fizz")])
+key = llm_cache._key("foo", "bar")
+assert llm_cache.redis.pttl(key) > 0
 def test_redis_cache() -> None:
 import redis
-langchain.llm_cache = RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL))
+set_llm_cache(RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL)))
 llm = FakeLLM()
 params = llm.dict()
 params["stop"] = None
 llm_string = str(sorted([(k, v) for k, v in params.items()]))
-langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
+get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
 output = llm.generate(["foo"])
 expected_output = LLMResult(
 generations=[[Generation(text="fizz")]],
 llm_output={},
 )
 assert output == expected_output
-langchain.llm_cache.redis.flushall()
+llm_cache = cast(RedisCache, get_llm_cache())
+llm_cache.redis.flushall()
 def test_redis_cache_chat() -> None:
 import redis
-langchain.llm_cache = RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL))
+set_llm_cache(RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL)))
 llm = FakeChatModel()
 params = llm.dict()
 params["stop"] = None
 llm_string = str(sorted([(k, v) for k, v in params.items()]))
 prompt: List[BaseMessage] = [HumanMessage(content="foo")]
-langchain.llm_cache.update(
+get_llm_cache().update(
 dumps(prompt), llm_string, [ChatGeneration(message=AIMessage(content="fizz"))]
 )
 output = llm.generate([prompt])
@@ -70,18 +72,21 @@ def test_redis_cache_chat() -> None:
 llm_output={},
 )
 assert output == expected_output
-langchain.llm_cache.redis.flushall()
+llm_cache = cast(RedisCache, get_llm_cache())
+llm_cache.redis.flushall()
 def test_redis_semantic_cache() -> None:
-langchain.llm_cache = RedisSemanticCache(
-embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
+set_llm_cache(
+RedisSemanticCache(
+embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
+)
 )
 llm = FakeLLM()
 params = llm.dict()
 params["stop"] = None
 llm_string = str(sorted([(k, v) for k, v in params.items()]))
-langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
+get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
 output = llm.generate(
 ["bar"]
 )  # foo and bar will have the same embedding produced by FakeEmbeddings
@@ -91,24 +96,26 @@ def test_redis_semantic_cache() -> None:
 )
 assert output == expected_output
 # clear the cache
-langchain.llm_cache.clear(llm_string=llm_string)
+get_llm_cache().clear(llm_string=llm_string)
 output = llm.generate(
 ["bar"]
 )  # foo and bar will have the same embedding produced by FakeEmbeddings
 # expect different output now without cached result
 assert output != expected_output
-langchain.llm_cache.clear(llm_string=llm_string)
+get_llm_cache().clear(llm_string=llm_string)
 def test_redis_semantic_cache_multi() -> None:
-langchain.llm_cache = RedisSemanticCache(
-embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
+set_llm_cache(
+RedisSemanticCache(
+embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
+)
 )
 llm = FakeLLM()
 params = llm.dict()
 params["stop"] = None
 llm_string = str(sorted([(k, v) for k, v in params.items()]))
-langchain.llm_cache.update(
+get_llm_cache().update(
 "foo", llm_string, [Generation(text="fizz"), Generation(text="Buzz")]
 )
 output = llm.generate(
@@ -120,19 +127,21 @@ def test_redis_semantic_cache_multi() -> None:
 )
 assert output == expected_output
 # clear the cache
-langchain.llm_cache.clear(llm_string=llm_string)
+get_llm_cache().clear(llm_string=llm_string)
 def test_redis_semantic_cache_chat() -> None:
-langchain.llm_cache = RedisSemanticCache(
-embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
+set_llm_cache(
+RedisSemanticCache(
+embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
+)
 )
 llm = FakeChatModel()
 params = llm.dict()
 params["stop"] = None
 llm_string = str(sorted([(k, v) for k, v in params.items()]))
 prompt: List[BaseMessage] = [HumanMessage(content="foo")]
-langchain.llm_cache.update(
+get_llm_cache().update(
 dumps(prompt), llm_string, [ChatGeneration(message=AIMessage(content="fizz"))]
 )
 output = llm.generate([prompt])
@@ -141,7 +150,7 @@ def test_redis_semantic_cache_chat() -> None:
 llm_output={},
 )
 assert output == expected_output
-langchain.llm_cache.clear(llm_string=llm_string)
+get_llm_cache().clear(llm_string=llm_string)
 @pytest.mark.parametrize("embedding", [ConsistentFakeEmbeddings()])
@@ -170,9 +179,7 @@ def test_redis_semantic_cache_chat() -> None:
 def test_redis_semantic_cache_hit(
 embedding: Embeddings, prompts: List[str], generations: List[List[str]]
 ) -> None:
-langchain.llm_cache = RedisSemanticCache(
-embedding=embedding, redis_url=REDIS_TEST_URL
-)
+set_llm_cache(RedisSemanticCache(embedding=embedding, redis_url=REDIS_TEST_URL))
 llm = FakeLLM()
 params = llm.dict()
@@ -189,7 +196,7 @@ def test_redis_semantic_cache_hit(
 for prompt_i, llm_generations_i in zip(prompts, llm_generations):
 print(prompt_i)
 print(llm_generations_i)
-langchain.llm_cache.update(prompt_i, llm_string, llm_generations_i)
+get_llm_cache().update(prompt_i, llm_string, llm_generations_i)
 llm.generate(prompts)
 assert llm.generate(prompts) == LLMResult(
 generations=llm_generations, llm_output={}

View File

@@ -6,25 +6,25 @@ try:
 except ImportError:
 from sqlalchemy.ext.declarative import declarative_base
-import langchain
 from langchain.cache import InMemoryCache, SQLAlchemyCache
+from langchain.globals import get_llm_cache, set_llm_cache
 from langchain.schema import Generation, LLMResult
 from tests.unit_tests.llms.fake_llm import FakeLLM
 def test_caching() -> None:
 """Test caching behavior."""
-langchain.llm_cache = InMemoryCache()
+set_llm_cache(InMemoryCache())
 llm = FakeLLM()
 params = llm.dict()
 params["stop"] = None
 llm_string = str(sorted([(k, v) for k, v in params.items()]))
-langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
+get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
 output = llm.generate(["foo", "bar", "foo"])
 expected_cache_output = [Generation(text="foo")]
-cache_output = langchain.llm_cache.lookup("bar", llm_string)
+cache_output = get_llm_cache().lookup("bar", llm_string)
 assert cache_output == expected_cache_output
-langchain.llm_cache = None
+set_llm_cache(None)
 expected_generations = [
 [Generation(text="fizz")],
 [Generation(text="foo")],
@@ -52,17 +52,17 @@ def test_custom_caching() -> None:
 response = Column(String)
 engine = create_engine("sqlite://")
-langchain.llm_cache = SQLAlchemyCache(engine, FulltextLLMCache)
+set_llm_cache(SQLAlchemyCache(engine, FulltextLLMCache))
 llm = FakeLLM()
 params = llm.dict()
 params["stop"] = None
 llm_string = str(sorted([(k, v) for k, v in params.items()]))
-langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
+get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
 output = llm.generate(["foo", "bar", "foo"])
 expected_cache_output = [Generation(text="foo")]
-cache_output = langchain.llm_cache.lookup("bar", llm_string)
+cache_output = get_llm_cache().lookup("bar", llm_string)
 assert cache_output == expected_cache_output
-langchain.llm_cache = None
+set_llm_cache(None)
 expected_generations = [
 [Generation(text="fizz")],
 [Generation(text="foo")],

View File

@@ -6,13 +6,13 @@ from _pytest.fixtures import FixtureRequest
 from sqlalchemy import create_engine
 from sqlalchemy.orm import Session
-import langchain
 from langchain.cache import (
 InMemoryCache,
 SQLAlchemyCache,
 )
 from langchain.chat_models import FakeListChatModel
 from langchain.chat_models.base import BaseChatModel, dumps
+from langchain.globals import get_llm_cache, set_llm_cache
 from langchain.llms import FakeListLLM
 from langchain.llms.base import BaseLLM
 from langchain.schema import (
@@ -36,18 +36,18 @@ CACHE_OPTIONS = [
 def set_cache_and_teardown(request: FixtureRequest) -> Generator[None, None, None]:
 # Will be run before each test
 cache_instance = request.param
-langchain.llm_cache = cache_instance()
-if langchain.llm_cache:
-langchain.llm_cache.clear()
+set_llm_cache(cache_instance())
+if get_llm_cache():
+get_llm_cache().clear()
 else:
 raise ValueError("Cache not set. This should never happen.")
 yield
 # Will be run after each test
-if langchain.llm_cache:
-langchain.llm_cache.clear()
-langchain.llm_cache = None
+if get_llm_cache():
+get_llm_cache().clear()
+set_llm_cache(None)
 else:
 raise ValueError("Cache not set. This should never happen.")
@@ -57,8 +57,8 @@ def test_llm_caching() -> None:
 response = "Test response"
 cached_response = "Cached test response"
 llm = FakeListLLM(responses=[response])
-if langchain.llm_cache:
-langchain.llm_cache.update(
+if get_llm_cache():
+get_llm_cache().update(
 prompt=prompt,
 llm_string=create_llm_string(llm),
 return_val=[Generation(text=cached_response)],
@@ -72,20 +72,21 @@ def test_llm_caching() -> None:
 def test_old_sqlite_llm_caching() -> None:
-if isinstance(langchain.llm_cache, SQLAlchemyCache):
+llm_cache = get_llm_cache()
+if isinstance(llm_cache, SQLAlchemyCache):
 prompt = "How are you?"
 response = "Test response"
 cached_response = "Cached test response"
 llm = FakeListLLM(responses=[response])
 items = [
-langchain.llm_cache.cache_schema(
+llm_cache.cache_schema(
 prompt=prompt,
 llm=create_llm_string(llm),
 response=cached_response,
 idx=0,
 )
 ]
-with Session(langchain.llm_cache.engine) as session, session.begin():
+with Session(llm_cache.engine) as session, session.begin():
 for item in items:
 session.merge(item)
 assert llm(prompt) == cached_response
@@ -97,8 +98,8 @@ def test_chat_model_caching() -> None:
 cached_response = "Cached test response"
 cached_message = AIMessage(content=cached_response)
 llm = FakeListChatModel(responses=[response])
-if langchain.llm_cache:
-langchain.llm_cache.update(
+if get_llm_cache():
+get_llm_cache().update(
 prompt=dumps(prompt),
 llm_string=llm._get_llm_string(),
 return_val=[ChatGeneration(message=cached_message)],
@@ -119,8 +120,8 @@ def test_chat_model_caching_params() -> None:
 cached_response = "Cached test response"
 cached_message = AIMessage(content=cached_response)
 llm = FakeListChatModel(responses=[response])
-if langchain.llm_cache:
-langchain.llm_cache.update(
+if get_llm_cache():
+get_llm_cache().update(
 prompt=dumps(prompt),
 llm_string=llm._get_llm_string(functions=[]),
 return_val=[ChatGeneration(message=cached_message)],
@@ -144,13 +145,13 @@ def test_llm_cache_clear() -> None:
 response = "Test response"
 cached_response = "Cached test response"
 llm = FakeListLLM(responses=[response])
-if langchain.llm_cache:
-langchain.llm_cache.update(
+if get_llm_cache():
+get_llm_cache().update(
 prompt=prompt,
 llm_string=create_llm_string(llm),
 return_val=[Generation(text=cached_response)],
 )
-langchain.llm_cache.clear()
+get_llm_cache().clear()
 assert llm(prompt) == response
 else:
 raise ValueError(