mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-07 12:06:43 +00:00
move get_num_tokens_from_messages to BaseChatModel and BaseChatModelV1
This commit is contained in:
parent
c409f723a2
commit
b8fed06409
@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import warnings
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from functools import cache
|
from functools import cache
|
||||||
@ -26,7 +25,6 @@ from langchain_core.messages import (
|
|||||||
AnyMessage,
|
AnyMessage,
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
MessageLikeRepresentation,
|
MessageLikeRepresentation,
|
||||||
get_buffer_string,
|
|
||||||
)
|
)
|
||||||
from langchain_core.messages.v1 import AIMessage as AIMessageV1
|
from langchain_core.messages.v1 import AIMessage as AIMessageV1
|
||||||
from langchain_core.prompt_values import PromptValue
|
from langchain_core.prompt_values import PromptValue
|
||||||
@ -368,33 +366,6 @@ class BaseLanguageModel(
|
|||||||
"""
|
"""
|
||||||
return len(self.get_token_ids(text))
|
return len(self.get_token_ids(text))
|
||||||
|
|
||||||
def get_num_tokens_from_messages(
|
|
||||||
self,
|
|
||||||
messages: list[BaseMessage],
|
|
||||||
tools: Optional[Sequence] = None,
|
|
||||||
) -> int:
|
|
||||||
"""Get the number of tokens in the messages.
|
|
||||||
|
|
||||||
Useful for checking if an input fits in a model's context window.
|
|
||||||
|
|
||||||
**Note**: the base implementation of get_num_tokens_from_messages ignores
|
|
||||||
tool schemas.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages: The message inputs to tokenize.
|
|
||||||
tools: If provided, sequence of dict, BaseModel, function, or BaseTools
|
|
||||||
to be converted to tool schemas.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The sum of the number of tokens across the messages.
|
|
||||||
"""
|
|
||||||
if tools is not None:
|
|
||||||
warnings.warn(
|
|
||||||
"Counting tokens in tool schemas is not yet supported. Ignoring tools.",
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
return sum(self.get_num_tokens(get_buffer_string([m])) for m in messages)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _all_required_field_names(cls) -> set:
|
def _all_required_field_names(cls) -> set:
|
||||||
"""DEPRECATED: Kept for backwards compatibility.
|
"""DEPRECATED: Kept for backwards compatibility.
|
||||||
|
@ -55,6 +55,7 @@ from langchain_core.messages import (
|
|||||||
HumanMessage,
|
HumanMessage,
|
||||||
convert_to_messages,
|
convert_to_messages,
|
||||||
convert_to_openai_image_block,
|
convert_to_openai_image_block,
|
||||||
|
get_buffer_string,
|
||||||
is_data_content_block,
|
is_data_content_block,
|
||||||
message_chunk_to_message,
|
message_chunk_to_message,
|
||||||
)
|
)
|
||||||
@ -1351,6 +1352,33 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
|||||||
starter_dict["_type"] = self._llm_type
|
starter_dict["_type"] = self._llm_type
|
||||||
return starter_dict
|
return starter_dict
|
||||||
|
|
||||||
|
def get_num_tokens_from_messages(
|
||||||
|
self,
|
||||||
|
messages: list[BaseMessage],
|
||||||
|
tools: Optional[Sequence] = None,
|
||||||
|
) -> int:
|
||||||
|
"""Get the number of tokens in the messages.
|
||||||
|
|
||||||
|
Useful for checking if an input fits in a model's context window.
|
||||||
|
|
||||||
|
**Note**: the base implementation of get_num_tokens_from_messages ignores
|
||||||
|
tool schemas.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: The message inputs to tokenize.
|
||||||
|
tools: If provided, sequence of dict, BaseModel, function, or BaseTools
|
||||||
|
to be converted to tool schemas.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The sum of the number of tokens across the messages.
|
||||||
|
"""
|
||||||
|
if tools is not None:
|
||||||
|
warnings.warn(
|
||||||
|
"Counting tokens in tool schemas is not yet supported. Ignoring tools.",
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
return sum(self.get_num_tokens(get_buffer_string([m])) for m in messages)
|
||||||
|
|
||||||
def bind_tools(
|
def bind_tools(
|
||||||
self,
|
self,
|
||||||
tools: Sequence[
|
tools: Sequence[
|
||||||
|
@ -4,6 +4,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import copy
|
import copy
|
||||||
import typing
|
import typing
|
||||||
|
import warnings
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import AsyncIterator, Iterator, Sequence
|
from collections.abc import AsyncIterator, Iterator, Sequence
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
@ -40,9 +41,13 @@ from langchain_core.messages import (
|
|||||||
AIMessage,
|
AIMessage,
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
convert_to_openai_image_block,
|
convert_to_openai_image_block,
|
||||||
|
get_buffer_string,
|
||||||
is_data_content_block,
|
is_data_content_block,
|
||||||
)
|
)
|
||||||
from langchain_core.messages.utils import convert_to_messages_v1
|
from langchain_core.messages.utils import (
|
||||||
|
_convert_from_v1_message,
|
||||||
|
convert_to_messages_v1,
|
||||||
|
)
|
||||||
from langchain_core.messages.v1 import AIMessage as AIMessageV1
|
from langchain_core.messages.v1 import AIMessage as AIMessageV1
|
||||||
from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
|
from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
|
||||||
from langchain_core.messages.v1 import HumanMessage as HumanMessageV1
|
from langchain_core.messages.v1 import HumanMessage as HumanMessageV1
|
||||||
@ -899,6 +904,34 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
|
|||||||
return RunnableMap(raw=llm) | parser_with_fallback
|
return RunnableMap(raw=llm) | parser_with_fallback
|
||||||
return llm | output_parser
|
return llm | output_parser
|
||||||
|
|
||||||
|
def get_num_tokens_from_messages(
|
||||||
|
self,
|
||||||
|
messages: list[MessageV1],
|
||||||
|
tools: Optional[Sequence] = None,
|
||||||
|
) -> int:
|
||||||
|
"""Get the number of tokens in the messages.
|
||||||
|
|
||||||
|
Useful for checking if an input fits in a model's context window.
|
||||||
|
|
||||||
|
**Note**: the base implementation of get_num_tokens_from_messages ignores
|
||||||
|
tool schemas.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: The message inputs to tokenize.
|
||||||
|
tools: If provided, sequence of dict, BaseModel, function, or BaseTools
|
||||||
|
to be converted to tool schemas.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The sum of the number of tokens across the messages.
|
||||||
|
"""
|
||||||
|
messages = [_convert_from_v1_message(message) for message in messages]
|
||||||
|
if tools is not None:
|
||||||
|
warnings.warn(
|
||||||
|
"Counting tokens in tool schemas is not yet supported. Ignoring tools.",
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
return sum(self.get_num_tokens(get_buffer_string([m])) for m in messages)
|
||||||
|
|
||||||
|
|
||||||
def _gen_info_and_msg_metadata(
|
def _gen_info_and_msg_metadata(
|
||||||
generation: Union[ChatGeneration, ChatGenerationChunk],
|
generation: Union[ChatGeneration, ChatGenerationChunk],
|
||||||
|
Loading…
Reference in New Issue
Block a user