[Inference] Adapt repetition_penalty and no_repeat_ngram_size (#5708)

* Adapt repetition_penalty and no_repeat_ngram_size

* fix no_repeat_ngram_size_logit_process

* remove batch_updated

* fix annotation

* modify code based on review feedback

* rm get_batch_token_ids
yuehuayingxueluo 2024-05-11 15:13:25 +08:00 committed by GitHub
parent 50104ab340
commit de4bf3dedf
5 changed files with 94 additions and 18 deletions


@@ -102,6 +102,13 @@ class BatchBucket:
     def num_tokens_to_verify(self) -> int:
         return self._num_tokens_to_verify
 
+    @property
+    def batch_token_ids(self) -> List[List[int]]:
+        out = []
+        for seq in self.seqs_li:
+            out.append(seq.input_token_id + seq.output_token_id)
+        return out
+
     def set_use_spec_dec(self, num_tokens_to_verify: int = 5) -> None:
         """Set batch bucket to use speculative decoding.
         This will adjust the lengths of inputs during modeling,
@@ -328,6 +335,7 @@ class BatchBucket:
             seqs.append(seq)
         if not self.is_compact:
             self._make_compact()
+
         return seqs, block_tables
 
     def pop_finished(
@@ -432,6 +440,7 @@ class BatchBucket:
         block_tables = torch.stack(block_tables_li)
         self.add_seqs(seqs, alloc_block_tables=block_tables)
         unmerged_ids = other.seqs_ids
+
         return unmerged_ids
 
     ########## The following methods are expected to be used in modeling ###########
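
For reference, the new `batch_token_ids` property just concatenates each sequence's prompt tokens with its generated tokens. Below is a minimal stand-alone sketch of that behaviour; `ToySeq` is a hypothetical stand-in for `Sequence`, since only the two fields used above matter here.

```python
# Minimal sketch of what batch_token_ids exposes. ToySeq is a hypothetical
# stand-in for the Sequence struct; only these two fields are relevant.
from dataclasses import dataclass
from typing import List


@dataclass
class ToySeq:
    input_token_id: List[int]   # prompt tokens
    output_token_id: List[int]  # tokens generated so far


def batch_token_ids(seqs: List[ToySeq]) -> List[List[int]]:
    # prompt tokens followed by generated tokens, one list per sequence
    return [seq.input_token_id + seq.output_token_id for seq in seqs]


print(batch_token_ids([ToySeq([1, 2, 3], [7, 8]), ToySeq([4, 5], [9])]))
# [[1, 2, 3, 7, 8], [4, 5, 9]]
```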


@@ -99,7 +99,9 @@ class InferenceConfig:
        early_stopping (Optional[bool]): Whether to stop the generation when all beam hypotheses have finished or not, defaults to False.
        top_k (Optional[int]): The number of highest probability vocabulary tokens to keep for top-k-filtering, defaults to None.
        top_p (Optional[float]): The cumulative probability threshold for retaining tokens with a total probability above it, defaults to None.
-       min_p (Optional[float]): The minimum probability to keep for top-p filtering, defaults to None.
+       temperature (Optional[float]): The value used to modulate next-token probabilities; higher values yield more random outputs, defaults to 1.0.
+       repetition_penalty (Optional[float]): Penalty applied to tokens that already appear in the prompt or the generated text. Values greater than 1 discourage repetition, values less than 1 encourage it, defaults to 1.0.
+       no_repeat_ngram_size (Optional[int]): If greater than 0, any n-gram of this size can occur only once in the generated text, defaults to 0.
        n_spec_tokens (int): The maximum number of speculating tokens, defaults to None.
        glimpse_large_kv (bool): Whether to use large KV in drafter model, defaults to False.
        block_size (int): The number of blocks in a logical block, defaults to 16.
@@ -136,7 +138,9 @@ class InferenceConfig:
     early_stopping: Optional[bool] = False
     top_k: Optional[int] = None
     top_p: Optional[float] = None
-    min_p: Optional[float] = None
+    temperature: Optional[float] = 1.0
+    no_repeat_ngram_size: Optional[int] = 0
+    repetition_penalty: Optional[float] = 1.0
 
     # speculative decoding configs
     max_n_spec_tokens: int = 5
@@ -213,7 +217,7 @@ class InferenceConfig:
             "do_sample": self.do_sample,
             "num_beams": self.beam_width,
         }
-        for type in ["top_k", "top_p", "min_p"]:
+        for type in ["repetition_penalty", "no_repeat_ngram_size", "temperature", "top_k", "top_p"]:
             if hasattr(self, type):
                 meta_config[type] = getattr(self, type)
         for type in ["pad_token_id", "bos_token_id", "eos_token_id"]:


@@ -424,7 +424,7 @@ class InferenceEngine:
         # 2. Prefill main model (Verifier) - fill past kv cache for main model
         logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
-        next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
+        next_tokens = self.request_handler.search_tokens(self.generation_config, logits, batch)
 
         # append new inputs to the batch, temporarily
         batch.append_batch_tokens(next_tokens)
         self.request_handler.allocate_batch_spec_dec(batch, 1)
@@ -472,7 +472,7 @@ class InferenceEngine:
             input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
             logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
-            next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
+            next_tokens = self.request_handler.search_tokens(self.generation_config, logits, batch)
 
             # 5. Compare and process the results
             diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
@@ -738,7 +738,7 @@ class InferenceEngine:
         logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
         if self.inference_config.pad_input:
             logits = logits[:, -1, :]
-        next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
+        next_tokens = self.request_handler.search_tokens(self.generation_config, logits, batch)
         self.request_handler.append_next_tokens(next_tokens)
 
         finished_sequences = self.request_handler.update()


@@ -11,12 +11,9 @@ from colossalai.inference.kv_cache import KVCacheManager
 from colossalai.inference.logit_processors import logit_processor
 from colossalai.inference.sampler import *
 from colossalai.inference.struct import RequestStatus, Sequence
-from colossalai.logging import get_dist_logger
 
 __all__ = ["RunningList", "RequestHandler"]
 
-logger = get_dist_logger(__name__)
-
 
 class RunningList:
     """
@@ -331,15 +328,21 @@ class RequestHandler:
     def total_requests_in_batch_bucket(self) -> int:
         return self.prefill_bb.current_batch_size + self.running_bb.current_batch_size
 
-    def search_tokens(self, generation_config: GenerationConfig, logits):
+    def search_tokens(self, generation_config: GenerationConfig, logits, cur_batch: BatchBucket):
         """
         Sample tokens for finished requests.
         """
 
-        # do logit processor
-        if generation_config.do_sample:
-            # NOTE: need to decide the granularity to process logits (sequence or batch)
-            config_dict = generation_config.to_dict()
+        # NOTE: need to decide the granularity to process logits (sequence or batch)
+        config_dict = generation_config.to_dict()
+        # process repetition_penalty, no_repeat_ngram_size
+        for type in ["repetition_penalty", "no_repeat_ngram_size"]:
+            if type in config_dict and config_dict[type] is not None:
+                logits = logit_processor(type, logits, config_dict[type], cur_batch)
+
+        # do logit processor
+        if generation_config.do_sample:
+            # process temperature, top_k, top_p
             for type in ["temperature", "top_k", "top_p"]:
                 if type in config_dict and config_dict[type] is not None:
                     logits = logit_processor(type, logits, config_dict[type])
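
One practical effect of the reordering above: `repetition_penalty` and `no_repeat_ngram_size` now run before the `do_sample` gate, so they also influence greedy decoding. A compact sketch of the new control flow; `apply_processor` is a hypothetical stand-in for the `logit_processor` dispatcher.

```python
# Stand-alone sketch of the ordering introduced in search_tokens.
# apply_processor is a hypothetical stand-in for logit_processor.
from typing import Any, Callable, Dict


def processing_order(
    config: Dict[str, Any],
    do_sample: bool,
    logits: Any,
    cur_batch: Any,
    apply_processor: Callable[..., Any],
) -> Any:
    # penalties run unconditionally, so greedy decoding is covered too
    for name in ["repetition_penalty", "no_repeat_ngram_size"]:
        if config.get(name) is not None:
            logits = apply_processor(name, logits, config[name], cur_batch)
    # sampling-only processors stay behind the do_sample gate
    if do_sample:
        for name in ["temperature", "top_k", "top_p"]:
            if config.get(name) is not None:
                logits = apply_processor(name, logits, config[name])
    return logits
```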


@@ -1,6 +1,10 @@
+# This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/generation/logits_process.py
+
 import torch
 import torch.nn.functional as F
 
+from colossalai.inference.batch_bucket import BatchBucket
+
 _LOGIT_PROCESSOR_MAP = {}
 
@@ -17,6 +21,66 @@ def register_logit_processor(process_type):
     return register
 
 
+@register_logit_processor("no_repeat_ngram_size")
+def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch: BatchBucket):
+    """
+    Enforces no repetition of n-grams to avoid repetition of word sequences.
+    """
+    if not isinstance(ngram_size, int) or ngram_size < 0:
+        raise ValueError(f"'ngram_size={ngram_size}' should be a non-negative integer.")
+
+    if ngram_size != 0:
+        batch_token_ids = batch.batch_token_ids
+        batch_size = len(batch_token_ids)
+
+        for batch_id in range(batch_size):
+            current_token_ids = batch_token_ids[batch_id]
+            current_len = len(current_token_ids)
+            if current_len + 1 < ngram_size:
+                continue
+
+            ngrams_dict = {}
+            for ngram in zip(*[current_token_ids[i:] for i in range(ngram_size)]):
+                prev_ngram_tuple = tuple(ngram[:-1])
+                ngrams_dict[prev_ngram_tuple] = ngrams_dict.get(prev_ngram_tuple, []) + [ngram[-1]]
+
+            prev_ngrams = tuple(current_token_ids[current_len + 1 - ngram_size : current_len])
+            banned_token = ngrams_dict.get(prev_ngrams, [])
+
+            logits[batch_id, banned_token] = -float("inf")
+
+    return logits
+
+
@register_logit_processor("repetition_penalty")
def repetition_penalty_logit_process(logits, penalty: float, batch: BatchBucket):
"""
apply the penalty to the tokens present in the prompt.
"""
if not isinstance(penalty, float) or not (penalty > 0):
raise ValueError(f"'penalty={penalty}' has to be a strictly positive float and greater than 0.")
logit_list = []
# TODO(yuehuayingxueluo) This is only a temporary implementation. Later, we will implement presence_penalties, frequency_penalties, and repetition_penalties using CUDA kernels.
if penalty != 1.0:
batch_token_ids = batch.batch_token_ids
for batch_id in range(len(batch_token_ids)):
current_logit = logits[batch_id]
current_token = torch.tensor(batch_token_ids[batch_id], dtype=torch.long, device=logits.device)
curretn_socre = torch.gather(current_logit, 0, current_token)
curretn_socre = torch.where(curretn_socre < 0, curretn_socre * penalty, curretn_socre / penalty)
logit_list.append(current_logit.scatter(0, current_token, curretn_socre))
logits = torch.stack(logit_list)
return logits
@register_logit_processor("temperature") @register_logit_processor("temperature")
def temperature_logit_process(logits, temperature: float): def temperature_logit_process(logits, temperature: float):
""" """
@@ -68,14 +132,13 @@ def top_p_logit_processor(logits, top_p: float):
     return logits
 
 
-def logit_processor(processor: str, logits, attrs):
+def logit_processor(processor: str, logits, *args, **kwargs):
     """
     Do logit processing for the given logits.
 
     Args:
         processor(str): the type of logit processor
         logits(torch.Tensor): input logits
-        attrs(dict): attrs of the logit processor
 
     Returns:
         logits after process
@@ -84,8 +147,5 @@ def logit_processor(processor: str, logits, attrs):
         return logits
     else:
         func = _LOGIT_PROCESSOR_MAP[processor]
-        try:
-            logits = func(logits, attrs)
-        except Exception:
-            return logits
+        logits = func(logits, *args, **kwargs)
     return logits
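
To sanity-check the maths of the two new processors, here is a self-contained toy run of the same operations (no ColossalAI imports; a plain token-id list stands in for `batch.batch_token_ids[0]`):

```python
# Toy illustration of the repetition penalty and the no-repeat-ngram ban.
import torch

token_ids = [2, 3, 2, 3]                                   # tokens generated so far
logits = torch.arange(6, dtype=torch.float).unsqueeze(0)   # 1 sequence, vocab of 6

# repetition penalty (penalty = 1.5): divide positive scores and multiply
# negative scores of every token that already occurred
penalty = 1.5
seen = torch.tensor(token_ids, dtype=torch.long)
scores = torch.gather(logits[0], 0, seen)
scores = torch.where(scores < 0, scores * penalty, scores / penalty)
logits[0] = logits[0].scatter(0, seen, scores)

# no-repeat 3-gram: the last two tokens are (2, 3); "2 3 2" already occurred,
# so token 2 gets banned to avoid repeating that 3-gram
ngram_size = 3
ngrams = {}
for ngram in zip(*[token_ids[i:] for i in range(ngram_size)]):
    ngrams.setdefault(tuple(ngram[:-1]), []).append(ngram[-1])
prefix = tuple(token_ids[-(ngram_size - 1):])
logits[0, ngrams.get(prefix, [])] = -float("inf")

print(logits)  # scores at positions 2 and 3 shrink; position 2 becomes -inf
```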