Mirror of https://github.com/hwchase17/langchain.git, synced 2025-09-13 21:47:12 +00:00
Add 'get_token_ids' method (#4784)
Let the user inspect the token ids in addition to getting the number of tokens.

Co-authored-by: Zach Schillaci <40636930+zachschillaci27@users.noreply.github.com>
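A minimal usage sketch of the new method (not part of the diff below): it assumes get_token_ids is exposed on the OpenAI wrapper that the tests exercise, that OPENAI_API_KEY is set so the model can be instantiated, and that the "ada" model choice is purely illustrative.

# Sketch: get_token_ids returns the individual token ids, while the
# pre-existing get_num_tokens returns only their count. Tokenization itself
# happens locally; no API call is made for it.
from langchain.llms.openai import OpenAI

llm = OpenAI(model="ada")
text = "表情符号是\n🦜🔗"  # "emoji are" in Chinese, plus the parrot-and-chain emoji

token_ids = llm.get_token_ids(text)   # list[int] of BPE token ids
num_tokens = llm.get_num_tokens(text)

assert num_tokens == len(token_ids)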
@@ -6,6 +6,7 @@ from typing import Generator
import pytest

from langchain.callbacks.manager import CallbackManager
from langchain.chat_models.openai import ChatOpenAI
from langchain.llms.loading import load_llm
from langchain.llms.openai import OpenAI, OpenAIChat
from langchain.schema import LLMResult

@@ -237,3 +238,40 @@ def test_openai_modelname_to_contextsize_invalid() -> None:
    """Test model name to context size on an invalid model."""
    with pytest.raises(ValueError):
        OpenAI().modelname_to_contextsize("foobar")


_EXPECTED_NUM_TOKENS = {
    "ada": 17,
    "babbage": 17,
    "curie": 17,
    "davinci": 17,
    "gpt-4": 12,
    "gpt-4-32k": 12,
    "gpt-3.5-turbo": 12,
}

_MODELS = [
    "ada",
    "babbage",
    "curie",
    "davinci",
]
_CHAT_MODELS = [
    "gpt-4",
    "gpt-4-32k",
    "gpt-3.5-turbo",
]


@pytest.mark.parametrize("model", _MODELS)
def test_openai_get_num_tokens(model: str) -> None:
    """Test get_num_tokens."""
    llm = OpenAI(model=model)
    assert llm.get_num_tokens("表情符号是\n🦜🔗") == _EXPECTED_NUM_TOKENS[model]


@pytest.mark.parametrize("model", _CHAT_MODELS)
def test_chat_openai_get_num_tokens(model: str) -> None:
    """Test get_num_tokens."""
    llm = ChatOpenAI(model=model)
    assert llm.get_num_tokens("表情符号是\n🦜🔗") == _EXPECTED_NUM_TOKENS[model]
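The two groups of expected counts above differ because the completion and chat models use different BPE encodings. A short sketch of how the table could be checked locally with tiktoken, the tokenizer langchain's OpenAI wrappers typically use when it is installed; the model-to-encoding mapping is an assumption about the installed tiktoken version.

# Sketch: reproduce the _EXPECTED_NUM_TOKENS values with tiktoken. The legacy
# completion models ("ada" ... "davinci") are assumed to resolve to r50k_base,
# the chat models to cl100k_base, which is why the counts differ.
import tiktoken

text = "表情符号是\n🦜🔗"
for model in ["ada", "gpt-3.5-turbo"]:
    enc = tiktoken.encoding_for_model(model)
    print(model, len(enc.encode(text)))  # the diff expects 17 and 12 respectively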
@@ -1,15 +1,19 @@
"""Test formatting functionality."""

from langchain.base_language import _get_num_tokens_default_method
from langchain.base_language import _get_token_ids_default_method


class TestTokenCountingWithGPT2Tokenizer:
    def test_tokenization(self) -> None:
        # Check that the tokenization is consistent with the GPT-2 tokenizer
        assert _get_token_ids_default_method("This is a test") == [1212, 318, 257, 1332]

    def test_empty_token(self) -> None:
        assert _get_num_tokens_default_method("") == 0
        assert len(_get_token_ids_default_method("")) == 0

    def test_multiple_tokens(self) -> None:
        assert _get_num_tokens_default_method("a b c") == 3
        assert len(_get_token_ids_default_method("a b c")) == 3

    def test_special_tokens(self) -> None:
        # test for consistency when the default tokenizer is changed
        assert _get_num_tokens_default_method("a:b_c d") == 6
        assert len(_get_token_ids_default_method("a:b_c d")) == 6
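The class name and the hard-coded ids above presume that the default tokenizer is GPT-2. A minimal sketch of what the two default methods plausibly look like, built on the transformers GPT-2 tokenizer; this illustrates the assumed behaviour and is not the PR's actual implementation.

# Hedged sketch of default token-id / token-count methods backed by the GPT-2
# tokenizer, consistent with the tests above ("This is a test" encodes to
# [1212, 318, 257, 1332] under GPT-2 BPE). Requires the transformers package.
from transformers import GPT2TokenizerFast

_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")


def _get_token_ids_default_method(text: str) -> list[int]:
    """Return the GPT-2 BPE token ids for the given text."""
    return _tokenizer.encode(text)


def _get_num_tokens_default_method(text: str) -> int:
    """Return the number of GPT-2 BPE tokens in the given text."""
    return len(_get_token_ids_default_method(text))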