mirror of
https://github.com/hwchase17/langchain.git
synced 2025-04-27 03:31:51 +00:00
x
This commit is contained in:
parent
3a6cbb1551
commit
be438bae36
@ -965,13 +965,13 @@ def test_convert_to_openai_messages_developer() -> None:
|
||||
result = convert_to_openai_messages(messages)
|
||||
assert result == [{"role": "developer", "content": "a"}] * 2
|
||||
|
||||
|
||||
def test_trim_messages_custom_counter() -> None:
|
||||
def dummy_token_counter(messages: list[BaseMessage]) -> int:
|
||||
# treat each message like it adds 3 default tokens at the beginning
|
||||
# of the message and at the end of the message. 3 + 4 + 3 = 10 tokens
|
||||
# per message.
|
||||
|
||||
default_content_len = 4
|
||||
default_msg_prefix_len = 3
|
||||
default_msg_suffix_len = 3
|
||||
|
||||
@ -980,9 +980,13 @@ def test_trim_messages_custom_counter() -> None:
|
||||
if isinstance(msg.content, str):
|
||||
str_len = int(len(msg.content) / 3)
|
||||
count += default_msg_prefix_len + str_len + default_msg_suffix_len
|
||||
if isinstance(msg.content, list):
|
||||
content = int(len(msg.content[0]["text"]) / 3)
|
||||
elif isinstance(msg.content, list):
|
||||
content = int(len(msg.content[0]["text"]))
|
||||
count += default_msg_prefix_len + content + default_msg_suffix_len
|
||||
else:
|
||||
msg = "Invalid message content"
|
||||
raise TypeError(msg)
|
||||
|
||||
return count
|
||||
|
||||
messages = [
|
||||
|
Loading…
Reference in New Issue
Block a user