def test_trim_messages_custom_counter() -> None:
    """Exercise ``trim_messages`` with a user-supplied, length-based counter.

    The counter charges a fixed per-message overhead plus a cost proportional
    to the text length, so removing text from a message lowers its cost. The
    test then checks the invariants that must hold for ``strategy="last"``:
    the result fits the token budget and is a suffix of the input.
    """
    max_tokens = 10

    def dummy_token_counter(messages: list[BaseMessage]) -> int:
        """Return a synthetic token count for ``messages``.

        Each message costs 3 prefix tokens + 3 suffix tokens, plus one token
        per 3 characters of text (rounded down, per text block).
        """
        msg_prefix_len = 3
        msg_suffix_len = 3
        chars_per_token = 3

        count = 0
        for msg in messages:
            if isinstance(msg.content, str):
                content_len = len(msg.content) // chars_per_token
            elif isinstance(msg.content, list):
                # Count every text block, not just the first one, so that
                # dropping blocks during partial trimming reduces the cost.
                content_len = 0
                for block in msg.content:
                    text = block if isinstance(block, str) else block.get("text", "")
                    content_len += len(text) // chars_per_token
            else:
                # Unknown content shapes contribute nothing.
                continue
            count += msg_prefix_len + content_len + msg_suffix_len
        return count

    messages = [
        SystemMessage("This is a 4 token text."),
        HumanMessage("This is a 4 token text.", id="first"),
        AIMessage(
            [
                {"type": "text", "text": "This is the FIRST 4 token block."},
                {"type": "text", "text": "This is the SECOND 4 token block."},
            ],
            id="second",
        ),
        HumanMessage("This is a 4 token text.", id="third"),
        AIMessage("This is a 4 token text.", id="fourth"),
    ]

    new_messages = trim_messages(
        messages,
        max_tokens=max_tokens,
        token_counter=dummy_token_counter,
        strategy="last",
        allow_partial=True,
    )

    # The trimmed conversation must respect the budget as measured by the
    # very counter that was handed to trim_messages.
    assert dummy_token_counter(new_messages) <= max_tokens

    # With strategy="last" the kept messages should be the tail of the input
    # (a partially trimmed message keeps its id), in the original order.
    # NOTE(review): exact expected content depends on the text splitter used
    # for partial trimming; pin it here once the intended output is settled.
    original_ids = [msg.id for msg in messages]
    kept_ids = [msg.id for msg in new_messages]
    assert len(kept_ids) <= len(original_ids)
    assert kept_ids == original_ids[len(original_ids) - len(kept_ids) :]