mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-05 07:08:03 +00:00
x
This commit is contained in:
parent
c599ba47d5
commit
3a6cbb1551
@ -964,3 +964,46 @@ def test_convert_to_openai_messages_developer() -> None:
|
||||
]
|
||||
result = convert_to_openai_messages(messages)
|
||||
assert result == [{"role": "developer", "content": "a"}] * 2
|
||||
|
||||
def test_trim_messages_custom_counter() -> None:
|
||||
def dummy_token_counter(messages: list[BaseMessage]) -> int:
|
||||
# treat each message like it adds 3 default tokens at the beginning
|
||||
# of the message and at the end of the message. 3 + 4 + 3 = 10 tokens
|
||||
# per message.
|
||||
|
||||
default_content_len = 4
|
||||
default_msg_prefix_len = 3
|
||||
default_msg_suffix_len = 3
|
||||
|
||||
count = 0
|
||||
for msg in messages:
|
||||
if isinstance(msg.content, str):
|
||||
str_len = int(len(msg.content) / 3)
|
||||
count += default_msg_prefix_len + str_len + default_msg_suffix_len
|
||||
if isinstance(msg.content, list):
|
||||
content = int(len(msg.content[0]["text"]) / 3)
|
||||
count += default_msg_prefix_len + content + default_msg_suffix_len
|
||||
return count
|
||||
|
||||
messages = [
|
||||
SystemMessage("This is a 4 token text."),
|
||||
HumanMessage("This is a 4 token text.", id="first"),
|
||||
AIMessage(
|
||||
[
|
||||
{"type": "text", "text": "This is the FIRST 4 token block."},
|
||||
{"type": "text", "text": "This is the SECOND 4 token block."},
|
||||
],
|
||||
id="second",
|
||||
),
|
||||
HumanMessage("This is a 4 token text.", id="third"),
|
||||
AIMessage("This is a 4 token text.", id="fourth"),
|
||||
]
|
||||
|
||||
new_messages = trim_messages(
|
||||
messages,
|
||||
max_tokens=10,
|
||||
token_counter=dummy_token_counter,
|
||||
strategy="last",
|
||||
allow_partial=True,
|
||||
)
|
||||
raise ValueError(new_messages)
|
||||
|
Loading…
Reference in New Issue
Block a user