mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-10 05:20:39 +00:00
core[patch]: optimize trim_messages (#30327)
Refactored with Claude. Up to 20x speedup! (With a theoretical maximum improvement of `O(n / log n)`.)
This commit is contained in:
parent
b78ae7817e
commit
07823cd41c
@ -824,10 +824,14 @@ def trim_messages(
|
|||||||
AIMessage( [{"type": "text", "text": "This is the FIRST 4 token block."}], id="second"),
|
AIMessage( [{"type": "text", "text": "This is the FIRST 4 token block."}], id="second"),
|
||||||
]
|
]
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
|
# Validate arguments
|
||||||
if start_on and strategy == "first":
|
if start_on and strategy == "first":
|
||||||
raise ValueError
|
msg = "start_on parameter is only valid with strategy='last'"
|
||||||
|
raise ValueError(msg)
|
||||||
if include_system and strategy == "first":
|
if include_system and strategy == "first":
|
||||||
raise ValueError
|
msg = "include_system parameter is only valid with strategy='last'"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
messages = convert_to_messages(messages)
|
messages = convert_to_messages(messages)
|
||||||
if hasattr(token_counter, "get_num_tokens_from_messages"):
|
if hasattr(token_counter, "get_num_tokens_from_messages"):
|
||||||
list_token_counter = token_counter.get_num_tokens_from_messages
|
list_token_counter = token_counter.get_num_tokens_from_messages
|
||||||
@ -1232,15 +1236,40 @@ def _first_max_tokens(
|
|||||||
messages = list(messages)
|
messages = list(messages)
|
||||||
if not messages:
|
if not messages:
|
||||||
return messages
|
return messages
|
||||||
idx = 0
|
|
||||||
for i in range(len(messages)):
|
# Check if all messages already fit within token limit
|
||||||
if token_counter(messages[:-i] if i else messages) <= max_tokens:
|
if token_counter(messages) <= max_tokens:
|
||||||
idx = len(messages) - i
|
# When all messages fit, only apply end_on filtering if needed
|
||||||
|
if end_on:
|
||||||
|
for _ in range(len(messages)):
|
||||||
|
if not _is_message_type(messages[-1], end_on):
|
||||||
|
messages.pop()
|
||||||
|
else:
|
||||||
break
|
break
|
||||||
if partial_strategy and (idx < len(messages) - 1 or idx == 0):
|
return messages
|
||||||
|
|
||||||
|
# Use binary search to find the maximum number of messages within token limit
|
||||||
|
left, right = 0, len(messages)
|
||||||
|
max_iterations = len(messages).bit_length()
|
||||||
|
for _ in range(max_iterations):
|
||||||
|
if left >= right:
|
||||||
|
break
|
||||||
|
mid = (left + right + 1) // 2
|
||||||
|
if token_counter(messages[:mid]) <= max_tokens:
|
||||||
|
left = mid
|
||||||
|
idx = mid
|
||||||
|
else:
|
||||||
|
right = mid - 1
|
||||||
|
|
||||||
|
# idx now contains the maximum number of complete messages we can include
|
||||||
|
idx = left
|
||||||
|
|
||||||
|
if partial_strategy and idx < len(messages):
|
||||||
included_partial = False
|
included_partial = False
|
||||||
|
copied = False
|
||||||
if isinstance(messages[idx].content, list):
|
if isinstance(messages[idx].content, list):
|
||||||
excluded = messages[idx].model_copy(deep=True)
|
excluded = messages[idx].model_copy(deep=True)
|
||||||
|
copied = True
|
||||||
num_block = len(excluded.content)
|
num_block = len(excluded.content)
|
||||||
if partial_strategy == "last":
|
if partial_strategy == "last":
|
||||||
excluded.content = list(reversed(excluded.content))
|
excluded.content = list(reversed(excluded.content))
|
||||||
@ -1254,41 +1283,59 @@ def _first_max_tokens(
|
|||||||
if included_partial and partial_strategy == "last":
|
if included_partial and partial_strategy == "last":
|
||||||
excluded.content = list(reversed(excluded.content))
|
excluded.content = list(reversed(excluded.content))
|
||||||
if not included_partial:
|
if not included_partial:
|
||||||
|
if not copied:
|
||||||
excluded = messages[idx].model_copy(deep=True)
|
excluded = messages[idx].model_copy(deep=True)
|
||||||
if isinstance(excluded.content, list) and any(
|
copied = True
|
||||||
isinstance(block, str) or block["type"] == "text"
|
|
||||||
for block in messages[idx].content
|
# Extract text content efficiently
|
||||||
):
|
|
||||||
text_block = next(
|
|
||||||
block
|
|
||||||
for block in messages[idx].content
|
|
||||||
if isinstance(block, str) or block["type"] == "text"
|
|
||||||
)
|
|
||||||
text = (
|
|
||||||
text_block["text"] if isinstance(text_block, dict) else text_block
|
|
||||||
)
|
|
||||||
elif isinstance(excluded.content, str):
|
|
||||||
text = excluded.content
|
|
||||||
else:
|
|
||||||
text = None
|
text = None
|
||||||
if text:
|
if isinstance(excluded.content, str):
|
||||||
split_texts = text_splitter(text)
|
text = excluded.content
|
||||||
num_splits = len(split_texts)
|
elif isinstance(excluded.content, list) and excluded.content:
|
||||||
if partial_strategy == "last":
|
for block in excluded.content:
|
||||||
split_texts = list(reversed(split_texts))
|
if isinstance(block, str):
|
||||||
for _ in range(num_splits - 1):
|
text = block
|
||||||
split_texts.pop()
|
break
|
||||||
excluded.content = "".join(split_texts)
|
elif isinstance(block, dict) and block.get("type") == "text":
|
||||||
if token_counter(messages[:idx] + [excluded]) <= max_tokens:
|
text = block.get("text")
|
||||||
if partial_strategy == "last":
|
|
||||||
excluded.content = "".join(reversed(split_texts))
|
|
||||||
messages = messages[:idx] + [excluded]
|
|
||||||
idx += 1
|
|
||||||
break
|
break
|
||||||
|
|
||||||
|
if text:
|
||||||
|
if not copied:
|
||||||
|
excluded = excluded.model_copy(deep=True)
|
||||||
|
|
||||||
|
split_texts = text_splitter(text)
|
||||||
|
base_message_count = token_counter(messages[:idx])
|
||||||
|
if partial_strategy == "last":
|
||||||
|
split_texts = list(reversed(split_texts))
|
||||||
|
|
||||||
|
# Binary search for the maximum number of splits we can include
|
||||||
|
left, right = 0, len(split_texts)
|
||||||
|
max_iterations = len(split_texts).bit_length()
|
||||||
|
for _ in range(max_iterations):
|
||||||
|
if left >= right:
|
||||||
|
break
|
||||||
|
mid = (left + right + 1) // 2
|
||||||
|
excluded.content = "".join(split_texts[:mid])
|
||||||
|
if base_message_count + token_counter([excluded]) <= max_tokens:
|
||||||
|
left = mid
|
||||||
|
else:
|
||||||
|
right = mid - 1
|
||||||
|
|
||||||
|
if left > 0:
|
||||||
|
content_splits = split_texts[:left]
|
||||||
|
if partial_strategy == "last":
|
||||||
|
content_splits = list(reversed(content_splits))
|
||||||
|
excluded.content = "".join(content_splits)
|
||||||
|
messages = messages[:idx] + [excluded]
|
||||||
|
idx += 1
|
||||||
|
|
||||||
if end_on:
|
if end_on:
|
||||||
while idx > 0 and not _is_message_type(messages[idx - 1], end_on):
|
for _ in range(idx):
|
||||||
|
if idx > 0 and not _is_message_type(messages[idx - 1], end_on):
|
||||||
idx -= 1
|
idx -= 1
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
return messages[:idx]
|
return messages[:idx]
|
||||||
|
|
||||||
@ -1311,24 +1358,45 @@ def _last_max_tokens(
|
|||||||
messages = list(messages)
|
messages = list(messages)
|
||||||
if len(messages) == 0:
|
if len(messages) == 0:
|
||||||
return []
|
return []
|
||||||
if end_on:
|
|
||||||
while messages and not _is_message_type(messages[-1], end_on):
|
|
||||||
messages.pop()
|
|
||||||
swapped_system = include_system and isinstance(messages[0], SystemMessage)
|
|
||||||
reversed_ = messages[:1] + messages[1:][::-1] if swapped_system else messages[::-1]
|
|
||||||
|
|
||||||
reversed_ = _first_max_tokens(
|
# Filter out messages after end_on type
|
||||||
reversed_,
|
if end_on:
|
||||||
max_tokens=max_tokens,
|
for _ in range(len(messages)):
|
||||||
|
if not _is_message_type(messages[-1], end_on):
|
||||||
|
messages.pop()
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Handle system message preservation
|
||||||
|
system_message = None
|
||||||
|
if include_system and len(messages) > 0 and isinstance(messages[0], SystemMessage):
|
||||||
|
system_message = messages[0]
|
||||||
|
messages = messages[1:]
|
||||||
|
|
||||||
|
# Reverse messages to use _first_max_tokens with reversed logic
|
||||||
|
reversed_messages = messages[::-1]
|
||||||
|
|
||||||
|
# Calculate remaining tokens after accounting for system message if present
|
||||||
|
remaining_tokens = max_tokens
|
||||||
|
if system_message:
|
||||||
|
system_tokens = token_counter([system_message])
|
||||||
|
remaining_tokens = max(0, max_tokens - system_tokens)
|
||||||
|
|
||||||
|
reversed_result = _first_max_tokens(
|
||||||
|
reversed_messages,
|
||||||
|
max_tokens=remaining_tokens,
|
||||||
token_counter=token_counter,
|
token_counter=token_counter,
|
||||||
text_splitter=text_splitter,
|
text_splitter=text_splitter,
|
||||||
partial_strategy="last" if allow_partial else None,
|
partial_strategy="last" if allow_partial else None,
|
||||||
end_on=start_on,
|
end_on=start_on,
|
||||||
)
|
)
|
||||||
if swapped_system:
|
|
||||||
return reversed_[:1] + reversed_[1:][::-1]
|
# Re-reverse the messages and add back the system message if needed
|
||||||
else:
|
result = reversed_result[::-1]
|
||||||
return reversed_[::-1]
|
if system_message:
|
||||||
|
result = [system_message] + result
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
_MSG_CHUNK_MAP: dict[type[BaseMessage], type[BaseMessageChunk]] = {
|
_MSG_CHUNK_MAP: dict[type[BaseMessage], type[BaseMessageChunk]] = {
|
||||||
|
@ -455,6 +455,7 @@ def dummy_token_counter(messages: list[BaseMessage]) -> int:
|
|||||||
|
|
||||||
def test_trim_messages_partial_text_splitting() -> None:
|
def test_trim_messages_partial_text_splitting() -> None:
|
||||||
messages = [HumanMessage(content="This is a long message that needs trimming")]
|
messages = [HumanMessage(content="This is a long message that needs trimming")]
|
||||||
|
messages_copy = [m.model_copy(deep=True) for m in messages]
|
||||||
|
|
||||||
def count_characters(msgs: list[BaseMessage]) -> int:
|
def count_characters(msgs: list[BaseMessage]) -> int:
|
||||||
return sum(len(m.content) if isinstance(m.content, str) else 0 for m in msgs)
|
return sum(len(m.content) if isinstance(m.content, str) else 0 for m in msgs)
|
||||||
@ -474,6 +475,7 @@ def test_trim_messages_partial_text_splitting() -> None:
|
|||||||
|
|
||||||
assert len(result) == 1
|
assert len(result) == 1
|
||||||
assert result[0].content == "This is a " # First 10 characters
|
assert result[0].content == "This is a " # First 10 characters
|
||||||
|
assert messages == messages_copy
|
||||||
|
|
||||||
|
|
||||||
def test_trim_messages_mixed_content_with_partial() -> None:
|
def test_trim_messages_mixed_content_with_partial() -> None:
|
||||||
@ -485,6 +487,7 @@ def test_trim_messages_mixed_content_with_partial() -> None:
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
messages_copy = [m.model_copy(deep=True) for m in messages]
|
||||||
|
|
||||||
# Count total length of all text parts
|
# Count total length of all text parts
|
||||||
def count_text_length(msgs: list[BaseMessage]) -> int:
|
def count_text_length(msgs: list[BaseMessage]) -> int:
|
||||||
@ -509,6 +512,7 @@ def test_trim_messages_mixed_content_with_partial() -> None:
|
|||||||
assert len(result) == 1
|
assert len(result) == 1
|
||||||
assert len(result[0].content) == 1
|
assert len(result[0].content) == 1
|
||||||
assert result[0].content[0]["text"] == "First part of text."
|
assert result[0].content[0]["text"] == "First part of text."
|
||||||
|
assert messages == messages_copy
|
||||||
|
|
||||||
|
|
||||||
def test_trim_messages_exact_token_boundary() -> None:
|
def test_trim_messages_exact_token_boundary() -> None:
|
||||||
@ -535,6 +539,7 @@ def test_trim_messages_exact_token_boundary() -> None:
|
|||||||
strategy="first",
|
strategy="first",
|
||||||
)
|
)
|
||||||
assert len(result2) == 2
|
assert len(result2) == 2
|
||||||
|
assert result2 == messages
|
||||||
|
|
||||||
|
|
||||||
def test_trim_messages_start_on_with_allow_partial() -> None:
|
def test_trim_messages_start_on_with_allow_partial() -> None:
|
||||||
@ -543,7 +548,7 @@ def test_trim_messages_start_on_with_allow_partial() -> None:
|
|||||||
AIMessage(content="AI response"),
|
AIMessage(content="AI response"),
|
||||||
HumanMessage(content="Second human message"),
|
HumanMessage(content="Second human message"),
|
||||||
]
|
]
|
||||||
|
messages_copy = [m.model_copy(deep=True) for m in messages]
|
||||||
result = trim_messages(
|
result = trim_messages(
|
||||||
messages,
|
messages,
|
||||||
max_tokens=20,
|
max_tokens=20,
|
||||||
@ -555,6 +560,7 @@ def test_trim_messages_start_on_with_allow_partial() -> None:
|
|||||||
|
|
||||||
assert len(result) == 1
|
assert len(result) == 1
|
||||||
assert result[0].content == "Second human message"
|
assert result[0].content == "Second human message"
|
||||||
|
assert messages == messages_copy
|
||||||
|
|
||||||
|
|
||||||
class FakeTokenCountingModel(FakeChatModel):
|
class FakeTokenCountingModel(FakeChatModel):
|
||||||
|
Loading…
Reference in New Issue
Block a user