move get_num_tokens_from_messages to BaseChatModel and BaseChatModelV1

This commit is contained in:
Chester Curme 2025-07-28 10:58:57 -04:00
parent c409f723a2
commit b8fed06409
3 changed files with 62 additions and 30 deletions

View File

@@ -2,7 +2,6 @@
from __future__ import annotations
import warnings
from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from functools import cache
@@ -26,7 +25,6 @@ from langchain_core.messages import (
AnyMessage,
BaseMessage,
MessageLikeRepresentation,
get_buffer_string,
)
from langchain_core.messages.v1 import AIMessage as AIMessageV1
from langchain_core.prompt_values import PromptValue
@@ -368,33 +366,6 @@ class BaseLanguageModel(
"""
return len(self.get_token_ids(text))
def get_num_tokens_from_messages(
    self,
    messages: list[BaseMessage],
    tools: Optional[Sequence] = None,
) -> int:
    """Count the tokens contained in a list of messages.

    Handy for checking whether an input will fit inside a model's context
    window before sending it.

    **Note**: this base implementation does not account for tool schemas;
    any ``tools`` argument is ignored (with a warning).

    Args:
        messages: The message inputs to tokenize.
        tools: If provided, sequence of dict, BaseModel, function, or BaseTools
            to be converted to tool schemas.

    Returns:
        The sum of the number of tokens across the messages.
    """
    if tools is not None:
        warnings.warn(
            "Counting tokens in tool schemas is not yet supported. Ignoring tools.",
            stacklevel=2,
        )
    # Tokenize each message independently and accumulate the counts.
    total = 0
    for message in messages:
        total += self.get_num_tokens(get_buffer_string([message]))
    return total
@classmethod
def _all_required_field_names(cls) -> set:
"""DEPRECATED: Kept for backwards compatibility.

View File

@@ -55,6 +55,7 @@ from langchain_core.messages import (
HumanMessage,
convert_to_messages,
convert_to_openai_image_block,
get_buffer_string,
is_data_content_block,
message_chunk_to_message,
)
@@ -1351,6 +1352,33 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
starter_dict["_type"] = self._llm_type
return starter_dict
def get_num_tokens_from_messages(
    self,
    messages: list[BaseMessage],
    tools: Optional[Sequence] = None,
) -> int:
    """Return the total token count of the given messages.

    Useful for checking if an input fits in a model's context window.

    **Note**: the base implementation of ``get_num_tokens_from_messages``
    ignores tool schemas.

    Args:
        messages: The message inputs to tokenize.
        tools: If provided, sequence of dict, BaseModel, function, or BaseTools
            to be converted to tool schemas.

    Returns:
        The sum of the number of tokens across the messages.
    """
    if tools is not None:
        warnings.warn(
            "Counting tokens in tool schemas is not yet supported. Ignoring tools.",
            stacklevel=2,
        )
    # Render each message to its buffer-string form, then count tokens per
    # message and sum the results.
    rendered = (get_buffer_string([msg]) for msg in messages)
    return sum(map(self.get_num_tokens, rendered))
def bind_tools(
self,
tools: Sequence[

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
import copy
import typing
import warnings
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Iterator, Sequence
from operator import itemgetter
@@ -40,9 +41,13 @@ from langchain_core.messages import (
AIMessage,
BaseMessage,
convert_to_openai_image_block,
get_buffer_string,
is_data_content_block,
)
from langchain_core.messages.utils import convert_to_messages_v1
from langchain_core.messages.utils import (
_convert_from_v1_message,
convert_to_messages_v1,
)
from langchain_core.messages.v1 import AIMessage as AIMessageV1
from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
from langchain_core.messages.v1 import HumanMessage as HumanMessageV1
@@ -899,6 +904,34 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
return RunnableMap(raw=llm) | parser_with_fallback
return llm | output_parser
def get_num_tokens_from_messages(
    self,
    messages: list[MessageV1],
    tools: Optional[Sequence] = None,
) -> int:
    """Count the tokens across a list of v1 messages.

    Useful for checking if an input fits in a model's context window.

    **Note**: the base implementation of ``get_num_tokens_from_messages``
    ignores tool schemas.

    Args:
        messages: The message inputs to tokenize.
        tools: If provided, sequence of dict, BaseModel, function, or BaseTools
            to be converted to tool schemas.

    Returns:
        The sum of the number of tokens across the messages.
    """
    if tools is not None:
        warnings.warn(
            "Counting tokens in tool schemas is not yet supported. Ignoring tools.",
            stacklevel=2,
        )
    # Downgrade each v1 message to the legacy BaseMessage representation so
    # the shared buffer-string tokenization path can be reused.
    legacy = map(_convert_from_v1_message, messages)
    return sum(
        self.get_num_tokens(get_buffer_string([msg])) for msg in legacy
    )
def _gen_info_and_msg_metadata(
generation: Union[ChatGeneration, ChatGenerationChunk],