1
0
mirror of https://github.com/hwchase17/langchain.git synced 2025-05-05 07:08:03 +00:00
This commit is contained in:
Eugene Yurtsev 2025-03-04 17:04:18 -05:00
parent c599ba47d5
commit 3a6cbb1551

View File

@ -964,3 +964,46 @@ def test_convert_to_openai_messages_developer() -> None:
]
result = convert_to_openai_messages(messages)
assert result == [{"role": "developer", "content": "a"}] * 2
def test_trim_messages_custom_counter() -> None:
def dummy_token_counter(messages: list[BaseMessage]) -> int:
# treat each message like it adds 3 default tokens at the beginning
# of the message and at the end of the message. 3 + 4 + 3 = 10 tokens
# per message.
default_content_len = 4
default_msg_prefix_len = 3
default_msg_suffix_len = 3
count = 0
for msg in messages:
if isinstance(msg.content, str):
str_len = int(len(msg.content) / 3)
count += default_msg_prefix_len + str_len + default_msg_suffix_len
if isinstance(msg.content, list):
content = int(len(msg.content[0]["text"]) / 3)
count += default_msg_prefix_len + content + default_msg_suffix_len
return count
messages = [
SystemMessage("This is a 4 token text."),
HumanMessage("This is a 4 token text.", id="first"),
AIMessage(
[
{"type": "text", "text": "This is the FIRST 4 token block."},
{"type": "text", "text": "This is the SECOND 4 token block."},
],
id="second",
),
HumanMessage("This is a 4 token text.", id="third"),
AIMessage("This is a 4 token text.", id="fourth"),
]
new_messages = trim_messages(
messages,
max_tokens=10,
token_counter=dummy_token_counter,
strategy="last",
allow_partial=True,
)
raise ValueError(new_messages)