diff --git a/libs/core/langchain_core/messages/utils.py b/libs/core/langchain_core/messages/utils.py index 98cdb71f6da..b3db8b2104f 100644 --- a/libs/core/langchain_core/messages/utils.py +++ b/libs/core/langchain_core/messages/utils.py @@ -819,6 +819,7 @@ def trim_messages( def list_token_counter(messages: Sequence[BaseMessage]) -> int: return sum(token_counter(msg) for msg in messages) # type: ignore[arg-type, misc] + else: list_token_counter = token_counter # type: ignore[assignment] else: @@ -956,6 +957,8 @@ def _last_max_tokens( ] = None, ) -> list[BaseMessage]: messages = list(messages) + if len(messages) == 0: + return [] if end_on: while messages and not _is_message_type(messages[-1], end_on): messages.pop() diff --git a/libs/core/tests/unit_tests/messages/test_utils.py b/libs/core/tests/unit_tests/messages/test_utils.py index 15539a19e3d..f994ace071c 100644 --- a/libs/core/tests/unit_tests/messages/test_utils.py +++ b/libs/core/tests/unit_tests/messages/test_utils.py @@ -332,6 +332,19 @@ def test_trim_messages_allow_partial_text_splitter() -> None: assert _MESSAGES_TO_TRIM == _MESSAGES_TO_TRIM_COPY +def test_trim_messages_include_system_strategy_last_empty_messages() -> None: + expected: list[BaseMessage] = [] + + actual = trim_messages( + max_tokens=10, + token_counter=dummy_token_counter, + strategy="last", + include_system=True, + ).invoke([]) + + assert actual == expected + + def test_trim_messages_invoke() -> None: actual = trim_messages(max_tokens=10, token_counter=dummy_token_counter).invoke( _MESSAGES_TO_TRIM