mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 15:19:33 +00:00
core[minor]: Enhance cache flexibility in BaseChatModel (#17386)
- **Description:** Enhanced the `BaseChatModel` to support an `Optional[Union[bool, BaseCache]]` type for the `cache` attribute, allowing for both boolean flags and custom cache implementations. Implemented logic within chat model methods to utilize the provided custom cache implementation effectively. This change aims to provide more flexibility in caching strategies for chat models. - **Issue:** Implements enhancement request #17242. - **Dependencies:** No additional dependencies required for this change. --------- Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
parent
4761c09e94
commit
50f93d86ec
@ -31,6 +31,7 @@ from langchain_core.runnables import Runnable, RunnableSerializable
|
||||
from langchain_core.utils import get_pydantic_field_names
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.caches import BaseCache
|
||||
from langchain_core.callbacks import Callbacks
|
||||
from langchain_core.outputs import LLMResult
|
||||
|
||||
@ -78,8 +79,16 @@ class BaseLanguageModel(
|
||||
All language model wrappers inherit from BaseLanguageModel.
|
||||
"""
|
||||
|
||||
cache: Optional[bool] = None
|
||||
"""Whether to cache the response."""
|
||||
cache: Union[BaseCache, bool, None] = None
|
||||
"""Whether to cache the response.
|
||||
|
||||
* If true, will use the global cache.
|
||||
* If false, will not use a cache
|
||||
* If None, will use the global cache if it's set, otherwise no cache.
|
||||
* If instance of BaseCache, will use the provided cache.
|
||||
|
||||
Caching is not currently supported for streaming methods of models.
|
||||
"""
|
||||
verbose: bool = Field(default_factory=_get_verbosity)
|
||||
"""Whether to print out response text."""
|
||||
callbacks: Callbacks = Field(default=None, exclude=True)
|
||||
|
@ -21,6 +21,7 @@ from typing import (
|
||||
)
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.caches import BaseCache
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManager,
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
@ -596,7 +597,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
llm_cache = get_llm_cache()
|
||||
if isinstance(self.cache, BaseCache):
|
||||
llm_cache = self.cache
|
||||
else:
|
||||
llm_cache = get_llm_cache()
|
||||
# We should check the cache unless it's explicitly set to False
|
||||
# A None cache means we should use the default global cache
|
||||
# if it's configured.
|
||||
check_cache = self.cache or self.cache is None
|
||||
if check_cache:
|
||||
if llm_cache:
|
||||
@ -618,6 +625,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
else:
|
||||
result = self._generate(messages, stop=stop, **kwargs)
|
||||
|
||||
# Add response metadata to each generation
|
||||
for generation in result.generations:
|
||||
generation.message.response_metadata = _gen_info_and_msg_metadata(
|
||||
generation
|
||||
@ -638,7 +646,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
llm_cache = get_llm_cache()
|
||||
if isinstance(self.cache, BaseCache):
|
||||
llm_cache = self.cache
|
||||
else:
|
||||
llm_cache = get_llm_cache()
|
||||
# We should check the cache unless it's explicitly set to False
|
||||
# A None cache means we should use the default global cache
|
||||
# if it's configured.
|
||||
check_cache = self.cache or self.cache is None
|
||||
if check_cache:
|
||||
if llm_cache:
|
||||
@ -659,6 +673,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
)
|
||||
else:
|
||||
result = await self._agenerate(messages, stop=stop, **kwargs)
|
||||
|
||||
# Add response metadata to each generation
|
||||
for generation in result.generations:
|
||||
generation.message.response_metadata = _gen_info_and_msg_metadata(
|
||||
generation
|
||||
|
@ -39,6 +39,7 @@ from tenacity import (
|
||||
)
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.caches import BaseCache
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManager,
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
@ -733,6 +734,10 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
missing_prompt_idxs,
|
||||
missing_prompts,
|
||||
) = get_prompts(params, prompts)
|
||||
if isinstance(self.cache, BaseCache):
|
||||
raise NotImplementedError(
|
||||
"Local cache is not yet supported for " "LLMs (only chat models)"
|
||||
)
|
||||
disregard_cache = self.cache is not None and not self.cache
|
||||
new_arg_supported = inspect.signature(self._generate).parameters.get(
|
||||
"run_manager"
|
||||
@ -942,6 +947,11 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
missing_prompt_idxs,
|
||||
missing_prompts,
|
||||
) = await aget_prompts(params, prompts)
|
||||
if isinstance(self.cache, BaseCache):
|
||||
raise NotImplementedError(
|
||||
"Local cache is not yet supported for " "LLMs (only chat models)"
|
||||
)
|
||||
|
||||
disregard_cache = self.cache is not None and not self.cache
|
||||
new_arg_supported = inspect.signature(self._agenerate).parameters.get(
|
||||
"run_manager"
|
||||
|
@ -0,0 +1,268 @@
|
||||
"""Module tests interaction of chat model with caching abstraction.."""
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
|
||||
from langchain_core.globals import set_llm_cache
|
||||
from langchain_core.language_models.fake_chat_models import (
|
||||
FakeListChatModel,
|
||||
GenericFakeChatModel,
|
||||
)
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.outputs import ChatGeneration
|
||||
|
||||
|
||||
class InMemoryCache(BaseCache):
|
||||
"""In-memory cache used for testing purposes."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize with empty cache."""
|
||||
self._cache: Dict[Tuple[str, str], RETURN_VAL_TYPE] = {}
|
||||
|
||||
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
|
||||
"""Look up based on prompt and llm_string."""
|
||||
return self._cache.get((prompt, llm_string), None)
|
||||
|
||||
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
|
||||
"""Update cache based on prompt and llm_string."""
|
||||
self._cache[(prompt, llm_string)] = return_val
|
||||
|
||||
def clear(self, **kwargs: Any) -> None:
|
||||
"""Clear cache."""
|
||||
self._cache = {}
|
||||
|
||||
|
||||
def test_local_cache_sync() -> None:
|
||||
"""Test that the local cache is being populated but not the global one."""
|
||||
global_cache = InMemoryCache()
|
||||
local_cache = InMemoryCache()
|
||||
try:
|
||||
set_llm_cache(global_cache)
|
||||
chat_model = FakeListChatModel(
|
||||
cache=local_cache, responses=["hello", "goodbye"]
|
||||
)
|
||||
assert chat_model.invoke("How are you?").content == "hello"
|
||||
# If the cache works we should get the same response since
|
||||
# the prompt is the same
|
||||
assert chat_model.invoke("How are you?").content == "hello"
|
||||
# The global cache should be empty
|
||||
assert global_cache._cache == {}
|
||||
# The local cache should be populated
|
||||
assert len(local_cache._cache) == 1
|
||||
llm_result = list(local_cache._cache.values())
|
||||
chat_generation = llm_result[0][0]
|
||||
assert isinstance(chat_generation, ChatGeneration)
|
||||
assert chat_generation.message.content == "hello"
|
||||
# Verify that another prompt will trigger the call to the model
|
||||
assert chat_model.invoke("meow?").content == "goodbye"
|
||||
# The global cache should be empty
|
||||
assert global_cache._cache == {}
|
||||
# The local cache should be populated
|
||||
assert len(local_cache._cache) == 2
|
||||
finally:
|
||||
set_llm_cache(None)
|
||||
|
||||
|
||||
async def test_local_cache_async() -> None:
|
||||
# Use MockCache as the cache
|
||||
global_cache = InMemoryCache()
|
||||
local_cache = InMemoryCache()
|
||||
try:
|
||||
set_llm_cache(global_cache)
|
||||
chat_model = FakeListChatModel(
|
||||
cache=local_cache, responses=["hello", "goodbye"]
|
||||
)
|
||||
assert (await chat_model.ainvoke("How are you?")).content == "hello"
|
||||
# If the cache works we should get the same response since
|
||||
# the prompt is the same
|
||||
assert (await chat_model.ainvoke("How are you?")).content == "hello"
|
||||
# The global cache should be empty
|
||||
assert global_cache._cache == {}
|
||||
# The local cache should be populated
|
||||
assert len(local_cache._cache) == 1
|
||||
llm_result = list(local_cache._cache.values())
|
||||
chat_generation = llm_result[0][0]
|
||||
assert isinstance(chat_generation, ChatGeneration)
|
||||
assert chat_generation.message.content == "hello"
|
||||
# Verify that another prompt will trigger the call to the model
|
||||
assert chat_model.invoke("meow?").content == "goodbye"
|
||||
# The global cache should be empty
|
||||
assert global_cache._cache == {}
|
||||
# The local cache should be populated
|
||||
assert len(local_cache._cache) == 2
|
||||
finally:
|
||||
set_llm_cache(None)
|
||||
|
||||
|
||||
def test_global_cache_sync() -> None:
|
||||
"""Test that the global cache gets populated when cache = True."""
|
||||
global_cache = InMemoryCache()
|
||||
try:
|
||||
set_llm_cache(global_cache)
|
||||
chat_model = FakeListChatModel(
|
||||
cache=True, responses=["hello", "goodbye", "meow", "woof"]
|
||||
)
|
||||
assert (chat_model.invoke("How are you?")).content == "hello"
|
||||
# If the cache works we should get the same response since
|
||||
# the prompt is the same
|
||||
assert (chat_model.invoke("How are you?")).content == "hello"
|
||||
# The global cache should be populated
|
||||
assert len(global_cache._cache) == 1
|
||||
llm_result = list(global_cache._cache.values())
|
||||
chat_generation = llm_result[0][0]
|
||||
assert isinstance(chat_generation, ChatGeneration)
|
||||
assert chat_generation.message.content == "hello"
|
||||
# Verify that another prompt will trigger the call to the model
|
||||
assert chat_model.invoke("nice").content == "goodbye"
|
||||
# The local cache should be populated
|
||||
assert len(global_cache._cache) == 2
|
||||
finally:
|
||||
set_llm_cache(None)
|
||||
|
||||
|
||||
async def test_global_cache_async() -> None:
|
||||
"""Test that the global cache gets populated when cache = True."""
|
||||
global_cache = InMemoryCache()
|
||||
try:
|
||||
set_llm_cache(global_cache)
|
||||
chat_model = FakeListChatModel(
|
||||
cache=True, responses=["hello", "goodbye", "meow", "woof"]
|
||||
)
|
||||
assert (await chat_model.ainvoke("How are you?")).content == "hello"
|
||||
# If the cache works we should get the same response since
|
||||
# the prompt is the same
|
||||
assert (await chat_model.ainvoke("How are you?")).content == "hello"
|
||||
# The global cache should be populated
|
||||
assert len(global_cache._cache) == 1
|
||||
llm_result = list(global_cache._cache.values())
|
||||
chat_generation = llm_result[0][0]
|
||||
assert isinstance(chat_generation, ChatGeneration)
|
||||
assert chat_generation.message.content == "hello"
|
||||
# Verify that another prompt will trigger the call to the model
|
||||
assert chat_model.invoke("nice").content == "goodbye"
|
||||
# The local cache should be populated
|
||||
assert len(global_cache._cache) == 2
|
||||
finally:
|
||||
set_llm_cache(None)
|
||||
|
||||
|
||||
def test_no_cache_sync() -> None:
|
||||
global_cache = InMemoryCache()
|
||||
try:
|
||||
set_llm_cache(global_cache)
|
||||
chat_model = FakeListChatModel(
|
||||
cache=False, responses=["hello", "goodbye"]
|
||||
) # Set cache=False
|
||||
assert (chat_model.invoke("How are you?")).content == "hello"
|
||||
# The global cache should not be populated since cache=False
|
||||
# so we should get the second response
|
||||
assert (chat_model.invoke("How are you?")).content == "goodbye"
|
||||
# The global cache should not be populated since cache=False
|
||||
assert len(global_cache._cache) == 0
|
||||
finally:
|
||||
set_llm_cache(None)
|
||||
|
||||
|
||||
async def test_no_cache_async() -> None:
|
||||
global_cache = InMemoryCache()
|
||||
try:
|
||||
set_llm_cache(global_cache)
|
||||
chat_model = FakeListChatModel(
|
||||
cache=False, responses=["hello", "goodbye"]
|
||||
) # Set cache=False
|
||||
assert (await chat_model.ainvoke("How are you?")).content == "hello"
|
||||
# The global cache should not be populated since cache=False
|
||||
# so we should get the second response
|
||||
assert (await chat_model.ainvoke("How are you?")).content == "goodbye"
|
||||
# The global cache should not be populated since cache=False
|
||||
assert len(global_cache._cache) == 0
|
||||
finally:
|
||||
set_llm_cache(None)
|
||||
|
||||
|
||||
async def test_global_cache_abatch() -> None:
|
||||
global_cache = InMemoryCache()
|
||||
try:
|
||||
set_llm_cache(global_cache)
|
||||
chat_model = FakeListChatModel(
|
||||
cache=True, responses=["hello", "goodbye", "meow", "woof"]
|
||||
)
|
||||
results = await chat_model.abatch(["first prompt", "second prompt"])
|
||||
assert results[0].content == "hello"
|
||||
assert results[1].content == "goodbye"
|
||||
|
||||
# Now try with the same prompt
|
||||
results = await chat_model.abatch(["first prompt", "first prompt"])
|
||||
assert results[0].content == "hello"
|
||||
assert results[1].content == "hello"
|
||||
|
||||
## RACE CONDITION -- note behavior is different from sync
|
||||
# Now, reset cache and test the race condition
|
||||
# For now we just hard-code the result, if this changes
|
||||
# we can investigate further
|
||||
global_cache = InMemoryCache()
|
||||
set_llm_cache(global_cache)
|
||||
assert global_cache._cache == {}
|
||||
results = await chat_model.abatch(["prompt", "prompt"])
|
||||
# suspecting that tasks will be scheduled and executed in order
|
||||
# if this ever fails, we can relax to a set comparison
|
||||
# Cache misses likely guaranteed?
|
||||
assert results[0].content == "meow"
|
||||
assert results[1].content == "woof"
|
||||
finally:
|
||||
set_llm_cache(None)
|
||||
|
||||
|
||||
def test_global_cache_batch() -> None:
|
||||
global_cache = InMemoryCache()
|
||||
try:
|
||||
set_llm_cache(global_cache)
|
||||
chat_model = FakeListChatModel(
|
||||
cache=True, responses=["hello", "goodbye", "meow", "woof"]
|
||||
)
|
||||
results = chat_model.batch(["first prompt", "second prompt"])
|
||||
# These may be in any order
|
||||
assert {results[0].content, results[1].content} == {"hello", "goodbye"}
|
||||
|
||||
# Now try with the same prompt
|
||||
results = chat_model.batch(["first prompt", "first prompt"])
|
||||
# These could be either "hello" or "goodbye" and should be identical
|
||||
assert results[0].content == results[1].content
|
||||
assert {results[0].content, results[1].content}.issubset({"hello", "goodbye"})
|
||||
|
||||
## RACE CONDITION -- note behavior is different from async
|
||||
# Now, reset cache and test the race condition
|
||||
# For now we just hard-code the result, if this changes
|
||||
# we can investigate further
|
||||
global_cache = InMemoryCache()
|
||||
set_llm_cache(global_cache)
|
||||
assert global_cache._cache == {}
|
||||
results = chat_model.batch(
|
||||
[
|
||||
"prompt",
|
||||
"prompt",
|
||||
]
|
||||
)
|
||||
assert {results[0].content, results[1].content} == {"meow"}
|
||||
finally:
|
||||
set_llm_cache(None)
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="Abstraction does not support caching for streaming yet.")
|
||||
def test_global_cache_stream() -> None:
|
||||
"""Test streaming."""
|
||||
global_cache = InMemoryCache()
|
||||
try:
|
||||
set_llm_cache(global_cache)
|
||||
messages = [
|
||||
AIMessage(content="hello world"),
|
||||
AIMessage(content="goodbye world"),
|
||||
]
|
||||
model = GenericFakeChatModel(messages=iter(messages), cache=True)
|
||||
chunks = [chunk for chunk in model.stream("some input")]
|
||||
assert len(chunks) == 3
|
||||
# Assert that streaming information gets cached
|
||||
assert global_cache._cache != {}
|
||||
finally:
|
||||
set_llm_cache(None)
|
Loading…
Reference in New Issue
Block a user