diff --git a/libs/langchain/langchain/schema/language_model.py b/libs/langchain/langchain/schema/language_model.py
index 6a46165e43f..8623233807f 100644
--- a/libs/langchain/langchain/schema/language_model.py
+++ b/libs/langchain/langchain/schema/language_model.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
+from functools import lru_cache
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -23,10 +24,8 @@ if TYPE_CHECKING:
     from langchain.callbacks.manager import Callbacks
 
 
-def _get_token_ids_default_method(text: str) -> List[int]:
-    """Encode the text into token IDs."""
-    # TODO: this method may not be exact.
-    # TODO: this method may differ based on model (eg codex).
+@lru_cache(maxsize=None)  # Cache the tokenizer
+def get_tokenizer() -> Any:
     try:
         from transformers import GPT2TokenizerFast
     except ImportError:
@@ -36,7 +35,13 @@ def _get_token_ids_default_method(text: str) -> List[int]:
         "Please install it with `pip install transformers`."
     )
     # create a GPT-2 tokenizer instance
-    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
+    return GPT2TokenizerFast.from_pretrained("gpt2")
+
+
+def _get_token_ids_default_method(text: str) -> List[int]:
+    """Encode the text into token IDs."""
+    # get the cached tokenizer
+    tokenizer = get_tokenizer()
     # tokenize the text using the GPT-2 tokenizer
     return tokenizer.encode(text)
 