FEATURE: Runnable with message history (#13418)

Add RunnableWithMessageHistory class that can wrap certain runnables and manages chat history for them.
Bagatur 2023-11-17 12:00:01 -08:00 committed by GitHub
parent 0fc3af8932
commit 2e2114d2d0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 939 additions and 21 deletions

View File

@@ -0,0 +1,396 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "6a4becbd-238e-4c1d-a02d-08e61fbc3763",
"metadata": {},
"source": [
"# Add message history (memory)\n",
"\n",
"The `RunnableWithMessageHistory` let's us add message history to certain types of chains.\n",
"\n",
"Specifically, it can be used for any Runnable that takes as input one of\n",
"* a sequence of `BaseMessage`\n",
"* a dict with a key that takes a sequence of `BaseMessage`\n",
"* a dict with a key that takes the latest message(s) as a string or sequence of `BaseMessage`, and a separate key that takes historical messages\n",
"\n",
"And returns as output one of\n",
"* a string that can be treated as the contents of an `AIMessage`\n",
"* a sequence of `BaseMessage`\n",
"* a dict with a key that contains a sequence of `BaseMessage`\n",
"\n",
"Let's take a look at some examples to see how it works."
]
},
{
"cell_type": "markdown",
"id": "6bca45e5-35d9-4603-9ca9-6ac0ce0e35cd",
"metadata": {},
"source": [
"## Setup\n",
"\n",
"We'll use Redis to store our chat message histories and Anthropic's claude-2 model so we'll need to install the following dependencies:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "477d04b3-c2b6-4ba5-962f-492c0d625cd5",
"metadata": {},
"outputs": [],
"source": [
"!pip install -U langchain redis anthropic"
]
},
{
"cell_type": "markdown",
"id": "93776323-d6b8-4912-bb6a-867c5e655f46",
"metadata": {},
"source": [
"Set your [Anthropic API key](https://console.anthropic.com/):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c7f56f69-d2f1-4a21-990c-b5551eb012fa",
"metadata": {},
"outputs": [],
"source": [
"import getpass\n",
"import os\n",
"\n",
"os.environ[\"ANTHROPIC_API_KEY\"] = getpass.getpass()"
]
},
{
"cell_type": "markdown",
"id": "6a0ec9e0-7b1c-4c6f-b570-e61d520b47c6",
"metadata": {},
"source": [
"Start a local Redis Stack server if we don't have an existing Redis deployment to connect to:\n",
"```bash\n",
"docker run -d -p 6379:6379 -p 8001:8001 redis/redis-stack:latest\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "cd6a250e-17fe-4368-a39d-1fe6b2cbde68",
"metadata": {},
"outputs": [],
"source": [
"REDIS_URL = \"redis://localhost:6379/0\""
]
},
{
"cell_type": "markdown",
"id": "36f43b87-655c-4f64-aa7b-bd8c1955d8e5",
"metadata": {},
"source": [
"### [LangSmith](/docs/langsmith)\n",
"\n",
"LangSmith is especially useful for something like message history injection, where it can be hard to otherwise understand what the inputs are to various parts of the chain.\n",
"\n",
"Note that LangSmith is not needed, but it is helpful.\n",
"If you do want to use LangSmith, after you sign up at the link above, make sure to uncoment the below and set your environment variables to start logging traces:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "2afc1556-8da1-4499-ba11-983b66c58b18",
"metadata": {},
"outputs": [],
"source": [
"# os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"\n",
"# os.environ[\"LANGCHAIN_API_KEY\"] = getpass.getpass()"
]
},
{
"cell_type": "markdown",
"id": "1a5a632e-ba9e-4488-b586-640ad5494f62",
"metadata": {},
"source": [
"## Example: Dict input, message output\n",
"\n",
"Let's create a simple chain that takes a dict as input and returns a BaseMessage.\n",
"\n",
"In this case the `\"question\"` key in the input represents our input message, and the `\"history\"` key is where our historical messages will be injected."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "2a150d6f-8878-4950-8634-a608c5faad56",
"metadata": {},
"outputs": [],
"source": [
"from typing import Optional\n",
"\n",
"from langchain.chat_models import ChatAnthropic\n",
"from langchain.memory.chat_message_histories import RedisChatMessageHistory\n",
"from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder\n",
"from langchain.schema.chat_history import BaseChatMessageHistory\n",
"from langchain.schema.runnable.history import RunnableWithMessageHistory"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "3185edba-4eb6-4b32-80c6-577c0d19af97",
"metadata": {},
"outputs": [],
"source": [
"prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\"system\", \"You're an assistant who's good at {ability}\"),\n",
" MessagesPlaceholder(variable_name=\"history\"),\n",
" (\"human\", \"{question}\"),\n",
" ]\n",
")\n",
"\n",
"chain = prompt | ChatAnthropic(model=\"claude-2\")"
]
},
{
"cell_type": "markdown",
"id": "f9d81796-ce61-484c-89e2-6c567d5e54ef",
"metadata": {},
"source": [
"### Adding message history\n",
"\n",
"To add message history to our original chain we wrap it in the `RunnableWithMessageHistory` class.\n",
"\n",
"Crucially, we also need to define a method that takes a session_id string and based on it returns a `BaseChatMessageHistory`. Given the same input, this method should return an equivalent output.\n",
"\n",
"In this case we'll also want to specify `input_messages_key` (the key to be treated as the latest input message) and `history_messages_key` (the key to add historical messages to)."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "ca7c64d8-e138-4ef8-9734-f82076c47d80",
"metadata": {},
"outputs": [],
"source": [
"chain_with_history = RunnableWithMessageHistory(\n",
" chain,\n",
" lambda session_id: RedisChatMessageHistory(session_id, url=REDIS_URL),\n",
" input_messages_key=\"question\",\n",
" history_messages_key=\"history\",\n",
")"
]
},
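{
"cell_type": "markdown",
"id": "8b52a4f6-6c95-4dd0-9cf7-4a8a1e4e0b2a",
"metadata": {},
"source": [
"The Redis-backed factory above is just one implementation; any callable with the same signature works. Here's a minimal in-memory sketch built on `ChatMessageHistory` (the `store` dict and `get_in_memory_history` name are illustrative, not part of the API):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f5b1c9e2-3d4a-4e8f-9a6b-2c7d8e9f0a1b",
"metadata": {},
"outputs": [],
"source": [
"from langchain.memory import ChatMessageHistory\n",
"\n",
"# In-memory session store; histories are lost when the process exits\n",
"store = {}\n",
"\n",
"\n",
"def get_in_memory_history(session_id: str) -> BaseChatMessageHistory:\n",
"    # The same session_id must always resolve to the same history object\n",
"    if session_id not in store:\n",
"        store[session_id] = ChatMessageHistory()\n",
"    return store[session_id]\n",
"\n",
"\n",
"# This factory could be passed to RunnableWithMessageHistory in place of\n",
"# the Redis lambda above."
]
},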
{
"cell_type": "markdown",
"id": "37eefdec-9901-4650-b64c-d3c097ed5f4d",
"metadata": {},
"source": [
"## Invoking with config\n",
"\n",
"Whenever we call our chain with message history, we need to include a config that contains the `session_id`\n",
"```python\n",
"config={\"configurable\": {\"session_id\": \"<SESSION_ID>\"}}\n",
"```\n",
"\n",
"Given the same configuration, our chain should be pulling from the same chat message history."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "a85bcc22-ca4c-4ad5-9440-f94be7318f3e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content=' Cosine is one of the basic trigonometric functions in mathematics. It is defined as the ratio of the adjacent side to the hypotenuse in a right triangle.\\n\\nSome key properties and facts about cosine:\\n\\n- It is denoted by cos(θ), where θ is the angle in a right triangle. \\n\\n- The cosine of an acute angle is always positive. For angles greater than 90 degrees, cosine can be negative.\\n\\n- Cosine is one of the three main trig functions along with sine and tangent.\\n\\n- The cosine of 0 degrees is 1. As the angle increases towards 90 degrees, the cosine value decreases towards 0.\\n\\n- The range of values for cosine is -1 to 1.\\n\\n- The cosine function maps angles in a circle to the x-coordinate on the unit circle.\\n\\n- Cosine is used to find adjacent side lengths in right triangles, and has many other applications in mathematics, physics, engineering and more.\\n\\n- Key cosine identities include: cos(A+B) = cosAcosB sinAsinB and cos(2A) = cos^2(A) sin^2(A)\\n\\nSo in summary, cosine is a fundamental trig')"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain_with_history.invoke(\n",
" {\"ability\": \"math\", \"question\": \"What does cosine mean?\"},\n",
" config={\"configurable\": {\"session_id\": \"foobar\"}},\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "ab29abd3-751f-41ce-a1b0-53f6b565e79d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content=' The inverse of the cosine function is called the arccosine or inverse cosine, often denoted as cos-1(x) or arccos(x).\\n\\nThe key properties and facts about arccosine:\\n\\n- It is defined as the angle θ between 0 and π radians whose cosine is x. So arccos(x) = θ such that cos(θ) = x.\\n\\n- The range of arccosine is 0 to π radians (0 to 180 degrees).\\n\\n- The domain of arccosine is -1 to 1. \\n\\n- arccos(cos(θ)) = θ for values of θ from 0 to π radians.\\n\\n- arccos(x) is the angle in a right triangle whose adjacent side is x and hypotenuse is 1.\\n\\n- arccos(0) = 90 degrees. As x increases from 0 to 1, arccos(x) decreases from 90 to 0 degrees.\\n\\n- arccos(1) = 0 degrees. arccos(-1) = 180 degrees.\\n\\n- The graph of y = arccos(x) is part of the unit circle, restricted to x')"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain_with_history.invoke(\n",
" {\"ability\": \"math\", \"question\": \"What's its inverse\"},\n",
" config={\"configurable\": {\"session_id\": \"foobar\"}},\n",
")"
]
},
{
"cell_type": "markdown",
"id": "da3d1feb-b4bb-4624-961c-7db2e1180df7",
"metadata": {},
"source": [
":::tip [LangSmith trace](https://smith.langchain.com/public/863a003b-7ca8-4b24-be9e-d63ec13c106e/r)\n",
":::"
]
},
{
"cell_type": "markdown",
"id": "61d5115e-64a1-4ad5-b676-8afd4ef6093e",
"metadata": {},
"source": [
"Looking at the Langsmith trace for the second call, we can see that when constructing the prompt, a \"history\" variable has been injected which is a list of two messages (our first input and first output)."
]
},
{
"cell_type": "markdown",
"id": "028cf151-6cd5-4533-b3cf-c8d735554647",
"metadata": {},
"source": [
"## Example: messages input, dict output"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "0bb446b5-6251-45fe-a92a-4c6171473c53",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'output_message': AIMessage(content=' Here is a summary of Simone de Beauvoir\\'s views on free will:\\n\\n- De Beauvoir was an existentialist philosopher and believed strongly in the concept of free will. She rejected the idea that human nature or instincts determine behavior.\\n\\n- Instead, de Beauvoir argued that human beings define their own essence or nature through their actions and choices. As she famously wrote, \"One is not born, but rather becomes, a woman.\"\\n\\n- De Beauvoir believed that while individuals are situated in certain cultural contexts and social conditions, they still have agency and the ability to transcend these situations. Freedom comes from choosing one\\'s attitude toward these constraints.\\n\\n- She emphasized the radical freedom and responsibility of the individual. We are \"condemned to be free\" because we cannot escape making choices and taking responsibility for our choices. \\n\\n- De Beauvoir felt that many people evade their freedom and responsibility by adopting rigid mindsets, ideologies, or conforming uncritically to social roles.\\n\\n- She advocated for the recognition of ambiguity in the human condition and warned against the quest for absolute rules that deny freedom and responsibility. Authentic living involves embracing ambiguity.\\n\\nIn summary, de Beauvoir promoted an existential ethics')}"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain.schema.messages import HumanMessage\n",
"from langchain.schema.runnable import RunnableMap\n",
"\n",
"chain = RunnableMap({\"output_message\": ChatAnthropic(model=\"claude-2\")})\n",
"chain_with_history = RunnableWithMessageHistory(\n",
" chain,\n",
" lambda session_id: RedisChatMessageHistory(session_id, url=REDIS_URL),\n",
" output_messages_key=\"output_message\",\n",
")\n",
"\n",
"chain_with_history.invoke(\n",
" [HumanMessage(content=\"What did Simone de Beauvoir believe about free will\")],\n",
" config={\"configurable\": {\"session_id\": \"baz\"}},\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "601ce3ff-aea8-424d-8e54-fd614256af4f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'output_message': AIMessage(content=\" There are many similarities between Simone de Beauvoir's views on free will and those of Jean-Paul Sartre, though some key differences emerge as well:\\n\\nSimilarities with Sartre:\\n\\n- Both were existentialist thinkers who rejected determinism and emphasized human freedom and responsibility.\\n\\n- They agreed that existence precedes essence - there is no predefined human nature that determines who we are.\\n\\n- Individuals must define themselves through their choices and actions. This leads to anxiety but also freedom.\\n\\n- The human condition is characterized by ambiguity and uncertainty, rather than fixed meanings/values.\\n\\n- Both felt that most people evade their freedom through self-deception, conformity, or adopting collective identities/values uncritically.\\n\\nDifferences from Sartre: \\n\\n- Sartre placed more emphasis on the burden and anguish of radical freedom. De Beauvoir focused more on its positive potential.\\n\\n- De Beauvoir critiqued Sartre's premise that human relations are necessarily conflictual. She saw more potential for mutual recognition.\\n\\n- Sartre saw the Other's gaze as a threat to freedom. De Beauvoir put more stress on how the Other's gaze can confirm\")}"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain_with_history.invoke(\n",
" [HumanMessage(content=\"How did this compare to Sartre\")],\n",
" config={\"configurable\": {\"session_id\": \"baz\"}},\n",
")"
]
},
{
"cell_type": "markdown",
"id": "b898d1b1-11e6-4d30-a8dd-cc5e45533611",
"metadata": {},
"source": [
":::tip [LangSmith trace](https://smith.langchain.com/public/f6c3e1d1-a49d-4955-a9fa-c6519df74fa7/r)\n",
":::"
]
},
{
"cell_type": "markdown",
"id": "1724292c-01c6-44bb-83e8-9cdb6bf01483",
"metadata": {},
"source": [
"## More examples\n",
"\n",
"We could also do any of the below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fd89240b-5a25-48f8-9568-5c1127f9ffad",
"metadata": {},
"outputs": [],
"source": [
"from operator import itemgetter\n",
"\n",
"# messages in, messages out\n",
"RunnableWithMessageHistory(\n",
" ChatAnthropic(model=\"claude-2\"),\n",
" lambda session_id: RedisChatMessageHistory(session_id, url=REDIS_URL),\n",
")\n",
"\n",
"# dict with single key for all messages in, messages out\n",
"RunnableWithMessageHistory(\n",
" itemgetter(\"input_messages\") | ChatAnthropic(model=\"claude-2\"),\n",
" lambda session_id: RedisChatMessageHistory(session_id, url=REDIS_URL),\n",
" input_messages_key=\"input_messages\",\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "poetry-venv",
"language": "python",
"name": "poetry-venv"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -58,8 +58,6 @@
"from langchain.text_splitter import CharacterTextSplitter\n",
"from langchain.vectorstores import FAISS\n",
"\n",
"from langchain.document_loaders import TextLoader\n",
"\n",
"loader = TextLoader(\"../../../extras/modules/state_of_the_union.txt\")\n",
"documents = loader.load()\n",
"text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",

View File

@@ -61,8 +61,6 @@
"from langchain.text_splitter import CharacterTextSplitter\n",
"from langchain.vectorstores import FAISS\n",
"\n",
"from langchain.document_loaders import TextLoader\n",
"\n",
"loader = TextLoader(\"../../../extras/modules/state_of_the_union.txt\")\n",
"documents = loader.load()\n",
"text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",

View File

@@ -54,11 +54,10 @@
}
],
"source": [
"from langchain.vectorstores.vearch import Vearch\n",
"\n",
"from langchain.document_loaders import TextLoader\n",
"from langchain.embeddings.huggingface import HuggingFaceEmbeddings\n",
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
"from langchain.vectorstores.vearch import Vearch\n",
"from transformers import AutoModel, AutoTokenizer\n",
"\n",
"# repalce to your local model path\n",

View File

@@ -1051,6 +1051,14 @@
":::"
]
},
{
"cell_type": "markdown",
"id": "fdf6c7e0-84f8-4747-b2ae-e84315152bd9",
"metadata": {},
"source": [
"Here we've gone over how to add chain logic for incorporating historical outputs. But how do we actually store and retrieve historical outputs for different sessions? For that check out the LCEL [How to add message history (memory)](/docs/expression_language/how_to/message_history) page."
]
},
{
"cell_type": "markdown",
"id": "580e18de-132d-4009-ba67-4aaf2c7717a2",

View File

@@ -85,6 +85,9 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
variable_name: str
"""Name of variable to use as messages."""
def __init__(self, variable_name: str, **kwargs: Any):
super().__init__(variable_name=variable_name, **kwargs)
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format messages from kwargs.

View File

@@ -42,7 +42,6 @@ if TYPE_CHECKING:
RunnableWithFallbacks as RunnableWithFallbacksT,
)
from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable
from langchain.pydantic_v1 import BaseModel, Field, create_model
@@ -298,7 +297,7 @@ class Runnable(Generic[Input, Output], ABC):
)
@property
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
def config_specs(self) -> List[ConfigurableFieldSpec]:
"""List configurable fields for this runnable."""
return []
@@ -1357,7 +1356,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
return self.last.get_output_schema(config)
@property
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
def config_specs(self) -> List[ConfigurableFieldSpec]:
return get_unique_config_specs(
spec for step in self.steps for spec in step.config_specs
)
@@ -1885,7 +1884,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
)
@property
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
def config_specs(self) -> List[ConfigurableFieldSpec]:
return get_unique_config_specs(
spec for step in self.steps.values() for spec in step.config_specs
)
@@ -2591,7 +2590,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
)
@property
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
def config_specs(self) -> List[ConfigurableFieldSpec]:
return self.bound.config_specs
@classmethod
@@ -2763,7 +2762,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
return self.bound.get_output_schema(merge_configs(self.config, config))
@property
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
def config_specs(self) -> List[ConfigurableFieldSpec]:
return self.bound.config_specs
@classmethod

View File

@@ -147,7 +147,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
return super().get_input_schema(config)
@property
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
def config_specs(self) -> List[ConfigurableFieldSpec]:
return get_unique_config_specs(
spec
for step in (

View File

@@ -209,7 +209,7 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
fields: Dict[str, AnyConfigurableField]
@property
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
def config_specs(self) -> List[ConfigurableFieldSpec]:
return get_unique_config_specs(
[
ConfigurableFieldSpec(
@@ -300,7 +300,7 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
default_key: str = "default"
@property
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
def config_specs(self) -> List[ConfigurableFieldSpec]:
with _enums_for_spec_lock:
if which_enum := _enums_for_spec.get(self.which):
pass

View File

@@ -112,7 +112,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
return self.runnable.get_output_schema(config)
@property
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
def config_specs(self) -> List[ConfigurableFieldSpec]:
return get_unique_config_specs(
spec
for step in [self.runnable, *self.fallbacks]

View File

@@ -0,0 +1,288 @@
from __future__ import annotations
import asyncio
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Type,
Union,
)
from langchain.load import load
from langchain.pydantic_v1 import BaseModel, create_model
from langchain.schema.chat_history import BaseChatMessageHistory
from langchain.schema.runnable.base import Runnable, RunnableBindingBase, RunnableLambda
from langchain.schema.runnable.passthrough import RunnablePassthrough
from langchain.schema.runnable.utils import (
ConfigurableFieldSpec,
get_unique_config_specs,
)
if TYPE_CHECKING:
from langchain.callbacks.tracers.schemas import Run
from langchain.schema.messages import BaseMessage
from langchain.schema.runnable.config import RunnableConfig
MessagesOrDictWithMessages = Union[Sequence["BaseMessage"], Dict[str, Any]]
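# Factory that returns the chat message history for a given session id.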
GetSessionHistoryCallable = Callable[..., BaseChatMessageHistory]
class RunnableWithMessageHistory(RunnableBindingBase):
"""A runnable that manages chat message history for another runnable.
Base runnable must have inputs and outputs that can be converted to a list of
BaseMessages.
RunnableWithMessageHistory must always be called with a config that contains session_id, e.g.:
``{"configurable": {"session_id": "<SESSION_ID>"}}``
Example (dict input):
.. code-block:: python
from typing import Optional
from langchain.chat_models import ChatAnthropic
from langchain.memory.chat_message_histories import RedisChatMessageHistory
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.schema.runnable.history import RunnableWithMessageHistory
prompt = ChatPromptTemplate.from_messages([
("system", "You're an assistant who's good at {ability}"),
MessagesPlaceholder(variable_name="history"),
("human", "{question}"),
])
chain = prompt | ChatAnthropic(model="claude-2")
chain_with_history = RunnableWithMessageHistory(
chain,
RedisChatMessageHistory,
input_messages_key="question",
history_messages_key="history",
)
chain_with_history.invoke(
{"ability": "math", "question": "What does cosine mean?"},
config={"configurable": {"session_id": "foo"}}
)
# -> "Cosine is ..."
chain_with_history.invoke(
{"ability": "math", "question": "What's its inverse"},
config={"configurable": {"session_id": "foo"}}
)
# -> "The inverse of cosine is called arccosine ..."
""" # noqa: E501
get_session_history: GetSessionHistoryCallable
input_messages_key: Optional[str] = None
output_messages_key: Optional[str] = None
history_messages_key: Optional[str] = None
def __init__(
self,
runnable: Runnable[
MessagesOrDictWithMessages,
Union[str, BaseMessage, MessagesOrDictWithMessages],
],
get_session_history: GetSessionHistoryCallable,
*,
input_messages_key: Optional[str] = None,
output_messages_key: Optional[str] = None,
history_messages_key: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Initialize RunnableWithMessageHistory.
Args:
runnable: The base Runnable to be wrapped.
Must take as input one of:
- A sequence of BaseMessages
- A dict with one key for all messages
- A dict with one key for the current input string/message(s) and
a separate key for historical messages. If the input key points
to a string, it will be treated as a HumanMessage in history.
Must return as output one of:
- A string which can be treated as an AIMessage
- A BaseMessage or sequence of BaseMessages
- A dict with a key for a BaseMessage or sequence of BaseMessages
get_session_history: Function that returns a new BaseChatMessageHistory
given a session id. Should take a single
positional argument `session_id` which is a string and a named argument
`user_id` which can be a string or None. e.g.:
```python
def get_session_history(
session_id: str,
*,
user_id: Optional[str]=None
) -> BaseChatMessageHistory:
...
```
input_messages_key: Must be specified if the base runnable accepts a dict
as input.
output_messages_key: Must be specified if the base runnable returns a dict
as output.
history_messages_key: Must be specified if the base runnable accepts a dict
as input and expects a separate key for historical messages.
**kwargs: Arbitrary additional kwargs to pass to parent class
``RunnableBindingBase`` init.
""" # noqa: E501
history_chain: Runnable = RunnableLambda(
self._enter_history, self._aenter_history
).with_config(run_name="load_history")
messages_key = history_messages_key or input_messages_key
if messages_key:
history_chain = RunnablePassthrough.assign(
**{messages_key: history_chain}
).with_config(run_name="insert_history")
bound = (
history_chain | runnable.with_listeners(on_end=self._exit_history)
).with_config(run_name="RunnableWithMessageHistory")
super().__init__(
get_session_history=get_session_history,
input_messages_key=input_messages_key,
output_messages_key=output_messages_key,
bound=bound,
history_messages_key=history_messages_key,
**kwargs,
)
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
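# Require session_id at runtime by exposing it as a configurable field.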
return get_unique_config_specs(
super().config_specs
+ [
ConfigurableFieldSpec(
id="session_id",
annotation=str,
name="Session ID",
description="Unique identifier for a session.",
default="",
),
]
)
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
super_schema = super().get_input_schema(config)
if super_schema.__custom_root_type__ is not None:
from langchain.schema.messages import BaseMessage
fields: Dict = {}
if self.input_messages_key and self.history_messages_key:
fields[self.input_messages_key] = (
Union[str, BaseMessage, Sequence[BaseMessage]],
...,
)
elif self.input_messages_key:
fields[self.input_messages_key] = (Sequence[BaseMessage], ...)
else:
fields["__root__"] = (Sequence[BaseMessage], ...)
if self.history_messages_key:
fields[self.history_messages_key] = (Sequence[BaseMessage], ...)
return create_model( # type: ignore[call-overload]
"RunnableWithChatHistoryInput",
**fields,
)
else:
return super_schema
def _get_input_messages(
self, input_val: Union[str, BaseMessage, Sequence[BaseMessage]]
) -> List[BaseMessage]:
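# Normalize any supported input shape to a list of BaseMessages.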
from langchain.schema.messages import BaseMessage
if isinstance(input_val, str):
from langchain.schema.messages import HumanMessage
return [HumanMessage(content=input_val)]
elif isinstance(input_val, BaseMessage):
return [input_val]
elif isinstance(input_val, (list, tuple)):
return list(input_val)
else:
raise ValueError(
f"Expected str, BaseMessage, List[BaseMessage], or Tuple[BaseMessage]. "
f"Got {input_val}."
)
def _get_output_messages(
self, output_val: Union[str, BaseMessage, Sequence[BaseMessage], dict]
) -> List[BaseMessage]:
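# Normalize any supported output shape to a list of BaseMessages.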
from langchain.schema.messages import BaseMessage
if isinstance(output_val, dict):
output_val = output_val[self.output_messages_key or "output"]
if isinstance(output_val, str):
from langchain.schema.messages import AIMessage
return [AIMessage(content=output_val)]
elif isinstance(output_val, BaseMessage):
return [output_val]
elif isinstance(output_val, (list, tuple)):
return list(output_val)
else:
raise ValueError(
f"Expected str, BaseMessage, List[BaseMessage], or Tuple[BaseMessage]. "
f"Got {output_val}."
)
def _enter_history(self, input: Any, config: RunnableConfig) -> List[BaseMessage]:
hist = config["configurable"]["message_history"]
# return only historic messages
if self.history_messages_key:
return hist.messages.copy()
# return all messages
else:
input_val = (
input if not self.input_messages_key else input[self.input_messages_key]
)
return hist.messages.copy() + self._get_input_messages(input_val)
async def _aenter_history(
self, input: Dict[str, Any], config: RunnableConfig
) -> List[BaseMessage]:
return await asyncio.get_running_loop().run_in_executor(
None, self._enter_history, input, config
)
def _exit_history(self, run: Run, config: RunnableConfig) -> None:
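# on_end listener: persist this turn's input and output messages.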
hist = config["configurable"]["message_history"]
# Get the input messages
inputs = load(run.inputs)
input_val = inputs[self.input_messages_key or "input"]
input_messages = self._get_input_messages(input_val)
# Get the output messages
output_val = load(run.outputs)
output_messages = self._get_output_messages(output_val)
for m in input_messages + output_messages:
hist.add_message(m)
def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig:
config = super()._merge_configs(*configs)
# extract session_id
if "session_id" not in config.get("configurable", {}):
example_input = {self.input_messages_key: "foo"}
example_config = {"configurable": {"session_id": "123"}}
raise ValueError(
"session_id_id is required."
" Pass it in as part of the config argument to .invoke() or .stream()"
f"\neg. chain.invoke({example_input}, {example_config})"
)
# attach message_history
session_id = config["configurable"]["session_id"]
config["configurable"]["message_history"] = self.get_session_history(session_id)
return config

View File

@@ -14,7 +14,6 @@ from typing import (
List,
Mapping,
Optional,
Sequence,
Type,
Union,
cast,
@@ -334,7 +333,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
return super().get_output_schema(config)
@property
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
def config_specs(self) -> List[ConfigurableFieldSpec]:
return self.mapper.config_specs
def invoke(

View File

@@ -8,7 +8,6 @@ from typing import (
List,
Mapping,
Optional,
Sequence,
Union,
cast,
)
@@ -55,7 +54,7 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
runnables: Mapping[str, Runnable[Any, Output]]
@property
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
def config_specs(self) -> List[ConfigurableFieldSpec]:
return get_unique_config_specs(
spec for step in self.runnables.values() for spec in step.config_specs
)

View File

@@ -308,7 +308,7 @@ class ConfigurableFieldSpec(NamedTuple):
def get_unique_config_specs(
specs: Iterable[ConfigurableFieldSpec],
) -> Sequence[ConfigurableFieldSpec]:
) -> List[ConfigurableFieldSpec]:
"""Get the unique config specs from a sequence of config specs."""
grouped = groupby(sorted(specs, key=lambda s: s.id), lambda s: s.id)
unique: List[ConfigurableFieldSpec] = []

View File

@@ -0,0 +1,231 @@
from typing import Any, Callable, Sequence, Union
from langchain.memory import ChatMessageHistory
from langchain.pydantic_v1 import BaseModel
from langchain.schema import AIMessage, BaseMessage, HumanMessage
from langchain.schema.runnable import RunnableConfig, RunnableLambda
from langchain.schema.runnable.history import RunnableWithMessageHistory
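# Build a get_session_history callable that closes over a fresh in-memory store.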
def _get_get_session_history() -> Callable[..., ChatMessageHistory]:
chat_history_store = {}
def get_session_history(session_id: str, **kwargs: Any) -> ChatMessageHistory:
if session_id not in chat_history_store:
chat_history_store[session_id] = ChatMessageHistory()
return chat_history_store[session_id]
return get_session_history
def test_input_messages() -> None:
runnable = RunnableLambda(
lambda messages: "you said: "
+ "\n".join(str(m.content) for m in messages if isinstance(m, HumanMessage))
)
get_session_history = _get_get_session_history()
with_history = RunnableWithMessageHistory(runnable, get_session_history)
config: RunnableConfig = {"configurable": {"session_id": "1"}}
output = with_history.invoke([HumanMessage(content="hello")], config)
assert output == "you said: hello"
output = with_history.invoke([HumanMessage(content="good bye")], config)
assert output == "you said: hello\ngood bye"
def test_input_dict() -> None:
runnable = RunnableLambda(
lambda input: "you said: "
+ "\n".join(
str(m.content) for m in input["messages"] if isinstance(m, HumanMessage)
)
)
get_session_history = _get_get_session_history()
with_history = RunnableWithMessageHistory(
runnable, get_session_history, input_messages_key="messages"
)
config: RunnableConfig = {"configurable": {"session_id": "2"}}
output = with_history.invoke({"messages": [HumanMessage(content="hello")]}, config)
assert output == "you said: hello"
output = with_history.invoke(
{"messages": [HumanMessage(content="good bye")]}, config
)
assert output == "you said: hello\ngood bye"
def test_input_dict_with_history_key() -> None:
runnable = RunnableLambda(
lambda input: "you said: "
+ "\n".join(
[str(m.content) for m in input["history"] if isinstance(m, HumanMessage)]
+ [input["input"]]
)
)
get_session_history = _get_get_session_history()
with_history = RunnableWithMessageHistory(
runnable,
get_session_history,
input_messages_key="input",
history_messages_key="history",
)
config: RunnableConfig = {"configurable": {"session_id": "3"}}
output = with_history.invoke({"input": "hello"}, config)
assert output == "you said: hello"
output = with_history.invoke({"input": "good bye"}, config)
assert output == "you said: hello\ngood bye"
def test_output_message() -> None:
runnable = RunnableLambda(
lambda input: AIMessage(
content="you said: "
+ "\n".join(
[
str(m.content)
for m in input["history"]
if isinstance(m, HumanMessage)
]
+ [input["input"]]
)
)
)
get_session_history = _get_get_session_history()
with_history = RunnableWithMessageHistory(
runnable,
get_session_history,
input_messages_key="input",
history_messages_key="history",
)
config: RunnableConfig = {"configurable": {"session_id": "4"}}
output = with_history.invoke({"input": "hello"}, config)
assert output == AIMessage(content="you said: hello")
output = with_history.invoke({"input": "good bye"}, config)
assert output == AIMessage(content="you said: hello\ngood bye")
def test_output_messages() -> None:
runnable = RunnableLambda(
lambda input: [
AIMessage(
content="you said: "
+ "\n".join(
[
str(m.content)
for m in input["history"]
if isinstance(m, HumanMessage)
]
+ [input["input"]]
)
)
]
)
get_session_history = _get_get_session_history()
with_history = RunnableWithMessageHistory(
runnable,
get_session_history,
input_messages_key="input",
history_messages_key="history",
)
config: RunnableConfig = {"configurable": {"session_id": "5"}}
output = with_history.invoke({"input": "hello"}, config)
assert output == [AIMessage(content="you said: hello")]
output = with_history.invoke({"input": "good bye"}, config)
assert output == [AIMessage(content="you said: hello\ngood bye")]
def test_output_dict() -> None:
runnable = RunnableLambda(
lambda input: {
"output": [
AIMessage(
content="you said: "
+ "\n".join(
[
str(m.content)
for m in input["history"]
if isinstance(m, HumanMessage)
]
+ [input["input"]]
)
)
]
}
)
get_session_history = _get_get_session_history()
with_history = RunnableWithMessageHistory(
runnable,
get_session_history,
input_messages_key="input",
history_messages_key="history",
output_messages_key="output",
)
config: RunnableConfig = {"configurable": {"session_id": "6"}}
output = with_history.invoke({"input": "hello"}, config)
assert output == {"output": [AIMessage(content="you said: hello")]}
output = with_history.invoke({"input": "good bye"}, config)
assert output == {"output": [AIMessage(content="you said: hello\ngood bye")]}
def test_get_input_schema_input_dict() -> None:
class RunnableWithChatHistoryInput(BaseModel):
input: Union[str, BaseMessage, Sequence[BaseMessage]]
history: Sequence[BaseMessage]
runnable = RunnableLambda(
lambda input: {
"output": [
AIMessage(
content="you said: "
+ "\n".join(
[
str(m.content)
for m in input["history"]
if isinstance(m, HumanMessage)
]
+ [input["input"]]
)
)
]
}
)
get_session_history = _get_get_session_history()
with_history = RunnableWithMessageHistory(
runnable,
get_session_history,
input_messages_key="input",
history_messages_key="history",
output_messages_key="output",
)
assert (
with_history.get_input_schema().schema()
== RunnableWithChatHistoryInput.schema()
)
def test_get_input_schema_input_messages() -> None:
class RunnableWithChatHistoryInput(BaseModel):
__root__: Sequence[BaseMessage]
runnable = RunnableLambda(
lambda messages: {
"output": [
AIMessage(
content="you said: "
+ "\n".join(
[
str(m.content)
for m in messages
if isinstance(m, HumanMessage)
]
)
)
]
}
)
get_session_history = _get_get_session_history()
with_history = RunnableWithMessageHistory(
runnable, get_session_history, output_messages_key="output"
)
assert (
with_history.get_input_schema().schema()
== RunnableWithChatHistoryInput.schema()
)