mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 23:54:14 +00:00
Tiktoken override (#6697)
This commit is contained in:
parent
f9771700e4
commit
e9877ea8b1
@ -184,6 +184,16 @@ class ChatOpenAI(BaseChatModel):
|
||||
"""Number of chat completions to generate for each prompt."""
|
||||
max_tokens: Optional[int] = None
|
||||
"""Maximum number of tokens to generate."""
|
||||
tiktoken_model_name: Optional[str] = None
|
||||
"""The model name to pass to tiktoken when using this class.
|
||||
Tiktoken is used to count the number of tokens in documents to constrain
|
||||
them to be under a certain limit. By default, when set to None, this will
|
||||
be the same as the embedding model name. However, there are some cases
|
||||
where you may want to use this Embedding class with a model name not
|
||||
supported by tiktoken. This can include when using Azure embeddings or
|
||||
when using one of the many model providers that expose an OpenAI-like
|
||||
API but with different models. In those cases, in order to avoid erroring
|
||||
when tiktoken is called, you can specify a model name to use here."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@ -448,15 +458,18 @@ class ChatOpenAI(BaseChatModel):
|
||||
|
||||
def _get_encoding_model(self) -> Tuple[str, tiktoken.Encoding]:
|
||||
tiktoken_ = _import_tiktoken()
|
||||
model = self.model_name
|
||||
if model == "gpt-3.5-turbo":
|
||||
# gpt-3.5-turbo may change over time.
|
||||
# Returning num tokens assuming gpt-3.5-turbo-0301.
|
||||
model = "gpt-3.5-turbo-0301"
|
||||
elif model == "gpt-4":
|
||||
# gpt-4 may change over time.
|
||||
# Returning num tokens assuming gpt-4-0314.
|
||||
model = "gpt-4-0314"
|
||||
if self.tiktoken_model_name is not None:
|
||||
model = self.tiktoken_model_name
|
||||
else:
|
||||
model = self.model_name
|
||||
if model == "gpt-3.5-turbo":
|
||||
# gpt-3.5-turbo may change over time.
|
||||
# Returning num tokens assuming gpt-3.5-turbo-0301.
|
||||
model = "gpt-3.5-turbo-0301"
|
||||
elif model == "gpt-4":
|
||||
# gpt-4 may change over time.
|
||||
# Returning num tokens assuming gpt-4-0314.
|
||||
model = "gpt-4-0314"
|
||||
# Returns the number of tokens used by a list of messages.
|
||||
try:
|
||||
encoding = tiktoken_.encoding_for_model(model)
|
||||
|
@ -170,6 +170,16 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
request_timeout: Optional[Union[float, Tuple[float, float]]] = None
|
||||
"""Timeout in seconds for the OpenAPI request."""
|
||||
headers: Any = None
|
||||
tiktoken_model_name: Optional[str] = None
|
||||
"""The model name to pass to tiktoken when using this class.
|
||||
Tiktoken is used to count the number of tokens in documents to constrain
|
||||
them to be under a certain limit. By default, when set to None, this will
|
||||
be the same as the embedding model name. However, there are some cases
|
||||
where you may want to use this Embedding class with a model name not
|
||||
supported by tiktoken. This can include when using Azure embeddings or
|
||||
when using one of the many model providers that expose an OpenAI-like
|
||||
API but with different models. In those cases, in order to avoid erroring
|
||||
when tiktoken is called, you can specify a model name to use here."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@ -265,7 +275,13 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
tokens = []
|
||||
indices = []
|
||||
encoding = tiktoken.model.encoding_for_model(self.model)
|
||||
model_name = self.tiktoken_model_name or self.model
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(model_name)
|
||||
except KeyError:
|
||||
logger.warning("Warning: model not found. Using cl100k_base encoding.")
|
||||
model = "cl100k_base"
|
||||
encoding = tiktoken.get_encoding(model)
|
||||
for i, text in enumerate(texts):
|
||||
if self.model.endswith("001"):
|
||||
# See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
|
||||
@ -329,7 +345,13 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
tokens = []
|
||||
indices = []
|
||||
encoding = tiktoken.model.encoding_for_model(self.model)
|
||||
model_name = self.tiktoken_model_name or self.model
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(model_name)
|
||||
except KeyError:
|
||||
logger.warning("Warning: model not found. Using cl100k_base encoding.")
|
||||
model = "cl100k_base"
|
||||
encoding = tiktoken.get_encoding(model)
|
||||
for i, text in enumerate(texts):
|
||||
if self.model.endswith("001"):
|
||||
# See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
|
||||
|
@ -171,6 +171,16 @@ class BaseOpenAI(BaseLLM):
|
||||
"""Set of special tokens that are allowed。"""
|
||||
disallowed_special: Union[Literal["all"], Collection[str]] = "all"
|
||||
"""Set of special tokens that are not allowed。"""
|
||||
tiktoken_model_name: Optional[str] = None
|
||||
"""The model name to pass to tiktoken when using this class.
|
||||
Tiktoken is used to count the number of tokens in documents to constrain
|
||||
them to be under a certain limit. By default, when set to None, this will
|
||||
be the same as the embedding model name. However, there are some cases
|
||||
where you may want to use this Embedding class with a model name not
|
||||
supported by tiktoken. This can include when using Azure embeddings or
|
||||
when using one of the many model providers that expose an OpenAI-like
|
||||
API but with different models. In those cases, in order to avoid erroring
|
||||
when tiktoken is called, you can specify a model name to use here."""
|
||||
|
||||
def __new__(cls, **data: Any) -> Union[OpenAIChat, BaseOpenAI]: # type: ignore
|
||||
"""Initialize the OpenAI object."""
|
||||
@ -491,7 +501,13 @@ class BaseOpenAI(BaseLLM):
|
||||
"Please install it with `pip install tiktoken`."
|
||||
)
|
||||
|
||||
enc = tiktoken.encoding_for_model(self.model_name)
|
||||
model_name = self.tiktoken_model_name or self.model_name
|
||||
try:
|
||||
enc = tiktoken.encoding_for_model(model_name)
|
||||
except KeyError:
|
||||
logger.warning("Warning: model not found. Using cl100k_base encoding.")
|
||||
model = "cl100k_base"
|
||||
enc = tiktoken.get_encoding(model)
|
||||
|
||||
return enc.encode(
|
||||
text,
|
||||
|
Loading…
Reference in New Issue
Block a user