From 07823cd41c95ef1594f9075f70e9f7a633433842 Mon Sep 17 00:00:00 2001 From: Vadym Barda Date: Fri, 21 Mar 2025 17:08:26 -0400 Subject: [PATCH] core[patch]: optimize trim_messages (#30327) Refactored w/ Claude Up to 20x speedup! (with theoretical max improvement of `O(n / log n)`) --- libs/core/langchain_core/messages/utils.py | 168 ++++++++++++------ .../tests/unit_tests/messages/test_utils.py | 8 +- 2 files changed, 125 insertions(+), 51 deletions(-) diff --git a/libs/core/langchain_core/messages/utils.py b/libs/core/langchain_core/messages/utils.py index 529f2a5c46e..ad6627888b4 100644 --- a/libs/core/langchain_core/messages/utils.py +++ b/libs/core/langchain_core/messages/utils.py @@ -824,10 +824,14 @@ def trim_messages( AIMessage( [{"type": "text", "text": "This is the FIRST 4 token block."}], id="second"), ] """ # noqa: E501 + # Validate arguments if start_on and strategy == "first": - raise ValueError + msg = "start_on parameter is only valid with strategy='last'" + raise ValueError(msg) if include_system and strategy == "first": - raise ValueError + msg = "include_system parameter is only valid with strategy='last'" + raise ValueError(msg) + messages = convert_to_messages(messages) if hasattr(token_counter, "get_num_tokens_from_messages"): list_token_counter = token_counter.get_num_tokens_from_messages @@ -1232,15 +1236,40 @@ def _first_max_tokens( messages = list(messages) if not messages: return messages - idx = 0 - for i in range(len(messages)): - if token_counter(messages[:-i] if i else messages) <= max_tokens: - idx = len(messages) - i + + # Check if all messages already fit within token limit + if token_counter(messages) <= max_tokens: + # When all messages fit, only apply end_on filtering if needed + if end_on: + for _ in range(len(messages)): + if not _is_message_type(messages[-1], end_on): + messages.pop() + else: + break + return messages + + # Use binary search to find the maximum number of messages within token limit + left, right = 0, len(messages) + max_iterations = len(messages).bit_length() + for _ in range(max_iterations): + if left >= right: break - if partial_strategy and (idx < len(messages) - 1 or idx == 0): + mid = (left + right + 1) // 2 + if token_counter(messages[:mid]) <= max_tokens: + left = mid + idx = mid + else: + right = mid - 1 + + # idx now contains the maximum number of complete messages we can include + idx = left + + if partial_strategy and idx < len(messages): included_partial = False + copied = False if isinstance(messages[idx].content, list): excluded = messages[idx].model_copy(deep=True) + copied = True num_block = len(excluded.content) if partial_strategy == "last": excluded.content = list(reversed(excluded.content)) @@ -1254,41 +1283,59 @@ def _first_max_tokens( if included_partial and partial_strategy == "last": excluded.content = list(reversed(excluded.content)) if not included_partial: - excluded = messages[idx].model_copy(deep=True) - if isinstance(excluded.content, list) and any( - isinstance(block, str) or block["type"] == "text" - for block in messages[idx].content - ): - text_block = next( - block - for block in messages[idx].content - if isinstance(block, str) or block["type"] == "text" - ) - text = ( - text_block["text"] if isinstance(text_block, dict) else text_block - ) - elif isinstance(excluded.content, str): + if not copied: + excluded = messages[idx].model_copy(deep=True) + copied = True + + # Extract text content efficiently + text = None + if isinstance(excluded.content, str): text = excluded.content - else: - text = None - if text: - split_texts = text_splitter(text) - num_splits = len(split_texts) - if partial_strategy == "last": - split_texts = list(reversed(split_texts)) - for _ in range(num_splits - 1): - split_texts.pop() - excluded.content = "".join(split_texts) - if token_counter(messages[:idx] + [excluded]) <= max_tokens: - if partial_strategy == "last": - excluded.content = "".join(reversed(split_texts)) - messages = messages[:idx] + [excluded] - idx += 1 + elif isinstance(excluded.content, list) and excluded.content: + for block in excluded.content: + if isinstance(block, str): + text = block + break + elif isinstance(block, dict) and block.get("type") == "text": + text = block.get("text") break + if text: + if not copied: + excluded = excluded.model_copy(deep=True) + + split_texts = text_splitter(text) + base_message_count = token_counter(messages[:idx]) + if partial_strategy == "last": + split_texts = list(reversed(split_texts)) + + # Binary search for the maximum number of splits we can include + left, right = 0, len(split_texts) + max_iterations = len(split_texts).bit_length() + for _ in range(max_iterations): + if left >= right: + break + mid = (left + right + 1) // 2 + excluded.content = "".join(split_texts[:mid]) + if base_message_count + token_counter([excluded]) <= max_tokens: + left = mid + else: + right = mid - 1 + + if left > 0: + content_splits = split_texts[:left] + if partial_strategy == "last": + content_splits = list(reversed(content_splits)) + excluded.content = "".join(content_splits) + messages = messages[:idx] + [excluded] + idx += 1 + if end_on: - while idx > 0 and not _is_message_type(messages[idx - 1], end_on): - idx -= 1 + for _ in range(idx): + if idx > 0 and not _is_message_type(messages[idx - 1], end_on): + idx -= 1 + else: + break return messages[:idx] @@ -1311,24 +1358,45 @@ def _last_max_tokens( messages = list(messages) if len(messages) == 0: return [] - if end_on: - while messages and not _is_message_type(messages[-1], end_on): - messages.pop() - swapped_system = include_system and isinstance(messages[0], SystemMessage) - reversed_ = messages[:1] + messages[1:][::-1] if swapped_system else messages[::-1] - reversed_ = _first_max_tokens( - reversed_, - max_tokens=max_tokens, + # Filter out messages after end_on type + if end_on: + for _ in range(len(messages)): + if not _is_message_type(messages[-1], end_on): + messages.pop() + else: + break + + # Handle system message preservation + system_message = None + if include_system and len(messages) > 0 and isinstance(messages[0], SystemMessage): + system_message = messages[0] + messages = messages[1:] + + # Reverse messages to use _first_max_tokens with reversed logic + reversed_messages = messages[::-1] + + # Calculate remaining tokens after accounting for system message if present + remaining_tokens = max_tokens + if system_message: + system_tokens = token_counter([system_message]) + remaining_tokens = max(0, max_tokens - system_tokens) + + reversed_result = _first_max_tokens( + reversed_messages, + max_tokens=remaining_tokens, token_counter=token_counter, text_splitter=text_splitter, partial_strategy="last" if allow_partial else None, end_on=start_on, ) - if swapped_system: - return reversed_[:1] + reversed_[1:][::-1] - else: - return reversed_[::-1] + + # Re-reverse the messages and add back the system message if needed + result = reversed_result[::-1] + if system_message: + result = [system_message] + result + + return result _MSG_CHUNK_MAP: dict[type[BaseMessage], type[BaseMessageChunk]] = { diff --git a/libs/core/tests/unit_tests/messages/test_utils.py b/libs/core/tests/unit_tests/messages/test_utils.py index a10c952d4fc..6e3c608939d 100644 --- a/libs/core/tests/unit_tests/messages/test_utils.py +++ b/libs/core/tests/unit_tests/messages/test_utils.py @@ -455,6 +455,7 @@ def dummy_token_counter(messages: list[BaseMessage]) -> int: def test_trim_messages_partial_text_splitting() -> None: messages = [HumanMessage(content="This is a long message that needs trimming")] + messages_copy = [m.model_copy(deep=True) for m in messages] def count_characters(msgs: list[BaseMessage]) -> int: return sum(len(m.content) if isinstance(m.content, str) else 0 for m in msgs) @@ -474,6 +475,7 @@ def test_trim_messages_partial_text_splitting() -> None: assert len(result) == 1 assert result[0].content == "This is a " # First 10 characters + assert messages == messages_copy def test_trim_messages_mixed_content_with_partial() -> None: @@ -485,6 +487,7 @@ def test_trim_messages_mixed_content_with_partial() -> None: ] ) ] + messages_copy = [m.model_copy(deep=True) for m in messages] # Count total length of all text parts def count_text_length(msgs: list[BaseMessage]) -> int: @@ -509,6 +512,7 @@ def test_trim_messages_mixed_content_with_partial() -> None: assert len(result) == 1 assert len(result[0].content) == 1 assert result[0].content[0]["text"] == "First part of text." + assert messages == messages_copy def test_trim_messages_exact_token_boundary() -> None: @@ -535,6 +539,7 @@ def test_trim_messages_exact_token_boundary() -> None: strategy="first", ) assert len(result2) == 2 + assert result2 == messages def test_trim_messages_start_on_with_allow_partial() -> None: @@ -543,7 +548,7 @@ def test_trim_messages_start_on_with_allow_partial() -> None: AIMessage(content="AI response"), HumanMessage(content="Second human message"), ] - + messages_copy = [m.model_copy(deep=True) for m in messages] result = trim_messages( messages, max_tokens=20, @@ -555,6 +560,7 @@ def test_trim_messages_start_on_with_allow_partial() -> None: assert len(result) == 1 assert result[0].content == "Second human message" + assert messages == messages_copy class FakeTokenCountingModel(FakeChatModel):