Add an optional ``maxsize`` bound (oldest-first eviction) to InMemoryCache.

Review note: ``update()`` evicts only when a *new* key is inserted into a full
cache, so re-caching an existing key never drops an unrelated entry. The
post-image blob hashes on the ``index`` lines predate that adjustment.

diff --git a/libs/core/langchain_core/caches.py b/libs/core/langchain_core/caches.py
index 8b9f63be474..1c46b0489b4 100644
--- a/libs/core/langchain_core/caches.py
+++ b/libs/core/langchain_core/caches.py
@@ -145,9 +145,18 @@ class BaseCache(ABC):
 class InMemoryCache(BaseCache):
     """Cache that stores things in memory."""
 
-    def __init__(self) -> None:
-        """Initialize with empty cache."""
+    def __init__(self, *, maxsize: Optional[int] = None) -> None:
+        """Initialize with empty cache.
+
+        Args:
+            maxsize: The maximum number of items to store in the cache.
+                If None, the cache has no maximum size.
+                If the cache exceeds the maximum size, the oldest items are removed.
+        """
         self._cache: Dict[Tuple[str, str], RETURN_VAL_TYPE] = {}
+        if maxsize is not None and maxsize <= 0:
+            raise ValueError("maxsize must be greater than 0")
+        self._maxsize = maxsize
 
     def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
         """Look up based on prompt and llm_string.
@@ -174,6 +183,15 @@ class InMemoryCache(BaseCache):
             return_val: The value to be cached. The value is a list of Generations
                 (or subclasses).
         """
+        # Evict the oldest entry only when inserting a *new* key into a full
+        # cache; overwriting an existing key does not grow the cache.
+        if (
+            self._maxsize is not None
+            and (prompt, llm_string) not in self._cache
+            and len(self._cache) >= self._maxsize
+        ):
+            # dicts preserve insertion order, so the first key is the oldest.
+            del self._cache[next(iter(self._cache))]
         self._cache[(prompt, llm_string)] = return_val
 
     def clear(self, **kwargs: Any) -> None:
