mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-15 22:44:36 +00:00
feat: add Momento as a standard cache and chat message history provider (#5221)
# Add Momento as a standard cache and chat message history provider This PR adds Momento as a standard caching provider. Implements the interface, adds integration tests, and documentation. We also add Momento as a chat history message provider along with integration tests, and documentation. [Momento](https://www.gomomento.com/) is a fully serverless cache. Similar to S3 or DynamoDB, it requires zero configuration, infrastructure management, and is instantly available. Users sign up for free and get 50GB of data in/out for free every month. ## Before submitting ✅ We have added documentation, notebooks, and integration tests demonstrating usage. Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
This commit is contained in:
94
tests/integration_tests/cache/test_momento_cache.py
vendored
Normal file
94
tests/integration_tests/cache/test_momento_cache.py
vendored
Normal file
@@ -0,0 +1,94 @@
|
||||
"""Test Momento cache functionality.
|
||||
|
||||
To run tests, set the environment variable MOMENTO_AUTH_TOKEN to a valid
|
||||
Momento auth token. This can be obtained by signing up for a free
|
||||
Momento account at https://gomomento.com/.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import timedelta
|
||||
from typing import Iterator
|
||||
|
||||
import pytest
|
||||
from momento import CacheClient, Configurations, CredentialProvider
|
||||
|
||||
import langchain
|
||||
from langchain.cache import MomentoCache
|
||||
from langchain.schema import Generation, LLMResult
|
||||
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||
|
||||
|
||||
def random_string() -> str:
    """Return a fresh UUID4 rendered as a string, for unique cache names."""
    return f"{uuid.uuid4()}"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def momento_cache() -> Iterator[MomentoCache]:
    """Yield a MomentoCache installed as the global langchain LLM cache.

    A uniquely named Momento cache is created for the whole test module
    and deleted in teardown so no server-side state leaks between runs.
    Requires the MOMENTO_AUTH_TOKEN environment variable.
    """
    test_cache_name = f"langchain-test-cache-{random_string()}"
    momento_client = CacheClient(
        Configurations.Laptop.v1(),
        CredentialProvider.from_environment_variable("MOMENTO_AUTH_TOKEN"),
        default_ttl=timedelta(seconds=30),
    )
    try:
        cache = MomentoCache(momento_client, test_cache_name)
        # Install globally so llm.generate() consults this cache.
        langchain.llm_cache = cache
        yield cache
    finally:
        # Always remove the remote cache, even if a test failed.
        momento_client.delete_cache(test_cache_name)
|
||||
|
||||
|
||||
def test_invalid_ttl() -> None:
    """Constructing a MomentoCache with a negative TTL raises ValueError."""
    momento_client = CacheClient(
        Configurations.Laptop.v1(),
        CredentialProvider.from_environment_variable("MOMENTO_AUTH_TOKEN"),
        default_ttl=timedelta(seconds=30),
    )
    negative_ttl = timedelta(seconds=-1)
    with pytest.raises(ValueError):
        MomentoCache(momento_client, cache_name=random_string(), ttl=negative_ttl)
|
||||
|
||||
|
||||
def test_momento_cache_miss(momento_cache: MomentoCache) -> None:
    """An unseen prompt misses the cache and yields the fake LLM's output."""
    fake_llm = FakeLLM()
    # FakeLLM answers "foo" for any prompt it has no canned response for.
    expected = LLMResult(generations=[[Generation(text="foo")]])
    assert fake_llm.generate([random_string()]) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "prompts, generations",
    [
        # Single prompt, one generation
        ([random_string()], [[random_string()]]),
        # Single prompt, two generations
        ([random_string()], [[random_string(), random_string()]]),
        # Single prompt, three generations
        ([random_string()], [[random_string(), random_string(), random_string()]]),
        # Two prompts with differing numbers of generations
        (
            [random_string(), random_string()],
            [[random_string()], [random_string(), random_string()]],
        ),
    ],
)
def test_momento_cache_hit(
    momento_cache: MomentoCache, prompts: list[str], generations: list[list[str]]
) -> None:
    """Entries written via update() are served back by llm.generate()."""
    fake_llm = FakeLLM()

    # Recreate the llm_string key exactly as the cache layer computes it.
    llm_params = fake_llm.dict()
    llm_params["stop"] = None
    llm_string = str(sorted(llm_params.items()))

    # Build the Generation objects we expect the cache to hand back.
    expected_generations = [
        [Generation(text=text, generation_info=llm_params) for text in per_prompt]
        for per_prompt in generations
    ]

    # Seed the cache with one entry per prompt.
    for prompt, prompt_generations in zip(prompts, expected_generations):
        momento_cache.update(prompt, llm_string, prompt_generations)

    assert fake_llm.generate(prompts) == LLMResult(
        generations=expected_generations, llm_output={}
    )
