mirror of
https://github.com/hwchase17/langchain.git
synced 2026-03-18 02:53:16 +00:00
feat(core): add 'approximate' alias in place of count_tokens_approximately (#33045)
### Description: Previously, approximate token counting required importing the counter and passing it explicitly: ```python from langchain_core.messages import trim_messages from langchain_core.messages.utils import count_tokens_approximately trim_messages(..., token_counter=count_tokens_approximately) ``` Now the string shortcut `"approximate"` can be used as well: ```python from langchain_core.messages import trim_messages trim_messages(..., token_counter="approximate") ``` - [x] **Added tests** - [x] **Lint and test**: Because the change is in langchain/core, I ran `uv run --group test pytest tests/unit_tests/messages/test_utils.py -v` <img width="1006" height="66" alt="image" src="https://github.com/user-attachments/assets/c6938c29-a781-4e7f-871b-8e888ee764b7" /> --------- Co-authored-by: Mason Daugherty <mason@langchain.dev> Co-authored-by: Mason Daugherty <github@mdrxy.com>
This commit is contained in:
@@ -720,7 +720,8 @@ def trim_messages(
|
||||
max_tokens: int,
|
||||
token_counter: Callable[[list[BaseMessage]], int]
|
||||
| Callable[[BaseMessage], int]
|
||||
| BaseLanguageModel,
|
||||
| BaseLanguageModel
|
||||
| Literal["approximate"],
|
||||
strategy: Literal["first", "last"] = "last",
|
||||
allow_partial: bool = False,
|
||||
end_on: str | type[BaseMessage] | Sequence[str | type[BaseMessage]] | None = None,
|
||||
@@ -758,53 +759,65 @@ def trim_messages(
|
||||
messages: Sequence of Message-like objects to trim.
|
||||
max_tokens: Max token count of trimmed messages.
|
||||
token_counter: Function or llm for counting tokens in a `BaseMessage` or a
|
||||
list of `BaseMessage`. If a `BaseLanguageModel` is passed in then
|
||||
`BaseLanguageModel.get_num_tokens_from_messages()` will be used.
|
||||
Set to `len` to count the number of **messages** in the chat history.
|
||||
list of `BaseMessage`.
|
||||
|
||||
If a `BaseLanguageModel` is passed in then
|
||||
`BaseLanguageModel.get_num_tokens_from_messages()` will be used. Set to
|
||||
`len` to count the number of **messages** in the chat history.
|
||||
|
||||
You can also use string shortcuts for convenience:
|
||||
|
||||
- `'approximate'`: Uses `count_tokens_approximately` for fast, approximate
|
||||
token counts.
|
||||
|
||||
!!! note
|
||||
|
||||
Use `count_tokens_approximately` to get fast, approximate token
|
||||
counts.
|
||||
|
||||
This is recommended for using `trim_messages` on the hot path, where
|
||||
exact token counting is not necessary.
|
||||
`count_tokens_approximately` (or the shortcut `'approximate'`) is
|
||||
recommended for using `trim_messages` on the hot path, where exact token
|
||||
counting is not necessary.
|
||||
|
||||
strategy: Strategy for trimming.
|
||||
|
||||
- `'first'`: Keep the first `<= n_count` tokens of the messages.
|
||||
- `'last'`: Keep the last `<= n_count` tokens of the messages.
|
||||
allow_partial: Whether to split a message if only part of the message can be
|
||||
included. If `strategy='last'` then the last partial contents of a message
|
||||
are included. If `strategy='first'` then the first partial contents of a
|
||||
message are included.
|
||||
end_on: The message type to end on. If specified then every message after the
|
||||
last occurrence of this type is ignored. If `strategy='last'` then this
|
||||
is done before we attempt to get the last `max_tokens`. If
|
||||
`strategy='first'` then this is done after we get the first
|
||||
`max_tokens`. Can be specified as string names (e.g. `'system'`,
|
||||
`'human'`, `'ai'`, ...) or as `BaseMessage` classes (e.g.
|
||||
`SystemMessage`, `HumanMessage`, `AIMessage`, ...). Can be a single
|
||||
type or a list of types.
|
||||
included.
|
||||
|
||||
start_on: The message type to start on. Should only be specified if
|
||||
`strategy='last'`. If specified then every message before
|
||||
the first occurrence of this type is ignored. This is done after we trim
|
||||
the initial messages to the last `max_tokens`. Does not
|
||||
apply to a `SystemMessage` at index 0 if `include_system=True`. Can be
|
||||
specified as string names (e.g. `'system'`, `'human'`, `'ai'`, ...) or
|
||||
as `BaseMessage` classes (e.g. `SystemMessage`, `HumanMessage`,
|
||||
`AIMessage`, ...). Can be a single type or a list of types.
|
||||
If `strategy='last'` then the last partial contents of a message are
|
||||
included. If `strategy='first'` then the first partial contents of a
|
||||
message are included.
|
||||
end_on: The message type to end on.
|
||||
|
||||
If specified then every message after the last occurrence of this type is
|
||||
ignored. If `strategy='last'` then this is done before we attempt to get the
|
||||
last `max_tokens`. If `strategy='first'` then this is done after we get the
|
||||
first `max_tokens`. Can be specified as string names (e.g. `'system'`,
|
||||
`'human'`, `'ai'`, ...) or as `BaseMessage` classes (e.g. `SystemMessage`,
|
||||
`HumanMessage`, `AIMessage`, ...). Can be a single type or a list of types.
|
||||
|
||||
start_on: The message type to start on.
|
||||
|
||||
Should only be specified if `strategy='last'`. If specified then every
|
||||
message before the first occurrence of this type is ignored. This is done
|
||||
after we trim the initial messages to the last `max_tokens`. Does not apply
|
||||
to a `SystemMessage` at index 0 if `include_system=True`. Can be specified
|
||||
as string names (e.g. `'system'`, `'human'`, `'ai'`, ...) or as
|
||||
`BaseMessage` classes (e.g. `SystemMessage`, `HumanMessage`, `AIMessage`,
|
||||
...). Can be a single type or a list of types.
|
||||
|
||||
include_system: Whether to keep the `SystemMessage` if there is one at index
|
||||
`0`. Should only be specified if `strategy="last"`.
|
||||
`0`.
|
||||
|
||||
Should only be specified if `strategy="last"`.
|
||||
text_splitter: Function or `langchain_text_splitters.TextSplitter` for
|
||||
splitting the string contents of a message. Only used if
|
||||
`allow_partial=True`. If `strategy='last'` then the last split tokens
|
||||
from a partial message will be included. if `strategy='first'` then the
|
||||
first split tokens from a partial message will be included. Token splitter
|
||||
assumes that separators are kept, so that split contents can be directly
|
||||
concatenated to recreate the original text. Defaults to splitting on
|
||||
newlines.
|
||||
splitting the string contents of a message.
|
||||
|
||||
Only used if `allow_partial=True`. If `strategy='last'` then the last split
|
||||
tokens from a partial message will be included. If `strategy='first'` then
|
||||
the first split tokens from a partial message will be included. Token
|
||||
splitter assumes that separators are kept, so that split contents can be
|
||||
directly concatenated to recreate the original text. Defaults to splitting
|
||||
on newlines.
|
||||
|
||||
Returns:
|
||||
List of trimmed `BaseMessage`.
|
||||
@@ -815,8 +828,8 @@ def trim_messages(
|
||||
|
||||
Example:
|
||||
Trim chat history based on token count, keeping the `SystemMessage` if
|
||||
present, and ensuring that the chat history starts with a `HumanMessage` (
|
||||
or a `SystemMessage` followed by a `HumanMessage`).
|
||||
present, and ensuring that the chat history starts with a `HumanMessage` (or a
|
||||
`SystemMessage` followed by a `HumanMessage`).
|
||||
|
||||
```python
|
||||
from langchain_core.messages import (
|
||||
@@ -869,8 +882,34 @@ def trim_messages(
|
||||
]
|
||||
```
|
||||
|
||||
Trim chat history using approximate token counting with `'approximate'`:
|
||||
|
||||
```python
|
||||
trim_messages(
|
||||
messages,
|
||||
max_tokens=45,
|
||||
strategy="last",
|
||||
# Using the "approximate" shortcut for fast token counting
|
||||
token_counter="approximate",
|
||||
start_on="human",
|
||||
include_system=True,
|
||||
)
|
||||
|
||||
# This is equivalent to using `count_tokens_approximately` directly
|
||||
from langchain_core.messages.utils import count_tokens_approximately
|
||||
|
||||
trim_messages(
|
||||
messages,
|
||||
max_tokens=45,
|
||||
strategy="last",
|
||||
token_counter=count_tokens_approximately,
|
||||
start_on="human",
|
||||
include_system=True,
|
||||
)
|
||||
```
|
||||
|
||||
Trim chat history based on the message count, keeping the `SystemMessage` if
|
||||
present, and ensuring that the chat history starts with a `HumanMessage` (
|
||||
present, and ensuring that the chat history starts with a `HumanMessage` (
|
||||
or a `SystemMessage` followed by a `HumanMessage`).
|
||||
|
||||
trim_messages(
|
||||
@@ -992,24 +1031,44 @@ def trim_messages(
|
||||
raise ValueError(msg)
|
||||
|
||||
messages = convert_to_messages(messages)
|
||||
if hasattr(token_counter, "get_num_tokens_from_messages"):
|
||||
list_token_counter = token_counter.get_num_tokens_from_messages
|
||||
elif callable(token_counter):
|
||||
|
||||
# Handle string shortcuts for token counter
|
||||
if isinstance(token_counter, str):
|
||||
if token_counter in _TOKEN_COUNTER_SHORTCUTS:
|
||||
actual_token_counter = _TOKEN_COUNTER_SHORTCUTS[token_counter]
|
||||
else:
|
||||
available_shortcuts = ", ".join(
|
||||
f"'{key}'" for key in _TOKEN_COUNTER_SHORTCUTS
|
||||
)
|
||||
msg = (
|
||||
f"Invalid token_counter shortcut '{token_counter}'. "
|
||||
f"Available shortcuts: {available_shortcuts}."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
else:
|
||||
# Type narrowing: at this point token_counter is not a str
|
||||
actual_token_counter = token_counter # type: ignore[assignment]
|
||||
|
||||
if hasattr(actual_token_counter, "get_num_tokens_from_messages"):
|
||||
list_token_counter = actual_token_counter.get_num_tokens_from_messages
|
||||
elif callable(actual_token_counter):
|
||||
if (
|
||||
next(iter(inspect.signature(token_counter).parameters.values())).annotation
|
||||
next(
|
||||
iter(inspect.signature(actual_token_counter).parameters.values())
|
||||
).annotation
|
||||
is BaseMessage
|
||||
):
|
||||
|
||||
def list_token_counter(messages: Sequence[BaseMessage]) -> int:
|
||||
return sum(token_counter(msg) for msg in messages) # type: ignore[arg-type, misc]
|
||||
return sum(actual_token_counter(msg) for msg in messages) # type: ignore[arg-type, misc]
|
||||
|
||||
else:
|
||||
list_token_counter = token_counter
|
||||
list_token_counter = actual_token_counter
|
||||
else:
|
||||
msg = (
|
||||
f"'token_counter' expected to be a model that implements "
|
||||
f"'get_num_tokens_from_messages()' or a function. Received object of type "
|
||||
f"{type(token_counter)}."
|
||||
f"{type(actual_token_counter)}."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
@@ -1807,3 +1866,14 @@ def count_tokens_approximately(
|
||||
|
||||
# round up one more time in case extra_tokens_per_message is a float
|
||||
return math.ceil(token_count)
|
||||
|
||||
|
||||
def _approximate_token_counter(messages: Sequence[BaseMessage]) -> int:
    """Count tokens in `messages` via `count_tokens_approximately`.

    Thin wrapper that exposes a plain `(messages) -> int` signature, which is the
    shape `trim_messages` expects from a list-level token counter.

    Args:
        messages: The messages whose tokens should be counted.

    Returns:
        The value returned by `count_tokens_approximately(messages)`.
    """
    return count_tokens_approximately(messages)


# Mapping from string shortcuts to token counter functions. `trim_messages`
# resolves a string `token_counter` through this table and raises `ValueError`
# for keys that are not listed here.
_TOKEN_COUNTER_SHORTCUTS = {
    "approximate": _approximate_token_counter,
}
|
||||
|
||||
@@ -673,6 +673,82 @@ def test_trim_messages_start_on_with_allow_partial() -> None:
|
||||
assert messages == messages_copy
|
||||
|
||||
|
||||
def test_trim_messages_token_counter_shortcut_approximate() -> None:
    """Check that `'approximate'` behaves like `count_tokens_approximately`."""
    messages = [
        SystemMessage("This is a test message"),
        HumanMessage("Another test message", id="first"),
        AIMessage("AI response here", id="second"),
    ]
    # Snapshot taken up front to verify that trimming does not mutate the input.
    snapshot = [message.model_copy(deep=True) for message in messages]

    # Reference run: pass the counting function explicitly.
    via_function = trim_messages(
        messages,
        max_tokens=50,
        token_counter=count_tokens_approximately,
        strategy="last",
    )

    # Run under test: resolve the same counter through the string shortcut.
    via_shortcut = trim_messages(
        messages,
        max_tokens=50,
        token_counter="approximate",
        strategy="last",
    )

    # The shortcut must be indistinguishable from the explicit function.
    assert via_shortcut == via_function
    assert messages == snapshot
|
||||
|
||||
|
||||
def test_trim_messages_token_counter_shortcut_invalid() -> None:
    """Check that an unknown `token_counter` string is rejected with `ValueError`."""
    history = [
        SystemMessage("This is a test message"),
        HumanMessage("Another test message"),
    ]
    expected_error = "Invalid token_counter shortcut 'invalid'"

    # The string is deliberately outside the accepted literals (hence the type
    # ignore) to exercise runtime validation of dynamically-constructed inputs.
    with pytest.raises(ValueError, match=expected_error):
        trim_messages(  # type: ignore[call-overload]
            history,
            max_tokens=50,
            token_counter="invalid",
            strategy="last",
        )
|
||||
|
||||
|
||||
def test_trim_messages_token_counter_shortcut_with_options() -> None:
    """Test that `'approximate'` shortcut works with different trim options.

    Combines the shortcut with `include_system=True` and `start_on="human"` and
    checks both the shape of the trimmed history and that the input list is left
    unmodified.
    """
    messages = [
        SystemMessage("System instructions"),
        HumanMessage("First human message", id="first"),
        AIMessage("First AI response", id="ai1"),
        HumanMessage("Second human message", id="second"),
        AIMessage("Second AI response", id="ai2"),
    ]
    messages_copy = [m.model_copy(deep=True) for m in messages]

    # Test with various options
    result = trim_messages(
        messages,
        max_tokens=100,
        token_counter="approximate",
        strategy="last",
        include_system=True,
        start_on="human",
    )

    # `include_system=True` keeps the leading system message at index 0.
    assert len(result) >= 2
    assert isinstance(result[0], SystemMessage)
    # `start_on="human"` requires the first message after the system message to
    # be a HumanMessage; checking for one anywhere later would be too weak.
    assert isinstance(result[1], HumanMessage)
    assert messages == messages_copy
|
||||
|
||||
|
||||
class FakeTokenCountingModel(FakeChatModel):
|
||||
@override
|
||||
def get_num_tokens_from_messages(
|
||||
|
||||
Reference in New Issue
Block a user