mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-05 12:48:12 +00:00
FEATURE: Runnable with message history (#13418)
Add RunnableWithMessageHistory class that can wrap certain runnables and manages chat history for them.
This commit is contained in:
parent
0fc3af8932
commit
2e2114d2d0
396
docs/docs/expression_language/how_to/message_history.ipynb
Normal file
396
docs/docs/expression_language/how_to/message_history.ipynb
Normal file
@ -0,0 +1,396 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "6a4becbd-238e-4c1d-a02d-08e61fbc3763",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Add message history (memory)\n",
|
||||
"\n",
|
||||
"The `RunnableWithMessageHistory` let's us add message history to certain types of chains.\n",
|
||||
"\n",
|
||||
"Specifically, it can be used for any Runnable that takes as input one of\n",
|
||||
"* a sequence of `BaseMessage`\n",
|
||||
"* a dict with a key that takes a sequence of `BaseMessage`\n",
|
||||
"* a dict with a key that takes the latest message(s) as a string or sequence of `BaseMessage`, and a separate key that takes historical messages\n",
|
||||
"\n",
|
||||
"And returns as output one of\n",
|
||||
"* a string that can be treated as the contents of an `AIMessage`\n",
|
||||
"* a sequence of `BaseMessage`\n",
|
||||
"* a dict with a key that contains a sequence of `BaseMessage`\n",
|
||||
"\n",
|
||||
"Let's take a look at some examples to see how it works."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "6bca45e5-35d9-4603-9ca9-6ac0ce0e35cd",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Setup\n",
|
||||
"\n",
|
||||
"We'll use Redis to store our chat message histories and Anthropic's claude-2 model so we'll need to install the following dependencies:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "477d04b3-c2b6-4ba5-962f-492c0d625cd5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install -U langchain redis anthropic"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "93776323-d6b8-4912-bb6a-867c5e655f46",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Set your [Anthropic API key](https://console.anthropic.com/):"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c7f56f69-d2f1-4a21-990c-b5551eb012fa",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import getpass\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"os.environ[\"ANTHROPIC_API_KEY\"] = getpass.getpass()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "6a0ec9e0-7b1c-4c6f-b570-e61d520b47c6",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Start a local Redis Stack server if we don't have an existing Redis deployment to connect to:\n",
|
||||
"```bash\n",
|
||||
"docker run -d -p 6379:6379 -p 8001:8001 redis/redis-stack:latest\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "cd6a250e-17fe-4368-a39d-1fe6b2cbde68",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"REDIS_URL = \"redis://localhost:6379/0\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "36f43b87-655c-4f64-aa7b-bd8c1955d8e5",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### [LangSmith](/docs/langsmith)\n",
|
||||
"\n",
|
||||
"LangSmith is especially useful for something like message history injection, where it can be hard to otherwise understand what the inputs are to various parts of the chain.\n",
|
||||
"\n",
|
||||
"Note that LangSmith is not needed, but it is helpful.\n",
|
||||
"If you do want to use LangSmith, after you sign up at the link above, make sure to uncoment the below and set your environment variables to start logging traces:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "2afc1556-8da1-4499-ba11-983b66c58b18",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"\n",
|
||||
"# os.environ[\"LANGCHAIN_API_KEY\"] = getpass.getpass()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1a5a632e-ba9e-4488-b586-640ad5494f62",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Example: Dict input, message output\n",
|
||||
"\n",
|
||||
"Let's create a simple chain that takes a dict as input and returns a BaseMessage.\n",
|
||||
"\n",
|
||||
"In this case the `\"question\"` key in the input represents our input message, and the `\"history\"` key is where our historical messages will be injected."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "2a150d6f-8878-4950-8634-a608c5faad56",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from typing import Optional\n",
|
||||
"\n",
|
||||
"from langchain.chat_models import ChatAnthropic\n",
|
||||
"from langchain.memory.chat_message_histories import RedisChatMessageHistory\n",
|
||||
"from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder\n",
|
||||
"from langchain.schema.chat_history import BaseChatMessageHistory\n",
|
||||
"from langchain.schema.runnable.history import RunnableWithMessageHistory"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "3185edba-4eb6-4b32-80c6-577c0d19af97",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"prompt = ChatPromptTemplate.from_messages(\n",
|
||||
" [\n",
|
||||
" (\"system\", \"You're an assistant who's good at {ability}\"),\n",
|
||||
" MessagesPlaceholder(variable_name=\"history\"),\n",
|
||||
" (\"human\", \"{question}\"),\n",
|
||||
" ]\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"chain = prompt | ChatAnthropic(model=\"claude-2\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f9d81796-ce61-484c-89e2-6c567d5e54ef",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Adding message history\n",
|
||||
"\n",
|
||||
"To add message history to our original chain we wrap it in the `RunnableWithMessageHistory` class.\n",
|
||||
"\n",
|
||||
"Crucially, we also need to define a method that takes a session_id string and based on it returns a `BaseChatMessageHistory`. Given the same input, this method should return an equivalent output.\n",
|
||||
"\n",
|
||||
"In this case we'll also want to specify `input_messages_key` (the key to be treated as the latest input message) and `history_messages_key` (the key to add historical messages to)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "ca7c64d8-e138-4ef8-9734-f82076c47d80",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"chain_with_history = RunnableWithMessageHistory(\n",
|
||||
" chain,\n",
|
||||
" lambda session_id: RedisChatMessageHistory(session_id, url=REDIS_URL),\n",
|
||||
" input_messages_key=\"question\",\n",
|
||||
" history_messages_key=\"history\",\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "37eefdec-9901-4650-b64c-d3c097ed5f4d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Invoking with config\n",
|
||||
"\n",
|
||||
"Whenever we call our chain with message history, we need to include a config that contains the `session_id`\n",
|
||||
"```python\n",
|
||||
"config={\"configurable\": {\"session_id\": \"<SESSION_ID>\"}}\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"Given the same configuration, our chain should be pulling from the same chat message history."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "a85bcc22-ca4c-4ad5-9440-f94be7318f3e",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AIMessage(content=' Cosine is one of the basic trigonometric functions in mathematics. It is defined as the ratio of the adjacent side to the hypotenuse in a right triangle.\\n\\nSome key properties and facts about cosine:\\n\\n- It is denoted by cos(θ), where θ is the angle in a right triangle. \\n\\n- The cosine of an acute angle is always positive. For angles greater than 90 degrees, cosine can be negative.\\n\\n- Cosine is one of the three main trig functions along with sine and tangent.\\n\\n- The cosine of 0 degrees is 1. As the angle increases towards 90 degrees, the cosine value decreases towards 0.\\n\\n- The range of values for cosine is -1 to 1.\\n\\n- The cosine function maps angles in a circle to the x-coordinate on the unit circle.\\n\\n- Cosine is used to find adjacent side lengths in right triangles, and has many other applications in mathematics, physics, engineering and more.\\n\\n- Key cosine identities include: cos(A+B) = cosAcosB − sinAsinB and cos(2A) = cos^2(A) − sin^2(A)\\n\\nSo in summary, cosine is a fundamental trig')"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"chain_with_history.invoke(\n",
|
||||
" {\"ability\": \"math\", \"question\": \"What does cosine mean?\"},\n",
|
||||
" config={\"configurable\": {\"session_id\": \"foobar\"}},\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "ab29abd3-751f-41ce-a1b0-53f6b565e79d",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AIMessage(content=' The inverse of the cosine function is called the arccosine or inverse cosine, often denoted as cos-1(x) or arccos(x).\\n\\nThe key properties and facts about arccosine:\\n\\n- It is defined as the angle θ between 0 and π radians whose cosine is x. So arccos(x) = θ such that cos(θ) = x.\\n\\n- The range of arccosine is 0 to π radians (0 to 180 degrees).\\n\\n- The domain of arccosine is -1 to 1. \\n\\n- arccos(cos(θ)) = θ for values of θ from 0 to π radians.\\n\\n- arccos(x) is the angle in a right triangle whose adjacent side is x and hypotenuse is 1.\\n\\n- arccos(0) = 90 degrees. As x increases from 0 to 1, arccos(x) decreases from 90 to 0 degrees.\\n\\n- arccos(1) = 0 degrees. arccos(-1) = 180 degrees.\\n\\n- The graph of y = arccos(x) is part of the unit circle, restricted to x')"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"chain_with_history.invoke(\n",
|
||||
" {\"ability\": \"math\", \"question\": \"What's its inverse\"},\n",
|
||||
" config={\"configurable\": {\"session_id\": \"foobar\"}},\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "da3d1feb-b4bb-4624-961c-7db2e1180df7",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
":::tip [Langsmith trace](https://smith.langchain.com/public/863a003b-7ca8-4b24-be9e-d63ec13c106e/r)\n",
|
||||
":::"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "61d5115e-64a1-4ad5-b676-8afd4ef6093e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Looking at the Langsmith trace for the second call, we can see that when constructing the prompt, a \"history\" variable has been injected which is a list of two messages (our first input and first output)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "028cf151-6cd5-4533-b3cf-c8d735554647",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Example: messages input, dict output"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"id": "0bb446b5-6251-45fe-a92a-4c6171473c53",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'output_message': AIMessage(content=' Here is a summary of Simone de Beauvoir\\'s views on free will:\\n\\n- De Beauvoir was an existentialist philosopher and believed strongly in the concept of free will. She rejected the idea that human nature or instincts determine behavior.\\n\\n- Instead, de Beauvoir argued that human beings define their own essence or nature through their actions and choices. As she famously wrote, \"One is not born, but rather becomes, a woman.\"\\n\\n- De Beauvoir believed that while individuals are situated in certain cultural contexts and social conditions, they still have agency and the ability to transcend these situations. Freedom comes from choosing one\\'s attitude toward these constraints.\\n\\n- She emphasized the radical freedom and responsibility of the individual. We are \"condemned to be free\" because we cannot escape making choices and taking responsibility for our choices. \\n\\n- De Beauvoir felt that many people evade their freedom and responsibility by adopting rigid mindsets, ideologies, or conforming uncritically to social roles.\\n\\n- She advocated for the recognition of ambiguity in the human condition and warned against the quest for absolute rules that deny freedom and responsibility. Authentic living involves embracing ambiguity.\\n\\nIn summary, de Beauvoir promoted an existential ethics')}"
|
||||
]
|
||||
},
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain.schema.messages import HumanMessage\n",
|
||||
"from langchain.schema.runnable import RunnableMap\n",
|
||||
"\n",
|
||||
"chain = RunnableMap({\"output_message\": ChatAnthropic(model=\"claude-2\")})\n",
|
||||
"chain_with_history = RunnableWithMessageHistory(\n",
|
||||
" chain,\n",
|
||||
" lambda session_id: RedisChatMessageHistory(session_id, url=REDIS_URL),\n",
|
||||
" output_messages_key=\"output_message\",\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"chain_with_history.invoke(\n",
|
||||
" [HumanMessage(content=\"What did Simone de Beauvoir believe about free will\")],\n",
|
||||
" config={\"configurable\": {\"session_id\": \"baz\"}},\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"id": "601ce3ff-aea8-424d-8e54-fd614256af4f",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'output_message': AIMessage(content=\" There are many similarities between Simone de Beauvoir's views on free will and those of Jean-Paul Sartre, though some key differences emerge as well:\\n\\nSimilarities with Sartre:\\n\\n- Both were existentialist thinkers who rejected determinism and emphasized human freedom and responsibility.\\n\\n- They agreed that existence precedes essence - there is no predefined human nature that determines who we are.\\n\\n- Individuals must define themselves through their choices and actions. This leads to anxiety but also freedom.\\n\\n- The human condition is characterized by ambiguity and uncertainty, rather than fixed meanings/values.\\n\\n- Both felt that most people evade their freedom through self-deception, conformity, or adopting collective identities/values uncritically.\\n\\nDifferences from Sartre: \\n\\n- Sartre placed more emphasis on the burden and anguish of radical freedom. De Beauvoir focused more on its positive potential.\\n\\n- De Beauvoir critiqued Sartre's premise that human relations are necessarily conflictual. She saw more potential for mutual recognition.\\n\\n- Sartre saw the Other's gaze as a threat to freedom. De Beauvoir put more stress on how the Other's gaze can confirm\")}"
|
||||
]
|
||||
},
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"chain_with_history.invoke(\n",
|
||||
" [HumanMessage(content=\"How did this compare to Sartre\")],\n",
|
||||
" config={\"configurable\": {\"session_id\": \"baz\"}},\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b898d1b1-11e6-4d30-a8dd-cc5e45533611",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
":::tip [LangSmith trace](https://smith.langchain.com/public/f6c3e1d1-a49d-4955-a9fa-c6519df74fa7/r)\n",
|
||||
":::"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1724292c-01c6-44bb-83e8-9cdb6bf01483",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## More examples\n",
|
||||
"\n",
|
||||
"We could also do any of the below:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "fd89240b-5a25-48f8-9568-5c1127f9ffad",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from operator import itemgetter\n",
|
||||
"\n",
|
||||
"# messages in, messages out\n",
|
||||
"RunnableWithMessageHistory(\n",
|
||||
" ChatAnthropic(model=\"claude-2\"),\n",
|
||||
" lambda session_id: RedisChatMessageHistory(session_id, url=REDIS_URL),\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# dict with single key for all messages in, messages out\n",
|
||||
"RunnableWithMessageHistory(\n",
|
||||
" itemgetter(\"input_messages\") | ChatAnthropic(model=\"claude-2\"),\n",
|
||||
" lambda session_id: RedisChatMessageHistory(session_id, url=REDIS_URL),\n",
|
||||
" input_messages_key=\"input_messages\",\n",
|
||||
")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "poetry-venv",
|
||||
"language": "python",
|
||||
"name": "poetry-venv"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
@ -58,8 +58,6 @@
|
||||
"from langchain.text_splitter import CharacterTextSplitter\n",
|
||||
"from langchain.vectorstores import FAISS\n",
|
||||
"\n",
|
||||
"from langchain.document_loaders import TextLoader\n",
|
||||
"\n",
|
||||
"loader = TextLoader(\"../../../extras/modules/state_of_the_union.txt\")\n",
|
||||
"documents = loader.load()\n",
|
||||
"text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
|
||||
|
@ -61,8 +61,6 @@
|
||||
"from langchain.text_splitter import CharacterTextSplitter\n",
|
||||
"from langchain.vectorstores import FAISS\n",
|
||||
"\n",
|
||||
"from langchain.document_loaders import TextLoader\n",
|
||||
"\n",
|
||||
"loader = TextLoader(\"../../../extras/modules/state_of_the_union.txt\")\n",
|
||||
"documents = loader.load()\n",
|
||||
"text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
|
||||
|
@ -54,11 +54,10 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain.vectorstores.vearch import Vearch\n",
|
||||
"\n",
|
||||
"from langchain.document_loaders import TextLoader\n",
|
||||
"from langchain.embeddings.huggingface import HuggingFaceEmbeddings\n",
|
||||
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
|
||||
"from langchain.vectorstores.vearch import Vearch\n",
|
||||
"from transformers import AutoModel, AutoTokenizer\n",
|
||||
"\n",
|
||||
"# repalce to your local model path\n",
|
||||
|
@ -1051,6 +1051,14 @@
|
||||
":::"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "fdf6c7e0-84f8-4747-b2ae-e84315152bd9",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Here we've gone over how to add chain logic for incorporating historical outputs. But how do we actually store and retrieve historical outputs for different sessions? For that check out the LCEL [How to add message history (memory)](/docs/expression_language/how_to/message_history) page."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "580e18de-132d-4009-ba67-4aaf2c7717a2",
|
||||
|
@ -85,6 +85,9 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
|
||||
variable_name: str
|
||||
"""Name of variable to use as messages."""
|
||||
|
||||
def __init__(self, variable_name: str, **kwargs: Any):
|
||||
return super().__init__(variable_name=variable_name, **kwargs)
|
||||
|
||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
"""Format messages from kwargs.
|
||||
|
||||
|
@ -42,7 +42,6 @@ if TYPE_CHECKING:
|
||||
RunnableWithFallbacks as RunnableWithFallbacksT,
|
||||
)
|
||||
|
||||
|
||||
from langchain.load.dump import dumpd
|
||||
from langchain.load.serializable import Serializable
|
||||
from langchain.pydantic_v1 import BaseModel, Field, create_model
|
||||
@ -298,7 +297,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
)
|
||||
|
||||
@property
|
||||
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
"""List configurable fields for this runnable."""
|
||||
return []
|
||||
|
||||
@ -1357,7 +1356,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
return self.last.get_output_schema(config)
|
||||
|
||||
@property
|
||||
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
return get_unique_config_specs(
|
||||
spec for step in self.steps for spec in step.config_specs
|
||||
)
|
||||
@ -1885,7 +1884,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
|
||||
)
|
||||
|
||||
@property
|
||||
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
return get_unique_config_specs(
|
||||
spec for step in self.steps.values() for spec in step.config_specs
|
||||
)
|
||||
@ -2591,7 +2590,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
|
||||
)
|
||||
|
||||
@property
|
||||
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
return self.bound.config_specs
|
||||
|
||||
@classmethod
|
||||
@ -2763,7 +2762,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
||||
return self.bound.get_output_schema(merge_configs(self.config, config))
|
||||
|
||||
@property
|
||||
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
return self.bound.config_specs
|
||||
|
||||
@classmethod
|
||||
|
@ -147,7 +147,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
||||
return super().get_input_schema(config)
|
||||
|
||||
@property
|
||||
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
return get_unique_config_specs(
|
||||
spec
|
||||
for step in (
|
||||
|
@ -209,7 +209,7 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
|
||||
fields: Dict[str, AnyConfigurableField]
|
||||
|
||||
@property
|
||||
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
return get_unique_config_specs(
|
||||
[
|
||||
ConfigurableFieldSpec(
|
||||
@ -300,7 +300,7 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
|
||||
default_key: str = "default"
|
||||
|
||||
@property
|
||||
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
with _enums_for_spec_lock:
|
||||
if which_enum := _enums_for_spec.get(self.which):
|
||||
pass
|
||||
|
@ -112,7 +112,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
return self.runnable.get_output_schema(config)
|
||||
|
||||
@property
|
||||
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
return get_unique_config_specs(
|
||||
spec
|
||||
for step in [self.runnable, *self.fallbacks]
|
||||
|
288
libs/langchain/langchain/schema/runnable/history.py
Normal file
288
libs/langchain/langchain/schema/runnable/history.py
Normal file
@ -0,0 +1,288 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from langchain.load import load
|
||||
from langchain.pydantic_v1 import BaseModel, create_model
|
||||
from langchain.schema.chat_history import BaseChatMessageHistory
|
||||
from langchain.schema.runnable.base import Runnable, RunnableBindingBase, RunnableLambda
|
||||
from langchain.schema.runnable.passthrough import RunnablePassthrough
|
||||
from langchain.schema.runnable.utils import (
|
||||
ConfigurableFieldSpec,
|
||||
get_unique_config_specs,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain.callbacks.tracers.schemas import Run
|
||||
from langchain.schema.messages import BaseMessage
|
||||
from langchain.schema.runnable.config import RunnableConfig
|
||||
|
||||
MessagesOrDictWithMessages = Union[Sequence["BaseMessage"], Dict[str, Any]]
|
||||
GetSessionHistoryCallable = Callable[..., BaseChatMessageHistory]
|
||||
|
||||
|
||||
class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
"""A runnable that manages chat message history for another runnable.
|
||||
|
||||
Base runnable must have inputs and outputs that can be converted to a list of
|
||||
BaseMessages.
|
||||
|
||||
RunnableWithMessageHistory must always be called with a config that contains session_id, e.g.:
|
||||
``{"configurable": {"session_id": "<SESSION_ID>"}}``
|
||||
|
||||
Example (dict input):
|
||||
.. code-block:: python
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from langchain.chat_models import ChatAnthropic
|
||||
from langchain.memory.chat_message_histories import RedisChatMessageHistory
|
||||
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from langchain.schema.runnable.history import RunnableWithMessageHistory
|
||||
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages([
|
||||
("system", "You're an assistant who's good at {ability}"),
|
||||
MessagesPlaceholder(variable_name="history"),
|
||||
("human", "{question}"),
|
||||
])
|
||||
|
||||
chain = prompt | ChatAnthropic(model="claude-2")
|
||||
|
||||
chain_with_history = RunnableWithMessageHistory(
|
||||
chain,
|
||||
RedisChatMessageHistory,
|
||||
input_messages_key="question",
|
||||
history_messages_key="history",
|
||||
)
|
||||
|
||||
chain_with_history.invoke(
|
||||
{"ability": "math", "question": "What does cosine mean?"},
|
||||
config={"configurable": {"session_id": "foo"}}
|
||||
)
|
||||
# -> "Cosine is ..."
|
||||
chain_with_history.invoke(
|
||||
{"ability": "math", "question": "What's its inverse"},
|
||||
config={"configurable": {"session_id": "foo"}}
|
||||
)
|
||||
# -> "The inverse of cosine is called arccosine ..."
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
get_session_history: GetSessionHistoryCallable
|
||||
input_messages_key: Optional[str] = None
|
||||
output_messages_key: Optional[str] = None
|
||||
history_messages_key: Optional[str] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
runnable: Runnable[
|
||||
MessagesOrDictWithMessages,
|
||||
Union[str, BaseMessage, MessagesOrDictWithMessages],
|
||||
],
|
||||
get_session_history: GetSessionHistoryCallable,
|
||||
*,
|
||||
input_messages_key: Optional[str] = None,
|
||||
output_messages_key: Optional[str] = None,
|
||||
history_messages_key: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize RunnableWithMessageHistory.
|
||||
|
||||
Args:
|
||||
runnable: The base Runnable to be wrapped.
|
||||
|
||||
Must take as input one of:
|
||||
- A sequence of BaseMessages
|
||||
- A dict with one key for all messages
|
||||
- A dict with one key for the current input string/message(s) and
|
||||
a separate key for historical messages. If the input key points
|
||||
to a string, it will be treated as a HumanMessage in history.
|
||||
|
||||
Must return as output one of:
|
||||
- A string which can be treated as an AIMessage
|
||||
- A BaseMessage or sequence of BaseMessages
|
||||
- A dict with a key for a BaseMessage or sequence of BaseMessages
|
||||
|
||||
get_session_history: Function that returns a new BaseChatMessageHistory
|
||||
given a session id. Should take a single
|
||||
positional argument `session_id` which is a string and a named argument
|
||||
`user_id` which can be a string or None. e.g.:
|
||||
|
||||
```python
|
||||
def get_session_history(
|
||||
session_id: str,
|
||||
*,
|
||||
user_id: Optional[str]=None
|
||||
) -> BaseChatMessageHistory:
|
||||
...
|
||||
```
|
||||
|
||||
input_messages_key: Must be specified if the base runnable accepts a dict
|
||||
as input.
|
||||
output_messages_key: Must be specified if the base runnable returns a dict
|
||||
as output.
|
||||
history_messages_key: Must be specified if the base runnable accepts a dict
|
||||
as input and expects a separate key for historical messages.
|
||||
**kwargs: Arbitrary additional kwargs to pass to parent class
|
||||
``RunnableBindingBase`` init.
|
||||
""" # noqa: E501
|
||||
history_chain: Runnable = RunnableLambda(
|
||||
self._enter_history, self._aenter_history
|
||||
).with_config(run_name="load_history")
|
||||
messages_key = history_messages_key or input_messages_key
|
||||
if messages_key:
|
||||
history_chain = RunnablePassthrough.assign(
|
||||
**{messages_key: history_chain}
|
||||
).with_config(run_name="insert_history")
|
||||
bound = (
|
||||
history_chain | runnable.with_listeners(on_end=self._exit_history)
|
||||
).with_config(run_name="RunnableWithMessageHistory")
|
||||
super().__init__(
|
||||
get_session_history=get_session_history,
|
||||
input_messages_key=input_messages_key,
|
||||
output_messages_key=output_messages_key,
|
||||
bound=bound,
|
||||
history_messages_key=history_messages_key,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@property
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
return get_unique_config_specs(
|
||||
super().config_specs
|
||||
+ [
|
||||
ConfigurableFieldSpec(
|
||||
id="session_id",
|
||||
annotation=str,
|
||||
name="Session ID",
|
||||
description="Unique identifier for a session.",
|
||||
default="",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def get_input_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
super_schema = super().get_input_schema(config)
|
||||
if super_schema.__custom_root_type__ is not None:
|
||||
from langchain.schema.messages import BaseMessage
|
||||
|
||||
fields: Dict = {}
|
||||
if self.input_messages_key and self.history_messages_key:
|
||||
fields[self.input_messages_key] = (
|
||||
Union[str, BaseMessage, Sequence[BaseMessage]],
|
||||
...,
|
||||
)
|
||||
elif self.input_messages_key:
|
||||
fields[self.input_messages_key] = (Sequence[BaseMessage], ...)
|
||||
else:
|
||||
fields["__root__"] = (Sequence[BaseMessage], ...)
|
||||
if self.history_messages_key:
|
||||
fields[self.history_messages_key] = (Sequence[BaseMessage], ...)
|
||||
return create_model( # type: ignore[call-overload]
|
||||
"RunnableWithChatHistoryInput",
|
||||
**fields,
|
||||
)
|
||||
else:
|
||||
return super_schema
|
||||
|
||||
def _get_input_messages(
|
||||
self, input_val: Union[str, BaseMessage, Sequence[BaseMessage]]
|
||||
) -> List[BaseMessage]:
|
||||
from langchain.schema.messages import BaseMessage
|
||||
|
||||
if isinstance(input_val, str):
|
||||
from langchain.schema.messages import HumanMessage
|
||||
|
||||
return [HumanMessage(content=input_val)]
|
||||
elif isinstance(input_val, BaseMessage):
|
||||
return [input_val]
|
||||
elif isinstance(input_val, (list, tuple)):
|
||||
return list(input_val)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Expected str, BaseMessage, List[BaseMessage], or Tuple[BaseMessage]. "
|
||||
f"Got {input_val}."
|
||||
)
|
||||
|
||||
def _get_output_messages(
|
||||
self, output_val: Union[str, BaseMessage, Sequence[BaseMessage], dict]
|
||||
) -> List[BaseMessage]:
|
||||
from langchain.schema.messages import BaseMessage
|
||||
|
||||
if isinstance(output_val, dict):
|
||||
output_val = output_val[self.output_messages_key or "output"]
|
||||
|
||||
if isinstance(output_val, str):
|
||||
from langchain.schema.messages import AIMessage
|
||||
|
||||
return [AIMessage(content=output_val)]
|
||||
elif isinstance(output_val, BaseMessage):
|
||||
return [output_val]
|
||||
elif isinstance(output_val, (list, tuple)):
|
||||
return list(output_val)
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
def _enter_history(self, input: Any, config: RunnableConfig) -> List[BaseMessage]:
|
||||
hist = config["configurable"]["message_history"]
|
||||
# return only historic messages
|
||||
if self.history_messages_key:
|
||||
return hist.messages.copy()
|
||||
# return all messages
|
||||
else:
|
||||
input_val = (
|
||||
input if not self.input_messages_key else input[self.input_messages_key]
|
||||
)
|
||||
return hist.messages.copy() + self._get_input_messages(input_val)
|
||||
|
||||
async def _aenter_history(
|
||||
self, input: Dict[str, Any], config: RunnableConfig
|
||||
) -> List[BaseMessage]:
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, self._enter_history, input, config
|
||||
)
|
||||
|
||||
def _exit_history(self, run: Run, config: RunnableConfig) -> None:
|
||||
hist = config["configurable"]["message_history"]
|
||||
|
||||
# Get the input messages
|
||||
inputs = load(run.inputs)
|
||||
input_val = inputs[self.input_messages_key or "input"]
|
||||
input_messages = self._get_input_messages(input_val)
|
||||
|
||||
# Get the output messages
|
||||
output_val = load(run.outputs)
|
||||
output_messages = self._get_output_messages(output_val)
|
||||
|
||||
for m in input_messages + output_messages:
|
||||
hist.add_message(m)
|
||||
|
||||
def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig:
|
||||
config = super()._merge_configs(*configs)
|
||||
# extract session_id
|
||||
if "session_id" not in config.get("configurable", {}):
|
||||
example_input = {self.input_messages_key: "foo"}
|
||||
example_config = {"configurable": {"session_id": "123"}}
|
||||
raise ValueError(
|
||||
"session_id_id is required."
|
||||
" Pass it in as part of the config argument to .invoke() or .stream()"
|
||||
f"\neg. chain.invoke({example_input}, {example_config})"
|
||||
)
|
||||
# attach message_history
|
||||
session_id = config["configurable"]["session_id"]
|
||||
config["configurable"]["message_history"] = self.get_session_history(session_id)
|
||||
return config
|
@ -14,7 +14,6 @@ from typing import (
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
@ -334,7 +333,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
return super().get_output_schema(config)
|
||||
|
||||
@property
|
||||
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
return self.mapper.config_specs
|
||||
|
||||
def invoke(
|
||||
|
@ -8,7 +8,6 @@ from typing import (
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
@ -55,7 +54,7 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
|
||||
runnables: Mapping[str, Runnable[Any, Output]]
|
||||
|
||||
@property
|
||||
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
return get_unique_config_specs(
|
||||
spec for step in self.runnables.values() for spec in step.config_specs
|
||||
)
|
||||
|
@ -308,7 +308,7 @@ class ConfigurableFieldSpec(NamedTuple):
|
||||
|
||||
def get_unique_config_specs(
|
||||
specs: Iterable[ConfigurableFieldSpec],
|
||||
) -> Sequence[ConfigurableFieldSpec]:
|
||||
) -> List[ConfigurableFieldSpec]:
|
||||
"""Get the unique config specs from a sequence of config specs."""
|
||||
grouped = groupby(sorted(specs, key=lambda s: s.id), lambda s: s.id)
|
||||
unique: List[ConfigurableFieldSpec] = []
|
||||
|
231
libs/langchain/tests/unit_tests/schema/runnable/test_history.py
Normal file
231
libs/langchain/tests/unit_tests/schema/runnable/test_history.py
Normal file
@ -0,0 +1,231 @@
|
||||
from typing import Any, Callable, Sequence, Union
|
||||
|
||||
from langchain.memory import ChatMessageHistory
|
||||
from langchain.pydantic_v1 import BaseModel
|
||||
from langchain.schema import AIMessage, BaseMessage, HumanMessage
|
||||
from langchain.schema.runnable import RunnableConfig, RunnableLambda
|
||||
from langchain.schema.runnable.history import RunnableWithMessageHistory
|
||||
|
||||
|
||||
def _get_get_session_history() -> Callable[..., ChatMessageHistory]:
|
||||
chat_history_store = {}
|
||||
|
||||
def get_session_history(session_id: str, **kwargs: Any) -> ChatMessageHistory:
|
||||
if session_id not in chat_history_store:
|
||||
chat_history_store[session_id] = ChatMessageHistory()
|
||||
return chat_history_store[session_id]
|
||||
|
||||
return get_session_history
|
||||
|
||||
|
||||
def test_input_messages() -> None:
|
||||
runnable = RunnableLambda(
|
||||
lambda messages: "you said: "
|
||||
+ "\n".join(str(m.content) for m in messages if isinstance(m, HumanMessage))
|
||||
)
|
||||
get_session_history = _get_get_session_history()
|
||||
with_history = RunnableWithMessageHistory(runnable, get_session_history)
|
||||
config: RunnableConfig = {"configurable": {"session_id": "1"}}
|
||||
output = with_history.invoke([HumanMessage(content="hello")], config)
|
||||
assert output == "you said: hello"
|
||||
output = with_history.invoke([HumanMessage(content="good bye")], config)
|
||||
assert output == "you said: hello\ngood bye"
|
||||
|
||||
|
||||
def test_input_dict() -> None:
|
||||
runnable = RunnableLambda(
|
||||
lambda input: "you said: "
|
||||
+ "\n".join(
|
||||
str(m.content) for m in input["messages"] if isinstance(m, HumanMessage)
|
||||
)
|
||||
)
|
||||
get_session_history = _get_get_session_history()
|
||||
with_history = RunnableWithMessageHistory(
|
||||
runnable, get_session_history, input_messages_key="messages"
|
||||
)
|
||||
config: RunnableConfig = {"configurable": {"session_id": "2"}}
|
||||
output = with_history.invoke({"messages": [HumanMessage(content="hello")]}, config)
|
||||
assert output == "you said: hello"
|
||||
output = with_history.invoke(
|
||||
{"messages": [HumanMessage(content="good bye")]}, config
|
||||
)
|
||||
assert output == "you said: hello\ngood bye"
|
||||
|
||||
|
||||
def test_input_dict_with_history_key() -> None:
|
||||
runnable = RunnableLambda(
|
||||
lambda input: "you said: "
|
||||
+ "\n".join(
|
||||
[str(m.content) for m in input["history"] if isinstance(m, HumanMessage)]
|
||||
+ [input["input"]]
|
||||
)
|
||||
)
|
||||
get_session_history = _get_get_session_history()
|
||||
with_history = RunnableWithMessageHistory(
|
||||
runnable,
|
||||
get_session_history,
|
||||
input_messages_key="input",
|
||||
history_messages_key="history",
|
||||
)
|
||||
config: RunnableConfig = {"configurable": {"session_id": "3"}}
|
||||
output = with_history.invoke({"input": "hello"}, config)
|
||||
assert output == "you said: hello"
|
||||
output = with_history.invoke({"input": "good bye"}, config)
|
||||
assert output == "you said: hello\ngood bye"
|
||||
|
||||
|
||||
def test_output_message() -> None:
|
||||
runnable = RunnableLambda(
|
||||
lambda input: AIMessage(
|
||||
content="you said: "
|
||||
+ "\n".join(
|
||||
[
|
||||
str(m.content)
|
||||
for m in input["history"]
|
||||
if isinstance(m, HumanMessage)
|
||||
]
|
||||
+ [input["input"]]
|
||||
)
|
||||
)
|
||||
)
|
||||
get_session_history = _get_get_session_history()
|
||||
with_history = RunnableWithMessageHistory(
|
||||
runnable,
|
||||
get_session_history,
|
||||
input_messages_key="input",
|
||||
history_messages_key="history",
|
||||
)
|
||||
config: RunnableConfig = {"configurable": {"session_id": "4"}}
|
||||
output = with_history.invoke({"input": "hello"}, config)
|
||||
assert output == AIMessage(content="you said: hello")
|
||||
output = with_history.invoke({"input": "good bye"}, config)
|
||||
assert output == AIMessage(content="you said: hello\ngood bye")
|
||||
|
||||
|
||||
def test_output_messages() -> None:
|
||||
runnable = RunnableLambda(
|
||||
lambda input: [
|
||||
AIMessage(
|
||||
content="you said: "
|
||||
+ "\n".join(
|
||||
[
|
||||
str(m.content)
|
||||
for m in input["history"]
|
||||
if isinstance(m, HumanMessage)
|
||||
]
|
||||
+ [input["input"]]
|
||||
)
|
||||
)
|
||||
]
|
||||
)
|
||||
get_session_history = _get_get_session_history()
|
||||
with_history = RunnableWithMessageHistory(
|
||||
runnable,
|
||||
get_session_history,
|
||||
input_messages_key="input",
|
||||
history_messages_key="history",
|
||||
)
|
||||
config: RunnableConfig = {"configurable": {"session_id": "5"}}
|
||||
output = with_history.invoke({"input": "hello"}, config)
|
||||
assert output == [AIMessage(content="you said: hello")]
|
||||
output = with_history.invoke({"input": "good bye"}, config)
|
||||
assert output == [AIMessage(content="you said: hello\ngood bye")]
|
||||
|
||||
|
||||
def test_output_dict() -> None:
|
||||
runnable = RunnableLambda(
|
||||
lambda input: {
|
||||
"output": [
|
||||
AIMessage(
|
||||
content="you said: "
|
||||
+ "\n".join(
|
||||
[
|
||||
str(m.content)
|
||||
for m in input["history"]
|
||||
if isinstance(m, HumanMessage)
|
||||
]
|
||||
+ [input["input"]]
|
||||
)
|
||||
)
|
||||
]
|
||||
}
|
||||
)
|
||||
get_session_history = _get_get_session_history()
|
||||
with_history = RunnableWithMessageHistory(
|
||||
runnable,
|
||||
get_session_history,
|
||||
input_messages_key="input",
|
||||
history_messages_key="history",
|
||||
output_messages_key="output",
|
||||
)
|
||||
config: RunnableConfig = {"configurable": {"session_id": "6"}}
|
||||
output = with_history.invoke({"input": "hello"}, config)
|
||||
assert output == {"output": [AIMessage(content="you said: hello")]}
|
||||
output = with_history.invoke({"input": "good bye"}, config)
|
||||
assert output == {"output": [AIMessage(content="you said: hello\ngood bye")]}
|
||||
|
||||
|
||||
def test_get_input_schema_input_dict() -> None:
|
||||
class RunnableWithChatHistoryInput(BaseModel):
|
||||
input: Union[str, BaseMessage, Sequence[BaseMessage]]
|
||||
history: Sequence[BaseMessage]
|
||||
|
||||
runnable = RunnableLambda(
|
||||
lambda input: {
|
||||
"output": [
|
||||
AIMessage(
|
||||
content="you said: "
|
||||
+ "\n".join(
|
||||
[
|
||||
str(m.content)
|
||||
for m in input["history"]
|
||||
if isinstance(m, HumanMessage)
|
||||
]
|
||||
+ [input["input"]]
|
||||
)
|
||||
)
|
||||
]
|
||||
}
|
||||
)
|
||||
get_session_history = _get_get_session_history()
|
||||
with_history = RunnableWithMessageHistory(
|
||||
runnable,
|
||||
get_session_history,
|
||||
input_messages_key="input",
|
||||
history_messages_key="history",
|
||||
output_messages_key="output",
|
||||
)
|
||||
assert (
|
||||
with_history.get_input_schema().schema()
|
||||
== RunnableWithChatHistoryInput.schema()
|
||||
)
|
||||
|
||||
|
||||
def test_get_input_schema_input_messages() -> None:
|
||||
class RunnableWithChatHistoryInput(BaseModel):
|
||||
__root__: Sequence[BaseMessage]
|
||||
|
||||
runnable = RunnableLambda(
|
||||
lambda messages: {
|
||||
"output": [
|
||||
AIMessage(
|
||||
content="you said: "
|
||||
+ "\n".join(
|
||||
[
|
||||
str(m.content)
|
||||
for m in messages
|
||||
if isinstance(m, HumanMessage)
|
||||
]
|
||||
)
|
||||
)
|
||||
]
|
||||
}
|
||||
)
|
||||
get_session_history = _get_get_session_history()
|
||||
with_history = RunnableWithMessageHistory(
|
||||
runnable, get_session_history, output_messages_key="output"
|
||||
)
|
||||
assert (
|
||||
with_history.get_input_schema().schema()
|
||||
== RunnableWithChatHistoryInput.schema()
|
||||
)
|
Loading…
Reference in New Issue
Block a user