diff --git a/libs/community/langchain_community/chat_message_histories/sql.py b/libs/community/langchain_community/chat_message_histories/sql.py index 9264dbeff0e..edfed45466f 100644 --- a/libs/community/langchain_community/chat_message_histories/sql.py +++ b/libs/community/langchain_community/chat_message_histories/sql.py @@ -1,4 +1,3 @@ -import asyncio import contextlib import json import logging @@ -252,17 +251,11 @@ class SQLChatMessageHistory(BaseChatMessageHistory): await session.commit() def add_messages(self, messages: Sequence[BaseMessage]) -> None: - # The method RunnableWithMessageHistory._exit_history() call - # add_message method by mistake and not aadd_message. - # See https://github.com/langchain-ai/langchain/issues/22021 - if self.async_mode: - loop = asyncio.get_event_loop() - loop.run_until_complete(self.aadd_messages(messages)) - else: - with self._make_sync_session() as session: - for message in messages: - session.add(self.converter.to_sql_model(message, self.session_id)) - session.commit() + # Add all messages in one transaction + with self._make_sync_session() as session: + for message in messages: + session.add(self.converter.to_sql_model(message, self.session_id)) + session.commit() async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None: # Add all messages in one transaction diff --git a/libs/core/langchain_core/runnables/history.py b/libs/core/langchain_core/runnables/history.py index 326f941263c..05d52c38bf8 100644 --- a/libs/core/langchain_core/runnables/history.py +++ b/libs/core/langchain_core/runnables/history.py @@ -16,6 +16,7 @@ from typing import ( from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.load.load import load from langchain_core.pydantic_v1 import BaseModel +from langchain_core.runnables import RunnableBranch from langchain_core.runnables.base import Runnable, RunnableBindingBase, RunnableLambda from langchain_core.runnables.passthrough import RunnablePassthrough from langchain_core.runnables.utils import ( @@ -306,8 +307,17 @@ class RunnableWithMessageHistory(RunnableBindingBase): history_chain = RunnablePassthrough.assign( **{messages_key: history_chain} ).with_config(run_name="insert_history") - bound = ( - history_chain | runnable.with_listeners(on_end=self._exit_history) + bound: Runnable = ( + history_chain + | RunnableBranch( + ( + RunnableLambda( + self._is_not_async, afunc=self._is_async + ).with_config(run_name="RunnableWithMessageHistoryInAsyncMode"), + runnable.with_alisteners(on_end=self._aexit_history), + ), + runnable.with_listeners(on_end=self._exit_history), + ) ).with_config(run_name="RunnableWithMessageHistory") if history_factory_config: @@ -367,6 +377,12 @@ class RunnableWithMessageHistory(RunnableBindingBase): else: return super_schema + def _is_not_async(self, *args: Sequence[Any], **kwargs: Dict[str, Any]) -> bool: + return False + + async def _is_async(self, *args: Sequence[Any], **kwargs: Dict[str, Any]) -> bool: + return True + def _get_input_messages( self, input_val: Union[str, BaseMessage, Sequence[BaseMessage], dict] ) -> List[BaseMessage]: @@ -483,6 +499,23 @@ class RunnableWithMessageHistory(RunnableBindingBase): output_messages = self._get_output_messages(output_val) hist.add_messages(input_messages + output_messages) + async def _aexit_history(self, run: Run, config: RunnableConfig) -> None: + hist: BaseChatMessageHistory = config["configurable"]["message_history"] + + # Get the input messages + inputs = load(run.inputs) + input_messages = self._get_input_messages(inputs) + # If historic messages were prepended to the input messages, remove them to + # avoid adding duplicate messages to history. + if not self.history_messages_key: + historic_messages = config["configurable"]["message_history"].messages + input_messages = input_messages[len(historic_messages) :] + + # Get the output messages + output_val = load(run.outputs) + output_messages = self._get_output_messages(output_val) + await hist.aadd_messages(input_messages + output_messages) + def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig: config = super()._merge_configs(*configs) expected_keys = [field_spec.id for field_spec in self.history_factory_config] diff --git a/libs/core/tests/unit_tests/runnables/test_history.py b/libs/core/tests/unit_tests/runnables/test_history.py index 3db18b9963f..60883f9a324 100644 --- a/libs/core/tests/unit_tests/runnables/test_history.py +++ b/libs/core/tests/unit_tests/runnables/test_history.py @@ -62,6 +62,31 @@ def test_input_messages() -> None: } +async def test_input_messages_async() -> None: + runnable = RunnableLambda( + lambda messages: "you said: " + + "\n".join(str(m.content) for m in messages if isinstance(m, HumanMessage)) + ) + store: Dict = {} + get_session_history = _get_get_session_history(store=store) + with_history = RunnableWithMessageHistory(runnable, get_session_history) + config: RunnableConfig = {"configurable": {"session_id": "1_async"}} + output = await with_history.ainvoke([HumanMessage(content="hello")], config) + assert output == "you said: hello" + output = await with_history.ainvoke([HumanMessage(content="good bye")], config) + assert output == "you said: hello\ngood bye" + assert store == { + "1_async": ChatMessageHistory( + messages=[ + HumanMessage(content="hello"), + AIMessage(content="you said: hello"), + HumanMessage(content="good bye"), + AIMessage(content="you said: hello\ngood bye"), + ] + ) + } + + def test_input_dict() -> None: runnable = RunnableLambda( lambda input: "you said: " @@ -82,6 +107,28 @@ def test_input_dict() -> None: assert output == "you said: hello\ngood bye" +async def test_input_dict_async() -> None: + runnable = RunnableLambda( + lambda input: "you said: " + + "\n".join( + str(m.content) for m in input["messages"] if isinstance(m, HumanMessage) + ) + ) + get_session_history = _get_get_session_history() + with_history = RunnableWithMessageHistory( + runnable, get_session_history, input_messages_key="messages" + ) + config: RunnableConfig = {"configurable": {"session_id": "2_async"}} + output = await with_history.ainvoke( + {"messages": [HumanMessage(content="hello")]}, config + ) + assert output == "you said: hello" + output = await with_history.ainvoke( + {"messages": [HumanMessage(content="good bye")]}, config + ) + assert output == "you said: hello\ngood bye" + + def test_input_dict_with_history_key() -> None: runnable = RunnableLambda( lambda input: "you said: " @@ -104,6 +151,28 @@ def test_input_dict_with_history_key() -> None: assert output == "you said: hello\ngood bye" +async def test_input_dict_with_history_key_async() -> None: + runnable = RunnableLambda( + lambda input: "you said: " + + "\n".join( + [str(m.content) for m in input["history"] if isinstance(m, HumanMessage)] + + [input["input"]] + ) + ) + get_session_history = _get_get_session_history() + with_history = RunnableWithMessageHistory( + runnable, + get_session_history, + input_messages_key="input", + history_messages_key="history", + ) + config: RunnableConfig = {"configurable": {"session_id": "3_async"}} + output = await with_history.ainvoke({"input": "hello"}, config) + assert output == "you said: hello" + output = await with_history.ainvoke({"input": "good bye"}, config) + assert output == "you said: hello\ngood bye" + + def test_output_message() -> None: runnable = RunnableLambda( lambda input: AIMessage( @@ -132,41 +201,82 @@ def test_output_message() -> None: assert output == AIMessage(content="you said: hello\ngood bye") -def test_input_messages_output_message() -> None: - class LengthChatModel(BaseChatModel): - """A fake chat model that returns the length of the messages passed in.""" - - def _generate( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> ChatResult: - """Top Level call""" - return ChatResult( - generations=[ - ChatGeneration(message=AIMessage(content=str(len(messages)))) +async def test_output_message_async() -> None: + runnable = RunnableLambda( + lambda input: AIMessage( + content="you said: " + + "\n".join( + [ + str(m.content) + for m in input["history"] + if isinstance(m, HumanMessage) ] + + [input["input"]] ) + ) + ) + get_session_history = _get_get_session_history() + with_history = RunnableWithMessageHistory( + runnable, + get_session_history, + input_messages_key="input", + history_messages_key="history", + ) + config: RunnableConfig = {"configurable": {"session_id": "4_async"}} + output = await with_history.ainvoke({"input": "hello"}, config) + assert output == AIMessage(content="you said: hello") + output = await with_history.ainvoke({"input": "good bye"}, config) + assert output == AIMessage(content="you said: hello\ngood bye") - @property - def _llm_type(self) -> str: - return "length-fake-chat-model" +class LengthChatModel(BaseChatModel): + """A fake chat model that returns the length of the messages passed in.""" + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + """Top Level call""" + return ChatResult( + generations=[ChatGeneration(message=AIMessage(content=str(len(messages))))] + ) + + @property + def _llm_type(self) -> str: + return "length-fake-chat-model" + + +def test_input_messages_output_message() -> None: runnable = LengthChatModel() get_session_history = _get_get_session_history() with_history = RunnableWithMessageHistory( runnable, get_session_history, ) - config: RunnableConfig = {"configurable": {"session_id": "4"}} + config: RunnableConfig = {"configurable": {"session_id": "5"}} output = with_history.invoke([HumanMessage(content="hi")], config) assert output.content == "1" output = with_history.invoke([HumanMessage(content="hi")], config) assert output.content == "3" +async def test_input_messages_output_message_async() -> None: + runnable = LengthChatModel() + get_session_history = _get_get_session_history() + with_history = RunnableWithMessageHistory( + runnable, + get_session_history, + ) + config: RunnableConfig = {"configurable": {"session_id": "5_async"}} + output = await with_history.ainvoke([HumanMessage(content="hi")], config) + assert output.content == "1" + output = await with_history.ainvoke([HumanMessage(content="hi")], config) + assert output.content == "3" + + def test_output_messages() -> None: runnable = RunnableLambda( lambda input: [ @@ -190,13 +300,43 @@ def test_output_messages() -> None: input_messages_key="input", history_messages_key="history", ) - config: RunnableConfig = {"configurable": {"session_id": "5"}} + config: RunnableConfig = {"configurable": {"session_id": "6"}} output = with_history.invoke({"input": "hello"}, config) assert output == [AIMessage(content="you said: hello")] output = with_history.invoke({"input": "good bye"}, config) assert output == [AIMessage(content="you said: hello\ngood bye")] +async def test_output_messages_async() -> None: + runnable = RunnableLambda( + lambda input: [ + AIMessage( + content="you said: " + + "\n".join( + [ + str(m.content) + for m in input["history"] + if isinstance(m, HumanMessage) + ] + + [input["input"]] + ) + ) + ] + ) + get_session_history = _get_get_session_history() + with_history = RunnableWithMessageHistory( + runnable, # type: ignore + get_session_history, + input_messages_key="input", + history_messages_key="history", + ) + config: RunnableConfig = {"configurable": {"session_id": "6_async"}} + output = await with_history.ainvoke({"input": "hello"}, config) + assert output == [AIMessage(content="you said: hello")] + output = await with_history.ainvoke({"input": "good bye"}, config) + assert output == [AIMessage(content="you said: hello\ngood bye")] + + def test_output_dict() -> None: runnable = RunnableLambda( lambda input: { @@ -223,13 +363,46 @@ def test_output_dict() -> None: history_messages_key="history", output_messages_key="output", ) - config: RunnableConfig = {"configurable": {"session_id": "6"}} + config: RunnableConfig = {"configurable": {"session_id": "7"}} output = with_history.invoke({"input": "hello"}, config) assert output == {"output": [AIMessage(content="you said: hello")]} output = with_history.invoke({"input": "good bye"}, config) assert output == {"output": [AIMessage(content="you said: hello\ngood bye")]} +async def test_output_dict_async() -> None: + runnable = RunnableLambda( + lambda input: { + "output": [ + AIMessage( + content="you said: " + + "\n".join( + [ + str(m.content) + for m in input["history"] + if isinstance(m, HumanMessage) + ] + + [input["input"]] + ) + ) + ] + } + ) + get_session_history = _get_get_session_history() + with_history = RunnableWithMessageHistory( + runnable, + get_session_history, + input_messages_key="input", + history_messages_key="history", + output_messages_key="output", + ) + config: RunnableConfig = {"configurable": {"session_id": "7_async"}} + output = await with_history.ainvoke({"input": "hello"}, config) + assert output == {"output": [AIMessage(content="you said: hello")]} + output = await with_history.ainvoke({"input": "good bye"}, config) + assert output == {"output": [AIMessage(content="you said: hello\ngood bye")]} + + def test_get_input_schema_input_dict() -> None: class RunnableWithChatHistoryInput(BaseModel): input: Union[str, BaseMessage, Sequence[BaseMessage]] @@ -404,3 +577,114 @@ def test_using_custom_config_specs() -> None: ] ), } + + +async def test_using_custom_config_specs_async() -> None: + """Test that we can configure which keys should be passed to the session factory.""" + + def _fake_llm(input: Dict[str, Any]) -> List[BaseMessage]: + messages = input["messages"] + return [ + AIMessage( + content="you said: " + + "\n".join( + str(m.content) for m in messages if isinstance(m, HumanMessage) + ) + ) + ] + + runnable = RunnableLambda(_fake_llm) + store = {} + + def get_session_history(user_id: str, conversation_id: str) -> ChatMessageHistory: + if (user_id, conversation_id) not in store: + store[(user_id, conversation_id)] = ChatMessageHistory() + return store[(user_id, conversation_id)] + + with_message_history = RunnableWithMessageHistory( + runnable, # type: ignore + get_session_history=get_session_history, + input_messages_key="messages", + history_messages_key="history", + history_factory_config=[ + ConfigurableFieldSpec( + id="user_id", + annotation=str, + name="User ID", + description="Unique identifier for the user.", + default="", + is_shared=True, + ), + ConfigurableFieldSpec( + id="conversation_id", + annotation=str, + name="Conversation ID", + description="Unique identifier for the conversation.", + default=None, + is_shared=True, + ), + ], + ) + result = await with_message_history.ainvoke( + { + "messages": [HumanMessage(content="hello")], + }, + {"configurable": {"user_id": "user1_async", "conversation_id": "1_async"}}, + ) + assert result == [ + AIMessage(content="you said: hello"), + ] + assert store == { + ("user1_async", "1_async"): ChatMessageHistory( + messages=[ + HumanMessage(content="hello"), + AIMessage(content="you said: hello"), + ] + ) + } + + result = await with_message_history.ainvoke( + { + "messages": [HumanMessage(content="goodbye")], + }, + {"configurable": {"user_id": "user1_async", "conversation_id": "1_async"}}, + ) + assert result == [ + AIMessage(content="you said: goodbye"), + ] + assert store == { + ("user1_async", "1_async"): ChatMessageHistory( + messages=[ + HumanMessage(content="hello"), + AIMessage(content="you said: hello"), + HumanMessage(content="goodbye"), + AIMessage(content="you said: goodbye"), + ] + ) + } + + result = await with_message_history.ainvoke( + { + "messages": [HumanMessage(content="meow")], + }, + {"configurable": {"user_id": "user2_async", "conversation_id": "1_async"}}, + ) + assert result == [ + AIMessage(content="you said: meow"), + ] + assert store == { + ("user1_async", "1_async"): ChatMessageHistory( + messages=[ + HumanMessage(content="hello"), + AIMessage(content="you said: hello"), + HumanMessage(content="goodbye"), + AIMessage(content="you said: goodbye"), + ] + ), + ("user2_async", "1_async"): ChatMessageHistory( + messages=[ + HumanMessage(content="meow"), + AIMessage(content="you said: meow"), + ] + ), + }