FEATURE: Runnable with message history (#13418)

Add a RunnableWithMessageHistory class that can wrap certain runnables and manage chat history for them.
This commit is contained in:
Bagatur
2023-11-17 12:00:01 -08:00
committed by GitHub
parent 0fc3af8932
commit 2e2114d2d0
15 changed files with 939 additions and 21 deletions

View File

@@ -85,6 +85,9 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
variable_name: str
"""Name of variable to use as messages."""
def __init__(self, variable_name: str, **kwargs: Any) -> None:
    """Initialize the placeholder, accepting ``variable_name`` positionally.

    Args:
        variable_name: Name of variable to use as messages.
        **kwargs: Additional keyword arguments forwarded unchanged to the
            parent initializer.
    """
    # ``__init__`` must return None, so the parent initializer is called for its
    # side effects only instead of being returned.
    super().__init__(variable_name=variable_name, **kwargs)
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format messages from kwargs.

View File

@@ -42,7 +42,6 @@ if TYPE_CHECKING:
RunnableWithFallbacks as RunnableWithFallbacksT,
)
from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable
from langchain.pydantic_v1 import BaseModel, Field, create_model
@@ -298,7 +297,7 @@ class Runnable(Generic[Input, Output], ABC):
)
@property
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
def config_specs(self) -> List[ConfigurableFieldSpec]:
"""List configurable fields for this runnable."""
return []
@@ -1357,7 +1356,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
return self.last.get_output_schema(config)
@property
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
def config_specs(self) -> List[ConfigurableFieldSpec]:
return get_unique_config_specs(
spec for step in self.steps for spec in step.config_specs
)
@@ -1885,7 +1884,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
)
@property
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
def config_specs(self) -> List[ConfigurableFieldSpec]:
return get_unique_config_specs(
spec for step in self.steps.values() for spec in step.config_specs
)
@@ -2591,7 +2590,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
)
@property
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
def config_specs(self) -> List[ConfigurableFieldSpec]:
return self.bound.config_specs
@classmethod
@@ -2763,7 +2762,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
return self.bound.get_output_schema(merge_configs(self.config, config))
@property
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
def config_specs(self) -> List[ConfigurableFieldSpec]:
return self.bound.config_specs
@classmethod

View File

