mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-04 04:07:54 +00:00
Add allowed and disallowed special arguments to BaseOpenAI (#3012)
## Background This PR fixes this error when there are special tokens when querying the chain: ``` Encountered text corresponding to disallowed special token '<|endofprompt|>'. If you want this text to be encoded as a special token, pass it to `allowed_special`, e.g. `allowed_special={'<|endofprompt|>', ...}`. If you want this text to be encoded as normal text, disable the check for this token by passing `disallowed_special=(enc.special_tokens_set - {'<|endofprompt|>'})`. To disable this check for all special tokens, pass `disallowed_special=()`. ``` Refer to the code snippet below, it breaks in the chain line. ``` chain = ConversationalRetrievalChain.from_llm( ChatOpenAI(openai_api_key=OPENAI_API_KEY), retriever=vectorstore.as_retriever(), qa_prompt=prompt, condense_question_prompt=condense_prompt, ) answer = chain({"question": f"{question}"}) ``` However `ChatOpenAI` class is not accepting `allowed_special` and `disallowed_special` at the moment so they cannot be passed to the `encode()` in `get_num_tokens` method to avoid the errors. ## Change - Add `allowed_special` and `disallowed_special` attributes to `BaseOpenAI` class. - Pass in `allowed_special` and `disallowed_special` as arguments of `encode()` in tiktoken. --------- Co-authored-by: samcarmen <“carmen.samkahman@gmail.com”>
This commit is contained in:
parent
9d23cfc7dd
commit
d54c88aa21
@ -5,11 +5,14 @@ import logging
|
||||
import sys
|
||||
import warnings
|
||||
from typing import (
|
||||
AbstractSet,
|
||||
Any,
|
||||
Callable,
|
||||
Collection,
|
||||
Dict,
|
||||
Generator,
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
Set,
|
||||
@ -150,6 +153,10 @@ class BaseOpenAI(BaseLLM):
|
||||
"""Maximum number of retries to make when generating."""
|
||||
streaming: bool = False
|
||||
"""Whether to stream the results or not."""
|
||||
allowed_special: Union[Literal["all"], AbstractSet[str]] = set()
|
||||
"""Set of special tokens that are allowed。"""
|
||||
disallowed_special: Union[Literal["all"], Collection[str]] = "all"
|
||||
"""Set of special tokens that are not allowed。"""
|
||||
|
||||
def __new__(cls, **data: Any) -> Union[OpenAIChat, BaseOpenAI]: # type: ignore
|
||||
"""Initialize the OpenAI object."""
|
||||
@ -449,7 +456,11 @@ class BaseOpenAI(BaseLLM):
|
||||
|
||||
enc = tiktoken.encoding_for_model(self.model_name)
|
||||
|
||||
tokenized_text = enc.encode(text)
|
||||
tokenized_text = enc.encode(
|
||||
text,
|
||||
allowed_special=self.allowed_special,
|
||||
disallowed_special=self.disallowed_special,
|
||||
)
|
||||
|
||||
# calculate the number of tokens in the encoded text
|
||||
return len(tokenized_text)
|
||||
@ -602,6 +613,10 @@ class OpenAIChat(BaseLLM):
|
||||
"""Series of messages for Chat input."""
|
||||
streaming: bool = False
|
||||
"""Whether to stream the results or not."""
|
||||
allowed_special: Union[Literal["all"], AbstractSet[str]] = set()
|
||||
"""Set of special tokens that are allowed。"""
|
||||
disallowed_special: Union[Literal["all"], Collection[str]] = "all"
|
||||
"""Set of special tokens that are not allowed。"""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@ -785,7 +800,11 @@ class OpenAIChat(BaseLLM):
|
||||
enc = tiktoken.encoding_for_model("gpt-3.5-turbo")
|
||||
|
||||
# encode the text using the GPT-3.5-Turbo encoder
|
||||
tokenized_text = enc.encode(text)
|
||||
tokenized_text = enc.encode(
|
||||
text,
|
||||
allowed_special=self.allowed_special,
|
||||
disallowed_special=self.disallowed_special,
|
||||
)
|
||||
|
||||
# calculate the number of tokens in the encoded text
|
||||
return len(tokenized_text)
|
||||
|
Loading…
Reference in New Issue
Block a user