core[patch]: Improve RunnableWithMessageHistory init arg types (#31639)

`Runnable`'s `Input` is contravariant, so we need to enumerate all
possible inputs, and it's not possible to put them in a `Union`.
Also, it's better to only require a runnable that
accepts `list[BaseMessage]` instead of a broader `Sequence[BaseMessage]`,
as internally the runnable is only ever called with a list.
This commit is contained in:
Christophe Bornet 2025-06-23 19:45:52 +02:00 committed by GitHub
parent dcf5c7b472
commit b1cc972567
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 15 additions and 12 deletions

View File

@ -241,7 +241,11 @@ class RunnableWithMessageHistory(RunnableBindingBase):
self, self,
runnable: Union[ runnable: Union[
Runnable[ Runnable[
Union[MessagesOrDictWithMessages], list[BaseMessage],
Union[str, BaseMessage, MessagesOrDictWithMessages],
],
Runnable[
dict[str, Any],
Union[str, BaseMessage, MessagesOrDictWithMessages], Union[str, BaseMessage, MessagesOrDictWithMessages],
], ],
LanguageModelLike, LanguageModelLike,
@ -258,7 +262,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
Args: Args:
runnable: The base Runnable to be wrapped. Must take as input one of: runnable: The base Runnable to be wrapped. Must take as input one of:
1. A sequence of BaseMessages 1. A list of BaseMessages
2. A dict with one key for all messages 2. A dict with one key for all messages
3. A dict with one key for the current input string/message(s) and 3. A dict with one key for the current input string/message(s) and
a separate key for historical messages. If the input key points a separate key for historical messages. If the input key points

View File

@ -553,7 +553,7 @@ def test_using_custom_config_specs() -> None:
return store[(user_id, conversation_id)] return store[(user_id, conversation_id)]
with_message_history = RunnableWithMessageHistory( with_message_history = RunnableWithMessageHistory(
runnable, # type: ignore[arg-type] runnable,
get_session_history=get_session_history, get_session_history=get_session_history,
input_messages_key="messages", input_messages_key="messages",
history_messages_key="history", history_messages_key="history",
@ -666,7 +666,7 @@ async def test_using_custom_config_specs_async() -> None:
return store[(user_id, conversation_id)] return store[(user_id, conversation_id)]
with_message_history = RunnableWithMessageHistory( with_message_history = RunnableWithMessageHistory(
runnable, # type: ignore[arg-type] runnable,
get_session_history=get_session_history, get_session_history=get_session_history,
input_messages_key="messages", input_messages_key="messages",
history_messages_key="history", history_messages_key="history",
@ -769,13 +769,13 @@ def test_ignore_session_id() -> None:
runnable = RunnableLambda(_fake_llm) runnable = RunnableLambda(_fake_llm)
history = InMemoryChatMessageHistory() history = InMemoryChatMessageHistory()
with_message_history = RunnableWithMessageHistory(runnable, lambda: history) # type: ignore[arg-type] with_message_history = RunnableWithMessageHistory(runnable, lambda: history)
_ = with_message_history.invoke("hello") _ = with_message_history.invoke("hello")
_ = with_message_history.invoke("hello again") _ = with_message_history.invoke("hello again")
assert len(history.messages) == 4 assert len(history.messages) == 4
class _RunnableLambdaWithRaiseError(RunnableLambda): class _RunnableLambdaWithRaiseError(RunnableLambda[Input, Output]):
from langchain_core.tracers.root_listeners import AsyncListener from langchain_core.tracers.root_listeners import AsyncListener
def with_listeners( def with_listeners(
@ -861,7 +861,7 @@ def test_get_output_messages_with_value_error() -> None:
runnable = _RunnableLambdaWithRaiseError(lambda _: illegal_bool_message) runnable = _RunnableLambdaWithRaiseError(lambda _: illegal_bool_message)
store: dict = {} store: dict = {}
get_session_history = _get_get_session_history(store=store) get_session_history = _get_get_session_history(store=store)
with_history = RunnableWithMessageHistory(runnable, get_session_history) with_history = RunnableWithMessageHistory(runnable, get_session_history) # type: ignore[arg-type]
config: RunnableConfig = { config: RunnableConfig = {
"configurable": {"session_id": "1", "message_history": get_session_history("1")} "configurable": {"session_id": "1", "message_history": get_session_history("1")}
} }
@ -876,8 +876,8 @@ def test_get_output_messages_with_value_error() -> None:
with_history.bound.invoke([HumanMessage(content="hello")], config) with_history.bound.invoke([HumanMessage(content="hello")], config)
illegal_int_message = 123 illegal_int_message = 123
runnable = _RunnableLambdaWithRaiseError(lambda _: illegal_int_message) runnable2 = _RunnableLambdaWithRaiseError(lambda _: illegal_int_message)
with_history = RunnableWithMessageHistory(runnable, get_session_history) with_history = RunnableWithMessageHistory(runnable2, get_session_history) # type: ignore[arg-type]
with pytest.raises( with pytest.raises(
ValueError, ValueError,

View File

@ -26,7 +26,6 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.retrievers import BaseRetriever from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import ( from langchain_core.runnables import (
ConfigurableField, ConfigurableField,
Runnable,
RunnableConfig, RunnableConfig,
RunnableLambda, RunnableLambda,
) )
@ -1935,7 +1934,7 @@ async def test_runnable_with_message_history() -> None:
) )
model = GenericFakeChatModel(messages=infinite_cycle) model = GenericFakeChatModel(messages=infinite_cycle)
chain: Runnable = prompt | model chain = prompt | model
with_message_history = RunnableWithMessageHistory( with_message_history = RunnableWithMessageHistory(
chain, chain,
get_session_history=get_by_session_id, get_session_history=get_by_session_id,

View File

@ -1890,7 +1890,7 @@ async def test_runnable_with_message_history() -> None:
) )
model = GenericFakeChatModel(messages=infinite_cycle) model = GenericFakeChatModel(messages=infinite_cycle)
chain: Runnable = prompt | model chain = prompt | model
with_message_history = RunnableWithMessageHistory( with_message_history = RunnableWithMessageHistory(
chain, chain,
get_session_history=get_by_session_id, get_session_history=get_by_session_id,