diff --git a/libs/core/langchain_core/runnables/history.py b/libs/core/langchain_core/runnables/history.py
index f02364bccde..c35cd42de5e 100644
--- a/libs/core/langchain_core/runnables/history.py
+++ b/libs/core/langchain_core/runnables/history.py
@@ -25,6 +25,7 @@ from langchain_core.runnables.utils import (
 )
 
 if TYPE_CHECKING:
+    from langchain_core.language_models import LanguageModelLike
     from langchain_core.messages import BaseMessage
     from langchain_core.runnables.config import RunnableConfig
     from langchain_core.tracers.schemas import Run
@@ -228,9 +229,12 @@ class RunnableWithMessageHistory(RunnableBindingBase):
 
     def __init__(
         self,
-        runnable: Runnable[
-            MessagesOrDictWithMessages,
-            Union[str, BaseMessage, MessagesOrDictWithMessages],
+        runnable: Union[
+            Runnable[
+                Union[MessagesOrDictWithMessages],
+                Union[str, BaseMessage, MessagesOrDictWithMessages],
+            ],
+            LanguageModelLike,
         ],
         get_session_history: GetSessionHistoryCallable,
         *,
@@ -364,10 +368,19 @@ class RunnableWithMessageHistory(RunnableBindingBase):
         return super_schema
 
     def _get_input_messages(
-        self, input_val: Union[str, BaseMessage, Sequence[BaseMessage]]
+        self, input_val: Union[str, BaseMessage, Sequence[BaseMessage], dict]
     ) -> List[BaseMessage]:
         from langchain_core.messages import BaseMessage
 
+        if isinstance(input_val, dict):
+            if self.input_messages_key:
+                key = self.input_messages_key
+            elif len(input_val) == 1:
+                key = list(input_val.keys())[0]
+            else:
+                key = "input"
+            input_val = input_val[key]
+
         if isinstance(input_val, str):
             from langchain_core.messages import HumanMessage
 
@@ -388,7 +401,18 @@ class RunnableWithMessageHistory(RunnableBindingBase):
         from langchain_core.messages import BaseMessage
 
         if isinstance(output_val, dict):
-            output_val = output_val[self.output_messages_key or "output"]
+            if self.output_messages_key:
+                key = self.output_messages_key
+            elif len(output_val) == 1:
+                key = list(output_val.keys())[0]
+            else:
+                key = "output"
+            # If you are wrapping a chat model directly
+            # The output is actually this weird generations object
+            if key not in output_val and "generations" in output_val:
+                output_val = output_val["generations"][0][0]["message"]
+            else:
+                output_val = output_val[key]
 
         if isinstance(output_val, str):
             from langchain_core.messages import AIMessage
@@ -407,10 +431,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
 
         if not self.history_messages_key:
             # return all messages
-            input_val = (
-                input if not self.input_messages_key else input[self.input_messages_key]
-            )
-            messages += self._get_input_messages(input_val)
+            messages += self._get_input_messages(input)
         return messages
 
     async def _aenter_history(
@@ -432,8 +453,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
 
         # Get the input messages
         inputs = load(run.inputs)
-        input_val = inputs[self.input_messages_key or "input"]
-        input_messages = self._get_input_messages(input_val)
+        input_messages = self._get_input_messages(inputs)
 
         # If historic messages were prepended to the input messages, remove them to
         # avoid adding duplicate messages to history.
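What the changes above enable, end to end: `RunnableWithMessageHistory` now accepts a bare chat model (`LanguageModelLike`) in addition to a `Runnable` over message dicts/sequences. Dict unwrapping infers the key when none is configured (explicit `input_messages_key`/`output_messages_key` first, then a sole dict key, then `"input"`/`"output"`), and a chat model's traced output is recovered from the `generations` payload (`output_val["generations"][0][0]["message"]`: first prompt, first generation). Below is a minimal sketch of the new direct-wrapping path, assuming langchain_core at this commit; `EchoChatModel`, `InMemoryHistory`, `store`, and `get_session_history` are illustrative names, not part of the patch:

```python
from typing import Any, Dict, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.runnables.history import RunnableWithMessageHistory


class EchoChatModel(BaseChatModel):
    """Toy chat model that echoes the last message (illustrative only)."""

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        reply = AIMessage(content=f"echo: {messages[-1].content}")
        return ChatResult(generations=[ChatGeneration(message=reply)])

    @property
    def _llm_type(self) -> str:
        return "echo-chat-model"


class InMemoryHistory(BaseChatMessageHistory):
    """Minimal in-memory history; a real app would use a persistent store."""

    def __init__(self) -> None:
        self.messages: List[BaseMessage] = []

    def add_message(self, message: BaseMessage) -> None:
        self.messages.append(message)

    def clear(self) -> None:
        self.messages = []


store: Dict[str, InMemoryHistory] = {}


def get_session_history(session_id: str) -> InMemoryHistory:
    return store.setdefault(session_id, InMemoryHistory())


# The chat model is passed directly -- no wrapping RunnableLambda and no
# input_messages_key/output_messages_key, since both are now inferred.
chain = RunnableWithMessageHistory(EchoChatModel(), get_session_history)
result = chain.invoke(
    [HumanMessage(content="hi")],
    config={"configurable": {"session_id": "demo"}},
)
print(result.content)  # "echo: hi"; the exchange is also saved in `store`
```

A second `invoke` with the same `session_id` feeds the stored exchange back in ahead of the new message, which is what the new test below asserts with its message-count model (1 message on the first call, 3 on the second).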
diff --git a/libs/core/tests/unit_tests/runnables/test_history.py b/libs/core/tests/unit_tests/runnables/test_history.py
index 72cfcbf77cd..3db18b9963f 100644
--- a/libs/core/tests/unit_tests/runnables/test_history.py
+++ b/libs/core/tests/unit_tests/runnables/test_history.py
@@ -1,6 +1,11 @@
 from typing import Any, Callable, Dict, List, Optional, Sequence, Union
 
+from langchain_core.callbacks import (
+    CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
+from langchain_core.outputs import ChatGeneration, ChatResult
 from langchain_core.pydantic_v1 import BaseModel
 from langchain_core.runnables.base import RunnableLambda
 from langchain_core.runnables.config import RunnableConfig
@@ -127,6 +132,41 @@ def test_output_message() -> None:
     assert output == AIMessage(content="you said: hello\ngood bye")
 
 
+def test_input_messages_output_message() -> None:
+    class LengthChatModel(BaseChatModel):
+        """A fake chat model that returns the length of the messages passed in."""
+
+        def _generate(
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> ChatResult:
+            """Top Level call"""
+            return ChatResult(
+                generations=[
+                    ChatGeneration(message=AIMessage(content=str(len(messages))))
+                ]
+            )
+
+        @property
+        def _llm_type(self) -> str:
+            return "length-fake-chat-model"
+
+    runnable = LengthChatModel()
+    get_session_history = _get_get_session_history()
+    with_history = RunnableWithMessageHistory(
+        runnable,
+        get_session_history,
+    )
+    config: RunnableConfig = {"configurable": {"session_id": "4"}}
+    output = with_history.invoke([HumanMessage(content="hi")], config)
+    assert output.content == "1"
+    output = with_history.invoke([HumanMessage(content="hi")], config)
+    assert output.content == "3"
+
+
 def test_output_messages() -> None:
     runnable = RunnableLambda(
         lambda input: [