From 4bd005adb69e5bc8b85a1e1504ba1d2b01316bf9 Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Tue, 20 Aug 2024 00:21:22 -0700 Subject: [PATCH] core[patch]: Allow bound models as token_counter in trim_messages (#25563) --- libs/core/langchain_core/messages/utils.py | 28 ++++++++++++------- .../tests/unit_tests/messages/test_utils.py | 19 +++++++++++++ 2 files changed, 37 insertions(+), 10 deletions(-) diff --git a/libs/core/langchain_core/messages/utils.py b/libs/core/langchain_core/messages/utils.py index b22c927ad4d..18225c53228 100644 --- a/libs/core/langchain_core/messages/utils.py +++ b/libs/core/langchain_core/messages/utils.py @@ -514,6 +514,8 @@ def merge_message_runs( return merged +# TODO: Update so validation errors (for token_counter, for example) are raised on +# init not at runtime. @_runnable_support def trim_messages( messages: Union[Iterable[MessageLikeRepresentation], PromptValue], @@ -759,24 +761,30 @@ def trim_messages( AIMessage("This is a 4 token text. The full message is 10 tokens.", id="fourth"), ] """ # noqa: E501 - from langchain_core.language_models import BaseLanguageModel if start_on and strategy == "first": raise ValueError if include_system and strategy == "first": raise ValueError messages = convert_to_messages(messages) - if isinstance(token_counter, BaseLanguageModel): - list_token_counter = token_counter.get_num_tokens_from_messages - elif ( - list(inspect.signature(token_counter).parameters.values())[0].annotation - is BaseMessage - ): + if hasattr(token_counter, "get_num_tokens_from_messages"): + list_token_counter = getattr(token_counter, "get_num_tokens_from_messages") + elif callable(token_counter): + if ( + list(inspect.signature(token_counter).parameters.values())[0].annotation + is BaseMessage + ): - def list_token_counter(messages: Sequence[BaseMessage]) -> int: - return sum(token_counter(msg) for msg in messages) # type: ignore[arg-type, misc] + def list_token_counter(messages: Sequence[BaseMessage]) -> int: + return sum(token_counter(msg) for msg in messages) # type: ignore[arg-type, misc] + else: + list_token_counter = token_counter # type: ignore[assignment] else: - list_token_counter = token_counter # type: ignore[assignment] + raise ValueError( + f"'token_counter' expected ot be a model that implements " + f"'get_num_tokens_from_messages()' or a function. Received object of type " + f"{type(token_counter)}." + ) try: from langchain_text_splitters import TextSplitter diff --git a/libs/core/tests/unit_tests/messages/test_utils.py b/libs/core/tests/unit_tests/messages/test_utils.py index d3160cb47ed..142272292f6 100644 --- a/libs/core/tests/unit_tests/messages/test_utils.py +++ b/libs/core/tests/unit_tests/messages/test_utils.py @@ -2,6 +2,7 @@ from typing import Dict, List, Type import pytest +from langchain_core.language_models.fake_chat_models import FakeChatModel from langchain_core.messages import ( AIMessage, BaseMessage, @@ -316,6 +317,19 @@ def test_trim_messages_invoke() -> None: assert actual == expected +def test_trim_messages_bound_model_token_counter() -> None: + trimmer = trim_messages( + max_tokens=10, token_counter=FakeTokenCountingModel().bind(foo="bar") + ) + trimmer.invoke([HumanMessage("foobar")]) + + +def test_trim_messages_bad_token_counter() -> None: + trimmer = trim_messages(max_tokens=10, token_counter={}) + with pytest.raises(ValueError): + trimmer.invoke([HumanMessage("foobar")]) + + def dummy_token_counter(messages: List[BaseMessage]) -> int: # treat each message like it adds 3 default tokens at the beginning # of the message and at the end of the message. 3 + 4 + 3 = 10 tokens @@ -338,3 +352,8 @@ def dummy_token_counter(messages: List[BaseMessage]) -> int: + default_msg_suffix_len ) return count + + +class FakeTokenCountingModel(FakeChatModel): + def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: + return dummy_token_counter(messages)