From b1cc972567490161984d78b9a42e216162cce193 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Mon, 23 Jun 2025 19:45:52 +0200 Subject: [PATCH] core[patch]: Improve `RunnableWithMessageHistory` init arg types (#31639) `Runnable`'s `Input` is contravariant so we need to enumerate all possible inputs and it's not possible to put them in a `Union`. Also, it's better to only require a runnable that accepts`list[BaseMessage]` instead of a broader `Sequence[BaseMessage]` as internally the runnable is only called with a list. --- libs/core/langchain_core/runnables/history.py | 8 ++++++-- .../tests/unit_tests/runnables/test_history.py | 14 +++++++------- .../runnables/test_runnable_events_v1.py | 3 +-- .../runnables/test_runnable_events_v2.py | 2 +- 4 files changed, 15 insertions(+), 12 deletions(-) diff --git a/libs/core/langchain_core/runnables/history.py b/libs/core/langchain_core/runnables/history.py index 3a467811b4c..53288b5160e 100644 --- a/libs/core/langchain_core/runnables/history.py +++ b/libs/core/langchain_core/runnables/history.py @@ -241,7 +241,11 @@ class RunnableWithMessageHistory(RunnableBindingBase): self, runnable: Union[ Runnable[ - Union[MessagesOrDictWithMessages], + list[BaseMessage], + Union[str, BaseMessage, MessagesOrDictWithMessages], + ], + Runnable[ + dict[str, Any], Union[str, BaseMessage, MessagesOrDictWithMessages], ], LanguageModelLike, @@ -258,7 +262,7 @@ class RunnableWithMessageHistory(RunnableBindingBase): Args: runnable: The base Runnable to be wrapped. Must take as input one of: - 1. A sequence of BaseMessages + 1. A list of BaseMessages 2. A dict with one key for all messages 3. A dict with one key for the current input string/message(s) and a separate key for historical messages. If the input key points diff --git a/libs/core/tests/unit_tests/runnables/test_history.py b/libs/core/tests/unit_tests/runnables/test_history.py index 1807d8b3d5c..5db5957435d 100644 --- a/libs/core/tests/unit_tests/runnables/test_history.py +++ b/libs/core/tests/unit_tests/runnables/test_history.py @@ -553,7 +553,7 @@ def test_using_custom_config_specs() -> None: return store[(user_id, conversation_id)] with_message_history = RunnableWithMessageHistory( - runnable, # type: ignore[arg-type] + runnable, get_session_history=get_session_history, input_messages_key="messages", history_messages_key="history", @@ -666,7 +666,7 @@ async def test_using_custom_config_specs_async() -> None: return store[(user_id, conversation_id)] with_message_history = RunnableWithMessageHistory( - runnable, # type: ignore[arg-type] + runnable, get_session_history=get_session_history, input_messages_key="messages", history_messages_key="history", @@ -769,13 +769,13 @@ def test_ignore_session_id() -> None: runnable = RunnableLambda(_fake_llm) history = InMemoryChatMessageHistory() - with_message_history = RunnableWithMessageHistory(runnable, lambda: history) # type: ignore[arg-type] + with_message_history = RunnableWithMessageHistory(runnable, lambda: history) _ = with_message_history.invoke("hello") _ = with_message_history.invoke("hello again") assert len(history.messages) == 4 -class _RunnableLambdaWithRaiseError(RunnableLambda): +class _RunnableLambdaWithRaiseError(RunnableLambda[Input, Output]): from langchain_core.tracers.root_listeners import AsyncListener def with_listeners( @@ -861,7 +861,7 @@ def test_get_output_messages_with_value_error() -> None: runnable = _RunnableLambdaWithRaiseError(lambda _: illegal_bool_message) store: dict = {} get_session_history = _get_get_session_history(store=store) - with_history = RunnableWithMessageHistory(runnable, get_session_history) + with_history = RunnableWithMessageHistory(runnable, get_session_history) # type: ignore[arg-type] config: RunnableConfig = { "configurable": {"session_id": "1", "message_history": get_session_history("1")} } @@ -876,8 +876,8 @@ def test_get_output_messages_with_value_error() -> None: with_history.bound.invoke([HumanMessage(content="hello")], config) illegal_int_message = 123 - runnable = _RunnableLambdaWithRaiseError(lambda _: illegal_int_message) - with_history = RunnableWithMessageHistory(runnable, get_session_history) + runnable2 = _RunnableLambdaWithRaiseError(lambda _: illegal_int_message) + with_history = RunnableWithMessageHistory(runnable2, get_session_history) # type: ignore[arg-type] with pytest.raises( ValueError, diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py index 327f96224b3..b41409754a3 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py @@ -26,7 +26,6 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.retrievers import BaseRetriever from langchain_core.runnables import ( ConfigurableField, - Runnable, RunnableConfig, RunnableLambda, ) @@ -1935,7 +1934,7 @@ async def test_runnable_with_message_history() -> None: ) model = GenericFakeChatModel(messages=infinite_cycle) - chain: Runnable = prompt | model + chain = prompt | model with_message_history = RunnableWithMessageHistory( chain, get_session_history=get_by_session_id, diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py index be384668edb..ea7fe95ba5b 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py @@ -1890,7 +1890,7 @@ async def test_runnable_with_message_history() -> None: ) model = GenericFakeChatModel(messages=infinite_cycle) - chain: Runnable = prompt | model + chain = prompt | model with_message_history = RunnableWithMessageHistory( chain, get_session_history=get_by_session_id,