diff --git a/libs/core/langchain_core/messages/utils.py b/libs/core/langchain_core/messages/utils.py index eaa04bf1afd..776f63e724d 100644 --- a/libs/core/langchain_core/messages/utils.py +++ b/libs/core/langchain_core/messages/utils.py @@ -720,7 +720,8 @@ def trim_messages( max_tokens: int, token_counter: Callable[[list[BaseMessage]], int] | Callable[[BaseMessage], int] - | BaseLanguageModel, + | BaseLanguageModel + | Literal["approximate"], strategy: Literal["first", "last"] = "last", allow_partial: bool = False, end_on: str | type[BaseMessage] | Sequence[str | type[BaseMessage]] | None = None, @@ -758,53 +759,65 @@ def trim_messages( messages: Sequence of Message-like objects to trim. max_tokens: Max token count of trimmed messages. token_counter: Function or llm for counting tokens in a `BaseMessage` or a - list of `BaseMessage`. If a `BaseLanguageModel` is passed in then - `BaseLanguageModel.get_num_tokens_from_messages()` will be used. - Set to `len` to count the number of **messages** in the chat history. + list of `BaseMessage`. + + If a `BaseLanguageModel` is passed in then + `BaseLanguageModel.get_num_tokens_from_messages()` will be used. Set to + `len` to count the number of **messages** in the chat history. + + You can also use string shortcuts for convenience: + + - `'approximate'`: Uses `count_tokens_approximately` for fast, approximate + token counts. !!! note - Use `count_tokens_approximately` to get fast, approximate token - counts. - - This is recommended for using `trim_messages` on the hot path, where - exact token counting is not necessary. + `count_tokens_approximately` (or the shortcut `'approximate'`) is + recommended for using `trim_messages` on the hot path, where exact token + counting is not necessary. strategy: Strategy for trimming. + - `'first'`: Keep the first `<= n_count` tokens of the messages. - `'last'`: Keep the last `<= n_count` tokens of the messages. 
allow_partial: Whether to split a message if only part of the message can be - included. If `strategy='last'` then the last partial contents of a message - are included. If `strategy='first'` then the first partial contents of a - message are included. - end_on: The message type to end on. If specified then every message after the - last occurrence of this type is ignored. If `strategy='last'` then this - is done before we attempt to get the last `max_tokens`. If - `strategy='first'` then this is done after we get the first - `max_tokens`. Can be specified as string names (e.g. `'system'`, - `'human'`, `'ai'`, ...) or as `BaseMessage` classes (e.g. - `SystemMessage`, `HumanMessage`, `AIMessage`, ...). Can be a single - type or a list of types. + included. - start_on: The message type to start on. Should only be specified if - `strategy='last'`. If specified then every message before - the first occurrence of this type is ignored. This is done after we trim - the initial messages to the last `max_tokens`. Does not - apply to a `SystemMessage` at index 0 if `include_system=True`. Can be - specified as string names (e.g. `'system'`, `'human'`, `'ai'`, ...) or - as `BaseMessage` classes (e.g. `SystemMessage`, `HumanMessage`, - `AIMessage`, ...). Can be a single type or a list of types. + If `strategy='last'` then the last partial contents of a message are + included. If `strategy='first'` then the first partial contents of a + message are included. + end_on: The message type to end on. + + If specified then every message after the last occurrence of this type is + ignored. If `strategy='last'` then this is done before we attempt to get the + last `max_tokens`. If `strategy='first'` then this is done after we get the + first `max_tokens`. Can be specified as string names (e.g. `'system'`, + `'human'`, `'ai'`, ...) or as `BaseMessage` classes (e.g. `SystemMessage`, + `HumanMessage`, `AIMessage`, ...). Can be a single type or a list of types. 
+ + start_on: The message type to start on. + + Should only be specified if `strategy='last'`. If specified then every + message before the first occurrence of this type is ignored. This is done + after we trim the initial messages to the last `max_tokens`. Does not apply + to a `SystemMessage` at index 0 if `include_system=True`. Can be specified + as string names (e.g. `'system'`, `'human'`, `'ai'`, ...) or as + `BaseMessage` classes (e.g. `SystemMessage`, `HumanMessage`, `AIMessage`, + ...). Can be a single type or a list of types. include_system: Whether to keep the `SystemMessage` if there is one at index - `0`. Should only be specified if `strategy="last"`. + `0`. + + Should only be specified if `strategy="last"`. text_splitter: Function or `langchain_text_splitters.TextSplitter` for - splitting the string contents of a message. Only used if - `allow_partial=True`. If `strategy='last'` then the last split tokens - from a partial message will be included. if `strategy='first'` then the - first split tokens from a partial message will be included. Token splitter - assumes that separators are kept, so that split contents can be directly - concatenated to recreate the original text. Defaults to splitting on - newlines. + splitting the string contents of a message. + + Only used if `allow_partial=True`. If `strategy='last'` then the last split + tokens from a partial message will be included. If `strategy='first'` then + the first split tokens from a partial message will be included. Token + splitter assumes that separators are kept, so that split contents can be + directly concatenated to recreate the original text. Defaults to splitting + on newlines. Returns: List of trimmed `BaseMessage`. @@ -815,8 +828,8 @@ def trim_messages( Example: Trim chat history based on token count, keeping the `SystemMessage` if - present, and ensuring that the chat history starts with a `HumanMessage` ( - or a `SystemMessage` followed by a `HumanMessage`).
+ present, and ensuring that the chat history starts with a `HumanMessage` (or a + `SystemMessage` followed by a `HumanMessage`). ```python from langchain_core.messages import ( @@ -869,8 +882,34 @@ def trim_messages( ] ``` + Trim chat history using approximate token counting with `'approximate'`: + + ```python + trim_messages( + messages, + max_tokens=45, + strategy="last", + # Using the "approximate" shortcut for fast token counting + token_counter="approximate", + start_on="human", + include_system=True, + ) + + # This is equivalent to using `count_tokens_approximately` directly + from langchain_core.messages.utils import count_tokens_approximately + + trim_messages( + messages, + max_tokens=45, + strategy="last", + token_counter=count_tokens_approximately, + start_on="human", + include_system=True, + ) + ``` + Trim chat history based on the message count, keeping the `SystemMessage` if - present, and ensuring that the chat history starts with a `HumanMessage` ( + present, and ensuring that the chat history starts with a `HumanMessage` ( or a `SystemMessage` followed by a `HumanMessage`). trim_messages( @@ -992,24 +1031,44 @@ def trim_messages( raise ValueError(msg) messages = convert_to_messages(messages) - if hasattr(token_counter, "get_num_tokens_from_messages"): - list_token_counter = token_counter.get_num_tokens_from_messages - elif callable(token_counter): + + # Handle string shortcuts for token counter + if isinstance(token_counter, str): + if token_counter in _TOKEN_COUNTER_SHORTCUTS: + actual_token_counter = _TOKEN_COUNTER_SHORTCUTS[token_counter] + else: + available_shortcuts = ", ".join( + f"'{key}'" for key in _TOKEN_COUNTER_SHORTCUTS + ) + msg = ( + f"Invalid token_counter shortcut '{token_counter}'. " + f"Available shortcuts: {available_shortcuts}."
+ ) + raise ValueError(msg) + else: + # Type narrowing: at this point token_counter is not a str + actual_token_counter = token_counter # type: ignore[assignment] + + if hasattr(actual_token_counter, "get_num_tokens_from_messages"): + list_token_counter = actual_token_counter.get_num_tokens_from_messages + elif callable(actual_token_counter): if ( - next(iter(inspect.signature(token_counter).parameters.values())).annotation + next( + iter(inspect.signature(actual_token_counter).parameters.values()) + ).annotation is BaseMessage ): def list_token_counter(messages: Sequence[BaseMessage]) -> int: - return sum(token_counter(msg) for msg in messages) # type: ignore[arg-type, misc] + return sum(actual_token_counter(msg) for msg in messages) # type: ignore[arg-type, misc] else: - list_token_counter = token_counter + list_token_counter = actual_token_counter else: msg = ( f"'token_counter' expected to be a model that implements " f"'get_num_tokens_from_messages()' or a function. Received object of type " - f"{type(token_counter)}." + f"{type(actual_token_counter)}." 
) raise ValueError(msg) @@ -1807,3 +1866,14 @@ def count_tokens_approximately( # round up once more time in case extra_tokens_per_message is a float return math.ceil(token_count) + + +# Mapping from string shortcuts to token counter functions +def _approximate_token_counter(messages: Sequence[BaseMessage]) -> int: + """Wrapper for `count_tokens_approximately` that matches expected signature.""" + return count_tokens_approximately(messages) + + +_TOKEN_COUNTER_SHORTCUTS = { + "approximate": _approximate_token_counter, +} diff --git a/libs/core/tests/unit_tests/messages/test_utils.py b/libs/core/tests/unit_tests/messages/test_utils.py index fbf949e9bba..c41cb9d65ed 100644 --- a/libs/core/tests/unit_tests/messages/test_utils.py +++ b/libs/core/tests/unit_tests/messages/test_utils.py @@ -673,6 +673,82 @@ def test_trim_messages_start_on_with_allow_partial() -> None: assert messages == messages_copy +def test_trim_messages_token_counter_shortcut_approximate() -> None: + """Test that `'approximate'` shortcut works for `token_counter`.""" + messages = [ + SystemMessage("This is a test message"), + HumanMessage("Another test message", id="first"), + AIMessage("AI response here", id="second"), + ] + messages_copy = [m.model_copy(deep=True) for m in messages] + + # Test using the "approximate" shortcut + result_shortcut = trim_messages( + messages, + max_tokens=50, + token_counter="approximate", + strategy="last", + ) + + # Test using count_tokens_approximately directly + result_direct = trim_messages( + messages, + max_tokens=50, + token_counter=count_tokens_approximately, + strategy="last", + ) + + # Both should produce the same result + assert result_shortcut == result_direct + assert messages == messages_copy + + +def test_trim_messages_token_counter_shortcut_invalid() -> None: + """Test that invalid `token_counter` shortcut raises `ValueError`.""" + messages = [ + SystemMessage("This is a test message"), + HumanMessage("Another test message"), + ] + + # Test with invalid 
shortcut - intentionally passing invalid string to verify + # runtime error handling for dynamically-constructed inputs + with pytest.raises(ValueError, match="Invalid token_counter shortcut 'invalid'"): + trim_messages( # type: ignore[call-overload] + messages, + max_tokens=50, + token_counter="invalid", + strategy="last", + ) + + +def test_trim_messages_token_counter_shortcut_with_options() -> None: + """Test that `'approximate'` shortcut works with different trim options.""" + messages = [ + SystemMessage("System instructions"), + HumanMessage("First human message", id="first"), + AIMessage("First AI response", id="ai1"), + HumanMessage("Second human message", id="second"), + AIMessage("Second AI response", id="ai2"), + ] + messages_copy = [m.model_copy(deep=True) for m in messages] + + # Test with various options + result = trim_messages( + messages, + max_tokens=100, + token_counter="approximate", + strategy="last", + include_system=True, + start_on="human", + ) + + # Should include system message and start on human + assert len(result) >= 2 + assert isinstance(result[0], SystemMessage) + assert any(isinstance(msg, HumanMessage) for msg in result[1:]) + assert messages == messages_copy + + class FakeTokenCountingModel(FakeChatModel): @override def get_num_tokens_from_messages(