This commit is contained in:
Eugene Yurtsev
2024-06-04 17:13:59 -04:00
parent 7106bc8125
commit bb65e49630
2 changed files with 20 additions and 22 deletions

View File

@@ -372,6 +372,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
) -> List[BaseMessage]:
from langchain_core.messages import BaseMessage
# If dictionary, try to pluck the single key representing messages
if isinstance(input_val, dict):
if self.input_messages_key:
key = self.input_messages_key
@@ -381,13 +382,25 @@ class RunnableWithMessageHistory(RunnableBindingBase):
key = "input"
input_val = input_val[key]
# If value is a string, convert to a human message
if isinstance(input_val, str):
from langchain_core.messages import HumanMessage
return [HumanMessage(content=input_val)]
# If value is a single message, convert to a list
elif isinstance(input_val, BaseMessage):
return [input_val]
# If value is a list or tuple...
elif isinstance(input_val, (list, tuple)):
# Handle empty case
if len(input_val) == 0:
return list(input_val)
# If is a list of list, then return the first value
# This occurs for chat models - since we batch inputs
if isinstance(input_val[0], list):
if len(input_val) != 1:
raise ValueError()
return input_val[0]
return list(input_val)
else:
raise ValueError(
@@ -400,6 +413,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
) -> List[BaseMessage]:
from langchain_core.messages import BaseMessage
# If dictionary, try to pluck the single key representing messages
if isinstance(output_val, dict):
if self.output_messages_key:
key = self.output_messages_key
@@ -418,6 +432,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
from langchain_core.messages import AIMessage
return [AIMessage(content=output_val)]
# If value is a single message, convert to a list
elif isinstance(output_val, BaseMessage):
return [output_val]
elif isinstance(output_val, (list, tuple)):
@@ -431,7 +446,10 @@ class RunnableWithMessageHistory(RunnableBindingBase):
if not self.history_messages_key:
# return all messages
messages += self._get_input_messages(input)
input_val = (
input if not self.input_messages_key else input[self.input_messages_key]
)
messages += self._get_input_messages(input_val)
return messages
async def _aenter_history(
@@ -454,7 +472,6 @@ class RunnableWithMessageHistory(RunnableBindingBase):
# Get the input messages
inputs = load(run.inputs)
input_messages = self._get_input_messages(inputs)
# If historic messages were prepended to the input messages, remove them to
# avoid adding duplicate messages to history.
if not self.history_messages_key:
@@ -466,23 +483,6 @@ class RunnableWithMessageHistory(RunnableBindingBase):
output_messages = self._get_output_messages(output_val)
hist.add_messages(input_messages + output_messages)
async def _aexit_history(self, run: Run, config: RunnableConfig) -> None:
hist: BaseChatMessageHistory = config["configurable"]["message_history"]
# Get the input messages
inputs = load(run.inputs)
input_messages = self._get_input_messages(inputs)
# If historic messages were prepended to the input messages, remove them to
# avoid adding duplicate messages to history.
if not self.history_messages_key:
historic_messages = config["configurable"]["message_history"].messages
input_messages = input_messages[len(historic_messages) :]
# Get the output messages
output_val = load(run.outputs)
output_messages = self._get_output_messages(output_val)
await hist.aadd_messages(input_messages + output_messages)
def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig:
config = super()._merge_configs(*configs)
expected_keys = [field_spec.id for field_spec in self.history_factory_config]

View File

@@ -1,7 +1,5 @@
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
from langchain_core.pydantic_v1 import Field
from langchain_core.callbacks import (
CallbackManagerForLLMRun,
)
@@ -9,7 +7,7 @@ from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables.base import RunnableLambda
from langchain_core.runnables.config import RunnableConfig
from langchain_core.runnables.history import RunnableWithMessageHistory