mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-03 20:16:52 +00:00
core[patch]: Allow bound models as token_counter in trim_messages (#25563)
This commit is contained in:
@@ -2,6 +2,7 @@ from typing import Dict, List, Type
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_core.language_models.fake_chat_models import FakeChatModel
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
@@ -316,6 +317,19 @@ def test_trim_messages_invoke() -> None:
|
||||
assert actual == expected
|
||||
|
||||
|
||||
def test_trim_messages_bound_model_token_counter() -> None:
|
||||
trimmer = trim_messages(
|
||||
max_tokens=10, token_counter=FakeTokenCountingModel().bind(foo="bar")
|
||||
)
|
||||
trimmer.invoke([HumanMessage("foobar")])
|
||||
|
||||
|
||||
def test_trim_messages_bad_token_counter() -> None:
|
||||
trimmer = trim_messages(max_tokens=10, token_counter={})
|
||||
with pytest.raises(ValueError):
|
||||
trimmer.invoke([HumanMessage("foobar")])
|
||||
|
||||
|
||||
def dummy_token_counter(messages: List[BaseMessage]) -> int:
|
||||
# treat each message like it adds 3 default tokens at the beginning
|
||||
# of the message and at the end of the message. 3 + 4 + 3 = 10 tokens
|
||||
@@ -338,3 +352,8 @@ def dummy_token_counter(messages: List[BaseMessage]) -> int:
|
||||
+ default_msg_suffix_len
|
||||
)
|
||||
return count
|
||||
|
||||
|
||||
class FakeTokenCountingModel(FakeChatModel):
|
||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||
return dummy_token_counter(messages)
|
||||
|
Reference in New Issue
Block a user