diff --git a/libs/core/langchain_core/runnables/history.py b/libs/core/langchain_core/runnables/history.py index 4ca57d0cbaa..c1b6b7f94c2 100644 --- a/libs/core/langchain_core/runnables/history.py +++ b/libs/core/langchain_core/runnables/history.py @@ -1,13 +1,17 @@ from __future__ import annotations import asyncio +import inspect from typing import ( TYPE_CHECKING, Any, + Callable, + Dict, List, Optional, Sequence, Type, + Union, ) from langchain_core.chat_history import BaseChatMessageHistory @@ -25,8 +29,6 @@ if TYPE_CHECKING: from langchain_core.runnables.config import RunnableConfig from langchain_core.tracers.schemas import Run -import inspect -from typing import Callable, Dict, Union MessagesOrDictWithMessages = Union[Sequence["BaseMessage"], Dict[str, Any]] GetSessionHistoryCallable = Callable[..., BaseChatMessageHistory] @@ -35,13 +37,9 @@ GetSessionHistoryCallable = Callable[..., BaseChatMessageHistory] class RunnableWithMessageHistory(RunnableBindingBase): """A runnable that manages chat message history for another runnable. - Base runnable must have inputs and outputs that can be converted to a list of - BaseMessages. + Base runnable must have inputs and outputs that can be converted to a list of BaseMessages. - RunnableWithMessageHistory must always be called with a config that contains - session_id, e.g.: - - ``{"configurable": {"session_id": ""}}` + RunnableWithMessageHistory must always be called with a config that contains session_id, e.g. ``{"configurable": {"session_id": ""}}``. Example (dict input): .. code-block:: python @@ -82,9 +80,7 @@ class RunnableWithMessageHistory(RunnableBindingBase): # -> "The inverse of cosine is called arccosine ..." - Here's an example that uses an in memory chat history, and a factory that - takes in two keys (user_id and conversation id) to create a chat history instance. - + Example (get_session_history takes two keys, user_id and conversation id): .. 
code-block:: python store = {} @@ -164,46 +160,43 @@ class RunnableWithMessageHistory(RunnableBindingBase): """Initialize RunnableWithMessageHistory. Args: - runnable: The base Runnable to be wrapped. - - Must take as input one of: - - A sequence of BaseMessages - - A dict with one key for all messages - - A dict with one key for the current input string/message(s) and + runnable: The base Runnable to be wrapped. Must take as input one of: + 1. A sequence of BaseMessages + 2. A dict with one key for all messages + 3. A dict with one key for the current input string/message(s) and a separate key for historical messages. If the input key points to a string, it will be treated as a HumanMessage in history. Must return as output one of: - - A string which can be treated as an AIMessage - - A BaseMessage or sequence of BaseMessages - - A dict with a key for a BaseMessage or sequence of BaseMessages + 1. A string which can be treated as an AIMessage + 2. A BaseMessage or sequence of BaseMessages + 3. A dict with a key for a BaseMessage or sequence of BaseMessages get_session_history: Function that returns a new BaseChatMessageHistory. This function should either take a single positional argument `session_id` of type string and return a corresponding chat message history instance. + .. code-block:: python - ```python - def get_session_history( - session_id: str, - *, - user_id: Optional[str]=None - ) -> BaseChatMessageHistory: - ... - ``` + def get_session_history( + session_id: str, + *, + user_id: Optional[str]=None + ) -> BaseChatMessageHistory: + ... Or it should take keyword arguments that match the keys of `session_history_config_specs` and return a corresponding chat message history instance. - ```python - def get_session_history( - *, - user_id: str, - thread_id: str, - ) -> BaseChatMessageHistory: - ... - ``` + .. code-block:: python + + def get_session_history( + *, + user_id: str, + thread_id: str, + ) -> BaseChatMessageHistory: + ... 
input_messages_key: Must be specified if the base runnable accepts a dict as input. @@ -350,6 +343,12 @@ class RunnableWithMessageHistory(RunnableBindingBase): input_val = inputs[self.input_messages_key or "input"] input_messages = self._get_input_messages(input_val) + # If historic messages were prepended to the input messages, remove them to + # avoid adding duplicate messages to history. + if not self.history_messages_key: + historic_messages = config["configurable"]["message_history"].messages + input_messages = input_messages[len(historic_messages) :] + # Get the output messages output_val = load(run.outputs) output_messages = self._get_output_messages(output_val) diff --git a/libs/core/tests/unit_tests/runnables/test_history.py b/libs/core/tests/unit_tests/runnables/test_history.py index 0bec016248b..193a779021f 100644 --- a/libs/core/tests/unit_tests/runnables/test_history.py +++ b/libs/core/tests/unit_tests/runnables/test_history.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, List, Sequence, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Union from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langchain_core.pydantic_v1 import BaseModel @@ -18,8 +18,10 @@ def test_interfaces() -> None: assert str(history) == "System: system\nHuman: human 1\nAI: ai\nHuman: human 2" -def _get_get_session_history() -> Callable[..., ChatMessageHistory]: - chat_history_store = {} +def _get_get_session_history( + *, store: Optional[Dict[str, Any]] = None +) -> Callable[..., ChatMessageHistory]: + chat_history_store = store if store is not None else {} def get_session_history(session_id: str, **kwargs: Any) -> ChatMessageHistory: if session_id not in chat_history_store: @@ -34,13 +36,24 @@ def test_input_messages() -> None: lambda messages: "you said: " + "\n".join(str(m.content) for m in messages if isinstance(m, HumanMessage)) ) - get_session_history = _get_get_session_history() + store: Dict = {} + 
get_session_history = _get_get_session_history(store=store) with_history = RunnableWithMessageHistory(runnable, get_session_history) config: RunnableConfig = {"configurable": {"session_id": "1"}} output = with_history.invoke([HumanMessage(content="hello")], config) assert output == "you said: hello" output = with_history.invoke([HumanMessage(content="good bye")], config) assert output == "you said: hello\ngood bye" + assert store == { + "1": ChatMessageHistory( + messages=[ + HumanMessage(content="hello"), + AIMessage(content="you said: hello"), + HumanMessage(content="good bye"), + AIMessage(content="you said: hello\ngood bye"), + ] + ) + } def test_input_dict() -> None: