From 794ff2dae845cd22b5e939db4ef4e005707aa941 Mon Sep 17 00:00:00 2001
From: Harrison Chase
Date: Sun, 3 Sep 2023 15:39:25 -0700
Subject: [PATCH] Harrison/hf lru (#10154)

Co-authored-by: Pascal Bro
Co-authored-by: Bagatur
---
 libs/langchain/langchain/schema/language_model.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/libs/langchain/langchain/schema/language_model.py b/libs/langchain/langchain/schema/language_model.py
index 6a46165e43f..8623233807f 100644
--- a/libs/langchain/langchain/schema/language_model.py
+++ b/libs/langchain/langchain/schema/language_model.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
+from functools import lru_cache
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -23,10 +24,8 @@ if TYPE_CHECKING:
     from langchain.callbacks.manager import Callbacks
 
 
-def _get_token_ids_default_method(text: str) -> List[int]:
-    """Encode the text into token IDs."""
-    # TODO: this method may not be exact.
-    # TODO: this method may differ based on model (eg codex).
+@lru_cache(maxsize=None)  # Cache the tokenizer
+def get_tokenizer() -> Any:
     try:
         from transformers import GPT2TokenizerFast
     except ImportError:
@@ -36,7 +35,13 @@
             "Please install it with `pip install transformers`."
         )
     # create a GPT-2 tokenizer instance
-    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
+    return GPT2TokenizerFast.from_pretrained("gpt2")
+
+
+def _get_token_ids_default_method(text: str) -> List[int]:
+    """Encode the text into token IDs."""
+    # get the cached tokenizer
+    tokenizer = get_tokenizer()
     # tokenize the text using the GPT-2 tokenizer
     return tokenizer.encode(text)
 
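
Note: a minimal standalone sketch (not part of the patch) of the behavior this
commit relies on: `functools.lru_cache(maxsize=None)` memoizes the zero-argument
`get_tokenizer()`, so `GPT2TokenizerFast.from_pretrained("gpt2")` runs once and
every later call returns the same cached tokenizer instance. The `print` call is
illustrative only and is not in the patched code:

    from functools import lru_cache

    @lru_cache(maxsize=None)
    def get_tokenizer():
        # Body runs only on the first call; later calls hit the cache.
        print("loading tokenizer...")
        from transformers import GPT2TokenizerFast
        return GPT2TokenizerFast.from_pretrained("gpt2")

    first = get_tokenizer()   # prints "loading tokenizer..." and loads the files
    second = get_tokenizer()  # no print, no reload: cached instance is returned
    assert first is second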