mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-27 17:08:47 +00:00
core[patch]: Improve RunnableWithMessageHistory
init arg types (#31639)
`Runnable`'s `Input` is contravariant so we need to enumerate all possible inputs and it's not possible to put them in a `Union`. Also, it's better to only require a runnable that accepts`list[BaseMessage]` instead of a broader `Sequence[BaseMessage]` as internally the runnable is only called with a list.
This commit is contained in:
parent
dcf5c7b472
commit
b1cc972567
@ -241,7 +241,11 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
|||||||
self,
|
self,
|
||||||
runnable: Union[
|
runnable: Union[
|
||||||
Runnable[
|
Runnable[
|
||||||
Union[MessagesOrDictWithMessages],
|
list[BaseMessage],
|
||||||
|
Union[str, BaseMessage, MessagesOrDictWithMessages],
|
||||||
|
],
|
||||||
|
Runnable[
|
||||||
|
dict[str, Any],
|
||||||
Union[str, BaseMessage, MessagesOrDictWithMessages],
|
Union[str, BaseMessage, MessagesOrDictWithMessages],
|
||||||
],
|
],
|
||||||
LanguageModelLike,
|
LanguageModelLike,
|
||||||
@ -258,7 +262,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
runnable: The base Runnable to be wrapped. Must take as input one of:
|
runnable: The base Runnable to be wrapped. Must take as input one of:
|
||||||
1. A sequence of BaseMessages
|
1. A list of BaseMessages
|
||||||
2. A dict with one key for all messages
|
2. A dict with one key for all messages
|
||||||
3. A dict with one key for the current input string/message(s) and
|
3. A dict with one key for the current input string/message(s) and
|
||||||
a separate key for historical messages. If the input key points
|
a separate key for historical messages. If the input key points
|
||||||
|
@ -553,7 +553,7 @@ def test_using_custom_config_specs() -> None:
|
|||||||
return store[(user_id, conversation_id)]
|
return store[(user_id, conversation_id)]
|
||||||
|
|
||||||
with_message_history = RunnableWithMessageHistory(
|
with_message_history = RunnableWithMessageHistory(
|
||||||
runnable, # type: ignore[arg-type]
|
runnable,
|
||||||
get_session_history=get_session_history,
|
get_session_history=get_session_history,
|
||||||
input_messages_key="messages",
|
input_messages_key="messages",
|
||||||
history_messages_key="history",
|
history_messages_key="history",
|
||||||
@ -666,7 +666,7 @@ async def test_using_custom_config_specs_async() -> None:
|
|||||||
return store[(user_id, conversation_id)]
|
return store[(user_id, conversation_id)]
|
||||||
|
|
||||||
with_message_history = RunnableWithMessageHistory(
|
with_message_history = RunnableWithMessageHistory(
|
||||||
runnable, # type: ignore[arg-type]
|
runnable,
|
||||||
get_session_history=get_session_history,
|
get_session_history=get_session_history,
|
||||||
input_messages_key="messages",
|
input_messages_key="messages",
|
||||||
history_messages_key="history",
|
history_messages_key="history",
|
||||||
@ -769,13 +769,13 @@ def test_ignore_session_id() -> None:
|
|||||||
|
|
||||||
runnable = RunnableLambda(_fake_llm)
|
runnable = RunnableLambda(_fake_llm)
|
||||||
history = InMemoryChatMessageHistory()
|
history = InMemoryChatMessageHistory()
|
||||||
with_message_history = RunnableWithMessageHistory(runnable, lambda: history) # type: ignore[arg-type]
|
with_message_history = RunnableWithMessageHistory(runnable, lambda: history)
|
||||||
_ = with_message_history.invoke("hello")
|
_ = with_message_history.invoke("hello")
|
||||||
_ = with_message_history.invoke("hello again")
|
_ = with_message_history.invoke("hello again")
|
||||||
assert len(history.messages) == 4
|
assert len(history.messages) == 4
|
||||||
|
|
||||||
|
|
||||||
class _RunnableLambdaWithRaiseError(RunnableLambda):
|
class _RunnableLambdaWithRaiseError(RunnableLambda[Input, Output]):
|
||||||
from langchain_core.tracers.root_listeners import AsyncListener
|
from langchain_core.tracers.root_listeners import AsyncListener
|
||||||
|
|
||||||
def with_listeners(
|
def with_listeners(
|
||||||
@ -861,7 +861,7 @@ def test_get_output_messages_with_value_error() -> None:
|
|||||||
runnable = _RunnableLambdaWithRaiseError(lambda _: illegal_bool_message)
|
runnable = _RunnableLambdaWithRaiseError(lambda _: illegal_bool_message)
|
||||||
store: dict = {}
|
store: dict = {}
|
||||||
get_session_history = _get_get_session_history(store=store)
|
get_session_history = _get_get_session_history(store=store)
|
||||||
with_history = RunnableWithMessageHistory(runnable, get_session_history)
|
with_history = RunnableWithMessageHistory(runnable, get_session_history) # type: ignore[arg-type]
|
||||||
config: RunnableConfig = {
|
config: RunnableConfig = {
|
||||||
"configurable": {"session_id": "1", "message_history": get_session_history("1")}
|
"configurable": {"session_id": "1", "message_history": get_session_history("1")}
|
||||||
}
|
}
|
||||||
@ -876,8 +876,8 @@ def test_get_output_messages_with_value_error() -> None:
|
|||||||
with_history.bound.invoke([HumanMessage(content="hello")], config)
|
with_history.bound.invoke([HumanMessage(content="hello")], config)
|
||||||
|
|
||||||
illegal_int_message = 123
|
illegal_int_message = 123
|
||||||
runnable = _RunnableLambdaWithRaiseError(lambda _: illegal_int_message)
|
runnable2 = _RunnableLambdaWithRaiseError(lambda _: illegal_int_message)
|
||||||
with_history = RunnableWithMessageHistory(runnable, get_session_history)
|
with_history = RunnableWithMessageHistory(runnable2, get_session_history) # type: ignore[arg-type]
|
||||||
|
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
ValueError,
|
ValueError,
|
||||||
|
@ -26,7 +26,6 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
|||||||
from langchain_core.retrievers import BaseRetriever
|
from langchain_core.retrievers import BaseRetriever
|
||||||
from langchain_core.runnables import (
|
from langchain_core.runnables import (
|
||||||
ConfigurableField,
|
ConfigurableField,
|
||||||
Runnable,
|
|
||||||
RunnableConfig,
|
RunnableConfig,
|
||||||
RunnableLambda,
|
RunnableLambda,
|
||||||
)
|
)
|
||||||
@ -1935,7 +1934,7 @@ async def test_runnable_with_message_history() -> None:
|
|||||||
)
|
)
|
||||||
model = GenericFakeChatModel(messages=infinite_cycle)
|
model = GenericFakeChatModel(messages=infinite_cycle)
|
||||||
|
|
||||||
chain: Runnable = prompt | model
|
chain = prompt | model
|
||||||
with_message_history = RunnableWithMessageHistory(
|
with_message_history = RunnableWithMessageHistory(
|
||||||
chain,
|
chain,
|
||||||
get_session_history=get_by_session_id,
|
get_session_history=get_by_session_id,
|
||||||
|
@ -1890,7 +1890,7 @@ async def test_runnable_with_message_history() -> None:
|
|||||||
)
|
)
|
||||||
model = GenericFakeChatModel(messages=infinite_cycle)
|
model = GenericFakeChatModel(messages=infinite_cycle)
|
||||||
|
|
||||||
chain: Runnable = prompt | model
|
chain = prompt | model
|
||||||
with_message_history = RunnableWithMessageHistory(
|
with_message_history = RunnableWithMessageHistory(
|
||||||
chain,
|
chain,
|
||||||
get_session_history=get_by_session_id,
|
get_session_history=get_by_session_id,
|
||||||
|
Loading…
Reference in New Issue
Block a user