core[patch]: Fix runnable with message history (#14629)

Fix bug shown in #14458. Namely, that saving inputs to history fails
when the input to base runnable is a list of messages
This commit is contained in:
Bagatur 2023-12-13 14:25:35 -08:00 committed by GitHub
parent 99743539ae
commit 47451951a1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 52 additions and 40 deletions

View File

@ -1,13 +1,17 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import inspect
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Callable,
Dict,
List, List,
Optional, Optional,
Sequence, Sequence,
Type, Type,
Union,
) )
from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.chat_history import BaseChatMessageHistory
@ -25,8 +29,6 @@ if TYPE_CHECKING:
from langchain_core.runnables.config import RunnableConfig from langchain_core.runnables.config import RunnableConfig
from langchain_core.tracers.schemas import Run from langchain_core.tracers.schemas import Run
import inspect
from typing import Callable, Dict, Union
MessagesOrDictWithMessages = Union[Sequence["BaseMessage"], Dict[str, Any]] MessagesOrDictWithMessages = Union[Sequence["BaseMessage"], Dict[str, Any]]
GetSessionHistoryCallable = Callable[..., BaseChatMessageHistory] GetSessionHistoryCallable = Callable[..., BaseChatMessageHistory]
@ -35,13 +37,9 @@ GetSessionHistoryCallable = Callable[..., BaseChatMessageHistory]
class RunnableWithMessageHistory(RunnableBindingBase): class RunnableWithMessageHistory(RunnableBindingBase):
"""A runnable that manages chat message history for another runnable. """A runnable that manages chat message history for another runnable.
Base runnable must have inputs and outputs that can be converted to a list of Base runnable must have inputs and outputs that can be converted to a list of BaseMessages.
BaseMessages.
RunnableWithMessageHistory must always be called with a config that contains RunnableWithMessageHistory must always be called with a config that contains session_id, e.g. ``{"configurable": {"session_id": "<SESSION_ID>"}}`.
session_id, e.g.:
``{"configurable": {"session_id": "<SESSION_ID>"}}`
Example (dict input): Example (dict input):
.. code-block:: python .. code-block:: python
@ -82,9 +80,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
# -> "The inverse of cosine is called arccosine ..." # -> "The inverse of cosine is called arccosine ..."
Here's an example that uses an in memory chat history, and a factory that Example (get_session_history takes two keys, user_id and conversation id):
takes in two keys (user_id and conversation id) to create a chat history instance.
.. code-block:: python .. code-block:: python
store = {} store = {}
@ -164,46 +160,43 @@ class RunnableWithMessageHistory(RunnableBindingBase):
"""Initialize RunnableWithMessageHistory. """Initialize RunnableWithMessageHistory.
Args: Args:
runnable: The base Runnable to be wrapped. runnable: The base Runnable to be wrapped. Must take as input one of:
1. A sequence of BaseMessages
Must take as input one of: 2. A dict with one key for all messages
- A sequence of BaseMessages 3. A dict with one key for the current input string/message(s) and
- A dict with one key for all messages
- A dict with one key for the current input string/message(s) and
a separate key for historical messages. If the input key points a separate key for historical messages. If the input key points
to a string, it will be treated as a HumanMessage in history. to a string, it will be treated as a HumanMessage in history.
Must return as output one of: Must return as output one of:
- A string which can be treated as an AIMessage 1. A string which can be treated as an AIMessage
- A BaseMessage or sequence of BaseMessages 2. A BaseMessage or sequence of BaseMessages
- A dict with a key for a BaseMessage or sequence of BaseMessages 3. A dict with a key for a BaseMessage or sequence of BaseMessages
get_session_history: Function that returns a new BaseChatMessageHistory. get_session_history: Function that returns a new BaseChatMessageHistory.
This function should either take a single positional argument This function should either take a single positional argument
`session_id` of type string and return a corresponding `session_id` of type string and return a corresponding
chat message history instance. chat message history instance.
.. code-block:: python
```python
def get_session_history( def get_session_history(
session_id: str, session_id: str,
*, *,
user_id: Optional[str]=None user_id: Optional[str]=None
) -> BaseChatMessageHistory: ) -> BaseChatMessageHistory:
... ...
```
Or it should take keyword arguments that match the keys of Or it should take keyword arguments that match the keys of
`session_history_config_specs` and return a corresponding `session_history_config_specs` and return a corresponding
chat message history instance. chat message history instance.
```python .. code-block:: python
def get_session_history( def get_session_history(
*, *,
user_id: str, user_id: str,
thread_id: str, thread_id: str,
) -> BaseChatMessageHistory: ) -> BaseChatMessageHistory:
... ...
```
input_messages_key: Must be specified if the base runnable accepts a dict input_messages_key: Must be specified if the base runnable accepts a dict
as input. as input.
@ -350,6 +343,12 @@ class RunnableWithMessageHistory(RunnableBindingBase):
input_val = inputs[self.input_messages_key or "input"] input_val = inputs[self.input_messages_key or "input"]
input_messages = self._get_input_messages(input_val) input_messages = self._get_input_messages(input_val)
# If historic messages were prepended to the input messages, remove them to
# avoid adding duplicate messages to history.
if not self.history_messages_key:
historic_messages = config["configurable"]["message_history"].messages
input_messages = input_messages[len(historic_messages) :]
# Get the output messages # Get the output messages
output_val = load(run.outputs) output_val = load(run.outputs)
output_messages = self._get_output_messages(output_val) output_messages = self._get_output_messages(output_val)

