diff --git a/libs/core/langchain_core/runnables/history.py b/libs/core/langchain_core/runnables/history.py index a8b338fce92..e169bd9a05d 100644 --- a/libs/core/langchain_core/runnables/history.py +++ b/libs/core/langchain_core/runnables/history.py @@ -16,7 +16,6 @@ from typing import ( from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.load.load import load from langchain_core.pydantic_v1 import BaseModel -from langchain_core.runnables import RunnableBranch from langchain_core.runnables.base import Runnable, RunnableBindingBase, RunnableLambda from langchain_core.runnables.passthrough import RunnablePassthrough from langchain_core.runnables.utils import ( @@ -320,17 +319,22 @@ class RunnableWithMessageHistory(RunnableBindingBase): history_chain = RunnablePassthrough.assign( **{messages_key: history_chain} ).with_config(run_name="insert_history") + + runnable_sync: Runnable = runnable.with_listeners(on_end=self._exit_history) + runnable_async: Runnable = runnable.with_alisteners(on_end=self._aexit_history) + + def _call_runnable_sync(_input: Any) -> Runnable: + return runnable_sync + + async def _call_runnable_async(_input: Any) -> Runnable: + return runnable_async + bound: Runnable = ( history_chain - | RunnableBranch( - ( - RunnableLambda( - self._is_not_async, afunc=self._is_async - ).with_config(run_name="RunnableWithMessageHistoryInAsyncMode"), - runnable.with_alisteners(on_end=self._aexit_history), - ), - runnable.with_listeners(on_end=self._exit_history), - ) + | RunnableLambda( + _call_runnable_sync, + _call_runnable_async, + ).with_config(run_name="check_sync_or_async") ).with_config(run_name="RunnableWithMessageHistory") if history_factory_config: @@ -468,7 +472,10 @@ class RunnableWithMessageHistory(RunnableBindingBase): elif isinstance(output_val, (list, tuple)): return list(output_val) else: - raise ValueError() + raise ValueError( + f"Expected str, BaseMessage, List[BaseMessage], or Tuple[BaseMessage]. " + f"Got {output_val}." + ) def _enter_history(self, input: Any, config: RunnableConfig) -> List[BaseMessage]: hist: BaseChatMessageHistory = config["configurable"]["message_history"] diff --git a/libs/core/tests/unit_tests/runnables/test_history.py b/libs/core/tests/unit_tests/runnables/test_history.py index 3de59e22973..bdec28ebc06 100644 --- a/libs/core/tests/unit_tests/runnables/test_history.py +++ b/libs/core/tests/unit_tests/runnables/test_history.py @@ -1,5 +1,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Union +import pytest + from langchain_core.callbacks import ( CallbackManagerForLLMRun, ) @@ -8,10 +10,12 @@ from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langchain_core.outputs import ChatGeneration, ChatResult from langchain_core.pydantic_v1 import BaseModel -from langchain_core.runnables.base import RunnableLambda +from langchain_core.runnables import Runnable +from langchain_core.runnables.base import RunnableBinding, RunnableLambda from langchain_core.runnables.config import RunnableConfig from langchain_core.runnables.history import RunnableWithMessageHistory -from langchain_core.runnables.utils import ConfigurableFieldSpec +from langchain_core.runnables.utils import ConfigurableFieldSpec, Input, Output +from langchain_core.tracers import Run from tests.unit_tests.pydantic_utils import _schema @@ -724,3 +728,115 @@ def test_ignore_session_id() -> None: _ = with_message_history.invoke("hello") _ = with_message_history.invoke("hello again") assert len(history.messages) == 4 + + +class _RunnableLambdaWithRaiseError(RunnableLambda): + from langchain_core.tracers.root_listeners import AsyncListener + + def with_listeners( + self, + *, + on_start: Optional[ + Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]] + ] = None, + on_end: Optional[ + Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]] + ] = None, + on_error: Optional[ + Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]] + ] = None, + ) -> Runnable[Input, Output]: + from langchain_core.tracers.root_listeners import RootListenersTracer + + def create_tracer(config: RunnableConfig) -> RunnableConfig: + tracer = RootListenersTracer( + config=config, + on_start=on_start, + on_end=on_end, + on_error=on_error, + ) + tracer.raise_error = True + return { + "callbacks": [tracer], + } + + return RunnableBinding( + bound=self, + config_factories=[lambda config: create_tracer(config)], + ) + + def with_alisteners( + self, + *, + on_start: Optional[AsyncListener] = None, + on_end: Optional[AsyncListener] = None, + on_error: Optional[AsyncListener] = None, + ) -> Runnable[Input, Output]: + from langchain_core.tracers.root_listeners import AsyncRootListenersTracer + + def create_tracer(config: RunnableConfig) -> RunnableConfig: + tracer = AsyncRootListenersTracer( + config=config, + on_start=on_start, + on_end=on_end, + on_error=on_error, + ) + tracer.raise_error = True + return { + "callbacks": [tracer], + } + + return RunnableBinding( + bound=self, + config_factories=[lambda config: create_tracer(config)], + ) + + +def test_get_output_messages_no_value_error() -> None: + runnable = _RunnableLambdaWithRaiseError( + lambda messages: "you said: " + + "\n".join(str(m.content) for m in messages if isinstance(m, HumanMessage)) + ) + store: Dict = {} + get_session_history = _get_get_session_history(store=store) + with_history = RunnableWithMessageHistory(runnable, get_session_history) + config: RunnableConfig = { + "configurable": {"session_id": "1", "message_history": get_session_history("1")} + } + may_catch_value_error = None + try: + with_history.bound.invoke([HumanMessage(content="hello")], config) + except ValueError as e: + may_catch_value_error = e + assert may_catch_value_error is None + + +def test_get_output_messages_with_value_error() -> None: + illegal_bool_message = False + runnable = _RunnableLambdaWithRaiseError(lambda messages: illegal_bool_message) + store: Dict = {} + get_session_history = _get_get_session_history(store=store) + with_history = RunnableWithMessageHistory(runnable, get_session_history) + config: RunnableConfig = { + "configurable": {"session_id": "1", "message_history": get_session_history("1")} + } + + with pytest.raises(ValueError) as excinfo: + with_history.bound.invoke([HumanMessage(content="hello")], config) + excepted = ( + "Expected str, BaseMessage, List[BaseMessage], or Tuple[BaseMessage]." + + (" Got {}.".format(illegal_bool_message)) + ) + assert excepted in str(excinfo.value) + + illegal_int_message = 123 + runnable = _RunnableLambdaWithRaiseError(lambda messages: illegal_int_message) + with_history = RunnableWithMessageHistory(runnable, get_session_history) + + with pytest.raises(ValueError) as excinfo: + with_history.bound.invoke([HumanMessage(content="hello")], config) + excepted = ( + "Expected str, BaseMessage, List[BaseMessage], or Tuple[BaseMessage]." + + (" Got {}.".format(illegal_int_message)) + ) + assert excepted in str(excinfo.value)