Compare commits

...

3 Commits

Author SHA1 Message Date
Bagatur
df704e1f4f Merge branch 'master' into eugene/message_history_test 2024-01-30 09:54:08 -08:00
Eugene Yurtsev
607a171a74 x 2024-01-05 21:42:15 -05:00
Eugene Yurtsev
c2da71982c x 2024-01-05 21:40:36 -05:00

View File

@@ -364,3 +364,47 @@ def test_using_custom_config_specs() -> None:
]
),
}
def test_thingy():
from langchain.schema import HumanMessage
from langchain_community.chat_message_histories import SQLChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory
def _fake_llm(input: Dict[str, Any]) -> List[BaseMessage]:
messages = input
return [
AIMessage(
content="you said: "
+ "\n".join(
str(m.content) for m in messages if isinstance(m, HumanMessage)
)
)
]
chat = RunnableLambda(_fake_llm)
def factory(session_id):
return SQLChatMessageHistory(
session_id=session_id, connection_string="sqlite:///sqlite.db"
)
session = factory("abc")
session.clear()
chat_with_history = RunnableWithMessageHistory(
chat,
factory,
)
# This is where we configure the session id, which is needed for fetching messages
config = {"configurable": {"session_id": "abc"}}
chat_with_history.invoke(HumanMessage(content="Hi! I'm Bob"), config=config)
chat_with_history.invoke(HumanMessage(content="Hi! I'm Alice"), config=config)
assert session.messages == [
HumanMessage(content="Hi! I'm Bob"),
AIMessage(content="you said: Hi! I'm Bob"),
HumanMessage(content="Hi! I'm Alice"),
AIMessage(content="you said: Hi! I'm Bob\nHi! I'm Alice"),
]