mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-20 13:54:48 +00:00
core[patch]: Fix runnable with message history (#14629)
Fix bug shown in #14458. Namely, that saving inputs to history fails when the input to base runnable is a list of messages
This commit is contained in:
parent
99743539ae
commit
47451951a1
@ -1,13 +1,17 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import inspect
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
Sequence,
|
||||||
Type,
|
Type,
|
||||||
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
from langchain_core.chat_history import BaseChatMessageHistory
|
from langchain_core.chat_history import BaseChatMessageHistory
|
||||||
@ -25,8 +29,6 @@ if TYPE_CHECKING:
|
|||||||
from langchain_core.runnables.config import RunnableConfig
|
from langchain_core.runnables.config import RunnableConfig
|
||||||
from langchain_core.tracers.schemas import Run
|
from langchain_core.tracers.schemas import Run
|
||||||
|
|
||||||
import inspect
|
|
||||||
from typing import Callable, Dict, Union
|
|
||||||
|
|
||||||
MessagesOrDictWithMessages = Union[Sequence["BaseMessage"], Dict[str, Any]]
|
MessagesOrDictWithMessages = Union[Sequence["BaseMessage"], Dict[str, Any]]
|
||||||
GetSessionHistoryCallable = Callable[..., BaseChatMessageHistory]
|
GetSessionHistoryCallable = Callable[..., BaseChatMessageHistory]
|
||||||
@ -35,13 +37,9 @@ GetSessionHistoryCallable = Callable[..., BaseChatMessageHistory]
|
|||||||
class RunnableWithMessageHistory(RunnableBindingBase):
|
class RunnableWithMessageHistory(RunnableBindingBase):
|
||||||
"""A runnable that manages chat message history for another runnable.
|
"""A runnable that manages chat message history for another runnable.
|
||||||
|
|
||||||
Base runnable must have inputs and outputs that can be converted to a list of
|
Base runnable must have inputs and outputs that can be converted to a list of BaseMessages.
|
||||||
BaseMessages.
|
|
||||||
|
|
||||||
RunnableWithMessageHistory must always be called with a config that contains
|
RunnableWithMessageHistory must always be called with a config that contains session_id, e.g. ``{"configurable": {"session_id": "<SESSION_ID>"}}`.
|
||||||
session_id, e.g.:
|
|
||||||
|
|
||||||
``{"configurable": {"session_id": "<SESSION_ID>"}}`
|
|
||||||
|
|
||||||
Example (dict input):
|
Example (dict input):
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
@ -82,9 +80,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
|||||||
# -> "The inverse of cosine is called arccosine ..."
|
# -> "The inverse of cosine is called arccosine ..."
|
||||||
|
|
||||||
|
|
||||||
Here's an example that uses an in memory chat history, and a factory that
|
Example (get_session_history takes two keys, user_id and conversation id):
|
||||||
takes in two keys (user_id and conversation id) to create a chat history instance.
|
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
store = {}
|
store = {}
|
||||||
@ -164,46 +160,43 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
|||||||
"""Initialize RunnableWithMessageHistory.
|
"""Initialize RunnableWithMessageHistory.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
runnable: The base Runnable to be wrapped.
|
runnable: The base Runnable to be wrapped. Must take as input one of:
|
||||||
|
1. A sequence of BaseMessages
|
||||||
Must take as input one of:
|
2. A dict with one key for all messages
|
||||||
- A sequence of BaseMessages
|
3. A dict with one key for the current input string/message(s) and
|
||||||
- A dict with one key for all messages
|
|
||||||
- A dict with one key for the current input string/message(s) and
|
|
||||||
a separate key for historical messages. If the input key points
|
a separate key for historical messages. If the input key points
|
||||||
to a string, it will be treated as a HumanMessage in history.
|
to a string, it will be treated as a HumanMessage in history.
|
||||||
|
|
||||||
Must return as output one of:
|
Must return as output one of:
|
||||||
- A string which can be treated as an AIMessage
|
1. A string which can be treated as an AIMessage
|
||||||
- A BaseMessage or sequence of BaseMessages
|
2. A BaseMessage or sequence of BaseMessages
|
||||||
- A dict with a key for a BaseMessage or sequence of BaseMessages
|
3. A dict with a key for a BaseMessage or sequence of BaseMessages
|
||||||
|
|
||||||
get_session_history: Function that returns a new BaseChatMessageHistory.
|
get_session_history: Function that returns a new BaseChatMessageHistory.
|
||||||
This function should either take a single positional argument
|
This function should either take a single positional argument
|
||||||
`session_id` of type string and return a corresponding
|
`session_id` of type string and return a corresponding
|
||||||
chat message history instance.
|
chat message history instance.
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
```python
|
|
||||||
def get_session_history(
|
def get_session_history(
|
||||||
session_id: str,
|
session_id: str,
|
||||||
*,
|
*,
|
||||||
user_id: Optional[str]=None
|
user_id: Optional[str]=None
|
||||||
) -> BaseChatMessageHistory:
|
) -> BaseChatMessageHistory:
|
||||||
...
|
...
|
||||||
```
|
|
||||||
|
|
||||||
Or it should take keyword arguments that match the keys of
|
Or it should take keyword arguments that match the keys of
|
||||||
`session_history_config_specs` and return a corresponding
|
`session_history_config_specs` and return a corresponding
|
||||||
chat message history instance.
|
chat message history instance.
|
||||||
|
|
||||||
```python
|
.. code-block:: python
|
||||||
|
|
||||||
def get_session_history(
|
def get_session_history(
|
||||||
*,
|
*,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
) -> BaseChatMessageHistory:
|
) -> BaseChatMessageHistory:
|
||||||
...
|
...
|
||||||
```
|
|
||||||
|
|
||||||
input_messages_key: Must be specified if the base runnable accepts a dict
|
input_messages_key: Must be specified if the base runnable accepts a dict
|
||||||
as input.
|
as input.
|
||||||
@ -350,6 +343,12 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
|||||||
input_val = inputs[self.input_messages_key or "input"]
|
input_val = inputs[self.input_messages_key or "input"]
|
||||||
input_messages = self._get_input_messages(input_val)
|
input_messages = self._get_input_messages(input_val)
|
||||||
|
|
||||||
|
# If historic messages were prepended to the input messages, remove them to
|
||||||
|
# avoid adding duplicate messages to history.
|
||||||
|
if not self.history_messages_key:
|
||||||
|
historic_messages = config["configurable"]["message_history"].messages
|
||||||
|
input_messages = input_messages[len(historic_messages) :]
|
||||||
|
|
||||||
# Get the output messages
|
# Get the output messages
|
||||||
output_val = load(run.outputs)
|
output_val = load(run.outputs)
|
||||||
output_messages = self._get_output_messages(output_val)
|
output_messages = self._get_output_messages(output_val)
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import Any, Callable, Dict, List, Sequence, Union
|
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
|
||||||
|
|
||||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||||
from langchain_core.pydantic_v1 import BaseModel
|
from langchain_core.pydantic_v1 import BaseModel
|
||||||
@ -18,8 +18,10 @@ def test_interfaces() -> None:
|
|||||||
assert str(history) == "System: system\nHuman: human 1\nAI: ai\nHuman: human 2"
|
assert str(history) == "System: system\nHuman: human 1\nAI: ai\nHuman: human 2"
|
||||||
|
|
||||||
|
|
||||||
def _get_get_session_history() -> Callable[..., ChatMessageHistory]:
|
def _get_get_session_history(
|
||||||
chat_history_store = {}
|
*, store: Optional[Dict[str, Any]] = None
|
||||||
|
) -> Callable[..., ChatMessageHistory]:
|
||||||
|
chat_history_store = store if store is not None else {}
|
||||||
|
|
||||||
def get_session_history(session_id: str, **kwargs: Any) -> ChatMessageHistory:
|
def get_session_history(session_id: str, **kwargs: Any) -> ChatMessageHistory:
|
||||||
if session_id not in chat_history_store:
|
if session_id not in chat_history_store:
|
||||||
@ -34,13 +36,24 @@ def test_input_messages() -> None:
|
|||||||
lambda messages: "you said: "
|
lambda messages: "you said: "
|
||||||
+ "\n".join(str(m.content) for m in messages if isinstance(m, HumanMessage))
|
+ "\n".join(str(m.content) for m in messages if isinstance(m, HumanMessage))
|
||||||
)
|
)
|
||||||
get_session_history = _get_get_session_history()
|
store: Dict = {}
|
||||||
|
get_session_history = _get_get_session_history(store=store)
|
||||||
with_history = RunnableWithMessageHistory(runnable, get_session_history)
|
with_history = RunnableWithMessageHistory(runnable, get_session_history)
|
||||||
config: RunnableConfig = {"configurable": {"session_id": "1"}}
|
config: RunnableConfig = {"configurable": {"session_id": "1"}}
|
||||||
output = with_history.invoke([HumanMessage(content="hello")], config)
|
output = with_history.invoke([HumanMessage(content="hello")], config)
|
||||||
assert output == "you said: hello"
|
assert output == "you said: hello"
|
||||||
output = with_history.invoke([HumanMessage(content="good bye")], config)
|
output = with_history.invoke([HumanMessage(content="good bye")], config)
|
||||||
assert output == "you said: hello\ngood bye"
|
assert output == "you said: hello\ngood bye"
|
||||||
|
assert store == {
|
||||||
|
"1": ChatMessageHistory(
|
||||||
|
messages=[
|
||||||
|
HumanMessage(content="hello"),
|
||||||
|
AIMessage(content="you said: hello"),
|
||||||
|
HumanMessage(content="good bye"),
|
||||||
|
AIMessage(content="you said: hello\ngood bye"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_input_dict() -> None:
|
def test_input_dict() -> None:
|
||||||
|
Loading…
Reference in New Issue
Block a user