@@ -147,7 +147,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
return super().get_input_schema(config)
@property
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
def config_specs(self) -> List[ConfigurableFieldSpec]:
return get_unique_config_specs(
spec
for step in (

View File

@@ -209,7 +209,7 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
fields: Dict[str, AnyConfigurableField]
@property
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
def config_specs(self) -> List[ConfigurableFieldSpec]:
return get_unique_config_specs(
[
ConfigurableFieldSpec(
@@ -300,7 +300,7 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
default_key: str = "default"
@property
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
def config_specs(self) -> List[ConfigurableFieldSpec]:
with _enums_for_spec_lock:
if which_enum := _enums_for_spec.get(self.which):
pass

View File

@@ -112,7 +112,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
return self.runnable.get_output_schema(config)
@property
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
def config_specs(self) -> List[ConfigurableFieldSpec]:
return get_unique_config_specs(
spec
for step in [self.runnable, *self.fallbacks]

View File

@@ -0,0 +1,288 @@
from __future__ import annotations
import asyncio
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Type,
Union,
)
from langchain.load import load
from langchain.pydantic_v1 import BaseModel, create_model
from langchain.schema.chat_history import BaseChatMessageHistory
from langchain.schema.runnable.base import Runnable, RunnableBindingBase, RunnableLambda
from langchain.schema.runnable.passthrough import RunnablePassthrough
from langchain.schema.runnable.utils import (
ConfigurableFieldSpec,
get_unique_config_specs,
)
if TYPE_CHECKING:
from langchain.callbacks.tracers.schemas import Run
from langchain.schema.messages import BaseMessage
from langchain.schema.runnable.config import RunnableConfig
# Payload shape used in the wrapped runnable's type parameters: either a sequence
# of messages, or a dict that carries messages under one or more keys.
MessagesOrDictWithMessages = Union[Sequence["BaseMessage"], Dict[str, Any]]
# Factory that returns the chat message history to use for a session;
# RunnableWithMessageHistory calls it with the session id taken from the config.
GetSessionHistoryCallable = Callable[..., BaseChatMessageHistory]
class RunnableWithMessageHistory(RunnableBindingBase):
    """A runnable that manages chat message history for another runnable.

    Base runnable must have inputs and outputs that can be converted to a list of
    BaseMessages.

    RunnableWithMessageHistory must always be called with a config that contains session_id, e.g.:
        ``{"configurable": {"session_id": "<SESSION_ID>"}}``

    Example (dict input):
        .. code-block:: python

            from typing import Optional

            from langchain.chat_models import ChatAnthropic
            from langchain.memory.chat_message_histories import RedisChatMessageHistory
            from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
            from langchain.schema.runnable.history import RunnableWithMessageHistory


            prompt = ChatPromptTemplate.from_messages([
                ("system", "You're an assistant who's good at {ability}"),
                MessagesPlaceholder(variable_name="history"),
                ("human", "{question}"),
            ])

            chain = prompt | ChatAnthropic(model="claude-2")

            chain_with_history = RunnableWithMessageHistory(
                chain,
                RedisChatMessageHistory,
                input_messages_key="question",
                history_messages_key="history",
            )

            chain_with_history.invoke(
                {"ability": "math", "question": "What does cosine mean?"},
                config={"configurable": {"session_id": "foo"}}
            )
            # -> "Cosine is ..."
            chain_with_history.invoke(
                {"ability": "math", "question": "What's its inverse"},
                config={"configurable": {"session_id": "foo"}}
            )
            # -> "The inverse of cosine is called arccosine ..."

    """  # noqa: E501

    # Factory called with the session id to obtain that session's history.
    get_session_history: GetSessionHistoryCallable
    # Key of the input dict holding the new message(s); None for non-dict input.
    input_messages_key: Optional[str] = None
    # Key of the output dict holding the produced message(s); None for non-dict
    # output (``"output"`` is used as a fallback when a dict is returned anyway).
    output_messages_key: Optional[str] = None
    # Key of the input dict under which stored history is injected, when the
    # wrapped runnable expects history separately from the new input.
    history_messages_key: Optional[str] = None

    def __init__(
        self,
        runnable: Runnable[
            MessagesOrDictWithMessages,
            Union[str, BaseMessage, MessagesOrDictWithMessages],
        ],
        get_session_history: GetSessionHistoryCallable,
        *,
        input_messages_key: Optional[str] = None,
        output_messages_key: Optional[str] = None,
        history_messages_key: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize RunnableWithMessageHistory.

        Args:
            runnable: The base Runnable to be wrapped.

                Must take as input one of:
                - A sequence of BaseMessages
                - A dict with one key for all messages
                - A dict with one key for the current input string/message(s) and
                    a separate key for historical messages. If the input key points
                    to a string, it will be treated as a HumanMessage in history.

                Must return as output one of:
                - A string which can be treated as an AIMessage
                - A BaseMessage or sequence of BaseMessages
                - A dict with a key for a BaseMessage or sequence of BaseMessages

            get_session_history: Function that returns a BaseChatMessageHistory
                given a session id. It is called with the session id as its single
                positional argument; no other arguments (such as ``user_id``) are
                passed by this class, so any further parameters must be optional,
                e.g.:

                    ```python
                    def get_session_history(
                        session_id: str,
                        *,
                        user_id: Optional[str]=None
                    ) -> BaseChatMessageHistory:
                        ...
                    ```

            input_messages_key: Must be specified if the base runnable accepts a dict
                as input.
            output_messages_key: Must be specified if the base runnable returns a dict
                as output.
            history_messages_key: Must be specified if the base runnable accepts a dict
                as input and expects a separate key for historical messages.
            **kwargs: Arbitrary additional kwargs to pass to parent class
                ``RunnableBindingBase`` init.
        """  # noqa: E501
        # Step that loads the stored messages (sync and async variants).
        history_chain: Runnable = RunnableLambda(
            self._enter_history, self._aenter_history
        ).with_config(run_name="load_history")

        # For dict input, write the loaded messages into the input dict: under the
        # dedicated history key if there is one, otherwise over the input key.
        messages_key = history_messages_key or input_messages_key
        if messages_key:
            history_chain = RunnablePassthrough.assign(
                **{messages_key: history_chain}
            ).with_config(run_name="insert_history")

        # Run the wrapped runnable after history loading and persist the new
        # messages once it finishes.
        bound = (
            history_chain | runnable.with_listeners(on_end=self._exit_history)
        ).with_config(run_name="RunnableWithMessageHistory")
        super().__init__(
            get_session_history=get_session_history,
            input_messages_key=input_messages_key,
            output_messages_key=output_messages_key,
            bound=bound,
            history_messages_key=history_messages_key,
            **kwargs,
        )

    @property
    def config_specs(self) -> List[ConfigurableFieldSpec]:
        """Return the bound runnable's config specs plus the ``session_id`` field."""
        return get_unique_config_specs(
            super().config_specs
            + [
                ConfigurableFieldSpec(
                    id="session_id",
                    annotation=str,
                    name="Session ID",
                    description="Unique identifier for a session.",
                    default="",
                ),
            ]
        )

    def get_input_schema(
        self, config: Optional[RunnableConfig] = None
    ) -> Type[BaseModel]:
        """Build the input schema from the configured message keys.

        Args:
            config: Optional config forwarded to the parent schema lookup.

        Returns:
            A ``RunnableWithChatHistoryInput`` model describing the expected
            input, or the parent schema if the custom-root check below fails.
        """
        super_schema = super().get_input_schema(config)
        # NOTE(review): pydantic v1 appears to define ``__custom_root_type__`` as
        # a bool, which would make this check always true and the ``else`` branch
        # unreachable - confirm before changing; behavior is intentionally kept.
        if super_schema.__custom_root_type__ is not None:
            # Imported lazily: at module level BaseMessage is only available
            # under TYPE_CHECKING.
            from langchain.schema.messages import BaseMessage

            fields: Dict = {}
            if self.input_messages_key and self.history_messages_key:
                # New input and history are separate, so the input may be a
                # plain string, a single message or several messages.
                fields[self.input_messages_key] = (
                    Union[str, BaseMessage, Sequence[BaseMessage]],
                    ...,
                )
            elif self.input_messages_key:
                fields[self.input_messages_key] = (Sequence[BaseMessage], ...)
            else:
                fields["__root__"] = (Sequence[BaseMessage], ...)
            if self.history_messages_key:
                fields[self.history_messages_key] = (Sequence[BaseMessage], ...)
            return create_model(  # type: ignore[call-overload]
                "RunnableWithChatHistoryInput",
                **fields,
            )
        else:
            return super_schema

    def _get_input_messages(
        self, input_val: Union[str, BaseMessage, Sequence[BaseMessage]]
    ) -> List[BaseMessage]:
        """Normalize an input value to a list of messages.

        A string becomes a single HumanMessage, a message is wrapped in a list,
        and a list or tuple is copied into a new list.

        Raises:
            ValueError: If ``input_val`` is none of the supported types.
        """
        from langchain.schema.messages import BaseMessage

        if isinstance(input_val, str):
            from langchain.schema.messages import HumanMessage

            return [HumanMessage(content=input_val)]
        elif isinstance(input_val, BaseMessage):
            return [input_val]
        elif isinstance(input_val, (list, tuple)):
            return list(input_val)
        else:
            raise ValueError(
                f"Expected str, BaseMessage, List[BaseMessage], or Tuple[BaseMessage]. "
                f"Got {input_val}."
            )

    def _get_output_messages(
        self, output_val: Union[str, BaseMessage, Sequence[BaseMessage], dict]
    ) -> List[BaseMessage]:
        """Normalize an output value to a list of messages.

        A dict is first unwrapped via ``output_messages_key`` (falling back to
        ``"output"``); a string becomes a single AIMessage, a message is wrapped
        in a list, and a list or tuple is copied into a new list.

        Raises:
            ValueError: If the (unwrapped) output is none of the supported types.
        """
        from langchain.schema.messages import BaseMessage

        if isinstance(output_val, dict):
            output_val = output_val[self.output_messages_key or "output"]

        if isinstance(output_val, str):
            from langchain.schema.messages import AIMessage

            return [AIMessage(content=output_val)]
        elif isinstance(output_val, BaseMessage):
            return [output_val]
        elif isinstance(output_val, (list, tuple)):
            return list(output_val)
        else:
            raise ValueError(
                f"Expected str, BaseMessage, List[BaseMessage], or Tuple[BaseMessage]. "
                f"Got {output_val}."
            )

    def _enter_history(self, input: Any, config: RunnableConfig) -> List[BaseMessage]:
        """Return the messages to feed to the wrapped runnable.

        With a dedicated history key only the stored messages are returned;
        otherwise the stored messages are followed by the new input messages.
        """
        hist = config["configurable"]["message_history"]
        # return only historic messages
        if self.history_messages_key:
            return hist.messages.copy()
        # return all messages
        else:
            input_val = (
                input if not self.input_messages_key else input[self.input_messages_key]
            )
            return hist.messages.copy() + self._get_input_messages(input_val)

    async def _aenter_history(
        self, input: Dict[str, Any], config: RunnableConfig
    ) -> List[BaseMessage]:
        """Async variant of ``_enter_history``; runs it in the default executor."""
        return await asyncio.get_running_loop().run_in_executor(
            None, self._enter_history, input, config
        )

    def _exit_history(self, run: Run, config: RunnableConfig) -> None:
        """Append the new input and output messages of a finished run to history.

        Args:
            run: The finished run of the wrapped runnable.
            config: Config carrying the session's ``message_history``.
        """
        hist = config["configurable"]["message_history"]

        # Get the input messages
        inputs = load(run.inputs)
        input_val = inputs[self.input_messages_key or "input"]
        input_messages = self._get_input_messages(input_val)

        # Without a dedicated history key, ``_enter_history`` prepended the stored
        # history to the input that the wrapped runnable saw. Drop that prefix so
        # already-stored messages are not appended to the history a second time.
        if not self.history_messages_key:
            input_messages = input_messages[len(hist.messages) :]

        # Get the output messages
        output_val = load(run.outputs)
        output_messages = self._get_output_messages(output_val)

        for m in input_messages + output_messages:
            hist.add_message(m)

    def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig:
        """Merge configs and attach the session's message history.

        Returns:
            The merged config with ``configurable["message_history"]`` set to the
            result of ``get_session_history(session_id)``.

        Raises:
            ValueError: If the merged config has no ``configurable.session_id``.
        """
        config = super()._merge_configs(*configs)
        # extract session_id
        if "session_id" not in config.get("configurable", {}):
            example_input = {self.input_messages_key: "foo"}
            example_config = {"configurable": {"session_id": "123"}}
            raise ValueError(
                "session_id is required."
                " Pass it in as part of the config argument to .invoke() or .stream()"
                f"\neg. chain.invoke({example_input}, {example_config})"
            )
        # attach message_history
        session_id = config["configurable"]["session_id"]
        config["configurable"]["message_history"] = self.get_session_history(session_id)
        return config

View File

@@ -14,7 +14,6 @@ from typing import (
List,
Mapping,
Optional,
Sequence,
Type,
Union,
cast,
@@ -334,7 +333,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
return super().get_output_schema(config)
@property
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
def config_specs(self) -> List[ConfigurableFieldSpec]:
return self.mapper.config_specs
def invoke(

View File

@@ -8,7 +8,6 @@ from typing import (
List,
Mapping,
Optional,
Sequence,
Union,
cast,
)
@@ -55,7 +54,7 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
runnables: Mapping[str, Runnable[Any, Output]]
@property
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
def config_specs(self) -> List[ConfigurableFieldSpec]:
return get_unique_config_specs(
spec for step in self.runnables.values() for spec in step.config_specs
)

View File

@@ -308,7 +308,7 @@ class ConfigurableFieldSpec(NamedTuple):
def get_unique_config_specs(
specs: Iterable[ConfigurableFieldSpec],
) -> Sequence[ConfigurableFieldSpec]:
) -> List[ConfigurableFieldSpec]:
"""Get the unique config specs from a sequence of config specs."""
grouped = groupby(sorted(specs, key=lambda s: s.id), lambda s: s.id)
unique: List[ConfigurableFieldSpec] = []

View File

@@ -0,0 +1,231 @@
from typing import Any, Callable, Sequence, Union
from langchain.memory import ChatMessageHistory
from langchain.pydantic_v1 import BaseModel
from langchain.schema import AIMessage, BaseMessage, HumanMessage
from langchain.schema.runnable import RunnableConfig, RunnableLambda
from langchain.schema.runnable.history import RunnableWithMessageHistory
def _get_get_session_history() -> Callable[..., ChatMessageHistory]:
    """Return a session-history factory backed by a private in-memory store.

    Each call creates a fresh store, so factories from different calls do not
    share histories.
    """
    histories_by_session = {}

    def get_session_history(session_id: str, **kwargs: Any) -> ChatMessageHistory:
        """Return the history for ``session_id``, creating it on first use."""
        history = histories_by_session.get(session_id)
        if history is None:
            history = ChatMessageHistory()
            histories_by_session[session_id] = history
        return history

    return get_session_history
def test_input_messages() -> None:
    """Plain message-list input: earlier human messages reappear on later turns."""

    def echo_human(messages: Any) -> Any:
        spoken = [str(msg.content) for msg in messages if isinstance(msg, HumanMessage)]
        return "you said: " + "\n".join(spoken)

    runnable = RunnableLambda(echo_human)
    with_history = RunnableWithMessageHistory(runnable, _get_get_session_history())
    config: RunnableConfig = {"configurable": {"session_id": "1"}}

    first_reply = with_history.invoke([HumanMessage(content="hello")], config)
    assert first_reply == "you said: hello"

    second_reply = with_history.invoke([HumanMessage(content="good bye")], config)
    assert second_reply == "you said: hello\ngood bye"
def test_input_dict() -> None:
    """Dict input whose single messages key also receives the stored history."""

    def echo_human(payload: Any) -> Any:
        spoken = (
            str(msg.content)
            for msg in payload["messages"]
            if isinstance(msg, HumanMessage)
        )
        return "you said: " + "\n".join(spoken)

    runnable = RunnableLambda(echo_human)
    with_history = RunnableWithMessageHistory(
        runnable,
        _get_get_session_history(),
        input_messages_key="messages",
    )
    config: RunnableConfig = {"configurable": {"session_id": "2"}}

    # Two consecutive turns in the same session, with the reply expected for each.
    turns = (
        ("hello", "you said: hello"),
        ("good bye", "you said: hello\ngood bye"),
    )
    for text, expected in turns:
        reply = with_history.invoke({"messages": [HumanMessage(content=text)]}, config)
        assert reply == expected
def test_input_dict_with_history_key() -> None:
    """Dict input with the history injected under its own key; string output."""

    def echo_human(payload: Any) -> Any:
        lines = [
            str(msg.content)
            for msg in payload["history"]
            if isinstance(msg, HumanMessage)
        ]
        lines.append(payload["input"])
        return "you said: " + "\n".join(lines)

    runnable = RunnableLambda(echo_human)
    with_history = RunnableWithMessageHistory(
        runnable,
        _get_get_session_history(),
        input_messages_key="input",
        history_messages_key="history",
    )
    config: RunnableConfig = {"configurable": {"session_id": "3"}}

    # Two consecutive turns in the same session, with the reply expected for each.
    turns = (
        ("hello", "you said: hello"),
        ("good bye", "you said: hello\ngood bye"),
    )
    for text, expected in turns:
        assert with_history.invoke({"input": text}, config) == expected
def test_output_message() -> None:
    """The wrapped runnable returns a single AIMessage."""

    def echo_human(payload: Any) -> Any:
        lines = [
            str(msg.content)
            for msg in payload["history"]
            if isinstance(msg, HumanMessage)
        ]
        lines.append(payload["input"])
        return AIMessage(content="you said: " + "\n".join(lines))

    runnable = RunnableLambda(echo_human)
    with_history = RunnableWithMessageHistory(
        runnable,
        _get_get_session_history(),
        input_messages_key="input",
        history_messages_key="history",
    )
    config: RunnableConfig = {"configurable": {"session_id": "4"}}

    # Two consecutive turns in the same session, with the reply expected for each.
    turns = (
        ("hello", "you said: hello"),
        ("good bye", "you said: hello\ngood bye"),
    )
    for text, expected_content in turns:
        reply = with_history.invoke({"input": text}, config)
        assert reply == AIMessage(content=expected_content)
def test_output_messages() -> None:
    """The wrapped runnable returns a list of AIMessages."""

    def echo_human(payload: Any) -> Any:
        lines = [
            str(msg.content)
            for msg in payload["history"]
            if isinstance(msg, HumanMessage)
        ]
        lines.append(payload["input"])
        return [AIMessage(content="you said: " + "\n".join(lines))]

    runnable = RunnableLambda(echo_human)
    with_history = RunnableWithMessageHistory(
        runnable,
        _get_get_session_history(),
        input_messages_key="input",
        history_messages_key="history",
    )
    config: RunnableConfig = {"configurable": {"session_id": "5"}}

    # Two consecutive turns in the same session, with the reply expected for each.
    turns = (
        ("hello", "you said: hello"),
        ("good bye", "you said: hello\ngood bye"),
    )
    for text, expected_content in turns:
        reply = with_history.invoke({"input": text}, config)
        assert reply == [AIMessage(content=expected_content)]
def test_output_dict() -> None:
    """The wrapped runnable returns a dict holding its messages under "output"."""

    def echo_human(payload: Any) -> Any:
        lines = [
            str(msg.content)
            for msg in payload["history"]
            if isinstance(msg, HumanMessage)
        ]
        lines.append(payload["input"])
        return {"output": [AIMessage(content="you said: " + "\n".join(lines))]}

    runnable = RunnableLambda(echo_human)
    with_history = RunnableWithMessageHistory(
        runnable,
        _get_get_session_history(),
        input_messages_key="input",
        history_messages_key="history",
        output_messages_key="output",
    )
    config: RunnableConfig = {"configurable": {"session_id": "6"}}

    # Two consecutive turns in the same session, with the reply expected for each.
    turns = (
        ("hello", "you said: hello"),
        ("good bye", "you said: hello\ngood bye"),
    )
    for text, expected_content in turns:
        reply = with_history.invoke({"input": text}, config)
        assert reply == {"output": [AIMessage(content=expected_content)]}
def test_get_input_schema_input_dict() -> None:
    """With input and history keys set, the schema exposes both as fields."""

    # Reference model; its name must match the generated model's name because
    # the name is part of the compared schema.
    class RunnableWithChatHistoryInput(BaseModel):
        input: Union[str, BaseMessage, Sequence[BaseMessage]]
        history: Sequence[BaseMessage]

    def echo_human(payload: Any) -> Any:
        lines = [
            str(msg.content)
            for msg in payload["history"]
            if isinstance(msg, HumanMessage)
        ]
        lines.append(payload["input"])
        return {"output": [AIMessage(content="you said: " + "\n".join(lines))]}

    runnable = RunnableLambda(echo_human)
    with_history = RunnableWithMessageHistory(
        runnable,
        _get_get_session_history(),
        input_messages_key="input",
        history_messages_key="history",
        output_messages_key="output",
    )

    actual_schema = with_history.get_input_schema().schema()
    expected_schema = RunnableWithChatHistoryInput.schema()
    assert actual_schema == expected_schema
def test_get_input_schema_input_messages() -> None:
    """With no input key set, the schema is a root sequence of messages."""

    # Reference model; its name must match the generated model's name because
    # the name is part of the compared schema.
    class RunnableWithChatHistoryInput(BaseModel):
        __root__: Sequence[BaseMessage]

    def echo_human(messages: Any) -> Any:
        spoken = [str(msg.content) for msg in messages if isinstance(msg, HumanMessage)]
        return {"output": [AIMessage(content="you said: " + "\n".join(spoken))]}

    runnable = RunnableLambda(echo_human)
    with_history = RunnableWithMessageHistory(
        runnable,
        _get_get_session_history(),
        output_messages_key="output",
    )

    actual_schema = with_history.get_input_schema().schema()
    expected_schema = RunnableWithChatHistoryInput.schema()
    assert actual_schema == expected_schema