Add 'get_token_ids' method (#4784)

Let the user inspect the token IDs in addition to getting the number of tokens.

---------

Co-authored-by: Zach Schillaci <40636930+zachschillaci27@users.noreply.github.com>

parent ef7d015be5
commit 785502edb3
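
The new method exposes the tokenizer's integer IDs directly, and get_num_tokens becomes the length of that list. A minimal sketch of the resulting API, using the module-level default (GPT-2) tokenizer; the expected IDs come from the unit test added below, and `transformers` is assumed to be installed:

    from langchain.base_language import _get_token_ids_default_method

    # Default fallback used by BaseLanguageModel.get_token_ids when a
    # subclass does not supply a model-specific tokenizer.
    ids = _get_token_ids_default_method("This is a test")
    print(ids)       # [1212, 318, 257, 1332]
    print(len(ids))  # 4 -- the count get_num_tokens now derives from the IDs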
langchain/base_language.py

@@ -10,8 +10,8 @@ from langchain.callbacks.manager import Callbacks
 from langchain.schema import BaseMessage, LLMResult, PromptValue, get_buffer_string
 
 
-def _get_num_tokens_default_method(text: str) -> int:
-    """Get the number of tokens present in the text."""
+def _get_token_ids_default_method(text: str) -> List[int]:
+    """Encode the text into token IDs."""
     # TODO: this method may not be exact.
     # TODO: this method may differ based on model (eg codex).
     try:
@@ -19,17 +19,14 @@ def _get_num_tokens_default_method(text: str) -> int:
     except ImportError:
         raise ValueError(
             "Could not import transformers python package. "
-            "This is needed in order to calculate get_num_tokens. "
+            "This is needed in order to calculate get_token_ids. "
             "Please install it with `pip install transformers`."
         )
     # create a GPT-2 tokenizer instance
     tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
 
     # tokenize the text using the GPT-2 tokenizer
-    tokenized_text = tokenizer.tokenize(text)
-
-    # calculate the number of tokens in the tokenized text
-    return len(tokenized_text)
+    return tokenizer.encode(text)
 
 
 class BaseLanguageModel(BaseModel, ABC):
@@ -61,9 +58,13 @@ class BaseLanguageModel(BaseModel, ABC):
     ) -> BaseMessage:
         """Predict message from messages."""
 
+    def get_token_ids(self, text: str) -> List[int]:
+        """Get the token present in the text."""
+        return _get_token_ids_default_method(text)
+
     def get_num_tokens(self, text: str) -> int:
         """Get the number of tokens present in the text."""
-        return _get_num_tokens_default_method(text)
+        return len(self.get_token_ids(text))
 
     def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
         """Get the number of tokens in the message."""
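
A design consequence of the base-class change: counting is now derived from the IDs, so a subclass only has to override get_token_ids and the two methods stay consistent. A hypothetical illustration (the whitespace tokenizer is invented for this sketch; it is not part of the commit):

    from typing import List

    class WhitespaceTokenModel:
        """Stand-in for a BaseLanguageModel subclass."""

        def get_token_ids(self, text: str) -> List[int]:
            # Toy scheme: one fake ID per whitespace-separated token.
            return [hash(tok) % 50_000 for tok in text.split()]

        def get_num_tokens(self, text: str) -> int:
            # Mirrors BaseLanguageModel.get_num_tokens after this commit.
            return len(self.get_token_ids(text))

    assert WhitespaceTokenModel().get_num_tokens("a b c") == 3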
langchain/chat_models/openai.py

@@ -3,7 +3,17 @@ from __future__ import annotations
 
 import logging
 import sys
-from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Mapping,
+    Optional,
+    Tuple,
+    Union,
+)
 
 from pydantic import Extra, Field, root_validator
 from tenacity import (
@@ -30,9 +40,24 @@ from langchain.schema import (
 )
 from langchain.utils import get_from_dict_or_env
 
+if TYPE_CHECKING:
+    import tiktoken
+
 logger = logging.getLogger(__name__)
 
 
+def _import_tiktoken() -> Any:
+    try:
+        import tiktoken
+    except ImportError:
+        raise ValueError(
+            "Could not import tiktoken python package. "
+            "This is needed in order to calculate get_token_ids. "
+            "Please install it with `pip install tiktoken`."
+        )
+    return tiktoken
+
+
 def _create_retry_decorator(llm: ChatOpenAI) -> Callable[[Any], Any]:
     import openai
 
@@ -354,42 +379,8 @@ class ChatOpenAI(BaseChatModel):
         """Return type of chat model."""
         return "openai-chat"
 
-    def get_num_tokens(self, text: str) -> int:
-        """Calculate num tokens with tiktoken package."""
-        # tiktoken NOT supported for Python 3.7 or below
-        if sys.version_info[1] <= 7:
-            return super().get_num_tokens(text)
-        try:
-            import tiktoken
-        except ImportError:
-            raise ValueError(
-                "Could not import tiktoken python package. "
-                "This is needed in order to calculate get_num_tokens. "
-                "Please install it with `pip install tiktoken`."
-            )
-        # create a GPT-3.5-Turbo encoder instance
-        enc = tiktoken.encoding_for_model(self.model_name)
-
-        # encode the text using the GPT-3.5-Turbo encoder
-        tokenized_text = enc.encode(text)
-
-        # calculate the number of tokens in the encoded text
-        return len(tokenized_text)
-
-    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
-        """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
-
-        Official documentation: https://github.com/openai/openai-cookbook/blob/
-        main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
-        try:
-            import tiktoken
-        except ImportError:
-            raise ValueError(
-                "Could not import tiktoken python package. "
-                "This is needed in order to calculate get_num_tokens. "
-                "Please install it with `pip install tiktoken`."
-            )
-
+    def _get_encoding_model(self) -> Tuple[str, tiktoken.Encoding]:
+        tiktoken_ = _import_tiktoken()
         model = self.model_name
         if model == "gpt-3.5-turbo":
             # gpt-3.5-turbo may change over time.
@@ -399,14 +390,31 @@ class ChatOpenAI(BaseChatModel):
             # gpt-4 may change over time.
             # Returning num tokens assuming gpt-4-0314.
             model = "gpt-4-0314"
 
         # Returns the number of tokens used by a list of messages.
         try:
-            encoding = tiktoken.encoding_for_model(model)
+            encoding = tiktoken_.encoding_for_model(model)
         except KeyError:
             logger.warning("Warning: model not found. Using cl100k_base encoding.")
-            encoding = tiktoken.get_encoding("cl100k_base")
+            model = "cl100k_base"
+            encoding = tiktoken_.get_encoding(model)
+        return model, encoding
+
+    def get_token_ids(self, text: str) -> List[int]:
+        """Get the tokens present in the text with tiktoken package."""
+        # tiktoken NOT supported for Python 3.7 or below
+        if sys.version_info[1] <= 7:
+            return super().get_token_ids(text)
+        _, encoding_model = self._get_encoding_model()
+        return encoding_model.encode(text)
+
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
+
+        Official documentation: https://github.com/openai/openai-cookbook/blob/
+        main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
+        if sys.version_info[1] <= 7:
+            return super().get_num_tokens_from_messages(messages)
+        model, encoding = self._get_encoding_model()
         if model == "gpt-3.5-turbo-0301":
             # every message follows <im_start>{role/name}\n{content}<im_end>\n
             tokens_per_message = 4
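
Both counting paths in ChatOpenAI now go through the new _get_encoding_model helper, which resolves the model name to a tiktoken encoding and falls back to cl100k_base for names tiktoken does not recognize. A standalone sketch of that lookup (the function name is illustrative; `tiktoken` is assumed to be installed):

    import tiktoken

    def encoding_with_fallback(model_name: str):
        # tiktoken.encoding_for_model raises KeyError for unknown names;
        # fall back to cl100k_base, as ChatOpenAI._get_encoding_model does.
        try:
            return model_name, tiktoken.encoding_for_model(model_name)
        except KeyError:
            return "cl100k_base", tiktoken.get_encoding("cl100k_base")

    model, enc = encoding_with_fallback("gpt-3.5-turbo")
    print(len(enc.encode("表情符号是\n🦜🔗")))  # 12, per the new integration test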
langchain/llms/openai.py

@@ -454,8 +454,8 @@ class BaseOpenAI(BaseLLM):
         """Return type of llm."""
         return "openai"
 
-    def get_num_tokens(self, text: str) -> int:
-        """Calculate num tokens with tiktoken package."""
+    def get_token_ids(self, text: str) -> List[int]:
+        """Get the token IDs using the tiktoken package."""
         # tiktoken NOT supported for Python < 3.8
         if sys.version_info[1] < 8:
             return super().get_num_tokens(text)
@@ -470,15 +470,12 @@ class BaseOpenAI(BaseLLM):
 
         enc = tiktoken.encoding_for_model(self.model_name)
 
-        tokenized_text = enc.encode(
+        return enc.encode(
             text,
             allowed_special=self.allowed_special,
             disallowed_special=self.disallowed_special,
         )
 
-        # calculate the number of tokens in the encoded text
-        return len(tokenized_text)
-
     def modelname_to_contextsize(self, modelname: str) -> int:
         """Calculate the maximum number of tokens possible to generate for a model.
 
@@ -802,11 +799,11 @@ class OpenAIChat(BaseLLM):
         """Return type of llm."""
         return "openai-chat"
 
-    def get_num_tokens(self, text: str) -> int:
-        """Calculate num tokens with tiktoken package."""
+    def get_token_ids(self, text: str) -> List[int]:
+        """Get the token IDs using the tiktoken package."""
         # tiktoken NOT supported for Python < 3.8
         if sys.version_info[1] < 8:
-            return super().get_num_tokens(text)
+            return super().get_token_ids(text)
         try:
             import tiktoken
         except ImportError:
@@ -815,15 +812,10 @@ class OpenAIChat(BaseLLM):
                 "This is needed in order to calculate get_num_tokens. "
                 "Please install it with `pip install tiktoken`."
             )
-        # create a GPT-3.5-Turbo encoder instance
-        enc = tiktoken.encoding_for_model("gpt-3.5-turbo")
-
-        # encode the text using the GPT-3.5-Turbo encoder
-        tokenized_text = enc.encode(
+        enc = tiktoken.encoding_for_model(self.model_name)
+        return enc.encode(
             text,
             allowed_special=self.allowed_special,
             disallowed_special=self.disallowed_special,
         )
 
-        # calculate the number of tokens in the encoded text
-        return len(tokenized_text)
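
In both OpenAI wrappers the encode call threads through allowed_special / disallowed_special, which govern how tiktoken treats special tokens found in user text. A small illustration of why that matters (assuming `tiktoken` is installed):

    import tiktoken

    enc = tiktoken.get_encoding("cl100k_base")
    text = "hello <|endoftext|>"

    # With the defaults, tiktoken rejects text containing special tokens.
    try:
        enc.encode(text)
    except ValueError:
        print("special token disallowed by default")

    # Explicitly allowing the marker encodes it to its reserved ID.
    ids = enc.encode(text, allowed_special={"<|endoftext|>"})
    print(ids[-1])  # 100257, the cl100k_base ID for <|endoftext|>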
tests/integration_tests/llms/test_openai.py

@@ -6,6 +6,7 @@ from typing import Generator
 import pytest
 
 from langchain.callbacks.manager import CallbackManager
+from langchain.chat_models.openai import ChatOpenAI
 from langchain.llms.loading import load_llm
 from langchain.llms.openai import OpenAI, OpenAIChat
 from langchain.schema import LLMResult
@@ -237,3 +238,40 @@ def test_openai_modelname_to_contextsize_invalid() -> None:
     """Test model name to context size on an invalid model."""
     with pytest.raises(ValueError):
         OpenAI().modelname_to_contextsize("foobar")
+
+
+_EXPECTED_NUM_TOKENS = {
+    "ada": 17,
+    "babbage": 17,
+    "curie": 17,
+    "davinci": 17,
+    "gpt-4": 12,
+    "gpt-4-32k": 12,
+    "gpt-3.5-turbo": 12,
+}
+
+_MODELS = models = [
+    "ada",
+    "babbage",
+    "curie",
+    "davinci",
+]
+_CHAT_MODELS = [
+    "gpt-4",
+    "gpt-4-32k",
+    "gpt-3.5-turbo",
+]
+
+
+@pytest.mark.parametrize("model", _MODELS)
+def test_openai_get_num_tokens(model: str) -> None:
+    """Test get_tokens."""
+    llm = OpenAI(model=model)
+    assert llm.get_num_tokens("表情符号是\n🦜🔗") == _EXPECTED_NUM_TOKENS[model]
+
+
+@pytest.mark.parametrize("model", _CHAT_MODELS)
+def test_chat_openai_get_num_tokens(model: str) -> None:
+    """Test get_tokens."""
+    llm = ChatOpenAI(model=model)
+    assert llm.get_num_tokens("表情符号是\n🦜🔗") == _EXPECTED_NUM_TOKENS[model]
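
The expected counts line up with the encodings tiktoken assigns these model families (r50k_base for the completion models, cl100k_base for the chat models), so they can be spot-checked without an API key, assuming `tiktoken` is installed:

    import tiktoken

    text = "表情符号是\n🦜🔗"
    assert len(tiktoken.get_encoding("r50k_base").encode(text)) == 17
    assert len(tiktoken.get_encoding("cl100k_base").encode(text)) == 12

The final hunk updates the unit tests for the default GPT-2 tokenizer: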
@@ -1,15 +1,19 @@
 """Test formatting functionality."""
 
-from langchain.base_language import _get_num_tokens_default_method
+from langchain.base_language import _get_token_ids_default_method
 
 
 class TestTokenCountingWithGPT2Tokenizer:
+    def test_tokenization(self) -> None:
+        # Check that the tokenization is consistent with the GPT-2 tokenizer
+        assert _get_token_ids_default_method("This is a test") == [1212, 318, 257, 1332]
+
     def test_empty_token(self) -> None:
-        assert _get_num_tokens_default_method("") == 0
+        assert len(_get_token_ids_default_method("")) == 0
 
     def test_multiple_tokens(self) -> None:
-        assert _get_num_tokens_default_method("a b c") == 3
+        assert len(_get_token_ids_default_method("a b c")) == 3
 
     def test_special_tokens(self) -> None:
         # test for consistency when the default tokenizer is changed
-        assert _get_num_tokens_default_method("a:b_c d") == 6
+        assert len(_get_token_ids_default_method("a:b_c d")) == 6
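
The hard-coded IDs in test_tokenization can be reproduced with the same GPT-2 tokenizer the default method uses; a quick check assuming `transformers` is installed:

    from transformers import GPT2TokenizerFast

    tok = GPT2TokenizerFast.from_pretrained("gpt2")
    ids = tok.encode("This is a test")
    assert ids == [1212, 318, 257, 1332]        # matches test_tokenization
    assert tok.decode(ids) == "This is a test"  # the IDs round-trip to text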
|
Loading…
Reference in New Issue
Block a user