Compare commits

...

2 Commits

Author SHA1 Message Date
William Fu-Hinthorn
f7078e08da Support message trimming on single messages 2024-10-29 18:51:23 -07:00
William Fu-Hinthorn
0496c15123 [Anthropic] Shallow Copy 2024-10-04 07:41:18 -07:00
3 changed files with 29 additions and 8 deletions

View File

@@ -876,13 +876,14 @@ def _first_max_tokens(
] = None,
) -> list[BaseMessage]:
messages = list(messages)
if not messages:
return messages
idx = 0
for i in range(len(messages)):
if token_counter(messages[:-i] if i else messages) <= max_tokens:
idx = len(messages) - i
break
if idx < len(messages) - 1 and partial_strategy:
if partial_strategy and (idx < len(messages) - 1 or idx == 0):
included_partial = False
if isinstance(messages[idx].content, list):
excluded = messages[idx].model_copy(deep=True)

View File

@@ -299,6 +299,24 @@ def test_trim_messages_last_40_include_system_allow_partial_start_on_human() ->
assert _MESSAGES_TO_TRIM == _MESSAGES_TO_TRIM_COPY
def test_trim_messages_allow_partial_one_message() -> None:
    """Check partial trimming when the input holds exactly one message.

    The token counter charges one token per character of content and the
    text splitter breaks text into individual characters, so with
    ``max_tokens=2`` the expected result is the same message cut to ``"Th"``.
    """
    single_message = [HumanMessage("This is a 4 token text.", id="third")]

    actual = trim_messages(
        single_message,
        max_tokens=2,
        # One "token" per character of message content.
        token_counter=lambda msgs: sum(len(msg.content) for msg in msgs),
        # Split text into single characters so a partial cut is possible.
        text_splitter=lambda text: list(text),
        strategy="first",
        allow_partial=True,
    )

    expected = [HumanMessage("Th", id="third")]
    assert actual == expected
    # The module-level fixture must still equal its reference copy.
    assert _MESSAGES_TO_TRIM == _MESSAGES_TO_TRIM_COPY
def test_trim_messages_allow_partial_text_splitter() -> None:
expected = [
HumanMessage("a 4 token text.", id="third"),

View File

@@ -1,3 +1,4 @@
import copy
import re
import warnings
from operator import itemgetter
@@ -116,7 +117,6 @@ def _merge_messages(
"""Merge runs of human/tool messages into single human messages with content blocks.""" # noqa: E501
merged: list = []
for curr in messages:
curr = curr.model_copy(deep=True)
if isinstance(curr, ToolMessage):
if isinstance(curr.content, list) and all(
isinstance(block, dict) and block.get("type") == "tool_result"
@@ -139,12 +139,12 @@ def _merge_messages(
if isinstance(last.content, str):
new_content: List = [{"type": "text", "text": last.content}]
else:
new_content = last.content
new_content = copy.copy(last.content)
if isinstance(curr.content, str):
new_content.append({"type": "text", "text": curr.content})
else:
new_content.extend(curr.content)
last.content = new_content
merged[-1] = curr.model_copy(update={"content": new_content}, deep=False)
else:
merged.append(curr)
return merged
@@ -174,9 +174,11 @@ def _format_messages(
raise ValueError("System message must be at beginning of message list.")
if isinstance(message.content, list):
system = [
block
if isinstance(block, dict)
else {"type": "text", "text": "block"}
(
block
if isinstance(block, dict)
else {"type": "text", "text": "block"}
)
for block in message.content
]
else: