core[patch]: Use InMemoryChatMessageHistory in unit tests (#23916)

Update unit test to use the existing implementation of chat message
history
This commit is contained in:
Eugene Yurtsev 2024-07-05 16:10:54 -04:00 committed by GitHub
parent 8b84457b17
commit 9787552b00
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 25 additions and 44 deletions

View File

@ -1,25 +0,0 @@
from typing import List
from langchain_core.chat_history import (
BaseChatMessageHistory,
)
from langchain_core.messages import BaseMessage
from langchain_core.pydantic_v1 import BaseModel, Field
class ChatMessageHistory(BaseChatMessageHistory, BaseModel):
"""In memory implementation of chat message history.
Stores messages in an in memory list.
"""
messages: List[BaseMessage] = Field(default_factory=list)
def add_message(self, message: BaseMessage) -> None:
"""Add a self-created message to the store"""
if not isinstance(message, BaseMessage):
raise ValueError
self.messages.append(message)
def clear(self) -> None:
self.messages = []

View File

@ -3,6 +3,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Union
from langchain_core.callbacks import ( from langchain_core.callbacks import (
CallbackManagerForLLMRun, CallbackManagerForLLMRun,
) )
from langchain_core.chat_history import InMemoryChatMessageHistory
from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.outputs import ChatGeneration, ChatResult from langchain_core.outputs import ChatGeneration, ChatResult
@ -11,11 +12,10 @@ from langchain_core.runnables.base import RunnableLambda
from langchain_core.runnables.config import RunnableConfig from langchain_core.runnables.config import RunnableConfig
from langchain_core.runnables.history import RunnableWithMessageHistory from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.runnables.utils import ConfigurableFieldSpec from langchain_core.runnables.utils import ConfigurableFieldSpec
from tests.unit_tests.fake.memory import ChatMessageHistory
def test_interfaces() -> None: def test_interfaces() -> None:
history = ChatMessageHistory() history = InMemoryChatMessageHistory()
history.add_message(SystemMessage(content="system")) history.add_message(SystemMessage(content="system"))
history.add_user_message("human 1") history.add_user_message("human 1")
history.add_ai_message("ai") history.add_ai_message("ai")
@ -26,12 +26,14 @@ def test_interfaces() -> None:
def _get_get_session_history( def _get_get_session_history(
*, *,
store: Optional[Dict[str, Any]] = None, store: Optional[Dict[str, Any]] = None,
) -> Callable[..., ChatMessageHistory]: ) -> Callable[..., InMemoryChatMessageHistory]:
chat_history_store = store if store is not None else {} chat_history_store = store if store is not None else {}
def get_session_history(session_id: str, **kwargs: Any) -> ChatMessageHistory: def get_session_history(
session_id: str, **kwargs: Any
) -> InMemoryChatMessageHistory:
if session_id not in chat_history_store: if session_id not in chat_history_store:
chat_history_store[session_id] = ChatMessageHistory() chat_history_store[session_id] = InMemoryChatMessageHistory()
return chat_history_store[session_id] return chat_history_store[session_id]
return get_session_history return get_session_history
@ -51,7 +53,7 @@ def test_input_messages() -> None:
output = with_history.invoke([HumanMessage(content="good bye")], config) output = with_history.invoke([HumanMessage(content="good bye")], config)
assert output == "you said: hello\ngood bye" assert output == "you said: hello\ngood bye"
assert store == { assert store == {
"1": ChatMessageHistory( "1": InMemoryChatMessageHistory(
messages=[ messages=[
HumanMessage(content="hello"), HumanMessage(content="hello"),
AIMessage(content="you said: hello"), AIMessage(content="you said: hello"),
@ -76,7 +78,7 @@ async def test_input_messages_async() -> None:
output = await with_history.ainvoke([HumanMessage(content="good bye")], config) output = await with_history.ainvoke([HumanMessage(content="good bye")], config)
assert output == "you said: hello\ngood bye" assert output == "you said: hello\ngood bye"
assert store == { assert store == {
"1_async": ChatMessageHistory( "1_async": InMemoryChatMessageHistory(
messages=[ messages=[
HumanMessage(content="hello"), HumanMessage(content="hello"),
AIMessage(content="you said: hello"), AIMessage(content="you said: hello"),
@ -485,9 +487,11 @@ def test_using_custom_config_specs() -> None:
runnable = RunnableLambda(_fake_llm) runnable = RunnableLambda(_fake_llm)
store = {} store = {}
def get_session_history(user_id: str, conversation_id: str) -> ChatMessageHistory: def get_session_history(
user_id: str, conversation_id: str
) -> InMemoryChatMessageHistory:
if (user_id, conversation_id) not in store: if (user_id, conversation_id) not in store:
store[(user_id, conversation_id)] = ChatMessageHistory() store[(user_id, conversation_id)] = InMemoryChatMessageHistory()
return store[(user_id, conversation_id)] return store[(user_id, conversation_id)]
with_message_history = RunnableWithMessageHistory( with_message_history = RunnableWithMessageHistory(
@ -524,7 +528,7 @@ def test_using_custom_config_specs() -> None:
AIMessage(content="you said: hello"), AIMessage(content="you said: hello"),
] ]
assert store == { assert store == {
("user1", "1"): ChatMessageHistory( ("user1", "1"): InMemoryChatMessageHistory(
messages=[ messages=[
HumanMessage(content="hello"), HumanMessage(content="hello"),
AIMessage(content="you said: hello"), AIMessage(content="you said: hello"),
@ -542,7 +546,7 @@ def test_using_custom_config_specs() -> None:
AIMessage(content="you said: goodbye"), AIMessage(content="you said: goodbye"),
] ]
assert store == { assert store == {
("user1", "1"): ChatMessageHistory( ("user1", "1"): InMemoryChatMessageHistory(
messages=[ messages=[
HumanMessage(content="hello"), HumanMessage(content="hello"),
AIMessage(content="you said: hello"), AIMessage(content="you said: hello"),
@ -562,7 +566,7 @@ def test_using_custom_config_specs() -> None:
AIMessage(content="you said: meow"), AIMessage(content="you said: meow"),
] ]
assert store == { assert store == {
("user1", "1"): ChatMessageHistory( ("user1", "1"): InMemoryChatMessageHistory(
messages=[ messages=[
HumanMessage(content="hello"), HumanMessage(content="hello"),
AIMessage(content="you said: hello"), AIMessage(content="you said: hello"),
@ -570,7 +574,7 @@ def test_using_custom_config_specs() -> None:
AIMessage(content="you said: goodbye"), AIMessage(content="you said: goodbye"),
] ]
), ),
("user2", "1"): ChatMessageHistory( ("user2", "1"): InMemoryChatMessageHistory(
messages=[ messages=[
HumanMessage(content="meow"), HumanMessage(content="meow"),
AIMessage(content="you said: meow"), AIMessage(content="you said: meow"),
@ -596,9 +600,11 @@ async def test_using_custom_config_specs_async() -> None:
runnable = RunnableLambda(_fake_llm) runnable = RunnableLambda(_fake_llm)
store = {} store = {}
def get_session_history(user_id: str, conversation_id: str) -> ChatMessageHistory: def get_session_history(
user_id: str, conversation_id: str
) -> InMemoryChatMessageHistory:
if (user_id, conversation_id) not in store: if (user_id, conversation_id) not in store:
store[(user_id, conversation_id)] = ChatMessageHistory() store[(user_id, conversation_id)] = InMemoryChatMessageHistory()
return store[(user_id, conversation_id)] return store[(user_id, conversation_id)]
with_message_history = RunnableWithMessageHistory( with_message_history = RunnableWithMessageHistory(
@ -635,7 +641,7 @@ async def test_using_custom_config_specs_async() -> None:
AIMessage(content="you said: hello"), AIMessage(content="you said: hello"),
] ]
assert store == { assert store == {
("user1_async", "1_async"): ChatMessageHistory( ("user1_async", "1_async"): InMemoryChatMessageHistory(
messages=[ messages=[
HumanMessage(content="hello"), HumanMessage(content="hello"),
AIMessage(content="you said: hello"), AIMessage(content="you said: hello"),
@ -653,7 +659,7 @@ async def test_using_custom_config_specs_async() -> None:
AIMessage(content="you said: goodbye"), AIMessage(content="you said: goodbye"),
] ]
assert store == { assert store == {
("user1_async", "1_async"): ChatMessageHistory( ("user1_async", "1_async"): InMemoryChatMessageHistory(
messages=[ messages=[
HumanMessage(content="hello"), HumanMessage(content="hello"),
AIMessage(content="you said: hello"), AIMessage(content="you said: hello"),
@ -673,7 +679,7 @@ async def test_using_custom_config_specs_async() -> None:
AIMessage(content="you said: meow"), AIMessage(content="you said: meow"),
] ]
assert store == { assert store == {
("user1_async", "1_async"): ChatMessageHistory( ("user1_async", "1_async"): InMemoryChatMessageHistory(
messages=[ messages=[
HumanMessage(content="hello"), HumanMessage(content="hello"),
AIMessage(content="you said: hello"), AIMessage(content="you said: hello"),
@ -681,7 +687,7 @@ async def test_using_custom_config_specs_async() -> None:
AIMessage(content="you said: goodbye"), AIMessage(content="you said: goodbye"),
] ]
), ),
("user2_async", "1_async"): ChatMessageHistory( ("user2_async", "1_async"): InMemoryChatMessageHistory(
messages=[ messages=[
HumanMessage(content="meow"), HumanMessage(content="meow"),
AIMessage(content="you said: meow"), AIMessage(content="you said: meow"),