mirror of
https://github.com/hwchase17/langchain.git
synced 2026-03-18 11:07:36 +00:00
refactor(core): add warning for fallback GPT-2 tokenizer usage (#34621)
This commit is contained in:
@@ -87,13 +87,28 @@ def get_tokenizer() -> Any:
|
||||
return GPT2TokenizerFast.from_pretrained("gpt2")
|
||||
|
||||
|
||||
_GPT2_TOKENIZER_WARNED = False
|
||||
|
||||
|
||||
def _get_token_ids_default_method(text: str) -> list[int]:
|
||||
"""Encode the text into token IDs."""
|
||||
# get the cached tokenizer
|
||||
"""Encode the text into token IDs using the fallback GPT-2 tokenizer."""
|
||||
global _GPT2_TOKENIZER_WARNED # noqa: PLW0603
|
||||
if not _GPT2_TOKENIZER_WARNED:
|
||||
warnings.warn(
|
||||
"Using fallback GPT-2 tokenizer for token counting. "
|
||||
"Token counts may be inaccurate for non-GPT-2 models. "
|
||||
"For accurate counts, use a model-specific method if available.",
|
||||
stacklevel=3,
|
||||
)
|
||||
_GPT2_TOKENIZER_WARNED = True
|
||||
|
||||
tokenizer = get_tokenizer()
|
||||
|
||||
# tokenize the text using the GPT-2 tokenizer
|
||||
return cast("list[int]", tokenizer.encode(text))
|
||||
# Pass verbose=False to suppress the "Token indices sequence length is longer than
|
||||
# the specified maximum sequence length" warning from HuggingFace. This warning is
|
||||
# about GPT-2's 1024 token context limit, but we're only using the tokenizer for
|
||||
# counting, not for model input.
|
||||
return cast("list[int]", tokenizer.encode(text, verbose=False))
|
||||
|
||||
|
||||
LanguageModelInput = PromptValue | str | Sequence[MessageLikeRepresentation]
|
||||
|
||||
Reference in New Issue
Block a user