diff --git a/libs/core/tests/unit_tests/messages/test_utils.py b/libs/core/tests/unit_tests/messages/test_utils.py index f0a5de99d4f..95344f58145 100644 --- a/libs/core/tests/unit_tests/messages/test_utils.py +++ b/libs/core/tests/unit_tests/messages/test_utils.py @@ -965,13 +965,13 @@ def test_convert_to_openai_messages_developer() -> None: result = convert_to_openai_messages(messages) assert result == [{"role": "developer", "content": "a"}] * 2 + def test_trim_messages_custom_counter() -> None: def dummy_token_counter(messages: list[BaseMessage]) -> int: # treat each message like it adds 3 default tokens at the beginning # of the message and at the end of the message. 3 + 4 + 3 = 10 tokens # per message. - default_content_len = 4 default_msg_prefix_len = 3 default_msg_suffix_len = 3 @@ -980,9 +980,13 @@ def test_trim_messages_custom_counter() -> None: if isinstance(msg.content, str): str_len = int(len(msg.content) / 3) count += default_msg_prefix_len + str_len + default_msg_suffix_len - if isinstance(msg.content, list): - content = int(len(msg.content[0]["text"]) / 3) + elif isinstance(msg.content, list): + content = int(len(msg.content[0]["text"])) count += default_msg_prefix_len + content + default_msg_suffix_len + else: + msg = "Invalid message content" + raise TypeError(msg) + return count messages = [