core[minor]: Use BaseChatMessageHistory async methods in RunnableWithMessageHistory (#19565)

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
Christophe Bornet 2024-03-26 15:13:58 +01:00 committed by GitHub
parent 8595c3ab59
commit 1f422318b7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 19 additions and 11 deletions

View File

@ -17,7 +17,6 @@ from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.load.load import load from langchain_core.load.load import load
from langchain_core.pydantic_v1 import BaseModel from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.base import Runnable, RunnableBindingBase, RunnableLambda from langchain_core.runnables.base import Runnable, RunnableBindingBase, RunnableLambda
from langchain_core.runnables.config import run_in_executor
from langchain_core.runnables.passthrough import RunnablePassthrough from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.runnables.utils import ( from langchain_core.runnables.utils import (
ConfigurableFieldSpec, ConfigurableFieldSpec,
@ -403,21 +402,30 @@ class RunnableWithMessageHistory(RunnableBindingBase):
raise ValueError() raise ValueError()
def _enter_history(self, input: Any, config: RunnableConfig) -> List[BaseMessage]: def _enter_history(self, input: Any, config: RunnableConfig) -> List[BaseMessage]:
hist = config["configurable"]["message_history"] hist: BaseChatMessageHistory = config["configurable"]["message_history"]
# return only historic messages messages = hist.messages.copy()
if self.history_messages_key:
return hist.messages.copy() if not self.history_messages_key:
# return all messages # return all messages
else:
input_val = ( input_val = (
input if not self.input_messages_key else input[self.input_messages_key] input if not self.input_messages_key else input[self.input_messages_key]
) )
return hist.messages.copy() + self._get_input_messages(input_val) messages += self._get_input_messages(input_val)
return messages
async def _aenter_history( async def _aenter_history(
self, input: Dict[str, Any], config: RunnableConfig self, input: Dict[str, Any], config: RunnableConfig
) -> List[BaseMessage]: ) -> List[BaseMessage]:
return await run_in_executor(config, self._enter_history, input, config) hist: BaseChatMessageHistory = config["configurable"]["message_history"]
messages = (await hist.aget_messages()).copy()
if not self.history_messages_key:
# return all messages
input_val = (
input if not self.input_messages_key else input[self.input_messages_key]
)
messages += self._get_input_messages(input_val)
return messages
def _exit_history(self, run: Run, config: RunnableConfig) -> None: def _exit_history(self, run: Run, config: RunnableConfig) -> None:
hist: BaseChatMessageHistory = config["configurable"]["message_history"] hist: BaseChatMessageHistory = config["configurable"]["message_history"]

View File

@ -1254,9 +1254,9 @@ async def test_runnable_with_message_history() -> None:
input_messages_key="question", input_messages_key="question",
history_messages_key="history", history_messages_key="history",
) )
with_message_history.with_config( await with_message_history.with_config(
{"configurable": {"session_id": "session-123"}} {"configurable": {"session_id": "session-123"}}
).invoke({"question": "hello"}) ).ainvoke({"question": "hello"})
assert store == { assert store == {
"session-123": [HumanMessage(content="hello"), AIMessage(content="hello")] "session-123": [HumanMessage(content="hello"), AIMessage(content="hello")]