From f92738a6f6c37909d42cda9e0493e5ad05b9af6e Mon Sep 17 00:00:00 2001 From: Dmitry Kankalovich Date: Thu, 8 Feb 2024 04:06:09 +0100 Subject: [PATCH] langchain[minor], community[minor], core[minor]: Async Cache support and AsyncRedisCache (#15817) * This PR adds async methods to the LLM cache. * Adds an implementation using Redis called AsyncRedisCache. * Adds a docker compose file at the /docker to help spin up docker * Updates redis tests to use a context manager so flushing always happens by default --- docker/docker-compose.yml | 17 ++ libs/community/langchain_community/cache.py | 214 ++++++++++++----- libs/core/langchain_core/caches.py | 15 ++ .../language_models/chat_models.py | 4 +- .../langchain_core/language_models/llms.py | 42 +++- libs/langchain/langchain/cache.py | 2 + .../cache/test_redis_cache.py | 218 ++++++++++++++---- libs/langchain/tests/unit_tests/test_cache.py | 93 ++++++-- 8 files changed, 472 insertions(+), 133 deletions(-) create mode 100644 docker/docker-compose.yml diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml new file mode 100644 index 00000000000..ce680ccafda --- /dev/null +++ b/docker/docker-compose.yml @@ -0,0 +1,17 @@ +# docker-compose to make it easier to spin up integration tests. +# Services should use NON standard ports to avoid collision with +version: "3" +name: langchain-tests + +services: + redis: + image: redis/redis-stack-server:latest + # We use non standard ports since + # these instances are used for testing + # and users may already have existing + # redis instances set up locally + # for other projects + ports: + - "6020:6379" + volumes: + - ./redis-volume:/data diff --git a/libs/community/langchain_community/cache.py b/libs/community/langchain_community/cache.py index b2001682942..54996071f86 100644 --- a/libs/community/langchain_community/cache.py +++ b/libs/community/langchain_community/cache.py @@ -27,6 +27,7 @@ import json import logging import uuid import warnings +from abc import ABC from datetime import timedelta from functools import lru_cache from typing import ( @@ -351,49 +352,26 @@ class UpstashRedisCache(BaseCache): self.redis.flushdb(flush_type=asynchronous) -class RedisCache(BaseCache): - """Cache that uses Redis as a backend.""" - - def __init__(self, redis_: Any, *, ttl: Optional[int] = None): - """ - Initialize an instance of RedisCache. - - This method initializes an object with Redis caching capabilities. - It takes a `redis_` parameter, which should be an instance of a Redis - client class, allowing the object to interact with a Redis - server for caching purposes. - - Parameters: - redis_ (Any): An instance of a Redis client class - (e.g., redis.Redis) used for caching. - This allows the object to communicate with a - Redis server for caching operations. - ttl (int, optional): Time-to-live (TTL) for cached items in seconds. - If provided, it sets the time duration for how long cached - items will remain valid. If not provided, cached items will not - have an automatic expiration. - """ - try: - from redis import Redis - except ImportError: - raise ValueError( - "Could not import redis python package. " - "Please install it with `pip install redis`." 
- ) - if not isinstance(redis_, Redis): - raise ValueError("Please pass in Redis object.") - self.redis = redis_ - self.ttl = ttl - - def _key(self, prompt: str, llm_string: str) -> str: +class _RedisCacheBase(BaseCache, ABC): + @staticmethod + def _key(prompt: str, llm_string: str) -> str: """Compute key from prompt and llm_string""" return _hash(prompt + llm_string) - def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: - """Look up based on prompt and llm_string.""" + @staticmethod + def _ensure_generation_type(return_val: RETURN_VAL_TYPE) -> None: + for gen in return_val: + if not isinstance(gen, Generation): + raise ValueError( + "RedisCache only supports caching of normal LLM generations, " + f"got {type(gen)}" + ) + + @staticmethod + def _get_generations( + results: dict[str | bytes, str | bytes], + ) -> Optional[List[Generation]]: generations = [] - # Read from a Redis HASH - results = self.redis.hgetall(self._key(prompt, llm_string)) if results: for _, text in results.items(): try: @@ -410,28 +388,69 @@ class RedisCache(BaseCache): generations.append(Generation(text=text)) return generations if generations else None + @staticmethod + def _configure_pipeline_for_update( + key: str, pipe: Any, return_val: RETURN_VAL_TYPE, ttl: Optional[int] = None + ) -> None: + pipe.hset( + key, + mapping={ + str(idx): dumps(generation) for idx, generation in enumerate(return_val) + }, + ) + if ttl is not None: + pipe.expire(key, ttl) + + +class RedisCache(_RedisCacheBase): + """ + Cache that uses Redis as a backend. Allows to use a sync `redis.Redis` client. + """ + + def __init__(self, redis_: Any, *, ttl: Optional[int] = None): + """ + Initialize an instance of RedisCache. + + This method initializes an object with Redis caching capabilities. + It takes a `redis_` parameter, which should be an instance of a Redis + client class (`redis.Redis`), allowing the object + to interact with a Redis server for caching purposes. + + Parameters: + redis_ (Any): An instance of a Redis client class + (`redis.Redis`) to be used for caching. + This allows the object to communicate with a + Redis server for caching operations. + ttl (int, optional): Time-to-live (TTL) for cached items in seconds. + If provided, it sets the time duration for how long cached + items will remain valid. If not provided, cached items will not + have an automatic expiration. + """ + try: + from redis import Redis + except ImportError: + raise ValueError( + "Could not import `redis` python package. " + "Please install it with `pip install redis`." 
+ ) + if not isinstance(redis_, Redis): + raise ValueError("Please pass a valid `redis.Redis` client.") + self.redis = redis_ + self.ttl = ttl + + def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: + """Look up based on prompt and llm_string.""" + # Read from a Redis HASH + results = self.redis.hgetall(self._key(prompt, llm_string)) + return self._get_generations(results) # type: ignore[arg-type] + def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: """Update cache based on prompt and llm_string.""" - for gen in return_val: - if not isinstance(gen, Generation): - raise ValueError( - "RedisCache only supports caching of normal LLM generations, " - f"got {type(gen)}" - ) - # Write to a Redis HASH + self._ensure_generation_type(return_val) key = self._key(prompt, llm_string) with self.redis.pipeline() as pipe: - pipe.hset( - key, - mapping={ - str(idx): dumps(generation) - for idx, generation in enumerate(return_val) - }, - ) - if self.ttl is not None: - pipe.expire(key, self.ttl) - + self._configure_pipeline_for_update(key, pipe, return_val, self.ttl) pipe.execute() def clear(self, **kwargs: Any) -> None: @@ -440,6 +459,89 @@ class RedisCache(BaseCache): self.redis.flushdb(asynchronous=asynchronous, **kwargs) +class AsyncRedisCache(_RedisCacheBase): + """ + Cache that uses Redis as a backend. Allows to use an + async `redis.asyncio.Redis` client. + """ + + def __init__(self, redis_: Any, *, ttl: Optional[int] = None): + """ + Initialize an instance of AsyncRedisCache. + + This method initializes an object with Redis caching capabilities. + It takes a `redis_` parameter, which should be an instance of a Redis + client class (`redis.asyncio.Redis`), allowing the object + to interact with a Redis server for caching purposes. + + Parameters: + redis_ (Any): An instance of a Redis client class + (`redis.asyncio.Redis`) to be used for caching. + This allows the object to communicate with a + Redis server for caching operations. + ttl (int, optional): Time-to-live (TTL) for cached items in seconds. + If provided, it sets the time duration for how long cached + items will remain valid. If not provided, cached items will not + have an automatic expiration. + """ + try: + from redis.asyncio import Redis + except ImportError: + raise ValueError( + "Could not import `redis.asyncio` python package. " + "Please install it with `pip install redis`." + ) + if not isinstance(redis_, Redis): + raise ValueError("Please pass a valid `redis.asyncio.Redis` client.") + self.redis = redis_ + self.ttl = ttl + + def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: + """Look up based on prompt and llm_string.""" + raise NotImplementedError( + "This async Redis cache does not implement `lookup()` method. " + "Consider using the async `alookup()` version." + ) + + async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: + """Look up based on prompt and llm_string. Async version.""" + results = await self.redis.hgetall(self._key(prompt, llm_string)) + return self._get_generations(results) # type: ignore[arg-type] + + def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: + """Update cache based on prompt and llm_string.""" + raise NotImplementedError( + "This async Redis cache does not implement `update()` method. " + "Consider using the async `aupdate()` version." 
+ ) + + async def aupdate( + self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE + ) -> None: + """Update cache based on prompt and llm_string. Async version.""" + self._ensure_generation_type(return_val) + key = self._key(prompt, llm_string) + + async with self.redis.pipeline() as pipe: + self._configure_pipeline_for_update(key, pipe, return_val, self.ttl) + await pipe.execute() # type: ignore[attr-defined] + + def clear(self, **kwargs: Any) -> None: + """Clear cache. If `asynchronous` is True, flush asynchronously.""" + raise NotImplementedError( + "This async Redis cache does not implement `clear()` method. " + "Consider using the async `aclear()` version." + ) + + async def aclear(self, **kwargs: Any) -> None: + """ + Clear cache. If `asynchronous` is True, flush asynchronously. + Async version. + """ + asynchronous = kwargs.get("asynchronous", False) + await self.redis.flushdb(asynchronous=asynchronous, **kwargs) + + class RedisSemanticCache(BaseCache): """Cache that uses Redis as a vector-store backend.""" diff --git a/libs/core/langchain_core/caches.py b/libs/core/langchain_core/caches.py index c14959c8f9c..626670950ab 100644 --- a/libs/core/langchain_core/caches.py +++ b/libs/core/langchain_core/caches.py @@ -4,6 +4,7 @@ from abc import ABC, abstractmethod from typing import Any, Optional, Sequence from langchain_core.outputs import Generation +from langchain_core.runnables import run_in_executor RETURN_VAL_TYPE = Sequence[Generation] @@ -22,3 +23,17 @@ class BaseCache(ABC): @abstractmethod def clear(self, **kwargs: Any) -> None: """Clear cache that can take additional keyword arguments.""" + + async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: + """Look up based on prompt and llm_string.""" + return await run_in_executor(None, self.lookup, prompt, llm_string) + + async def aupdate( + self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE + ) -> None: + """Update cache based on prompt and llm_string.""" + return await run_in_executor(None, self.update, prompt, llm_string, return_val) + + async def aclear(self, **kwargs: Any) -> None: + """Clear cache that can take additional keyword arguments.""" + return await run_in_executor(None, self.clear, **kwargs) diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py index aaf61a7810e..6279116093e 100644 --- a/libs/core/langchain_core/language_models/chat_models.py +++ b/libs/core/langchain_core/language_models/chat_models.py @@ -622,7 +622,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): else: llm_string = self._get_llm_string(stop=stop, **kwargs) prompt = dumps(messages) - cache_val = llm_cache.lookup(prompt, llm_string) + cache_val = await llm_cache.alookup(prompt, llm_string) if isinstance(cache_val, list): return ChatResult(generations=cache_val) else: @@ -632,7 +632,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): ) else: result = await self._agenerate(messages, stop=stop, **kwargs) - llm_cache.update(prompt, llm_string, result.generations) + await llm_cache.aupdate(prompt, llm_string, result.generations) return result @abstractmethod diff --git a/libs/core/langchain_core/language_models/llms.py b/libs/core/langchain_core/language_models/llms.py index 7f987baf88d..ec2af1b91e0 100644 --- a/libs/core/langchain_core/language_models/llms.py +++ b/libs/core/langchain_core/language_models/llms.py @@ -139,6 +139,26 @@ def get_prompts( return existing_prompts, llm_string, 
missing_prompt_idxs, missing_prompts +async def aget_prompts( + params: Dict[str, Any], prompts: List[str] +) -> Tuple[Dict[int, List], str, List[int], List[str]]: + """Get prompts that are already cached. Async version.""" + llm_string = str(sorted([(k, v) for k, v in params.items()])) + missing_prompts = [] + missing_prompt_idxs = [] + existing_prompts = {} + llm_cache = get_llm_cache() + for i, prompt in enumerate(prompts): + if llm_cache: + cache_val = await llm_cache.alookup(prompt, llm_string) + if isinstance(cache_val, list): + existing_prompts[i] = cache_val + else: + missing_prompts.append(prompt) + missing_prompt_idxs.append(i) + return existing_prompts, llm_string, missing_prompt_idxs, missing_prompts + + def update_cache( existing_prompts: Dict[int, List], llm_string: str, @@ -157,6 +177,24 @@ def update_cache( return llm_output +async def aupdate_cache( + existing_prompts: Dict[int, List], + llm_string: str, + missing_prompt_idxs: List[int], + new_results: LLMResult, + prompts: List[str], +) -> Optional[dict]: + """Update the cache and get the LLM output. Async version""" + llm_cache = get_llm_cache() + for i, result in enumerate(new_results.generations): + existing_prompts[missing_prompt_idxs[i]] = result + prompt = prompts[missing_prompt_idxs[i]] + if llm_cache: + await llm_cache.aupdate(prompt, llm_string, result) + llm_output = new_results.llm_output + return llm_output + + class BaseLLM(BaseLanguageModel[str], ABC): """Base LLM abstract interface. @@ -869,7 +907,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): llm_string, missing_prompt_idxs, missing_prompts, - ) = get_prompts(params, prompts) + ) = await aget_prompts(params, prompts) disregard_cache = self.cache is not None and not self.cache new_arg_supported = inspect.signature(self._agenerate).parameters.get( "run_manager" @@ -917,7 +955,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): new_results = await self._agenerate_helper( missing_prompts, stop, run_managers, bool(new_arg_supported), **kwargs ) - llm_output = update_cache( + llm_output = await aupdate_cache( existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts ) run_info = ( diff --git a/libs/langchain/langchain/cache.py b/libs/langchain/langchain/cache.py index fae1d1cf032..3c249e964c5 100644 --- a/libs/langchain/langchain/cache.py +++ b/libs/langchain/langchain/cache.py @@ -1,6 +1,7 @@ from langchain_community.cache import ( AstraDBCache, AstraDBSemanticCache, + AsyncRedisCache, CassandraCache, CassandraSemanticCache, FullLLMCache, @@ -22,6 +23,7 @@ __all__ = [ "SQLAlchemyCache", "SQLiteCache", "UpstashRedisCache", + "AsyncRedisCache", "RedisCache", "RedisSemanticCache", "GPTCache", diff --git a/libs/langchain/tests/integration_tests/cache/test_redis_cache.py b/libs/langchain/tests/integration_tests/cache/test_redis_cache.py index 7b44ade026d..846c709b971 100644 --- a/libs/langchain/tests/integration_tests/cache/test_redis_cache.py +++ b/libs/langchain/tests/integration_tests/cache/test_redis_cache.py @@ -1,6 +1,7 @@ """Test Redis cache functionality.""" import uuid -from typing import List, cast +from contextlib import asynccontextmanager, contextmanager +from typing import AsyncGenerator, Generator, List, Optional, cast import pytest from langchain_core.embeddings import Embeddings @@ -8,7 +9,7 @@ from langchain_core.load.dump import dumps from langchain_core.messages import AIMessage, BaseMessage, HumanMessage from langchain_core.outputs import ChatGeneration, Generation, LLMResult -from langchain.cache import RedisCache, 
RedisSemanticCache +from langchain.cache import AsyncRedisCache, RedisCache, RedisSemanticCache from langchain.globals import get_llm_cache, set_llm_cache from tests.integration_tests.cache.fake_embeddings import ( ConsistentFakeEmbeddings, @@ -17,65 +18,176 @@ from tests.integration_tests.cache.fake_embeddings import ( from tests.unit_tests.llms.fake_chat_model import FakeChatModel from tests.unit_tests.llms.fake_llm import FakeLLM -REDIS_TEST_URL = "redis://localhost:6379" +# Using a non-standard port to avoid conflicts with potentially local running +# redis instances +# You can spin up a local redis using docker compose +# cd [repository-root]/docker +# docker-compose up redis +REDIS_TEST_URL = "redis://localhost:6020" def random_string() -> str: return str(uuid.uuid4()) +@contextmanager +def get_sync_redis(*, ttl: Optional[int] = 1) -> Generator[RedisCache, None, None]: + """Get a sync RedisCache instance.""" + import redis + + cache = RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL), ttl=ttl) + try: + yield cache + finally: + cache.clear() + + +@asynccontextmanager +async def get_async_redis( + *, ttl: Optional[int] = 1 +) -> AsyncGenerator[AsyncRedisCache, None]: + """Get an async RedisCache instance.""" + from redis.asyncio import Redis + + cache = AsyncRedisCache(redis_=Redis.from_url(REDIS_TEST_URL), ttl=ttl) + try: + yield cache + finally: + await cache.aclear() + + def test_redis_cache_ttl() -> None: - import redis + from redis import Redis - set_llm_cache(RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL), ttl=1)) - llm_cache = cast(RedisCache, get_llm_cache()) - llm_cache.update("foo", "bar", [Generation(text="fizz")]) - key = llm_cache._key("foo", "bar") - assert llm_cache.redis.pttl(key) > 0 + with get_sync_redis() as llm_cache: + set_llm_cache(llm_cache) + llm_cache.update("foo", "bar", [Generation(text="fizz")]) + key = llm_cache._key("foo", "bar") + assert isinstance(llm_cache.redis, Redis) + assert llm_cache.redis.pttl(key) > 0 -def test_redis_cache() -> None: - import redis +async def test_async_redis_cache_ttl() -> None: + from redis.asyncio import Redis as AsyncRedis - set_llm_cache(RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL))) - llm = FakeLLM() - params = llm.dict() - params["stop"] = None - llm_string = str(sorted([(k, v) for k, v in params.items()])) - get_llm_cache().update("foo", llm_string, [Generation(text="fizz")]) - output = llm.generate(["foo"]) - expected_output = LLMResult( - generations=[[Generation(text="fizz")]], - llm_output={}, - ) - assert output == expected_output - llm_cache = cast(RedisCache, get_llm_cache()) - llm_cache.redis.flushall() + async with get_async_redis() as redis_cache: + set_llm_cache(redis_cache) + llm_cache = cast(RedisCache, get_llm_cache()) + await llm_cache.aupdate("foo", "bar", [Generation(text="fizz")]) + key = llm_cache._key("foo", "bar") + assert isinstance(llm_cache.redis, AsyncRedis) + assert await llm_cache.redis.pttl(key) > 0 + + +def test_sync_redis_cache() -> None: + with get_sync_redis() as llm_cache: + set_llm_cache(llm_cache) + llm = FakeLLM() + params = llm.dict() + params["stop"] = None + llm_string = str(sorted([(k, v) for k, v in params.items()])) + llm_cache.update("prompt", llm_string, [Generation(text="fizz0")]) + output = llm.generate(["prompt"]) + expected_output = LLMResult( + generations=[[Generation(text="fizz0")]], + llm_output={}, + ) + assert output == expected_output + + +async def test_sync_in_async_redis_cache() -> None: + """Test the sync RedisCache invoked with async 
methods""" + with get_sync_redis() as llm_cache: + set_llm_cache(llm_cache) + llm = FakeLLM() + params = llm.dict() + params["stop"] = None + llm_string = str(sorted([(k, v) for k, v in params.items()])) + # llm_cache.update("meow", llm_string, [Generation(text="meow")]) + await llm_cache.aupdate("prompt", llm_string, [Generation(text="fizz1")]) + output = await llm.agenerate(["prompt"]) + expected_output = LLMResult( + generations=[[Generation(text="fizz1")]], + llm_output={}, + ) + assert output == expected_output + + +async def test_async_redis_cache() -> None: + async with get_async_redis() as redis_cache: + set_llm_cache(redis_cache) + llm = FakeLLM() + params = llm.dict() + params["stop"] = None + llm_string = str(sorted([(k, v) for k, v in params.items()])) + llm_cache = cast(RedisCache, get_llm_cache()) + await llm_cache.aupdate("prompt", llm_string, [Generation(text="fizz2")]) + output = await llm.agenerate(["prompt"]) + expected_output = LLMResult( + generations=[[Generation(text="fizz2")]], + llm_output={}, + ) + assert output == expected_output + + +async def test_async_in_sync_redis_cache() -> None: + async with get_async_redis() as redis_cache: + set_llm_cache(redis_cache) + llm = FakeLLM() + params = llm.dict() + params["stop"] = None + llm_string = str(sorted([(k, v) for k, v in params.items()])) + llm_cache = cast(RedisCache, get_llm_cache()) + with pytest.raises(NotImplementedError): + llm_cache.update("foo", llm_string, [Generation(text="fizz")]) def test_redis_cache_chat() -> None: - import redis + with get_sync_redis() as redis_cache: + set_llm_cache(redis_cache) + llm = FakeChatModel() + params = llm.dict() + params["stop"] = None + llm_string = str(sorted([(k, v) for k, v in params.items()])) + prompt: List[BaseMessage] = [HumanMessage(content="foo")] + llm_cache = cast(RedisCache, get_llm_cache()) + llm_cache.update( + dumps(prompt), + llm_string, + [ChatGeneration(message=AIMessage(content="fizz"))], + ) + output = llm.generate([prompt]) + expected_output = LLMResult( + generations=[[ChatGeneration(message=AIMessage(content="fizz"))]], + llm_output={}, + ) + assert output == expected_output - set_llm_cache(RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL))) - llm = FakeChatModel() - params = llm.dict() - params["stop"] = None - llm_string = str(sorted([(k, v) for k, v in params.items()])) - prompt: List[BaseMessage] = [HumanMessage(content="foo")] - get_llm_cache().update( - dumps(prompt), llm_string, [ChatGeneration(message=AIMessage(content="fizz"))] - ) - output = llm.generate([prompt]) - expected_output = LLMResult( - generations=[[ChatGeneration(message=AIMessage(content="fizz"))]], - llm_output={}, - ) - assert output == expected_output - llm_cache = cast(RedisCache, get_llm_cache()) - llm_cache.redis.flushall() + +async def test_async_redis_cache_chat() -> None: + async with get_async_redis() as redis_cache: + set_llm_cache(redis_cache) + llm = FakeChatModel() + params = llm.dict() + params["stop"] = None + llm_string = str(sorted([(k, v) for k, v in params.items()])) + prompt: List[BaseMessage] = [HumanMessage(content="foo")] + llm_cache = cast(RedisCache, get_llm_cache()) + await llm_cache.aupdate( + dumps(prompt), + llm_string, + [ChatGeneration(message=AIMessage(content="fizz"))], + ) + output = await llm.agenerate([prompt]) + expected_output = LLMResult( + generations=[[ChatGeneration(message=AIMessage(content="fizz"))]], + llm_output={}, + ) + assert output == expected_output def test_redis_semantic_cache() -> None: + """Test redis semantic cache 
functionality.""" set_llm_cache( RedisSemanticCache( embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1 @@ -85,7 +197,8 @@ def test_redis_semantic_cache() -> None: params = llm.dict() params["stop"] = None llm_string = str(sorted([(k, v) for k, v in params.items()])) - get_llm_cache().update("foo", llm_string, [Generation(text="fizz")]) + llm_cache = cast(RedisSemanticCache, get_llm_cache()) + llm_cache.update("foo", llm_string, [Generation(text="fizz")]) output = llm.generate( ["bar"] ) # foo and bar will have the same embedding produced by FakeEmbeddings @@ -95,13 +208,13 @@ def test_redis_semantic_cache() -> None: ) assert output == expected_output # clear the cache - get_llm_cache().clear(llm_string=llm_string) + llm_cache.clear(llm_string=llm_string) output = llm.generate( ["bar"] ) # foo and bar will have the same embedding produced by FakeEmbeddings # expect different output now without cached result assert output != expected_output - get_llm_cache().clear(llm_string=llm_string) + llm_cache.clear(llm_string=llm_string) def test_redis_semantic_cache_multi() -> None: @@ -114,7 +227,8 @@ def test_redis_semantic_cache_multi() -> None: params = llm.dict() params["stop"] = None llm_string = str(sorted([(k, v) for k, v in params.items()])) - get_llm_cache().update( + llm_cache = cast(RedisSemanticCache, get_llm_cache()) + llm_cache.update( "foo", llm_string, [Generation(text="fizz"), Generation(text="Buzz")] ) output = llm.generate( @@ -126,7 +240,7 @@ def test_redis_semantic_cache_multi() -> None: ) assert output == expected_output # clear the cache - get_llm_cache().clear(llm_string=llm_string) + llm_cache.clear(llm_string=llm_string) def test_redis_semantic_cache_chat() -> None: @@ -140,7 +254,8 @@ def test_redis_semantic_cache_chat() -> None: params["stop"] = None llm_string = str(sorted([(k, v) for k, v in params.items()])) prompt: List[BaseMessage] = [HumanMessage(content="foo")] - get_llm_cache().update( + llm_cache = cast(RedisSemanticCache, get_llm_cache()) + llm_cache.update( dumps(prompt), llm_string, [ChatGeneration(message=AIMessage(content="fizz"))] ) output = llm.generate([prompt]) @@ -149,7 +264,7 @@ def test_redis_semantic_cache_chat() -> None: llm_output={}, ) assert output == expected_output - get_llm_cache().clear(llm_string=llm_string) + llm_cache.clear(llm_string=llm_string) @pytest.mark.parametrize("embedding", [ConsistentFakeEmbeddings()]) @@ -192,10 +307,11 @@ def test_redis_semantic_cache_hit( ] for prompt_i_generations in generations ] + llm_cache = cast(RedisSemanticCache, get_llm_cache()) for prompt_i, llm_generations_i in zip(prompts, llm_generations): print(prompt_i) print(llm_generations_i) - get_llm_cache().update(prompt_i, llm_string, llm_generations_i) + llm_cache.update(prompt_i, llm_string, llm_generations_i) llm.generate(prompts) assert llm.generate(prompts) == LLMResult( generations=llm_generations, llm_output={} diff --git a/libs/langchain/tests/unit_tests/test_cache.py b/libs/langchain/tests/unit_tests/test_cache.py index 88260a6f71c..5b2e1a4da4e 100644 --- a/libs/langchain/tests/unit_tests/test_cache.py +++ b/libs/langchain/tests/unit_tests/test_cache.py @@ -1,4 +1,5 @@ """Test caching for LLMs and ChatModels.""" +import sqlite3 from typing import Dict, Generator, List, Union import pytest @@ -21,7 +22,11 @@ from langchain.globals import get_llm_cache, set_llm_cache def get_sqlite_cache() -> SQLAlchemyCache: - return SQLAlchemyCache(engine=create_engine("sqlite://")) + return SQLAlchemyCache( + engine=create_engine( + 
"sqlite://", creator=lambda: sqlite3.connect("file::memory:?cache=shared") + ) + ) CACHE_OPTIONS = [ @@ -35,33 +40,41 @@ def set_cache_and_teardown(request: FixtureRequest) -> Generator[None, None, Non # Will be run before each test cache_instance = request.param set_llm_cache(cache_instance()) - if get_llm_cache(): - get_llm_cache().clear() + if llm_cache := get_llm_cache(): + llm_cache.clear() else: raise ValueError("Cache not set. This should never happen.") yield # Will be run after each test - if get_llm_cache(): - get_llm_cache().clear() + if llm_cache: + llm_cache.clear() set_llm_cache(None) else: raise ValueError("Cache not set. This should never happen.") -def test_llm_caching() -> None: +async def test_llm_caching() -> None: prompt = "How are you?" response = "Test response" cached_response = "Cached test response" llm = FakeListLLM(responses=[response]) - if get_llm_cache(): - get_llm_cache().update( + if llm_cache := get_llm_cache(): + # sync test + llm_cache.update( prompt=prompt, llm_string=create_llm_string(llm), return_val=[Generation(text=cached_response)], ) assert llm(prompt) == cached_response + # async test + await llm_cache.aupdate( + prompt=prompt, + llm_string=create_llm_string(llm), + return_val=[Generation(text=cached_response)], + ) + assert await llm.ainvoke(prompt) == cached_response else: raise ValueError( "The cache not set. This should never happen, as the pytest fixture " @@ -90,14 +103,15 @@ def test_old_sqlite_llm_caching() -> None: assert llm(prompt) == cached_response -def test_chat_model_caching() -> None: +async def test_chat_model_caching() -> None: prompt: List[BaseMessage] = [HumanMessage(content="How are you?")] response = "Test response" cached_response = "Cached test response" cached_message = AIMessage(content=cached_response) llm = FakeListChatModel(responses=[response]) - if get_llm_cache(): - get_llm_cache().update( + if llm_cache := get_llm_cache(): + # sync test + llm_cache.update( prompt=dumps(prompt), llm_string=llm._get_llm_string(), return_val=[ChatGeneration(message=cached_message)], @@ -105,6 +119,16 @@ def test_chat_model_caching() -> None: result = llm(prompt) assert isinstance(result, AIMessage) assert result.content == cached_response + + # async test + await llm_cache.aupdate( + prompt=dumps(prompt), + llm_string=llm._get_llm_string(), + return_val=[ChatGeneration(message=cached_message)], + ) + result = await llm.ainvoke(prompt) + assert isinstance(result, AIMessage) + assert result.content == cached_response else: raise ValueError( "The cache not set. 
This should never happen, as the pytest fixture " @@ -112,25 +136,38 @@ def test_chat_model_caching() -> None: ) -def test_chat_model_caching_params() -> None: +async def test_chat_model_caching_params() -> None: prompt: List[BaseMessage] = [HumanMessage(content="How are you?")] response = "Test response" cached_response = "Cached test response" cached_message = AIMessage(content=cached_response) llm = FakeListChatModel(responses=[response]) - if get_llm_cache(): - get_llm_cache().update( + if llm_cache := get_llm_cache(): + # sync test + llm_cache.update( prompt=dumps(prompt), llm_string=llm._get_llm_string(functions=[]), return_val=[ChatGeneration(message=cached_message)], ) result = llm(prompt, functions=[]) + result_no_params = llm(prompt) assert isinstance(result, AIMessage) assert result.content == cached_response - result_no_params = llm(prompt) assert isinstance(result_no_params, AIMessage) assert result_no_params.content == response + # async test + await llm_cache.aupdate( + prompt=dumps(prompt), + llm_string=llm._get_llm_string(functions=[]), + return_val=[ChatGeneration(message=cached_message)], + ) + result = await llm.ainvoke(prompt, functions=[]) + result_no_params = await llm.ainvoke(prompt) + assert isinstance(result, AIMessage) + assert result.content == cached_response + assert isinstance(result_no_params, AIMessage) + assert result_no_params.content == response else: raise ValueError( "The cache not set. This should never happen, as the pytest fixture " @@ -138,19 +175,31 @@ def test_chat_model_caching_params() -> None: ) -def test_llm_cache_clear() -> None: +async def test_llm_cache_clear() -> None: prompt = "How are you?" - response = "Test response" + expected_response = "Test response" cached_response = "Cached test response" - llm = FakeListLLM(responses=[response]) - if get_llm_cache(): - get_llm_cache().update( + llm = FakeListLLM(responses=[expected_response]) + if llm_cache := get_llm_cache(): + # sync test + llm_cache.update( prompt=prompt, llm_string=create_llm_string(llm), return_val=[Generation(text=cached_response)], ) - get_llm_cache().clear() - assert llm(prompt) == response + llm_cache.clear() + response = llm(prompt) + assert response == expected_response + + # async test + await llm_cache.aupdate( + prompt=prompt, + llm_string=create_llm_string(llm), + return_val=[Generation(text=cached_response)], + ) + await llm_cache.aclear() + response = await llm.ainvoke(prompt) + assert response == expected_response else: raise ValueError( "The cache not set. This should never happen, as the pytest fixture "
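
For reference, a minimal usage sketch of the async cache interface this patch introduces (aupdate / alookup / aclear and AsyncRedisCache). This is not part of the patch; it assumes the `redis` package is installed and a Redis server is reachable at redis://localhost:6020, matching the docker-compose service above.

    import asyncio

    from redis.asyncio import Redis
    from langchain_core.outputs import Generation

    from langchain.cache import AsyncRedisCache
    from langchain.globals import set_llm_cache


    async def main() -> None:
        # Assumed local test Redis (port 6020 from docker/docker-compose.yml);
        # adjust the URL and TTL for your environment.
        cache = AsyncRedisCache(redis_=Redis.from_url("redis://localhost:6020"), ttl=60)
        set_llm_cache(cache)

        # Async counterparts added by this patch: write, read back, then flush.
        await cache.aupdate("prompt", "llm_string", [Generation(text="cached answer")])
        cached = await cache.alookup("prompt", "llm_string")
        assert cached is not None and cached[0].text == "cached answer"

        await cache.aclear()


    asyncio.run(main())

Note that AsyncRedisCache deliberately raises NotImplementedError from the sync lookup()/update()/clear() methods, so it should only be installed as the global cache when the calling code uses the async LLM paths (ainvoke/agenerate); the plain RedisCache remains the right choice for sync code, and its inherited alookup/aupdate defaults delegate to an executor.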