mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-01 09:04:03 +00:00
move get_num_tokens_from_messages to BaseChatModel and BaseChatModelV1
This commit is contained in:
parent
c409f723a2
commit
b8fed06409
@ -2,7 +2,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping, Sequence
|
||||
from functools import cache
|
||||
@ -26,7 +25,6 @@ from langchain_core.messages import (
|
||||
AnyMessage,
|
||||
BaseMessage,
|
||||
MessageLikeRepresentation,
|
||||
get_buffer_string,
|
||||
)
|
||||
from langchain_core.messages.v1 import AIMessage as AIMessageV1
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
@ -368,33 +366,6 @@ class BaseLanguageModel(
|
||||
"""
|
||||
return len(self.get_token_ids(text))
|
||||
|
||||
def get_num_tokens_from_messages(
    self,
    messages: list[BaseMessage],
    tools: Optional[Sequence] = None,
) -> int:
    """Count the tokens contained in a sequence of messages.

    Useful for checking whether an input fits inside a model's context
    window.

    **Note**: this base implementation does not account for tool schemas
    when counting tokens.

    Args:
        messages: Messages whose tokens should be counted.
        tools: If provided, sequence of dict, BaseModel, function, or
            BaseTools to be converted to tool schemas.

    Returns:
        The total number of tokens summed across all messages.
    """
    if tools is not None:
        warnings.warn(
            "Counting tokens in tool schemas is not yet supported. Ignoring tools.",
            stacklevel=2,
        )
    # Serialize each message individually and accumulate its token count.
    total = 0
    for message in messages:
        total += self.get_num_tokens(get_buffer_string([message]))
    return total
|
||||
|
||||
@classmethod
|
||||
def _all_required_field_names(cls) -> set:
|
||||
"""DEPRECATED: Kept for backwards compatibility.
|
||||
|
@ -55,6 +55,7 @@ from langchain_core.messages import (
|
||||
HumanMessage,
|
||||
convert_to_messages,
|
||||
convert_to_openai_image_block,
|
||||
get_buffer_string,
|
||||
is_data_content_block,
|
||||
message_chunk_to_message,
|
||||
)
|
||||
@ -1351,6 +1352,33 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
starter_dict["_type"] = self._llm_type
|
||||
return starter_dict
|
||||
|
||||
def get_num_tokens_from_messages(
    self,
    messages: list[BaseMessage],
    tools: Optional[Sequence] = None,
) -> int:
    """Return the number of tokens across the given messages.

    Handy for verifying that an input will fit within a model's context
    window.

    **Note**: the base implementation ignores tool schemas when counting
    tokens.

    Args:
        messages: The message inputs to tokenize.
        tools: If provided, sequence of dict, BaseModel, function, or
            BaseTools to be converted to tool schemas.

    Returns:
        The sum of the number of tokens across the messages.
    """
    if tools is not None:
        warnings.warn(
            "Counting tokens in tool schemas is not yet supported. Ignoring tools.",
            stacklevel=2,
        )
    # Render each message to its buffer-string form, then count tokens.
    rendered = (get_buffer_string([msg]) for msg in messages)
    return sum(map(self.get_num_tokens, rendered))
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[
|
||||
|
@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import typing
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncIterator, Iterator, Sequence
|
||||
from operator import itemgetter
|
||||
@ -40,9 +41,13 @@ from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
convert_to_openai_image_block,
|
||||
get_buffer_string,
|
||||
is_data_content_block,
|
||||
)
|
||||
from langchain_core.messages.utils import convert_to_messages_v1
|
||||
from langchain_core.messages.utils import (
|
||||
_convert_from_v1_message,
|
||||
convert_to_messages_v1,
|
||||
)
|
||||
from langchain_core.messages.v1 import AIMessage as AIMessageV1
|
||||
from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
|
||||
from langchain_core.messages.v1 import HumanMessage as HumanMessageV1
|
||||
@ -899,6 +904,34 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
|
||||
return RunnableMap(raw=llm) | parser_with_fallback
|
||||
return llm | output_parser
|
||||
|
||||
def get_num_tokens_from_messages(
    self,
    messages: list[MessageV1],
    tools: Optional[Sequence] = None,
) -> int:
    """Get the number of tokens in the messages.

    Useful for checking if an input fits in a model's context window.

    **Note**: the base implementation of get_num_tokens_from_messages ignores
    tool schemas.

    Args:
        messages: The message inputs to tokenize.
        tools: If provided, sequence of dict, BaseModel, function, or BaseTools
            to be converted to tool schemas.

    Returns:
        The sum of the number of tokens across the messages.
    """
    # Convert v1 messages to legacy BaseMessage so get_buffer_string can
    # serialize them. Bind to a fresh name rather than rebinding
    # ``messages`` — the parameter is annotated ``list[MessageV1]`` and
    # rebinding it to a list of a different message type violates that
    # annotation and misleads type checkers.
    converted = [_convert_from_v1_message(message) for message in messages]
    if tools is not None:
        warnings.warn(
            "Counting tokens in tool schemas is not yet supported. Ignoring tools.",
            stacklevel=2,
        )
    return sum(self.get_num_tokens(get_buffer_string([m])) for m in converted)
|
||||
|
||||
|
||||
def _gen_info_and_msg_metadata(
|
||||
generation: Union[ChatGeneration, ChatGenerationChunk],
|
||||
|
Loading…
Reference in New Issue
Block a user