diff --git a/libs/core/langchain_core/runnables/history.py b/libs/core/langchain_core/runnables/history.py index 9917b2340a9..fdc40e55835 100644 --- a/libs/core/langchain_core/runnables/history.py +++ b/libs/core/langchain_core/runnables/history.py @@ -372,6 +372,7 @@ class RunnableWithMessageHistory(RunnableBindingBase): ) -> List[BaseMessage]: from langchain_core.messages import BaseMessage + # If dictionary, try to pluck the single key representing messages if isinstance(input_val, dict): if self.input_messages_key: key = self.input_messages_key @@ -381,13 +382,25 @@ class RunnableWithMessageHistory(RunnableBindingBase): key = "input" input_val = input_val[key] + # If value is a string, convert to a human message if isinstance(input_val, str): from langchain_core.messages import HumanMessage return [HumanMessage(content=input_val)] + # If value is a single message, convert to a list elif isinstance(input_val, BaseMessage): return [input_val] + # If value is a list or tuple... elif isinstance(input_val, (list, tuple)): + # Handle empty case + if len(input_val) == 0: + return list(input_val) + # If is a list of list, then return the first value + # This occurs for chat models - since we batch inputs + if isinstance(input_val[0], list): + if len(input_val) != 1: + raise ValueError() + return input_val[0] return list(input_val) else: raise ValueError( @@ -400,6 +413,7 @@ class RunnableWithMessageHistory(RunnableBindingBase): ) -> List[BaseMessage]: from langchain_core.messages import BaseMessage + # If dictionary, try to pluck the single key representing messages if isinstance(output_val, dict): if self.output_messages_key: key = self.output_messages_key @@ -418,6 +432,7 @@ class RunnableWithMessageHistory(RunnableBindingBase): from langchain_core.messages import AIMessage return [AIMessage(content=output_val)] + # If value is a single message, convert to a list elif isinstance(output_val, BaseMessage): return [output_val] elif isinstance(output_val, (list, tuple)): @@ -431,7 +446,10 @@ class RunnableWithMessageHistory(RunnableBindingBase): if not self.history_messages_key: # return all messages - messages += self._get_input_messages(input) + input_val = ( + input if not self.input_messages_key else input[self.input_messages_key] + ) + messages += self._get_input_messages(input_val) return messages async def _aenter_history( @@ -454,7 +472,6 @@ class RunnableWithMessageHistory(RunnableBindingBase): # Get the input messages inputs = load(run.inputs) input_messages = self._get_input_messages(inputs) - # If historic messages were prepended to the input messages, remove them to # avoid adding duplicate messages to history. if not self.history_messages_key: @@ -466,23 +483,6 @@ class RunnableWithMessageHistory(RunnableBindingBase): output_messages = self._get_output_messages(output_val) hist.add_messages(input_messages + output_messages) - async def _aexit_history(self, run: Run, config: RunnableConfig) -> None: - hist: BaseChatMessageHistory = config["configurable"]["message_history"] - - # Get the input messages - inputs = load(run.inputs) - input_messages = self._get_input_messages(inputs) - - # If historic messages were prepended to the input messages, remove them to - # avoid adding duplicate messages to history. - if not self.history_messages_key: - historic_messages = config["configurable"]["message_history"].messages - input_messages = input_messages[len(historic_messages) :] - - # Get the output messages - output_val = load(run.outputs) - output_messages = self._get_output_messages(output_val) - await hist.aadd_messages(input_messages + output_messages) def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig: config = super()._merge_configs(*configs) expected_keys = [field_spec.id for field_spec in self.history_factory_config] diff --git a/libs/core/tests/unit_tests/runnables/test_history.py b/libs/core/tests/unit_tests/runnables/test_history.py index 22ded25cdda..f14e87b9e6c 100644 --- a/libs/core/tests/unit_tests/runnables/test_history.py +++ b/libs/core/tests/unit_tests/runnables/test_history.py @@ -1,7 +1,5 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Union -from langchain_core.pydantic_v1 import Field - from langchain_core.callbacks import ( CallbackManagerForLLMRun, ) @@ -9,7 +7,7 @@ from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langchain_core.outputs import ChatGeneration, ChatResult -from langchain_core.pydantic_v1 import BaseModel +from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.runnables.base import RunnableLambda from langchain_core.runnables.config import RunnableConfig from langchain_core.runnables.history import RunnableWithMessageHistory