|
70
tests/integration_tests/memory/test_momento.py
Normal file
70
tests/integration_tests/memory/test_momento.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""Test Momento chat message history functionality.
|
||||
|
||||
To run tests, set the environment variable MOMENTO_AUTH_TOKEN to a valid
|
||||
Momento auth token. This can be obtained by signing up for a free
|
||||
Momento account at https://gomomento.com/.
|
||||
"""
|
||||
import json
|
||||
import uuid
|
||||
from datetime import timedelta
|
||||
from typing import Iterator
|
||||
|
||||
import pytest
|
||||
from momento import CacheClient, Configurations, CredentialProvider
|
||||
|
||||
from langchain.memory import ConversationBufferMemory
|
||||
from langchain.memory.chat_message_histories import MomentoChatMessageHistory
|
||||
from langchain.schema import _message_to_dict
|
||||
|
||||
|
||||
def random_string() -> str:
    """Generate a unique string (UUID4) to namespace per-test resources."""
    unique_id = uuid.uuid4()
    return str(unique_id)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
def message_history() -> Iterator[MomentoChatMessageHistory]:
    """Yield a chat message history backed by a throwaway Momento cache.

    A uniquely named cache is created per test function and deleted in
    teardown. Requires the MOMENTO_AUTH_TOKEN environment variable.
    """
    test_cache_name = f"langchain-test-cache-{random_string()}"
    momento_client = CacheClient(
        Configurations.Laptop.v1(),
        CredentialProvider.from_environment_variable("MOMENTO_AUTH_TOKEN"),
        default_ttl=timedelta(seconds=30),
    )
    try:
        history = MomentoChatMessageHistory(
            session_id="my-test-session",
            cache_client=momento_client,
            cache_name=test_cache_name,
        )
        yield history
    finally:
        # Remove the remote cache regardless of test outcome.
        momento_client.delete_cache(test_cache_name)
|
||||
|
||||
|
||||
def test_memory_empty_on_new_session(
    message_history: MomentoChatMessageHistory,
) -> None:
    """A brand-new session has no messages in its history."""
    memory = ConversationBufferMemory(
        memory_key="foo", chat_memory=message_history, return_messages=True
    )
    stored_messages = memory.chat_memory.messages
    assert stored_messages == []
|
||||
|
||||
|
||||
def test_memory_with_message_store(message_history: MomentoChatMessageHistory) -> None:
    """Messages round-trip through the Momento-backed store and can be cleared."""
    memory = ConversationBufferMemory(
        memory_key="baz", chat_memory=message_history, return_messages=True
    )

    # Write one AI and one human message into the backing store.
    memory.chat_memory.add_ai_message("This is me, the AI")
    memory.chat_memory.add_user_message("This is me, the human")

    # Read everything back and serialize it to check both messages survived.
    serialized = json.dumps(
        [_message_to_dict(message) for message in memory.chat_memory.messages]
    )
    assert "This is me, the AI" in serialized
    assert "This is me, the human" in serialized

    # Clearing the history must leave the store empty.
    memory.chat_memory.clear()
    assert memory.chat_memory.messages == []
|
Reference in New Issue
Block a user