mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-26 16:43:35 +00:00
core[patch]: Improve RunnableWithMessageHistory
init arg types (#31639)
`Runnable`'s `Input` is contravariant so we need to enumerate all possible inputs and it's not possible to put them in a `Union`. Also, it's better to only require a runnable that accepts`list[BaseMessage]` instead of a broader `Sequence[BaseMessage]` as internally the runnable is only called with a list.
This commit is contained in:
parent
dcf5c7b472
commit
b1cc972567
@ -241,7 +241,11 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
self,
|
||||
runnable: Union[
|
||||
Runnable[
|
||||
Union[MessagesOrDictWithMessages],
|
||||
list[BaseMessage],
|
||||
Union[str, BaseMessage, MessagesOrDictWithMessages],
|
||||
],
|
||||
Runnable[
|
||||
dict[str, Any],
|
||||
Union[str, BaseMessage, MessagesOrDictWithMessages],
|
||||
],
|
||||
LanguageModelLike,
|
||||
@ -258,7 +262,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
|
||||
Args:
|
||||
runnable: The base Runnable to be wrapped. Must take as input one of:
|
||||
1. A sequence of BaseMessages
|
||||
1. A list of BaseMessages
|
||||
2. A dict with one key for all messages
|
||||
3. A dict with one key for the current input string/message(s) and
|
||||
a separate key for historical messages. If the input key points
|
||||
|
@ -553,7 +553,7 @@ def test_using_custom_config_specs() -> None:
|
||||
return store[(user_id, conversation_id)]
|
||||
|
||||
with_message_history = RunnableWithMessageHistory(
|
||||
runnable, # type: ignore[arg-type]
|
||||
runnable,
|
||||
get_session_history=get_session_history,
|
||||
input_messages_key="messages",
|
||||
history_messages_key="history",
|
||||
@ -666,7 +666,7 @@ async def test_using_custom_config_specs_async() -> None:
|
||||
return store[(user_id, conversation_id)]
|
||||
|
||||
with_message_history = RunnableWithMessageHistory(
|
||||
runnable, # type: ignore[arg-type]
|
||||
runnable,
|
||||
get_session_history=get_session_history,
|
||||
input_messages_key="messages",
|
||||
history_messages_key="history",
|
||||
@ -769,13 +769,13 @@ def test_ignore_session_id() -> None:
|
||||
|
||||
runnable = RunnableLambda(_fake_llm)
|
||||
history = InMemoryChatMessageHistory()
|
||||
with_message_history = RunnableWithMessageHistory(runnable, lambda: history) # type: ignore[arg-type]
|
||||
with_message_history = RunnableWithMessageHistory(runnable, lambda: history)
|
||||
_ = with_message_history.invoke("hello")
|
||||
_ = with_message_history.invoke("hello again")
|
||||
assert len(history.messages) == 4
|
||||
|
||||
|
||||
class _RunnableLambdaWithRaiseError(RunnableLambda):
|
||||
class _RunnableLambdaWithRaiseError(RunnableLambda[Input, Output]):
|
||||
from langchain_core.tracers.root_listeners import AsyncListener
|
||||
|
||||
def with_listeners(
|
||||
@ -861,7 +861,7 @@ def test_get_output_messages_with_value_error() -> None:
|
||||
runnable = _RunnableLambdaWithRaiseError(lambda _: illegal_bool_message)
|
||||
store: dict = {}
|
||||
get_session_history = _get_get_session_history(store=store)
|
||||
with_history = RunnableWithMessageHistory(runnable, get_session_history)
|
||||
with_history = RunnableWithMessageHistory(runnable, get_session_history) # type: ignore[arg-type]
|
||||
config: RunnableConfig = {
|
||||
"configurable": {"session_id": "1", "message_history": get_session_history("1")}
|
||||
}
|
||||
@ -876,8 +876,8 @@ def test_get_output_messages_with_value_error() -> None:
|
||||
with_history.bound.invoke([HumanMessage(content="hello")], config)
|
||||
|
||||
illegal_int_message = 123
|
||||
runnable = _RunnableLambdaWithRaiseError(lambda _: illegal_int_message)
|
||||
with_history = RunnableWithMessageHistory(runnable, get_session_history)
|
||||
runnable2 = _RunnableLambdaWithRaiseError(lambda _: illegal_int_message)
|
||||
with_history = RunnableWithMessageHistory(runnable2, get_session_history) # type: ignore[arg-type]
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
|
@ -26,7 +26,6 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.runnables import (
|
||||
ConfigurableField,
|
||||
Runnable,
|
||||
RunnableConfig,
|
||||
RunnableLambda,
|
||||
)
|
||||
@ -1935,7 +1934,7 @@ async def test_runnable_with_message_history() -> None:
|
||||
)
|
||||
model = GenericFakeChatModel(messages=infinite_cycle)
|
||||
|
||||
chain: Runnable = prompt | model
|
||||
chain = prompt | model
|
||||
with_message_history = RunnableWithMessageHistory(
|
||||
chain,
|
||||
get_session_history=get_by_session_id,
|
||||
|
@ -1890,7 +1890,7 @@ async def test_runnable_with_message_history() -> None:
|
||||
)
|
||||
model = GenericFakeChatModel(messages=infinite_cycle)
|
||||
|
||||
chain: Runnable = prompt | model
|
||||
chain = prompt | model
|
||||
with_message_history = RunnableWithMessageHistory(
|
||||
chain,
|
||||
get_session_history=get_by_session_id,
|
||||
|
Loading…
Reference in New Issue
Block a user