mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-06 03:27:55 +00:00
core[patch]: optimize trim_messages (#30327)
Refactored w/ Claude Up to 20x speedup! (with theoretical max improvement of `O(n / log n)`)
This commit is contained in:
parent
b78ae7817e
commit
07823cd41c
@ -824,10 +824,14 @@ def trim_messages(
|
||||
AIMessage( [{"type": "text", "text": "This is the FIRST 4 token block."}], id="second"),
|
||||
]
|
||||
""" # noqa: E501
|
||||
# Validate arguments
|
||||
if start_on and strategy == "first":
|
||||
raise ValueError
|
||||
msg = "start_on parameter is only valid with strategy='last'"
|
||||
raise ValueError(msg)
|
||||
if include_system and strategy == "first":
|
||||
raise ValueError
|
||||
msg = "include_system parameter is only valid with strategy='last'"
|
||||
raise ValueError(msg)
|
||||
|
||||
messages = convert_to_messages(messages)
|
||||
if hasattr(token_counter, "get_num_tokens_from_messages"):
|
||||
list_token_counter = token_counter.get_num_tokens_from_messages
|
||||
@ -1232,15 +1236,40 @@ def _first_max_tokens(
|
||||
messages = list(messages)
|
||||
if not messages:
|
||||
return messages
|
||||
idx = 0
|
||||
for i in range(len(messages)):
|
||||
if token_counter(messages[:-i] if i else messages) <= max_tokens:
|
||||
idx = len(messages) - i
|
||||
|
||||
# Check if all messages already fit within token limit
|
||||
if token_counter(messages) <= max_tokens:
|
||||
# When all messages fit, only apply end_on filtering if needed
|
||||
if end_on:
|
||||
for _ in range(len(messages)):
|
||||
if not _is_message_type(messages[-1], end_on):
|
||||
messages.pop()
|
||||
else:
|
||||
break
|
||||
return messages
|
||||
|
||||
# Use binary search to find the maximum number of messages within token limit
|
||||
left, right = 0, len(messages)
|
||||
max_iterations = len(messages).bit_length()
|
||||
for _ in range(max_iterations):
|
||||
if left >= right:
|
||||
break
|
||||
if partial_strategy and (idx < len(messages) - 1 or idx == 0):
|
||||
mid = (left + right + 1) // 2
|
||||
if token_counter(messages[:mid]) <= max_tokens:
|
||||
left = mid
|
||||
idx = mid
|
||||
else:
|
||||
right = mid - 1
|
||||
|
||||
# idx now contains the maximum number of complete messages we can include
|
||||
idx = left
|
||||
|
||||
if partial_strategy and idx < len(messages):
|
||||
included_partial = False
|
||||
copied = False
|
||||
if isinstance(messages[idx].content, list):
|
||||
excluded = messages[idx].model_copy(deep=True)
|
||||
copied = True
|
||||
num_block = len(excluded.content)
|
||||
if partial_strategy == "last":
|
||||
excluded.content = list(reversed(excluded.content))
|
||||
@ -1254,41 +1283,59 @@ def _first_max_tokens(
|
||||
if included_partial and partial_strategy == "last":
|
||||
excluded.content = list(reversed(excluded.content))
|
||||
if not included_partial:
|
||||
excluded = messages[idx].model_copy(deep=True)
|
||||
if isinstance(excluded.content, list) and any(
|
||||
isinstance(block, str) or block["type"] == "text"
|
||||
for block in messages[idx].content
|
||||
):
|
||||
text_block = next(
|
||||
block
|
||||
for block in messages[idx].content
|
||||
if isinstance(block, str) or block["type"] == "text"
|
||||
)
|
||||
text = (
|
||||
text_block["text"] if isinstance(text_block, dict) else text_block
|
||||
)
|
||||
elif isinstance(excluded.content, str):
|
||||
if not copied:
|
||||
excluded = messages[idx].model_copy(deep=True)
|
||||
copied = True
|
||||
|
||||
# Extract text content efficiently
|
||||
text = None
|
||||
if isinstance(excluded.content, str):
|
||||
text = excluded.content
|
||||
else:
|
||||
text = None
|
||||
if text:
|
||||
split_texts = text_splitter(text)
|
||||
num_splits = len(split_texts)
|
||||
if partial_strategy == "last":
|
||||
split_texts = list(reversed(split_texts))
|
||||
for _ in range(num_splits - 1):
|
||||
split_texts.pop()
|
||||
excluded.content = "".join(split_texts)
|
||||
if token_counter(messages[:idx] + [excluded]) <= max_tokens:
|
||||
if partial_strategy == "last":
|
||||
excluded.content = "".join(reversed(split_texts))
|
||||
messages = messages[:idx] + [excluded]
|
||||
idx += 1
|
||||
elif isinstance(excluded.content, list) and excluded.content:
|
||||
for block in excluded.content:
|
||||
if isinstance(block, str):
|
||||
text = block
|
||||
break
|
||||
elif isinstance(block, dict) and block.get("type") == "text":
|
||||
text = block.get("text")
|
||||
break
|
||||
|
||||
if text:
|
||||
if not copied:
|
||||
excluded = excluded.model_copy(deep=True)
|
||||
|
||||
split_texts = text_splitter(text)
|
||||
base_message_count = token_counter(messages[:idx])
|
||||
if partial_strategy == "last":
|
||||
split_texts = list(reversed(split_texts))
|
||||
|
||||
# Binary search for the maximum number of splits we can include
|
||||
left, right = 0, len(split_texts)
|
||||
max_iterations = len(split_texts).bit_length()
|
||||
for _ in range(max_iterations):
|
||||
if left >= right:
|
||||
break
|
||||
mid = (left + right + 1) // 2
|
||||
excluded.content = "".join(split_texts[:mid])
|
||||
if base_message_count + token_counter([excluded]) <= max_tokens:
|
||||
left = mid
|
||||
else:
|
||||
right = mid - 1
|
||||
|
||||
if left > 0:
|
||||
content_splits = split_texts[:left]
|
||||
if partial_strategy == "last":
|
||||
content_splits = list(reversed(content_splits))
|
||||
excluded.content = "".join(content_splits)
|
||||
messages = messages[:idx] + [excluded]
|
||||
idx += 1
|
||||
|
||||
if end_on:
|
||||
while idx > 0 and not _is_message_type(messages[idx - 1], end_on):
|
||||
idx -= 1
|
||||
for _ in range(idx):
|
||||
if idx > 0 and not _is_message_type(messages[idx - 1], end_on):
|
||||
idx -= 1
|
||||
else:
|
||||
break
|
||||
|
||||
return messages[:idx]
|
||||
|
||||
@ -1311,24 +1358,45 @@ def _last_max_tokens(
|
||||
messages = list(messages)
|
||||
if len(messages) == 0:
|
||||
return []
|
||||
if end_on:
|
||||
while messages and not _is_message_type(messages[-1], end_on):
|
||||
messages.pop()
|
||||
swapped_system = include_system and isinstance(messages[0], SystemMessage)
|
||||
reversed_ = messages[:1] + messages[1:][::-1] if swapped_system else messages[::-1]
|
||||
|
||||
reversed_ = _first_max_tokens(
|
||||
reversed_,
|
||||
max_tokens=max_tokens,
|
||||
# Filter out messages after end_on type
|
||||
if end_on:
|
||||
for _ in range(len(messages)):
|
||||
if not _is_message_type(messages[-1], end_on):
|
||||
messages.pop()
|
||||
else:
|
||||
break
|
||||
|
||||
# Handle system message preservation
|
||||
system_message = None
|
||||
if include_system and len(messages) > 0 and isinstance(messages[0], SystemMessage):
|
||||
system_message = messages[0]
|
||||
messages = messages[1:]
|
||||
|
||||
# Reverse messages to use _first_max_tokens with reversed logic
|
||||
reversed_messages = messages[::-1]
|
||||
|
||||
# Calculate remaining tokens after accounting for system message if present
|
||||
remaining_tokens = max_tokens
|
||||
if system_message:
|
||||
system_tokens = token_counter([system_message])
|
||||
remaining_tokens = max(0, max_tokens - system_tokens)
|
||||
|
||||
reversed_result = _first_max_tokens(
|
||||
reversed_messages,
|
||||
max_tokens=remaining_tokens,
|
||||
token_counter=token_counter,
|
||||
text_splitter=text_splitter,
|
||||
partial_strategy="last" if allow_partial else None,
|
||||
end_on=start_on,
|
||||
)
|
||||
if swapped_system:
|
||||
return reversed_[:1] + reversed_[1:][::-1]
|
||||
else:
|
||||
return reversed_[::-1]
|
||||
|
||||
# Re-reverse the messages and add back the system message if needed
|
||||
result = reversed_result[::-1]
|
||||
if system_message:
|
||||
result = [system_message] + result
|
||||
|
||||
return result
|
||||
|
||||
|
||||
_MSG_CHUNK_MAP: dict[type[BaseMessage], type[BaseMessageChunk]] = {
|
||||
|
@ -455,6 +455,7 @@ def dummy_token_counter(messages: list[BaseMessage]) -> int:
|
||||
|
||||
def test_trim_messages_partial_text_splitting() -> None:
|
||||
messages = [HumanMessage(content="This is a long message that needs trimming")]
|
||||
messages_copy = [m.model_copy(deep=True) for m in messages]
|
||||
|
||||
def count_characters(msgs: list[BaseMessage]) -> int:
|
||||
return sum(len(m.content) if isinstance(m.content, str) else 0 for m in msgs)
|
||||
@ -474,6 +475,7 @@ def test_trim_messages_partial_text_splitting() -> None:
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].content == "This is a " # First 10 characters
|
||||
assert messages == messages_copy
|
||||
|
||||
|
||||
def test_trim_messages_mixed_content_with_partial() -> None:
|
||||
@ -485,6 +487,7 @@ def test_trim_messages_mixed_content_with_partial() -> None:
|
||||
]
|
||||
)
|
||||
]
|
||||
messages_copy = [m.model_copy(deep=True) for m in messages]
|
||||
|
||||
# Count total length of all text parts
|
||||
def count_text_length(msgs: list[BaseMessage]) -> int:
|
||||
@ -509,6 +512,7 @@ def test_trim_messages_mixed_content_with_partial() -> None:
|
||||
assert len(result) == 1
|
||||
assert len(result[0].content) == 1
|
||||
assert result[0].content[0]["text"] == "First part of text."
|
||||
assert messages == messages_copy
|
||||
|
||||
|
||||
def test_trim_messages_exact_token_boundary() -> None:
|
||||
@ -535,6 +539,7 @@ def test_trim_messages_exact_token_boundary() -> None:
|
||||
strategy="first",
|
||||
)
|
||||
assert len(result2) == 2
|
||||
assert result2 == messages
|
||||
|
||||
|
||||
def test_trim_messages_start_on_with_allow_partial() -> None:
|
||||
@ -543,7 +548,7 @@ def test_trim_messages_start_on_with_allow_partial() -> None:
|
||||
AIMessage(content="AI response"),
|
||||
HumanMessage(content="Second human message"),
|
||||
]
|
||||
|
||||
messages_copy = [m.model_copy(deep=True) for m in messages]
|
||||
result = trim_messages(
|
||||
messages,
|
||||
max_tokens=20,
|
||||
@ -555,6 +560,7 @@ def test_trim_messages_start_on_with_allow_partial() -> None:
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].content == "Second human message"
|
||||
assert messages == messages_copy
|
||||
|
||||
|
||||
class FakeTokenCountingModel(FakeChatModel):
|
||||
|
Loading…
Reference in New Issue
Block a user