mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-01 12:38:45 +00:00
Harrison/hf lru (#10154)
Co-authored-by: Pascal Bro <git@pascalbrokmeier.de> Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
4765c09703
commit
794ff2dae8
@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import lru_cache
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
@ -23,10 +24,8 @@ if TYPE_CHECKING:
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
|
||||
|
||||
def _get_token_ids_default_method(text: str) -> List[int]:
|
||||
"""Encode the text into token IDs."""
|
||||
# TODO: this method may not be exact.
|
||||
# TODO: this method may differ based on model (eg codex).
|
||||
@lru_cache(maxsize=None) # Cache the tokenizer
|
||||
def get_tokenizer() -> Any:
|
||||
try:
|
||||
from transformers import GPT2TokenizerFast
|
||||
except ImportError:
|
||||
@ -36,7 +35,13 @@ def _get_token_ids_default_method(text: str) -> List[int]:
|
||||
"Please install it with `pip install transformers`."
|
||||
)
|
||||
# create a GPT-2 tokenizer instance
|
||||
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
|
||||
return GPT2TokenizerFast.from_pretrained("gpt2")
|
||||
|
||||
|
||||
def _get_token_ids_default_method(text: str) -> List[int]:
|
||||
"""Encode the text into token IDs."""
|
||||
# get the cached tokenizer
|
||||
tokenizer = get_tokenizer()
|
||||
|
||||
# tokenize the text using the GPT-2 tokenizer
|
||||
return tokenizer.encode(text)
|
||||
|
Loading…
Reference in New Issue
Block a user