core[patch]: Allow bound models as token_counter in trim_messages (#25563)

This commit is contained in:
Bagatur
2024-08-20 00:21:22 -07:00
committed by GitHub
parent e01c6789c4
commit 4bd005adb6
2 changed files with 37 additions and 10 deletions

View File

@@ -2,6 +2,7 @@ from typing import Dict, List, Type
import pytest
from langchain_core.language_models.fake_chat_models import FakeChatModel
from langchain_core.messages import (
AIMessage,
BaseMessage,
@@ -316,6 +317,19 @@ def test_trim_messages_invoke() -> None:
assert actual == expected
def test_trim_messages_bound_model_token_counter() -> None:
trimmer = trim_messages(
max_tokens=10, token_counter=FakeTokenCountingModel().bind(foo="bar")
)
trimmer.invoke([HumanMessage("foobar")])
def test_trim_messages_bad_token_counter() -> None:
trimmer = trim_messages(max_tokens=10, token_counter={})
with pytest.raises(ValueError):
trimmer.invoke([HumanMessage("foobar")])
def dummy_token_counter(messages: List[BaseMessage]) -> int:
# treat each message like it adds 3 default tokens at the beginning
# of the message and at the end of the message. 3 + 4 + 3 = 10 tokens
@@ -338,3 +352,8 @@ def dummy_token_counter(messages: List[BaseMessage]) -> int:
+ default_msg_suffix_len
)
return count
class FakeTokenCountingModel(FakeChatModel):
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
return dummy_token_counter(messages)