diff --git a/libs/core/tests/unit_tests/caches/test_in_memory_cache.py b/libs/core/tests/unit_tests/caches/test_in_memory_cache.py
new file mode 100644
index 00000000000..67143c0ff92
--- /dev/null
+++ b/libs/core/tests/unit_tests/caches/test_in_memory_cache.py
@@ -0,0 +1,138 @@
+from typing import Tuple
+
+import pytest
+
+from langchain_core.caches import RETURN_VAL_TYPE, InMemoryCache
+from langchain_core.outputs import Generation
+
+
+@pytest.fixture
+def cache() -> InMemoryCache:
+    """Fixture to provide an instance of InMemoryCache."""
+    return InMemoryCache()
+
+
+def cache_item(item_id: int) -> Tuple[str, str, RETURN_VAL_TYPE]:
+    """Generate a valid cache item."""
+    prompt = f"prompt{item_id}"
+    llm_string = f"llm_string{item_id}"
+    generations = [Generation(text=f"text{item_id}")]
+    return prompt, llm_string, generations
+
+
+def test_initialization() -> None:
+    """Test the initialization of InMemoryCache."""
+    cache = InMemoryCache()
+    assert cache._cache == {}
+    assert cache._maxsize is None
+
+    cache_with_maxsize = InMemoryCache(maxsize=2)
+    assert cache_with_maxsize._cache == {}
+    assert cache_with_maxsize._maxsize == 2
+
+    with pytest.raises(ValueError):
+        InMemoryCache(maxsize=0)
+
+
+def test_lookup(
+    cache: InMemoryCache,
+) -> None:
+    """Test the lookup method of InMemoryCache."""
+    prompt, llm_string, generations = cache_item(1)
+    cache.update(prompt, llm_string, generations)
+    assert cache.lookup(prompt, llm_string) == generations
+    assert cache.lookup("prompt2", "llm_string2") is None
+
+
+def test_update_with_no_maxsize(cache: InMemoryCache) -> None:
+    """Test the update method of InMemoryCache with no maximum size."""
+    prompt, llm_string, generations = cache_item(1)
+    cache.update(prompt, llm_string, generations)
+    assert cache.lookup(prompt, llm_string) == generations
+
+
+def test_update_with_maxsize() -> None:
+    """Test the update method of InMemoryCache with a maximum size."""
+    cache = InMemoryCache(maxsize=2)
+
+    prompt1, llm_string1, generations1 = cache_item(1)
+    cache.update(prompt1, llm_string1, generations1)
+    assert cache.lookup(prompt1, llm_string1) == generations1
+
+    prompt2, llm_string2, generations2 = cache_item(2)
+    cache.update(prompt2, llm_string2, generations2)
+    assert cache.lookup(prompt2, llm_string2) == generations2
+
+    prompt3, llm_string3, generations3 = cache_item(3)
+    cache.update(prompt3, llm_string3, generations3)
+
+    assert cache.lookup(prompt1, llm_string1) is None  # 'prompt1' should be evicted
+    assert cache.lookup(prompt2, llm_string2) == generations2
+    assert cache.lookup(prompt3, llm_string3) == generations3
+
+
+def test_update_existing_key_at_maxsize() -> None:
+    """Re-caching an existing key in a full cache must not evict other items."""
+    cache = InMemoryCache(maxsize=2)
+
+    prompt1, llm_string1, generations1 = cache_item(1)
+    prompt2, llm_string2, generations2 = cache_item(2)
+    cache.update(prompt1, llm_string1, generations1)
+    cache.update(prompt2, llm_string2, generations2)
+
+    # Overwrite the newest key while the cache is full.
+    _, _, new_generations = cache_item(3)
+    cache.update(prompt2, llm_string2, new_generations)
+
+    assert cache.lookup(prompt1, llm_string1) == generations1
+    assert cache.lookup(prompt2, llm_string2) == new_generations
+
+
+def test_clear(cache: InMemoryCache) -> None:
+    """Test the clear method of InMemoryCache."""
+    prompt, llm_string, generations = cache_item(1)
+    cache.update(prompt, llm_string, generations)
+    cache.clear()
+    assert cache.lookup(prompt, llm_string) is None
+
+
+async def test_alookup(cache: InMemoryCache) -> None:
+    """Test the asynchronous lookup method of InMemoryCache."""
+    prompt, llm_string, generations = cache_item(1)
+    await cache.aupdate(prompt, llm_string, generations)
+    assert await cache.alookup(prompt, llm_string) == generations
+    assert await cache.alookup("prompt2", "llm_string2") is None
+
+
+async def test_aupdate_with_no_maxsize(cache: InMemoryCache) -> None:
+    """Test the asynchronous update method of InMemoryCache with no maximum size."""
+    prompt, llm_string, generations = cache_item(1)
+    await cache.aupdate(prompt, llm_string, generations)
+    assert await cache.alookup(prompt, llm_string) == generations
+
+
+async def test_aupdate_with_maxsize() -> None:
+    """Test the asynchronous update method of InMemoryCache with a maximum size."""
+    cache = InMemoryCache(maxsize=2)
+    prompt, llm_string, generations = cache_item(1)
+    await cache.aupdate(prompt, llm_string, generations)
+    assert await cache.alookup(prompt, llm_string) == generations
+
+    prompt2, llm_string2, generations2 = cache_item(2)
+    await cache.aupdate(prompt2, llm_string2, generations2)
+    assert await cache.alookup(prompt2, llm_string2) == generations2
+
+    prompt3, llm_string3, generations3 = cache_item(3)
+    await cache.aupdate(prompt3, llm_string3, generations3)
+
+    assert await cache.alookup(prompt, llm_string) is None
+    assert await cache.alookup(prompt2, llm_string2) == generations2
+    assert await cache.alookup(prompt3, llm_string3) == generations3
+
+
+async def test_aclear(cache: InMemoryCache) -> None:
+    """Test the asynchronous clear method of InMemoryCache."""
+    prompt, llm_string, generations = cache_item(1)
+    await cache.aupdate(prompt, llm_string, generations)
+    await cache.aclear()
+    assert await cache.alookup(prompt, llm_string) is None