diff --git a/libs/core/langchain_core/language_models/base.py b/libs/core/langchain_core/language_models/base.py index 1fef01a6859..daf6e69cdf9 100644 --- a/libs/core/langchain_core/language_models/base.py +++ b/libs/core/langchain_core/language_models/base.py @@ -2,7 +2,6 @@ from __future__ import annotations -import warnings from abc import ABC, abstractmethod from collections.abc import Mapping, Sequence from functools import cache @@ -26,7 +25,6 @@ from langchain_core.messages import ( AnyMessage, BaseMessage, MessageLikeRepresentation, - get_buffer_string, ) from langchain_core.messages.v1 import AIMessage as AIMessageV1 from langchain_core.prompt_values import PromptValue @@ -368,33 +366,6 @@ class BaseLanguageModel( """ return len(self.get_token_ids(text)) - def get_num_tokens_from_messages( - self, - messages: list[BaseMessage], - tools: Optional[Sequence] = None, - ) -> int: - """Get the number of tokens in the messages. - - Useful for checking if an input fits in a model's context window. - - **Note**: the base implementation of get_num_tokens_from_messages ignores - tool schemas. - - Args: - messages: The message inputs to tokenize. - tools: If provided, sequence of dict, BaseModel, function, or BaseTools - to be converted to tool schemas. - - Returns: - The sum of the number of tokens across the messages. - """ - if tools is not None: - warnings.warn( - "Counting tokens in tool schemas is not yet supported. Ignoring tools.", - stacklevel=2, - ) - return sum(self.get_num_tokens(get_buffer_string([m])) for m in messages) - @classmethod def _all_required_field_names(cls) -> set: """DEPRECATED: Kept for backwards compatibility. 
diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py index d7fc51d76d3..d836c80f6b1 100644 --- a/libs/core/langchain_core/language_models/chat_models.py +++ b/libs/core/langchain_core/language_models/chat_models.py @@ -55,6 +55,7 @@ from langchain_core.messages import ( HumanMessage, convert_to_messages, convert_to_openai_image_block, + get_buffer_string, is_data_content_block, message_chunk_to_message, ) @@ -1351,6 +1352,33 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): starter_dict["_type"] = self._llm_type return starter_dict + def get_num_tokens_from_messages( + self, + messages: list[BaseMessage], + tools: Optional[Sequence] = None, + ) -> int: + """Get the number of tokens in the messages. + + Useful for checking if an input fits in a model's context window. + + **Note**: the base implementation of get_num_tokens_from_messages ignores + tool schemas. + + Args: + messages: The message inputs to tokenize. + tools: If provided, sequence of dict, BaseModel, function, or BaseTools + to be converted to tool schemas. + + Returns: + The sum of the number of tokens across the messages. + """ + if tools is not None: + warnings.warn( + "Counting tokens in tool schemas is not yet supported. Ignoring tools.", + stacklevel=2, + ) + return sum(self.get_num_tokens(get_buffer_string([m])) for m in messages) + def bind_tools( self, tools: Sequence[ diff --git a/libs/core/langchain_core/language_models/v1/chat_models.py b/libs/core/langchain_core/language_models/v1/chat_models.py index 437fb876496..ac0c5c04238 100644 --- a/libs/core/langchain_core/language_models/v1/chat_models.py +++ b/libs/core/langchain_core/language_models/v1/chat_models.py @@ -4,6 +4,7 @@ from __future__ import annotations import copy import typing +import warnings from abc import ABC, abstractmethod from collections.abc import AsyncIterator, Iterator, Sequence from operator import itemgetter @@ -40,9 +41,13 @@ from langchain_core.messages import ( AIMessage, BaseMessage, convert_to_openai_image_block, + get_buffer_string, is_data_content_block, ) -from langchain_core.messages.utils import convert_to_messages_v1 +from langchain_core.messages.utils import ( + _convert_from_v1_message, + convert_to_messages_v1, +) from langchain_core.messages.v1 import AIMessage as AIMessageV1 from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1 from langchain_core.messages.v1 import HumanMessage as HumanMessageV1 @@ -899,6 +904,34 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC): return RunnableMap(raw=llm) | parser_with_fallback return llm | output_parser + def get_num_tokens_from_messages( + self, + messages: list[MessageV1], + tools: Optional[Sequence] = None, + ) -> int: + """Get the number of tokens in the messages. + + Useful for checking if an input fits in a model's context window. + + **Note**: the base implementation of get_num_tokens_from_messages ignores + tool schemas. + + Args: + messages: The message inputs to tokenize. + tools: If provided, sequence of dict, BaseModel, function, or BaseTools + to be converted to tool schemas. + + Returns: + The sum of the number of tokens across the messages. 
+ """ + messages = [_convert_from_v1_message(message) for message in messages] + if tools is not None: + warnings.warn( + "Counting tokens in tool schemas is not yet supported. Ignoring tools.", + stacklevel=2, + ) + return sum(self.get_num_tokens(get_buffer_string([m])) for m in messages) + def _gen_info_and_msg_metadata( generation: Union[ChatGeneration, ChatGenerationChunk],