Mirror of https://github.com/hwchase17/langchain.git
Harrison/token counts (#311)
Co-authored-by: thepok <richterthepok@yahoo.de>
parent 19a9fa16a9
commit a7c8e37e77
@@ -116,9 +116,86 @@ class OpenAI(LLM, BaseModel):
                 response = openai("Tell me a joke.")
         """
         params = self._default_params
+        if params["max_tokens"] == -1:
+            params["max_tokens"] = self.max_tokens_for_prompt(prompt)
         if stop is not None:
             if "stop" in params:
                 raise ValueError("`stop` found in both the input and default params.")
             params["stop"] = stop
         response = self.client.create(model=self.model_name, prompt=prompt, **params)
         return response["choices"][0]["text"]
+
+    def modelname_to_contextsize(self, modelname: str) -> int:
+        """Calculate the maximum number of tokens possible to generate for a model.
+
+        text-davinci-003: 4,000 tokens
+        text-curie-001: 2,048 tokens
+        text-babbage-001: 2,048 tokens
+        text-ada-001: 2,048 tokens
+        code-davinci-002: 8,000 tokens
+        code-cushman-001: 2,048 tokens
+
+        Args:
+            modelname: The modelname we want to know the context size for.
+
+        Returns:
+            The maximum context size.
+
+        Example:
+            .. code-block:: python
+
+                max_tokens = openai.modelname_to_contextsize("text-davinci-003")
+        """
+        if modelname == "text-davinci-003":
+            return 4000
+        elif modelname == "text-curie-001":
+            return 2048
+        elif modelname == "text-babbage-001":
+            return 2048
+        elif modelname == "text-ada-001":
+            return 2048
+        elif modelname == "code-davinci-002":
+            return 8000
+        elif modelname == "code-cushman-001":
+            return 2048
+        else:
+            return 4000
+
+    def max_tokens_for_prompt(self, prompt: str) -> int:
+        """Calculate the maximum number of tokens possible to generate for a prompt.
+
+        Args:
+            prompt: The prompt to pass into the model.
+
+        Returns:
+            The maximum number of tokens to generate for a prompt.
+
+        Example:
+            .. code-block:: python
+
+                max_tokens = openai.max_tokens_for_prompt("Tell me a joke.")
+        """
+        # TODO: this method may not be exact.
+        # TODO: this method may differ based on model (eg codex).
+        try:
+            from transformers import GPT2TokenizerFast
+        except ImportError:
+            raise ValueError(
+                "Could not import transformers python package. "
+                "This is needed in order to calculate max_tokens_for_prompt. "
+                "Please install it with `pip install transformers`."
+            )
+        # Create a GPT-2 tokenizer instance (used here as a stand-in for the
+        # GPT-3 tokenizer).
+        tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
+
+        # Tokenize the prompt with the GPT-2 tokenizer.
+        tokenized_text = tokenizer.tokenize(prompt)
+
+        # Count the tokens in the tokenized text.
+        num_tokens = len(tokenized_text)
+
+        # Look up the max context size for the model by name.
+        max_size = self.modelname_to_contextsize(self.model_name)
+        return max_size - num_tokens
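Taken together, the change makes `max_tokens=-1` mean "spend whatever context is left after the prompt." Below is a minimal usage sketch; the import path, constructor arguments, and printed values are assumptions based on langchain's API around this commit rather than anything shown in the diff, and a live OpenAI API key would be needed for the final call.

# Hypothetical usage sketch -- not part of this commit's diff.
# Assumes `from langchain.llms import OpenAI` is the correct import at
# this point in the project's history and that OPENAI_API_KEY is set.
from langchain.llms import OpenAI

# max_tokens=-1 triggers the new branch in __call__: the completion
# budget becomes the model's context size minus the prompt's token count.
openai = OpenAI(model_name="text-davinci-003", max_tokens=-1)

print(openai.modelname_to_contextsize("text-davinci-003"))  # 4000
print(openai.max_tokens_for_prompt("Tell me a joke."))      # 4000 - <prompt tokens>

# The completion call then fills the remaining context.
response = openai("Tell me a joke.")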
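The counting step itself leans on the GPT-2 byte-pair tokenizer as a stand-in for the GPT-3 tokenizer (hence the TODOs in the diff), so the computed budget is an approximation for newer models. The step can be reproduced outside the class; a minimal sketch, assuming `transformers` is installed:

# Standalone sketch of the token counting used by max_tokens_for_prompt.
# Requires `pip install transformers`; downloads the "gpt2" vocab on
# first use. Counts from the GPT-2 BPE are approximate for GPT-3 models.
from transformers import GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
num_tokens = len(tokenizer.tokenize("Tell me a joke."))

# Mirrors max_size - num_tokens for text-davinci-003 (context size 4000).
max_tokens = 4000 - num_tokens
print(num_tokens, max_tokens)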