core[patch]: optimize trim_messages (#30327)

Refactored w/ Claude

Up to 20x speedup! (with theoretical max improvement of `O(n / log n)`)
This commit is contained in:
Vadym Barda
2025-03-21 17:08:26 -04:00
committed by GitHub
parent b78ae7817e
commit 07823cd41c
2 changed files with 125 additions and 51 deletions

View File

@@ -455,6 +455,7 @@ def dummy_token_counter(messages: list[BaseMessage]) -> int:
def test_trim_messages_partial_text_splitting() -> None:
messages = [HumanMessage(content="This is a long message that needs trimming")]
messages_copy = [m.model_copy(deep=True) for m in messages]
def count_characters(msgs: list[BaseMessage]) -> int:
return sum(len(m.content) if isinstance(m.content, str) else 0 for m in msgs)
@@ -474,6 +475,7 @@ def test_trim_messages_partial_text_splitting() -> None:
assert len(result) == 1
assert result[0].content == "This is a " # First 10 characters
assert messages == messages_copy
def test_trim_messages_mixed_content_with_partial() -> None:
@@ -485,6 +487,7 @@ def test_trim_messages_mixed_content_with_partial() -> None:
]
)
]
messages_copy = [m.model_copy(deep=True) for m in messages]
# Count total length of all text parts
def count_text_length(msgs: list[BaseMessage]) -> int:
@@ -509,6 +512,7 @@ def test_trim_messages_mixed_content_with_partial() -> None:
assert len(result) == 1
assert len(result[0].content) == 1
assert result[0].content[0]["text"] == "First part of text."
assert messages == messages_copy
def test_trim_messages_exact_token_boundary() -> None:
@@ -535,6 +539,7 @@ def test_trim_messages_exact_token_boundary() -> None:
strategy="first",
)
assert len(result2) == 2
assert result2 == messages
def test_trim_messages_start_on_with_allow_partial() -> None:
@@ -543,7 +548,7 @@ def test_trim_messages_start_on_with_allow_partial() -> None:
AIMessage(content="AI response"),
HumanMessage(content="Second human message"),
]
messages_copy = [m.model_copy(deep=True) for m in messages]
result = trim_messages(
messages,
max_tokens=20,
@@ -555,6 +560,7 @@ def test_trim_messages_start_on_with_allow_partial() -> None:
assert len(result) == 1
assert result[0].content == "Second human message"
assert messages == messages_copy
class FakeTokenCountingModel(FakeChatModel):