Mirror of https://github.com/hwchase17/langchain.git, synced 2025-06-24 15:43:54 +00:00
Harrison/token counts (#311)
Co-authored-by: thepok <richterthepok@yahoo.de>
parent 19a9fa16a9
commit a7c8e37e77
@@ -116,9 +116,86 @@ class OpenAI(LLM, BaseModel):
                response = openai("Tell me a joke.")
        """
        params = self._default_params
        # max_tokens == -1 means "use whatever room the prompt leaves in the
        # model's context window"; compute that budget from the prompt.
        if params["max_tokens"] == -1:
            params["max_tokens"] = self.max_tokens_for_prompt(prompt)
        if stop is not None:
            if "stop" in params:
                raise ValueError("`stop` found in both the input and default params.")
            params["stop"] = stop
        response = self.client.create(model=self.model_name, prompt=prompt, **params)
        return response["choices"][0]["text"]
    def modelname_to_contextsize(self, modelname: str) -> int:
        """Look up the maximum context size (in tokens) for a model.

        text-davinci-003: 4,000 tokens
        text-curie-001: 2,048 tokens
        text-babbage-001: 2,048 tokens
        text-ada-001: 2,048 tokens
        code-davinci-002: 8,000 tokens
        code-cushman-001: 2,048 tokens

        Args:
            modelname: The model name to look up the context size for.

        Returns:
            The maximum context size.

        Example:
            .. code-block:: python

                max_tokens = openai.modelname_to_contextsize("text-davinci-003")
        """
        if modelname == "text-davinci-003":
            return 4000
        elif modelname == "text-curie-001":
            return 2048
        elif modelname == "text-babbage-001":
            return 2048
        elif modelname == "text-ada-001":
            return 2048
        elif modelname == "code-davinci-002":
            return 8000
        elif modelname == "code-cushman-001":
            return 2048
        else:
            # Unknown model: fall back to the davinci default.
            return 4000
    def max_tokens_for_prompt(self, prompt: str) -> int:
        """Calculate the maximum number of tokens possible to generate for a prompt.

        Args:
            prompt: The prompt to pass into the model.

        Returns:
            The maximum number of tokens to generate for a prompt.

        Example:
            .. code-block:: python

                max_tokens = openai.max_tokens_for_prompt("Tell me a joke.")
        """
        # TODO: this method may not be exact.
        # TODO: this method may differ based on model (eg codex).
        try:
            from transformers import GPT2TokenizerFast
        except ImportError:
            raise ValueError(
                "Could not import transformers python package. "
                "This is needed in order to calculate max_tokens_for_prompt. "
                "Please install it with `pip install transformers`."
            )
        # Create a GPT-2 tokenizer instance (a close approximation of the
        # tokenizer used by the GPT-3 models).
        tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

        # Tokenize the prompt and count the resulting tokens.
        tokenized_text = tokenizer.tokenize(prompt)
        num_tokens = len(tokenized_text)

        # Whatever the prompt leaves of the model's context window is the
        # room available for generation.
        max_size = self.modelname_to_contextsize(self.model_name)
        return max_size - num_tokens
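Taken together, the two new methods let a caller set `max_tokens=-1` and have the wrapper budget the completion automatically: the prompt is tokenized, its token count is subtracted from the model's context size, and the remainder becomes `max_tokens`. A minimal usage sketch, assuming the wrapper is importable as `from langchain.llms import OpenAI` (the path used at this point in the repository's history) and that an OpenAI API key is configured in the environment:

.. code-block:: python

    from langchain.llms import OpenAI

    # max_tokens=-1 defers the token budget to max_tokens_for_prompt.
    openai = OpenAI(model_name="text-davinci-003", max_tokens=-1)

    prompt = "Tell me a joke."
    # 4,000-token context for text-davinci-003, minus the prompt's tokens.
    print(openai.max_tokens_for_prompt(prompt))

    # The same computation happens implicitly inside the call itself.
    print(openai(prompt))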