From 1f422318b71348d2497fcb3aab6a6548b91102a0 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Tue, 26 Mar 2024 15:13:58 +0100 Subject: [PATCH] core[minor]: Use BaseChatMessageHistory async methods in RunnableWithMessageHistory (#19565) Co-authored-by: Eugene Yurtsev --- libs/core/langchain_core/runnables/history.py | 26 ++++++++++++------- .../runnables/test_runnable_events.py | 4 +-- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/libs/core/langchain_core/runnables/history.py b/libs/core/langchain_core/runnables/history.py index 6d2dbe33bc9..f02364bccde 100644 --- a/libs/core/langchain_core/runnables/history.py +++ b/libs/core/langchain_core/runnables/history.py @@ -17,7 +17,6 @@ from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.load.load import load from langchain_core.pydantic_v1 import BaseModel from langchain_core.runnables.base import Runnable, RunnableBindingBase, RunnableLambda -from langchain_core.runnables.config import run_in_executor from langchain_core.runnables.passthrough import RunnablePassthrough from langchain_core.runnables.utils import ( ConfigurableFieldSpec, @@ -403,21 +402,30 @@ class RunnableWithMessageHistory(RunnableBindingBase): raise ValueError() def _enter_history(self, input: Any, config: RunnableConfig) -> List[BaseMessage]: - hist = config["configurable"]["message_history"] - # return only historic messages - if self.history_messages_key: - return hist.messages.copy() - # return all messages - else: + hist: BaseChatMessageHistory = config["configurable"]["message_history"] + messages = hist.messages.copy() + + if not self.history_messages_key: + # return all messages input_val = ( input if not self.input_messages_key else input[self.input_messages_key] ) - return hist.messages.copy() + self._get_input_messages(input_val) + messages += self._get_input_messages(input_val) + return messages async def _aenter_history( self, input: Dict[str, Any], config: RunnableConfig ) -> List[BaseMessage]: - return await run_in_executor(config, self._enter_history, input, config) + hist: BaseChatMessageHistory = config["configurable"]["message_history"] + messages = (await hist.aget_messages()).copy() + + if not self.history_messages_key: + # return all messages + input_val = ( + input if not self.input_messages_key else input[self.input_messages_key] + ) + messages += self._get_input_messages(input_val) + return messages def _exit_history(self, run: Run, config: RunnableConfig) -> None: hist: BaseChatMessageHistory = config["configurable"]["message_history"] diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events.py b/libs/core/tests/unit_tests/runnables/test_runnable_events.py index d2d76d87799..3b822bce6fe 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events.py @@ -1254,9 +1254,9 @@ async def test_runnable_with_message_history() -> None: input_messages_key="question", history_messages_key="history", ) - with_message_history.with_config( + await with_message_history.with_config( {"configurable": {"session_id": "session-123"}} - ).invoke({"question": "hello"}) + ).ainvoke({"question": "hello"}) assert store == { "session-123": [HumanMessage(content="hello"), AIMessage(content="hello")]