core[patch]: Improve RunnableWithMessageHistory init arg types (#31639)

`Runnable`'s `Input` is contravariant so we need to enumerate all
possible inputs and it's not possible to put them in a `Union`.
Also, it's better to only require a runnable that
accepts`list[BaseMessage]` instead of a broader `Sequence[BaseMessage]`
as internally the runnable is only called with a list.
This commit is contained in:
Christophe Bornet 2025-06-23 19:45:52 +02:00 committed by GitHub
parent dcf5c7b472
commit b1cc972567
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 15 additions and 12 deletions

View File

@ -241,7 +241,11 @@ class RunnableWithMessageHistory(RunnableBindingBase):
self,
runnable: Union[
Runnable[
Union[MessagesOrDictWithMessages],
list[BaseMessage],
Union[str, BaseMessage, MessagesOrDictWithMessages],
],
Runnable[
dict[str, Any],
Union[str, BaseMessage, MessagesOrDictWithMessages],
],
LanguageModelLike,
@ -258,7 +262,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
Args:
runnable: The base Runnable to be wrapped. Must take as input one of:
1. A sequence of BaseMessages
1. A list of BaseMessages
2. A dict with one key for all messages
3. A dict with one key for the current input string/message(s) and
a separate key for historical messages. If the input key points

View File

@ -553,7 +553,7 @@ def test_using_custom_config_specs() -> None:
return store[(user_id, conversation_id)]
with_message_history = RunnableWithMessageHistory(
runnable, # type: ignore[arg-type]
runnable,
get_session_history=get_session_history,
input_messages_key="messages",
history_messages_key="history",
@ -666,7 +666,7 @@ async def test_using_custom_config_specs_async() -> None:
return store[(user_id, conversation_id)]
with_message_history = RunnableWithMessageHistory(
runnable, # type: ignore[arg-type]
runnable,
get_session_history=get_session_history,
input_messages_key="messages",
history_messages_key="history",
@ -769,13 +769,13 @@ def test_ignore_session_id() -> None:
runnable = RunnableLambda(_fake_llm)
history = InMemoryChatMessageHistory()
with_message_history = RunnableWithMessageHistory(runnable, lambda: history) # type: ignore[arg-type]
with_message_history = RunnableWithMessageHistory(runnable, lambda: history)
_ = with_message_history.invoke("hello")
_ = with_message_history.invoke("hello again")
assert len(history.messages) == 4
class _RunnableLambdaWithRaiseError(RunnableLambda):
class _RunnableLambdaWithRaiseError(RunnableLambda[Input, Output]):
from langchain_core.tracers.root_listeners import AsyncListener
def with_listeners(
@ -861,7 +861,7 @@ def test_get_output_messages_with_value_error() -> None:
runnable = _RunnableLambdaWithRaiseError(lambda _: illegal_bool_message)
store: dict = {}
get_session_history = _get_get_session_history(store=store)
with_history = RunnableWithMessageHistory(runnable, get_session_history)
with_history = RunnableWithMessageHistory(runnable, get_session_history) # type: ignore[arg-type]
config: RunnableConfig = {
"configurable": {"session_id": "1", "message_history": get_session_history("1")}
}
@ -876,8 +876,8 @@ def test_get_output_messages_with_value_error() -> None:
with_history.bound.invoke([HumanMessage(content="hello")], config)
illegal_int_message = 123
runnable = _RunnableLambdaWithRaiseError(lambda _: illegal_int_message)
with_history = RunnableWithMessageHistory(runnable, get_session_history)
runnable2 = _RunnableLambdaWithRaiseError(lambda _: illegal_int_message)
with_history = RunnableWithMessageHistory(runnable2, get_session_history) # type: ignore[arg-type]
with pytest.raises(
ValueError,

View File

@ -26,7 +26,6 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import (
ConfigurableField,
Runnable,
RunnableConfig,
RunnableLambda,
)
@ -1935,7 +1934,7 @@ async def test_runnable_with_message_history() -> None:
)
model = GenericFakeChatModel(messages=infinite_cycle)
chain: Runnable = prompt | model
chain = prompt | model
with_message_history = RunnableWithMessageHistory(
chain,
get_session_history=get_by_session_id,

View File

@ -1890,7 +1890,7 @@ async def test_runnable_with_message_history() -> None:
)
model = GenericFakeChatModel(messages=infinite_cycle)
chain: Runnable = prompt | model
chain = prompt | model
with_message_history = RunnableWithMessageHistory(
chain,
get_session_history=get_by_session_id,