From 4fbeffcfeea022ce7a5a9eaafd0ba4d70ccc0bca Mon Sep 17 00:00:00 2001
From: James <151067328+Davda-James@users.noreply.github.com>
Date: Sat, 20 Dec 2025 06:55:29 +0530
Subject: [PATCH] feat(core): add `'approximate'` alias for
`count_tokens_approximately` (#33045)
### Description:
Previously, it had to be used like this:
```python
from langchain_core.messages import trim_messages
from langchain_core.messages.utils import count_tokens_approximately
trim_messages(..., token_counter=count_tokens_approximately)
```
Now it can also be used like this:
```python
from langchain_core.messages import trim_messages
trim_messages(..., token_counter="approximate")
```
- [x] **Added tests**
- [x] **Lint and test**: Ran since I made changes in `langchain_core`:
`uv run --group test pytest tests/unit_tests/messages/test_utils.py -v`
---------
Co-authored-by: Mason Daugherty
Co-authored-by: Mason Daugherty
---
libs/core/langchain_core/messages/utils.py | 162 +++++++++++++-----
.../tests/unit_tests/messages/test_utils.py | 76 ++++++++
2 files changed, 192 insertions(+), 46 deletions(-)
diff --git a/libs/core/langchain_core/messages/utils.py b/libs/core/langchain_core/messages/utils.py
index eaa04bf1afd..776f63e724d 100644
--- a/libs/core/langchain_core/messages/utils.py
+++ b/libs/core/langchain_core/messages/utils.py
@@ -720,7 +720,8 @@ def trim_messages(
max_tokens: int,
token_counter: Callable[[list[BaseMessage]], int]
| Callable[[BaseMessage], int]
- | BaseLanguageModel,
+ | BaseLanguageModel
+ | Literal["approximate"],
strategy: Literal["first", "last"] = "last",
allow_partial: bool = False,
end_on: str | type[BaseMessage] | Sequence[str | type[BaseMessage]] | None = None,
@@ -758,53 +759,65 @@ def trim_messages(
messages: Sequence of Message-like objects to trim.
max_tokens: Max token count of trimmed messages.
token_counter: Function or llm for counting tokens in a `BaseMessage` or a
- list of `BaseMessage`. If a `BaseLanguageModel` is passed in then
- `BaseLanguageModel.get_num_tokens_from_messages()` will be used.
- Set to `len` to count the number of **messages** in the chat history.
+ list of `BaseMessage`.
+
+ If a `BaseLanguageModel` is passed in then
+ `BaseLanguageModel.get_num_tokens_from_messages()` will be used. Set to
+ `len` to count the number of **messages** in the chat history.
+
+ You can also use string shortcuts for convenience:
+
+ - `'approximate'`: Uses `count_tokens_approximately` for fast, approximate
+ token counts.
!!! note
- Use `count_tokens_approximately` to get fast, approximate token
- counts.
-
- This is recommended for using `trim_messages` on the hot path, where
- exact token counting is not necessary.
+ `count_tokens_approximately` (or the shortcut `'approximate'`) is
+ recommended for using `trim_messages` on the hot path, where exact token
+ counting is not necessary.
strategy: Strategy for trimming.
+
- `'first'`: Keep the first `<= n_count` tokens of the messages.
- `'last'`: Keep the last `<= n_count` tokens of the messages.
allow_partial: Whether to split a message if only part of the message can be
- included. If `strategy='last'` then the last partial contents of a message
- are included. If `strategy='first'` then the first partial contents of a
- message are included.
- end_on: The message type to end on. If specified then every message after the
- last occurrence of this type is ignored. If `strategy='last'` then this
- is done before we attempt to get the last `max_tokens`. If
- `strategy='first'` then this is done after we get the first
- `max_tokens`. Can be specified as string names (e.g. `'system'`,
- `'human'`, `'ai'`, ...) or as `BaseMessage` classes (e.g.
- `SystemMessage`, `HumanMessage`, `AIMessage`, ...). Can be a single
- type or a list of types.
+ included.
- start_on: The message type to start on. Should only be specified if
- `strategy='last'`. If specified then every message before
- the first occurrence of this type is ignored. This is done after we trim
- the initial messages to the last `max_tokens`. Does not
- apply to a `SystemMessage` at index 0 if `include_system=True`. Can be
- specified as string names (e.g. `'system'`, `'human'`, `'ai'`, ...) or
- as `BaseMessage` classes (e.g. `SystemMessage`, `HumanMessage`,
- `AIMessage`, ...). Can be a single type or a list of types.
+ If `strategy='last'` then the last partial contents of a message are
+ included. If `strategy='first'` then the first partial contents of a
+ message are included.
+ end_on: The message type to end on.
+
+ If specified then every message after the last occurrence of this type is
+ ignored. If `strategy='last'` then this is done before we attempt to get the
+ last `max_tokens`. If `strategy='first'` then this is done after we get the
+ first `max_tokens`. Can be specified as string names (e.g. `'system'`,
+ `'human'`, `'ai'`, ...) or as `BaseMessage` classes (e.g. `SystemMessage`,
+ `HumanMessage`, `AIMessage`, ...). Can be a single type or a list of types.
+
+ start_on: The message type to start on.
+
+ Should only be specified if `strategy='last'`. If specified then every
+ message before the first occurrence of this type is ignored. This is done
+ after we trim the initial messages to the last `max_tokens`. Does not apply
+ to a `SystemMessage` at index 0 if `include_system=True`. Can be specified
+ as string names (e.g. `'system'`, `'human'`, `'ai'`, ...) or as
+ `BaseMessage` classes (e.g. `SystemMessage`, `HumanMessage`, `AIMessage`,
+ ...). Can be a single type or a list of types.
include_system: Whether to keep the `SystemMessage` if there is one at index
- `0`. Should only be specified if `strategy="last"`.
+ `0`.
+
+ Should only be specified if `strategy="last"`.
text_splitter: Function or `langchain_text_splitters.TextSplitter` for
- splitting the string contents of a message. Only used if
- `allow_partial=True`. If `strategy='last'` then the last split tokens
- from a partial message will be included. if `strategy='first'` then the
- first split tokens from a partial message will be included. Token splitter
- assumes that separators are kept, so that split contents can be directly
- concatenated to recreate the original text. Defaults to splitting on
- newlines.
+ splitting the string contents of a message.
+
+ Only used if `allow_partial=True`. If `strategy='last'` then the last split
+ tokens from a partial message will be included. If `strategy='first'` then
+ the first split tokens from a partial message will be included. Token
+ splitter assumes that separators are kept, so that split contents can be
+ directly concatenated to recreate the original text. Defaults to splitting
+ on newlines.
Returns:
List of trimmed `BaseMessage`.
@@ -815,8 +828,8 @@ def trim_messages(
Example:
Trim chat history based on token count, keeping the `SystemMessage` if
- present, and ensuring that the chat history starts with a `HumanMessage` (
- or a `SystemMessage` followed by a `HumanMessage`).
+ present, and ensuring that the chat history starts with a `HumanMessage` (or a
+ `SystemMessage` followed by a `HumanMessage`).
```python
from langchain_core.messages import (
@@ -869,8 +882,34 @@ def trim_messages(
]
```
+ Trim chat history using approximate token counting with `'approximate'`:
+
+ ```python
+ trim_messages(
+ messages,
+ max_tokens=45,
+ strategy="last",
+ # Using the "approximate" shortcut for fast token counting
+ token_counter="approximate",
+ start_on="human",
+ include_system=True,
+ )
+
+ # This is equivalent to using `count_tokens_approximately` directly
+ from langchain_core.messages.utils import count_tokens_approximately
+
+ trim_messages(
+ messages,
+ max_tokens=45,
+ strategy="last",
+ token_counter=count_tokens_approximately,
+ start_on="human",
+ include_system=True,
+ )
+ ```
+
Trim chat history based on the message count, keeping the `SystemMessage` if
- present, and ensuring that the chat history starts with a `HumanMessage` (
+ present, and ensuring that the chat history starts with a `HumanMessage` (
or a `SystemMessage` followed by a `HumanMessage`).
trim_messages(
@@ -992,24 +1031,44 @@ def trim_messages(
raise ValueError(msg)
messages = convert_to_messages(messages)
- if hasattr(token_counter, "get_num_tokens_from_messages"):
- list_token_counter = token_counter.get_num_tokens_from_messages
- elif callable(token_counter):
+
+ # Handle string shortcuts for token counter
+ if isinstance(token_counter, str):
+ if token_counter in _TOKEN_COUNTER_SHORTCUTS:
+ actual_token_counter = _TOKEN_COUNTER_SHORTCUTS[token_counter]
+ else:
+ available_shortcuts = ", ".join(
+ f"'{key}'" for key in _TOKEN_COUNTER_SHORTCUTS
+ )
+ msg = (
+ f"Invalid token_counter shortcut '{token_counter}'. "
+ f"Available shortcuts: {available_shortcuts}."
+ )
+ raise ValueError(msg)
+ else:
+ # Type narrowing: at this point token_counter is not a str
+ actual_token_counter = token_counter # type: ignore[assignment]
+
+ if hasattr(actual_token_counter, "get_num_tokens_from_messages"):
+ list_token_counter = actual_token_counter.get_num_tokens_from_messages
+ elif callable(actual_token_counter):
if (
- next(iter(inspect.signature(token_counter).parameters.values())).annotation
+ next(
+ iter(inspect.signature(actual_token_counter).parameters.values())
+ ).annotation
is BaseMessage
):
def list_token_counter(messages: Sequence[BaseMessage]) -> int:
- return sum(token_counter(msg) for msg in messages) # type: ignore[arg-type, misc]
+ return sum(actual_token_counter(msg) for msg in messages) # type: ignore[arg-type, misc]
else:
- list_token_counter = token_counter
+ list_token_counter = actual_token_counter
else:
msg = (
f"'token_counter' expected to be a model that implements "
f"'get_num_tokens_from_messages()' or a function. Received object of type "
- f"{type(token_counter)}."
+ f"{type(actual_token_counter)}."
)
raise ValueError(msg)
@@ -1807,3 +1866,14 @@ def count_tokens_approximately(
# round up once more time in case extra_tokens_per_message is a float
return math.ceil(token_count)
+
+
+# Mapping from string shortcuts to token counter functions
+def _approximate_token_counter(messages: Sequence[BaseMessage]) -> int:
+ """Wrapper for `count_tokens_approximately` that matches expected signature."""
+ return count_tokens_approximately(messages)
+
+
+_TOKEN_COUNTER_SHORTCUTS = {
+ "approximate": _approximate_token_counter,
+}
diff --git a/libs/core/tests/unit_tests/messages/test_utils.py b/libs/core/tests/unit_tests/messages/test_utils.py
index fbf949e9bba..c41cb9d65ed 100644
--- a/libs/core/tests/unit_tests/messages/test_utils.py
+++ b/libs/core/tests/unit_tests/messages/test_utils.py
@@ -673,6 +673,82 @@ def test_trim_messages_start_on_with_allow_partial() -> None:
assert messages == messages_copy
+def test_trim_messages_token_counter_shortcut_approximate() -> None:
+ """Test that `'approximate'` shortcut works for `token_counter`."""
+ messages = [
+ SystemMessage("This is a test message"),
+ HumanMessage("Another test message", id="first"),
+ AIMessage("AI response here", id="second"),
+ ]
+ messages_copy = [m.model_copy(deep=True) for m in messages]
+
+ # Test using the "approximate" shortcut
+ result_shortcut = trim_messages(
+ messages,
+ max_tokens=50,
+ token_counter="approximate",
+ strategy="last",
+ )
+
+ # Test using count_tokens_approximately directly
+ result_direct = trim_messages(
+ messages,
+ max_tokens=50,
+ token_counter=count_tokens_approximately,
+ strategy="last",
+ )
+
+ # Both should produce the same result
+ assert result_shortcut == result_direct
+ assert messages == messages_copy
+
+
+def test_trim_messages_token_counter_shortcut_invalid() -> None:
+ """Test that invalid `token_counter` shortcut raises `ValueError`."""
+ messages = [
+ SystemMessage("This is a test message"),
+ HumanMessage("Another test message"),
+ ]
+
+ # Test with invalid shortcut - intentionally passing invalid string to verify
+ # runtime error handling for dynamically-constructed inputs
+ with pytest.raises(ValueError, match="Invalid token_counter shortcut 'invalid'"):
+ trim_messages( # type: ignore[call-overload]
+ messages,
+ max_tokens=50,
+ token_counter="invalid",
+ strategy="last",
+ )
+
+
+def test_trim_messages_token_counter_shortcut_with_options() -> None:
+ """Test that `'approximate'` shortcut works with different trim options."""
+ messages = [
+ SystemMessage("System instructions"),
+ HumanMessage("First human message", id="first"),
+ AIMessage("First AI response", id="ai1"),
+ HumanMessage("Second human message", id="second"),
+ AIMessage("Second AI response", id="ai2"),
+ ]
+ messages_copy = [m.model_copy(deep=True) for m in messages]
+
+ # Test with various options
+ result = trim_messages(
+ messages,
+ max_tokens=100,
+ token_counter="approximate",
+ strategy="last",
+ include_system=True,
+ start_on="human",
+ )
+
+ # Should include system message and start on human
+ assert len(result) >= 2
+ assert isinstance(result[0], SystemMessage)
+ assert any(isinstance(msg, HumanMessage) for msg in result[1:])
+ assert messages == messages_copy
+
+
class FakeTokenCountingModel(FakeChatModel):
@override
def get_num_tokens_from_messages(