diff --git a/libs/core/langchain_core/language_models/base.py b/libs/core/langchain_core/language_models/base.py
index eef6263983b..c11cf4e1b79 100644
--- a/libs/core/langchain_core/language_models/base.py
+++ b/libs/core/langchain_core/language_models/base.py
@@ -31,6 +31,7 @@ from langchain_core.runnables import Runnable, RunnableSerializable
 from langchain_core.utils import get_pydantic_field_names
 
 if TYPE_CHECKING:
+    from langchain_core.caches import BaseCache
     from langchain_core.callbacks import Callbacks
     from langchain_core.outputs import LLMResult
 
@@ -78,8 +79,16 @@ class BaseLanguageModel(
     All language model wrappers inherit from BaseLanguageModel.
     """
 
-    cache: Optional[bool] = None
-    """Whether to cache the response."""
+    cache: Union[BaseCache, bool, None] = None
+    """Whether to cache the response.
+
+    * If true, will use the global cache.
+    * If false, will not use a cache.
+    * If None, will use the global cache if it's set, otherwise no cache.
+    * If an instance of BaseCache, will use the provided cache.
+
+    Caching is not currently supported for streaming methods of models.
+    """
     verbose: bool = Field(default_factory=_get_verbosity)
     """Whether to print out response text."""
     callbacks: Callbacks = Field(default=None, exclude=True)
diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py
index 8472dbcc033..cfadfafa338 100644
--- a/libs/core/langchain_core/language_models/chat_models.py
+++ b/libs/core/langchain_core/language_models/chat_models.py
@@ -21,6 +21,7 @@ from typing import (
 )
 
 from langchain_core._api import deprecated
+from langchain_core.caches import BaseCache
 from langchain_core.callbacks import (
     AsyncCallbackManager,
     AsyncCallbackManagerForLLMRun,
@@ -596,7 +597,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        llm_cache = get_llm_cache()
+        if isinstance(self.cache, BaseCache):
+            llm_cache = self.cache
+        else:
+            llm_cache = get_llm_cache()
+        # We should check the cache unless it's explicitly set to False.
+        # A None cache means we should use the default global cache
+        # if it's configured.
         check_cache = self.cache or self.cache is None
         if check_cache:
             if llm_cache:
@@ -618,6 +625,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         else:
             result = self._generate(messages, stop=stop, **kwargs)
 
+        # Add response metadata to each generation
         for generation in result.generations:
             generation.message.response_metadata = _gen_info_and_msg_metadata(
                 generation
@@ -638,7 +646,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        llm_cache = get_llm_cache()
+        if isinstance(self.cache, BaseCache):
+            llm_cache = self.cache
+        else:
+            llm_cache = get_llm_cache()
+        # We should check the cache unless it's explicitly set to False.
+        # A None cache means we should use the default global cache
+        # if it's configured.
         check_cache = self.cache or self.cache is None
         if check_cache:
             if llm_cache:
@@ -659,6 +673,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             )
         else:
             result = await self._agenerate(messages, stop=stop, **kwargs)
+
+        # Add response metadata to each generation
         for generation in result.generations:
             generation.message.response_metadata = _gen_info_and_msg_metadata(
                 generation
diff --git a/libs/core/langchain_core/language_models/llms.py b/libs/core/langchain_core/language_models/llms.py
index 789d9baa2ab..9f7ba5ce1ee 100644
--- a/libs/core/langchain_core/language_models/llms.py
+++ b/libs/core/langchain_core/language_models/llms.py
@@ -39,6 +39,7 @@ from tenacity import (
 )
 
 from langchain_core._api import deprecated
+from langchain_core.caches import BaseCache
 from langchain_core.callbacks import (
     AsyncCallbackManager,
     AsyncCallbackManagerForLLMRun,
@@ -733,6 +734,10 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             missing_prompt_idxs,
             missing_prompts,
         ) = get_prompts(params, prompts)
+        if isinstance(self.cache, BaseCache):
+            raise NotImplementedError(
+                "Local cache is not yet supported for LLMs (only chat models)"
+            )
         disregard_cache = self.cache is not None and not self.cache
         new_arg_supported = inspect.signature(self._generate).parameters.get(
             "run_manager"
@@ -942,6 +947,11 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             missing_prompt_idxs,
             missing_prompts,
         ) = await aget_prompts(params, prompts)
+        if isinstance(self.cache, BaseCache):
+            raise NotImplementedError(
+                "Local cache is not yet supported for LLMs (only chat models)"
+            )
+
         disregard_cache = self.cache is not None and not self.cache
         new_arg_supported = inspect.signature(self._agenerate).parameters.get(
             "run_manager"
diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py b/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py
new file mode 100644
index 00000000000..a2cc4fc4591
--- /dev/null
+++ b/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py
@@ -0,0 +1,268 @@
+"""Tests for the interaction of chat models with the caching abstraction."""
+from typing import Any, Dict, Optional, Tuple
+
+import pytest
+
+from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
+from langchain_core.globals import set_llm_cache
+from langchain_core.language_models.fake_chat_models import (
+    FakeListChatModel,
+    GenericFakeChatModel,
+)
+from langchain_core.messages import AIMessage
+from langchain_core.outputs import ChatGeneration
+
+
+class InMemoryCache(BaseCache):
+    """In-memory cache used for testing purposes."""
+
+    def __init__(self) -> None:
+        """Initialize with empty cache."""
+        self._cache: Dict[Tuple[str, str], RETURN_VAL_TYPE] = {}
+
+    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
+        """Look up based on prompt and llm_string."""
+        return self._cache.get((prompt, llm_string), None)
+
+    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
+        """Update cache based on prompt and llm_string."""
+        self._cache[(prompt, llm_string)] = return_val
+
+    def clear(self, **kwargs: Any) -> None:
+        """Clear cache."""
+        self._cache = {}
+
+
+def test_local_cache_sync() -> None:
+    """Test that the local cache is being populated but not the global one."""
+    global_cache = InMemoryCache()
+    local_cache = InMemoryCache()
+    try:
+        set_llm_cache(global_cache)
+        chat_model = FakeListChatModel(
+            cache=local_cache, responses=["hello", "goodbye"]
+        )
+        assert chat_model.invoke("How are you?").content == "hello"
+        # If the cache works we should get the same response since
+        # the prompt is the same
+        assert chat_model.invoke("How are you?").content == "hello"
+        # The global cache should be empty
+        assert global_cache._cache == {}
+        # The local cache should be populated
+        assert len(local_cache._cache) == 1
+        llm_result = list(local_cache._cache.values())
+        chat_generation = llm_result[0][0]
+        assert isinstance(chat_generation, ChatGeneration)
+        assert chat_generation.message.content == "hello"
+        # Verify that another prompt will trigger the call to the model
+        assert chat_model.invoke("meow?").content == "goodbye"
+        # The global cache should be empty
+        assert global_cache._cache == {}
+        # The local cache should be populated
+        assert len(local_cache._cache) == 2
+    finally:
+        set_llm_cache(None)
+
+
+async def test_local_cache_async() -> None:
+    """Test that the local cache is being populated but not the global one."""
+    global_cache = InMemoryCache()
+    local_cache = InMemoryCache()
+    try:
+        set_llm_cache(global_cache)
+        chat_model = FakeListChatModel(
+            cache=local_cache, responses=["hello", "goodbye"]
+        )
+        assert (await chat_model.ainvoke("How are you?")).content == "hello"
+        # If the cache works we should get the same response since
+        # the prompt is the same
+        assert (await chat_model.ainvoke("How are you?")).content == "hello"
+        # The global cache should be empty
+        assert global_cache._cache == {}
+        # The local cache should be populated
+        assert len(local_cache._cache) == 1
+        llm_result = list(local_cache._cache.values())
+        chat_generation = llm_result[0][0]
+        assert isinstance(chat_generation, ChatGeneration)
+        assert chat_generation.message.content == "hello"
+        # Verify that another prompt will trigger the call to the model
+        assert chat_model.invoke("meow?").content == "goodbye"
+        # The global cache should be empty
+        assert global_cache._cache == {}
+        # The local cache should be populated
+        assert len(local_cache._cache) == 2
+    finally:
+        set_llm_cache(None)
+
+
+def test_global_cache_sync() -> None:
+    """Test that the global cache gets populated when cache = True."""
+    global_cache = InMemoryCache()
+    try:
+        set_llm_cache(global_cache)
+        chat_model = FakeListChatModel(
+            cache=True, responses=["hello", "goodbye", "meow", "woof"]
+        )
+        assert (chat_model.invoke("How are you?")).content == "hello"
+        # If the cache works we should get the same response since
+        # the prompt is the same
+        assert (chat_model.invoke("How are you?")).content == "hello"
+        # The global cache should be populated
+        assert len(global_cache._cache) == 1
+        llm_result = list(global_cache._cache.values())
+        chat_generation = llm_result[0][0]
+        assert isinstance(chat_generation, ChatGeneration)
+        assert chat_generation.message.content == "hello"
+        # Verify that another prompt will trigger the call to the model
+        assert chat_model.invoke("nice").content == "goodbye"
+        # The global cache should now hold both prompts
+        assert len(global_cache._cache) == 2
+    finally:
+        set_llm_cache(None)
+
+
+async def test_global_cache_async() -> None:
+    """Test that the global cache gets populated when cache = True."""
+    global_cache = InMemoryCache()
+    try:
+        set_llm_cache(global_cache)
+        chat_model = FakeListChatModel(
+            cache=True, responses=["hello", "goodbye", "meow", "woof"]
+        )
+        assert (await chat_model.ainvoke("How are you?")).content == "hello"
+        # If the cache works we should get the same response since
+        # the prompt is the same
+        assert (await chat_model.ainvoke("How are you?")).content == "hello"
+        # The global cache should be populated
+        assert len(global_cache._cache) == 1
+        llm_result = list(global_cache._cache.values())
+        chat_generation = llm_result[0][0]
+        assert isinstance(chat_generation, ChatGeneration)
+        assert chat_generation.message.content == "hello"
+        # Verify that another prompt will trigger the call to the model
+        assert chat_model.invoke("nice").content == "goodbye"
+        # The global cache should now hold both prompts
+        assert len(global_cache._cache) == 2
+    finally:
+        set_llm_cache(None)
+
+
+def test_no_cache_sync() -> None:
+    """Test that the global cache is bypassed when cache=False."""
+    global_cache = InMemoryCache()
+    try:
+        set_llm_cache(global_cache)
+        chat_model = FakeListChatModel(
+            cache=False, responses=["hello", "goodbye"]
+        )  # Set cache=False
+        assert (chat_model.invoke("How are you?")).content == "hello"
+        # The global cache should not be populated since cache=False
+        # so we should get the second response
+        assert (chat_model.invoke("How are you?")).content == "goodbye"
+        # The global cache should not be populated since cache=False
+        assert len(global_cache._cache) == 0
+    finally:
+        set_llm_cache(None)
+
+
+async def test_no_cache_async() -> None:
+    """Test that the global cache is bypassed when cache=False."""
+    global_cache = InMemoryCache()
+    try:
+        set_llm_cache(global_cache)
+        chat_model = FakeListChatModel(
+            cache=False, responses=["hello", "goodbye"]
+        )  # Set cache=False
+        assert (await chat_model.ainvoke("How are you?")).content == "hello"
+        # The global cache should not be populated since cache=False
+        # so we should get the second response
+        assert (await chat_model.ainvoke("How are you?")).content == "goodbye"
+        # The global cache should not be populated since cache=False
+        assert len(global_cache._cache) == 0
+    finally:
+        set_llm_cache(None)
+
+
+async def test_global_cache_abatch() -> None:
+    """Test abatch with a global cache and cache=True."""
+    global_cache = InMemoryCache()
+    try:
+        set_llm_cache(global_cache)
+        chat_model = FakeListChatModel(
+            cache=True, responses=["hello", "goodbye", "meow", "woof"]
+        )
+        results = await chat_model.abatch(["first prompt", "second prompt"])
+        assert results[0].content == "hello"
+        assert results[1].content == "goodbye"
+
+        # Now try with the same prompt
+        results = await chat_model.abatch(["first prompt", "first prompt"])
+        assert results[0].content == "hello"
+        assert results[1].content == "hello"
+
+        ## RACE CONDITION -- note behavior is different from sync
+        # Now, reset cache and test the race condition
+        # For now we just hard-code the result; if this changes
+        # we can investigate further
+        global_cache = InMemoryCache()
+        set_llm_cache(global_cache)
+        assert global_cache._cache == {}
+        results = await chat_model.abatch(["prompt", "prompt"])
+        # Suspecting that tasks will be scheduled and executed in order;
+        # if this ever fails, we can relax to a set comparison.
+        # Both concurrent calls likely miss the cache, so the model is hit twice.
+ assert results[0].content == "meow" + assert results[1].content == "woof" + finally: + set_llm_cache(None) + + +def test_global_cache_batch() -> None: + global_cache = InMemoryCache() + try: + set_llm_cache(global_cache) + chat_model = FakeListChatModel( + cache=True, responses=["hello", "goodbye", "meow", "woof"] + ) + results = chat_model.batch(["first prompt", "second prompt"]) + # These may be in any order + assert {results[0].content, results[1].content} == {"hello", "goodbye"} + + # Now try with the same prompt + results = chat_model.batch(["first prompt", "first prompt"]) + # These could be either "hello" or "goodbye" and should be identical + assert results[0].content == results[1].content + assert {results[0].content, results[1].content}.issubset({"hello", "goodbye"}) + + ## RACE CONDITION -- note behavior is different from async + # Now, reset cache and test the race condition + # For now we just hard-code the result, if this changes + # we can investigate further + global_cache = InMemoryCache() + set_llm_cache(global_cache) + assert global_cache._cache == {} + results = chat_model.batch( + [ + "prompt", + "prompt", + ] + ) + assert {results[0].content, results[1].content} == {"meow"} + finally: + set_llm_cache(None) + + +@pytest.mark.xfail(reason="Abstraction does not support caching for streaming yet.") +def test_global_cache_stream() -> None: + """Test streaming.""" + global_cache = InMemoryCache() + try: + set_llm_cache(global_cache) + messages = [ + AIMessage(content="hello world"), + AIMessage(content="goodbye world"), + ] + model = GenericFakeChatModel(messages=iter(messages), cache=True) + chunks = [chunk for chunk in model.stream("some input")] + assert len(chunks) == 3 + # Assert that streaming information gets cached + assert global_cache._cache != {} + finally: + set_llm_cache(None)