Tiktoken override (#6697)

This commit is contained in:
Harrison Chase 2023-06-26 00:49:32 -07:00 committed by GitHub
parent f9771700e4
commit e9877ea8b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 63 additions and 12 deletions

View File

@ -184,6 +184,16 @@ class ChatOpenAI(BaseChatModel):
"""Number of chat completions to generate for each prompt."""
max_tokens: Optional[int] = None
"""Maximum number of tokens to generate."""
tiktoken_model_name: Optional[str] = None
"""The model name to pass to tiktoken when using this class.
Tiktoken is used to count the number of tokens in documents to constrain
them to be under a certain limit. By default, when set to None, this will
be the same as the embedding model name. However, there are some cases
where you may want to use this Embedding class with a model name not
supported by tiktoken. This can include when using Azure embeddings or
when using one of the many model providers that expose an OpenAI-like
API but with different models. In those cases, in order to avoid erroring
when tiktoken is called, you can specify a model name to use here."""
class Config:
"""Configuration for this pydantic object."""
@ -448,15 +458,18 @@ class ChatOpenAI(BaseChatModel):
def _get_encoding_model(self) -> Tuple[str, tiktoken.Encoding]:
tiktoken_ = _import_tiktoken()
model = self.model_name
if model == "gpt-3.5-turbo":
# gpt-3.5-turbo may change over time.
# Returning num tokens assuming gpt-3.5-turbo-0301.
model = "gpt-3.5-turbo-0301"
elif model == "gpt-4":
# gpt-4 may change over time.
# Returning num tokens assuming gpt-4-0314.
model = "gpt-4-0314"
if self.tiktoken_model_name is not None:
model = self.tiktoken_model_name
else:
model = self.model_name
if model == "gpt-3.5-turbo":
# gpt-3.5-turbo may change over time.
# Returning num tokens assuming gpt-3.5-turbo-0301.
model = "gpt-3.5-turbo-0301"
elif model == "gpt-4":
# gpt-4 may change over time.
# Returning num tokens assuming gpt-4-0314.
model = "gpt-4-0314"
# Returns the number of tokens used by a list of messages.
try:
encoding = tiktoken_.encoding_for_model(model)

View File

@ -170,6 +170,16 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
request_timeout: Optional[Union[float, Tuple[float, float]]] = None
"""Timeout in seconds for the OpenAPI request."""
headers: Any = None
tiktoken_model_name: Optional[str] = None
"""The model name to pass to tiktoken when using this class.
Tiktoken is used to count the number of tokens in documents to constrain
them to be under a certain limit. By default, when set to None, this will
be the same as the embedding model name. However, there are some cases
where you may want to use this Embedding class with a model name not
supported by tiktoken. This can include when using Azure embeddings or
when using one of the many model providers that expose an OpenAI-like
API but with different models. In those cases, in order to avoid erroring
when tiktoken is called, you can specify a model name to use here."""
class Config:
"""Configuration for this pydantic object."""
@ -265,7 +275,13 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
tokens = []
indices = []
encoding = tiktoken.model.encoding_for_model(self.model)
model_name = self.tiktoken_model_name or self.model
try:
encoding = tiktoken.encoding_for_model(model_name)
except KeyError:
logger.warning("Warning: model not found. Using cl100k_base encoding.")
model = "cl100k_base"
encoding = tiktoken.get_encoding(model)
for i, text in enumerate(texts):
if self.model.endswith("001"):
# See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
@ -329,7 +345,13 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
tokens = []
indices = []
encoding = tiktoken.model.encoding_for_model(self.model)
model_name = self.tiktoken_model_name or self.model
try:
encoding = tiktoken.encoding_for_model(model_name)
except KeyError:
logger.warning("Warning: model not found. Using cl100k_base encoding.")
model = "cl100k_base"
encoding = tiktoken.get_encoding(model)
for i, text in enumerate(texts):
if self.model.endswith("001"):
# See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500

View File

@ -171,6 +171,16 @@ class BaseOpenAI(BaseLLM):
"""Set of special tokens that are allowed。"""
disallowed_special: Union[Literal["all"], Collection[str]] = "all"
"""Set of special tokens that are not allowed。"""
tiktoken_model_name: Optional[str] = None
"""The model name to pass to tiktoken when using this class.
Tiktoken is used to count the number of tokens in documents to constrain
them to be under a certain limit. By default, when set to None, this will
be the same as the embedding model name. However, there are some cases
where you may want to use this Embedding class with a model name not
supported by tiktoken. This can include when using Azure embeddings or
when using one of the many model providers that expose an OpenAI-like
API but with different models. In those cases, in order to avoid erroring
when tiktoken is called, you can specify a model name to use here."""
def __new__(cls, **data: Any) -> Union[OpenAIChat, BaseOpenAI]: # type: ignore
"""Initialize the OpenAI object."""
@ -491,7 +501,13 @@ class BaseOpenAI(BaseLLM):
"Please install it with `pip install tiktoken`."
)
enc = tiktoken.encoding_for_model(self.model_name)
model_name = self.tiktoken_model_name or self.model_name
try:
enc = tiktoken.encoding_for_model(model_name)
except KeyError:
logger.warning("Warning: model not found. Using cl100k_base encoding.")
model = "cl100k_base"
enc = tiktoken.get_encoding(model)
return enc.encode(
text,