Move `get_num_tokens_from_messages` from `BaseLanguageModel` to `BaseChatModel` and `BaseChatModelV1`

This commit is contained in:
Chester Curme 2025-07-28 10:58:57 -04:00
parent c409f723a2
commit b8fed06409
3 changed files with 62 additions and 30 deletions

View File

@@ -2,7 +2,6 @@
from __future__ import annotations

import warnings
from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from functools import cache
@@ -26,7 +25,6 @@ from langchain_core.messages import (
    AnyMessage,
    BaseMessage,
    MessageLikeRepresentation,
    get_buffer_string,
)
from langchain_core.messages.v1 import AIMessage as AIMessageV1
from langchain_core.prompt_values import PromptValue
@@ -368,33 +366,6 @@ class BaseLanguageModel(
        """
        return len(self.get_token_ids(text))
def get_num_tokens_from_messages(
    self,
    messages: list[BaseMessage],
    tools: Optional[Sequence] = None,
) -> int:
    """Count the tokens contained in a sequence of messages.

    Handy for checking whether an input will fit inside a model's
    context window.

    **Note**: this base implementation of ``get_num_tokens_from_messages``
    does not account for tool schemas; any supplied ``tools`` are ignored.

    Args:
        messages: The message inputs to tokenize.
        tools: If provided, sequence of dict, BaseModel, function, or BaseTools
            to be converted to tool schemas.

    Returns:
        The sum of the number of tokens across the messages.
    """
    if tools is not None:
        warnings.warn(
            "Counting tokens in tool schemas is not yet supported. Ignoring tools.",
            stacklevel=2,
        )
    # Tokenize each message independently via its buffer-string rendering
    # and accumulate the per-message counts.
    total = 0
    for message in messages:
        total += self.get_num_tokens(get_buffer_string([message]))
    return total
    @classmethod
    def _all_required_field_names(cls) -> set:
        """DEPRECATED: Kept for backwards compatibility.

View File

@@ -55,6 +55,7 @@ from langchain_core.messages import (
    HumanMessage,
    convert_to_messages,
    convert_to_openai_image_block,
    get_buffer_string,
    is_data_content_block,
    message_chunk_to_message,
)
@@ -1351,6 +1352,33 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
        starter_dict["_type"] = self._llm_type
        return starter_dict
def get_num_tokens_from_messages(
    self,
    messages: list[BaseMessage],
    tools: Optional[Sequence] = None,
) -> int:
    """Return the total token count for the given messages.

    Useful for checking whether an input fits in a model's context window.

    **Note**: the base implementation of ``get_num_tokens_from_messages``
    ignores tool schemas.

    Args:
        messages: The message inputs to tokenize.
        tools: If provided, sequence of dict, BaseModel, function, or BaseTools
            to be converted to tool schemas.

    Returns:
        The sum of the number of tokens across the messages.
    """
    if tools is not None:
        warnings.warn(
            "Counting tokens in tool schemas is not yet supported. Ignoring tools.",
            stacklevel=2,
        )
    # Each message is rendered to its buffer-string form and tokenized on
    # its own; the overall count is the sum of the per-message counts.
    per_message_counts = (
        self.get_num_tokens(get_buffer_string([msg])) for msg in messages
    )
    return sum(per_message_counts)
    def bind_tools(
        self,
        tools: Sequence[

View File

@@ -4,6 +4,7 @@ from __future__ import annotations

import copy
import typing
import warnings
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Iterator, Sequence
from operator import itemgetter
@@ -40,9 +41,13 @@ from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    convert_to_openai_image_block,
    get_buffer_string,
    is_data_content_block,
)
from langchain_core.messages.utils import (
    _convert_from_v1_message,
    convert_to_messages_v1,
)
from langchain_core.messages.v1 import AIMessage as AIMessageV1
from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
from langchain_core.messages.v1 import HumanMessage as HumanMessageV1
@@ -899,6 +904,34 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
            return RunnableMap(raw=llm) | parser_with_fallback
        return llm | output_parser
def get_num_tokens_from_messages(
    self,
    messages: list[MessageV1],
    tools: Optional[Sequence] = None,
) -> int:
    """Return the total token count for the given v1 messages.

    Useful for checking whether an input fits in a model's context window.

    **Note**: the base implementation of ``get_num_tokens_from_messages``
    ignores tool schemas.

    Args:
        messages: The message inputs to tokenize.
        tools: If provided, sequence of dict, BaseModel, function, or BaseTools
            to be converted to tool schemas.

    Returns:
        The sum of the number of tokens across the messages.
    """
    # Bridge v1 messages back to the legacy BaseMessage representation so
    # that get_buffer_string can render them for tokenization.
    legacy_messages = [_convert_from_v1_message(message) for message in messages]
    if tools is not None:
        warnings.warn(
            "Counting tokens in tool schemas is not yet supported. Ignoring tools.",
            stacklevel=2,
        )
    return sum(
        self.get_num_tokens(get_buffer_string([legacy]))
        for legacy in legacy_messages
    )
def _gen_info_and_msg_metadata(
    generation: Union[ChatGeneration, ChatGenerationChunk],