core[patch]: optimize trim_messages (#30327)

Refactored w/ Claude

Up to 20x speedup! (with a theoretical maximum improvement of `O(n / log n)`)
Vadym Barda 2025-03-21 17:08:26 -04:00 committed by GitHub
parent b78ae7817e
commit 07823cd41c
2 changed files with 125 additions and 51 deletions
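The speedup comes from replacing a linear scan over message prefixes with a binary search, so the number of token-counter calls drops from O(n) to O(log n). Below is a minimal, self-contained sketch of that idea, not the library code itself; `longest_prefix_under_budget`, `count_words`, and the toy "tokenizer" are illustrative names invented for this example.

from typing import Callable, Sequence


def longest_prefix_under_budget(
    items: Sequence[str],
    max_tokens: int,
    count_tokens: Callable[[Sequence[str]], int],
) -> int:
    """Return the largest k such that count_tokens(items[:k]) <= max_tokens."""
    # Token counts grow monotonically with prefix length, so binary search
    # over k needs only O(log n) counter calls instead of O(n).
    left, right = 0, len(items)
    while left < right:
        mid = (left + right + 1) // 2  # round up so the loop always makes progress
        if count_tokens(items[:mid]) <= max_tokens:
            left = mid  # a prefix of length mid still fits; try a longer one
        else:
            right = mid - 1  # too long; search among shorter prefixes
    return left


def count_words(msgs: Sequence[str]) -> int:
    # Toy "tokenizer": one token per whitespace-separated word.
    return sum(len(m.split()) for m in msgs)


words = "one two three four five six seven".split()
assert longest_prefix_under_budget(words, max_tokens=4, count_tokens=count_words) == 4

The committed code does the same search but caps the loop at `len(messages).bit_length()` iterations instead of using an unbounded while loop, which is an equivalent bound for a list of n messages.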


@@ -824,10 +824,14 @@ def trim_messages(
            AIMessage( [{"type": "text", "text": "This is the FIRST 4 token block."}], id="second"),
        ]
    """  # noqa: E501
    # Validate arguments
    if start_on and strategy == "first":
        msg = "start_on parameter is only valid with strategy='last'"
        raise ValueError(msg)
    if include_system and strategy == "first":
        msg = "include_system parameter is only valid with strategy='last'"
        raise ValueError(msg)
    messages = convert_to_messages(messages)
    if hasattr(token_counter, "get_num_tokens_from_messages"):
        list_token_counter = token_counter.get_num_tokens_from_messages
@@ -1232,15 +1236,40 @@ def _first_max_tokens(
    messages = list(messages)
    if not messages:
        return messages

    # Check if all messages already fit within token limit
    if token_counter(messages) <= max_tokens:
        # When all messages fit, only apply end_on filtering if needed
        if end_on:
            for _ in range(len(messages)):
                if not _is_message_type(messages[-1], end_on):
                    messages.pop()
                else:
                    break
        return messages

    # Use binary search to find the maximum number of messages within token limit
    left, right = 0, len(messages)
    max_iterations = len(messages).bit_length()
    for _ in range(max_iterations):
        if left >= right:
            break
        mid = (left + right + 1) // 2
        if token_counter(messages[:mid]) <= max_tokens:
            left = mid
            idx = mid
        else:
            right = mid - 1

    # idx now contains the maximum number of complete messages we can include
    idx = left

    if partial_strategy and idx < len(messages):
        included_partial = False
        copied = False
        if isinstance(messages[idx].content, list):
            excluded = messages[idx].model_copy(deep=True)
            copied = True
            num_block = len(excluded.content)
            if partial_strategy == "last":
                excluded.content = list(reversed(excluded.content))
@@ -1254,41 +1283,59 @@ def _first_max_tokens(
            if included_partial and partial_strategy == "last":
                excluded.content = list(reversed(excluded.content))

        if not included_partial:
            if not copied:
                excluded = messages[idx].model_copy(deep=True)
                copied = True

            # Extract text content efficiently
            text = None
            if isinstance(excluded.content, str):
                text = excluded.content
            elif isinstance(excluded.content, list) and excluded.content:
                for block in excluded.content:
                    if isinstance(block, str):
                        text = block
                        break
                    elif isinstance(block, dict) and block.get("type") == "text":
                        text = block.get("text")
                        break

            if text:
                if not copied:
                    excluded = excluded.model_copy(deep=True)

                split_texts = text_splitter(text)
                base_message_count = token_counter(messages[:idx])
                if partial_strategy == "last":
                    split_texts = list(reversed(split_texts))

                # Binary search for the maximum number of splits we can include
                left, right = 0, len(split_texts)
                max_iterations = len(split_texts).bit_length()
                for _ in range(max_iterations):
                    if left >= right:
                        break
                    mid = (left + right + 1) // 2
                    excluded.content = "".join(split_texts[:mid])
                    if base_message_count + token_counter([excluded]) <= max_tokens:
                        left = mid
                    else:
                        right = mid - 1

                if left > 0:
                    content_splits = split_texts[:left]
                    if partial_strategy == "last":
                        content_splits = list(reversed(content_splits))
                    excluded.content = "".join(content_splits)
                    messages = messages[:idx] + [excluded]
                    idx += 1

    if end_on:
        for _ in range(idx):
            if idx > 0 and not _is_message_type(messages[idx - 1], end_on):
                idx -= 1
            else:
                break

    return messages[:idx]
@@ -1311,24 +1358,45 @@ def _last_max_tokens(
    messages = list(messages)
    if len(messages) == 0:
        return []

    # Filter out messages after end_on type
    if end_on:
        for _ in range(len(messages)):
            if not _is_message_type(messages[-1], end_on):
                messages.pop()
            else:
                break

    # Handle system message preservation
    system_message = None
    if include_system and len(messages) > 0 and isinstance(messages[0], SystemMessage):
        system_message = messages[0]
        messages = messages[1:]

    # Reverse messages to use _first_max_tokens with reversed logic
    reversed_messages = messages[::-1]

    # Calculate remaining tokens after accounting for system message if present
    remaining_tokens = max_tokens
    if system_message:
        system_tokens = token_counter([system_message])
        remaining_tokens = max(0, max_tokens - system_tokens)

    reversed_result = _first_max_tokens(
        reversed_messages,
        max_tokens=remaining_tokens,
        token_counter=token_counter,
        text_splitter=text_splitter,
        partial_strategy="last" if allow_partial else None,
        end_on=start_on,
    )

    # Re-reverse the messages and add back the system message if needed
    result = reversed_result[::-1]
    if system_message:
        result = [system_message] + result

    return result
_MSG_CHUNK_MAP: dict[type[BaseMessage], type[BaseMessageChunk]] = {
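For reference, the helpers changed above back the public `trim_messages` API. A call like the following exercises the `strategy="last"` path with system-message preservation; the character-based counter, the sample history, and the budget of 60 are purely illustrative and not taken from the commit.

from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    HumanMessage,
    SystemMessage,
    trim_messages,
)


def count_characters(messages: list[BaseMessage]) -> int:
    # Illustrative token counter: one "token" per character of string content.
    return sum(len(m.content) for m in messages if isinstance(m.content, str))


history = [
    SystemMessage(content="You are a helpful assistant."),
    HumanMessage(content="First question?"),
    AIMessage(content="First answer."),
    HumanMessage(content="Second question?"),
]

trimmed = trim_messages(
    history,
    max_tokens=60,
    strategy="last",        # keep the most recent messages that fit
    token_counter=count_characters,
    include_system=True,    # always retain the leading SystemMessage
    start_on="human",       # the kept history must start on a human turn
)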


@@ -455,6 +455,7 @@ def dummy_token_counter(messages: list[BaseMessage]) -> int:
def test_trim_messages_partial_text_splitting() -> None:
    messages = [HumanMessage(content="This is a long message that needs trimming")]
    messages_copy = [m.model_copy(deep=True) for m in messages]

    def count_characters(msgs: list[BaseMessage]) -> int:
        return sum(len(m.content) if isinstance(m.content, str) else 0 for m in msgs)
@@ -474,6 +475,7 @@ def test_trim_messages_partial_text_splitting() -> None:
    assert len(result) == 1
    assert result[0].content == "This is a "  # First 10 characters
    assert messages == messages_copy
def test_trim_messages_mixed_content_with_partial() -> None:
@@ -485,6 +487,7 @@ def test_trim_messages_mixed_content_with_partial() -> None:
            ]
        )
    ]
    messages_copy = [m.model_copy(deep=True) for m in messages]

    # Count total length of all text parts
    def count_text_length(msgs: list[BaseMessage]) -> int:
@@ -509,6 +512,7 @@ def test_trim_messages_mixed_content_with_partial() -> None:
    assert len(result) == 1
    assert len(result[0].content) == 1
    assert result[0].content[0]["text"] == "First part of text."
    assert messages == messages_copy
def test_trim_messages_exact_token_boundary() -> None:
@@ -535,6 +539,7 @@ def test_trim_messages_exact_token_boundary() -> None:
        strategy="first",
    )
    assert len(result2) == 2
    assert result2 == messages
def test_trim_messages_start_on_with_allow_partial() -> None:
@@ -543,7 +548,7 @@ def test_trim_messages_start_on_with_allow_partial() -> None:
        AIMessage(content="AI response"),
        HumanMessage(content="Second human message"),
    ]
    messages_copy = [m.model_copy(deep=True) for m in messages]

    result = trim_messages(
        messages,
        max_tokens=20,
@@ -555,6 +560,7 @@ def test_trim_messages_start_on_with_allow_partial() -> None:
    assert len(result) == 1
    assert result[0].content == "Second human message"
    assert messages == messages_copy
class FakeTokenCountingModel(FakeChatModel):