View File

@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, List, Sequence, Union from typing import Any, Callable, Dict, List, Optional, Sequence, Union
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.pydantic_v1 import BaseModel from langchain_core.pydantic_v1 import BaseModel
@ -18,8 +18,10 @@ def test_interfaces() -> None:
assert str(history) == "System: system\nHuman: human 1\nAI: ai\nHuman: human 2" assert str(history) == "System: system\nHuman: human 1\nAI: ai\nHuman: human 2"
def _get_get_session_history() -> Callable[..., ChatMessageHistory]: def _get_get_session_history(
chat_history_store = {} *, store: Optional[Dict[str, Any]] = None
) -> Callable[..., ChatMessageHistory]:
chat_history_store = store if store is not None else {}
def get_session_history(session_id: str, **kwargs: Any) -> ChatMessageHistory: def get_session_history(session_id: str, **kwargs: Any) -> ChatMessageHistory:
if session_id not in chat_history_store: if session_id not in chat_history_store:
@ -34,13 +36,24 @@ def test_input_messages() -> None:
lambda messages: "you said: " lambda messages: "you said: "
+ "\n".join(str(m.content) for m in messages if isinstance(m, HumanMessage)) + "\n".join(str(m.content) for m in messages if isinstance(m, HumanMessage))
) )
get_session_history = _get_get_session_history() store: Dict = {}
get_session_history = _get_get_session_history(store=store)
with_history = RunnableWithMessageHistory(runnable, get_session_history) with_history = RunnableWithMessageHistory(runnable, get_session_history)
config: RunnableConfig = {"configurable": {"session_id": "1"}} config: RunnableConfig = {"configurable": {"session_id": "1"}}
output = with_history.invoke([HumanMessage(content="hello")], config) output = with_history.invoke([HumanMessage(content="hello")], config)
assert output == "you said: hello" assert output == "you said: hello"
output = with_history.invoke([HumanMessage(content="good bye")], config) output = with_history.invoke([HumanMessage(content="good bye")], config)
assert output == "you said: hello\ngood bye" assert output == "you said: hello\ngood bye"
assert store == {
"1": ChatMessageHistory(
messages=[
HumanMessage(content="hello"),
AIMessage(content="you said: hello"),
HumanMessage(content="good bye"),
AIMessage(content="you said: hello\ngood bye"),
]
)
}
def test_input_dict() -> None: def test_input_dict() -> None: