From de4bf3dedf2c7cb7ba6c3044745bab3c3ef6352d Mon Sep 17 00:00:00 2001
From: yuehuayingxueluo <867460659@qq.com>
Date: Sat, 11 May 2024 15:13:25 +0800
Subject: [PATCH] [Inference]Adapt repetition_penalty and no_repeat_ngram_size
 (#5708)

* Adapt repetition_penalty and no_repeat_ngram_size

* fix no_repeat_ngram_size_logit_process

* remove batch_updated

* fix annotation

* modified codes based on the review feedback.

* rm get_batch_token_ids
---
 colossalai/inference/batch_bucket.py         |  9 +++
 colossalai/inference/config.py               | 10 ++-
 colossalai/inference/core/engine.py          |  6 +-
 colossalai/inference/core/request_handler.py | 15 ++--
 colossalai/inference/logit_processors.py     | 72 ++++++++++++++++++--
 5 files changed, 94 insertions(+), 18 deletions(-)

diff --git a/colossalai/inference/batch_bucket.py b/colossalai/inference/batch_bucket.py
index 8cc9eebaa..f8571c0ca 100644
--- a/colossalai/inference/batch_bucket.py
+++ b/colossalai/inference/batch_bucket.py
@@ -102,6 +102,13 @@ class BatchBucket:
     def num_tokens_to_verify(self) -> int:
         return self._num_tokens_to_verify
 
+    @property
+    def batch_token_ids(self) -> List[List[int]]:
+        out = []
+        for seq in self.seqs_li:
+            out.append(seq.input_token_id + seq.output_token_id)
+        return out
+
     def set_use_spec_dec(self, num_tokens_to_verify: int = 5) -> None:
         """Set batch bucket to use speculatvie decoding.
         This will notify the adjust the lengths of inputs during modeling,
@@ -328,6 +335,7 @@ class BatchBucket:
                 seqs.append(seq)
         if not self.is_compact:
             self._make_compact()
+
         return seqs, block_tables
 
     def pop_finished(
@@ -432,6 +440,7 @@ class BatchBucket:
             block_tables = torch.stack(block_tables_li)
             self.add_seqs(seqs, alloc_block_tables=block_tables)
         unmerged_ids = other.seqs_ids
+
         return unmerged_ids
 
     ########## The following methods are expected to be used in modeling ###########
diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py
index aae2024e0..8bd2394ad 100644
--- a/colossalai/inference/config.py
+++ b/colossalai/inference/config.py
@@ -99,7 +99,9 @@ class InferenceConfig:
         early_stopping (Optional[bool]): Whether to stop the generation when all beam hypotheses have finished or not, defaults to False.
         top_k (Optional[int]): The number of highest probability vocabulary tokens to keep for top-k-filtering, defaults to None.
         top_p (Optional[float]): The cumulative probability threshold for retaining tokens with a total probability above it, defaults to None.
-        min_p (Optional[float]): The minimum probability to keep for top-p filtering, defaults to None.
+        temperature (Optional[float]): Temperature used to control the randomness of sampling, defaults to 1.0.
+        repetition_penalty (Optional[float]): The parameter that influences how the model treats tokens already present in the prompt and the generated text. Values greater than 1 encourage the model to introduce new tokens, whereas values less than 1 encourage token repetition, defaults to 1.0.
+        no_repeat_ngram_size (Optional[int]): If no_repeat_ngram_size > 0, n-grams of that size can only occur once in the generated text, defaults to 0.
         n_spec_tokens (int): The maximum number of speculating tokens, defaults to None.
         glimpse_large_kv (bool): Whether to use large KV in drafter model, defaults to False.
         block_size (int): The number of blocks in a logical block, defaults to 16. 
@@ -136,7 +138,9 @@ class InferenceConfig: early_stopping: Optional[bool] = False top_k: Optional[int] = None top_p: Optional[float] = None - min_p: Optional[float] = None + temperature: Optional[float] = 1.0 + no_repeat_ngram_size: Optional[int] = 0 + repetition_penalty: Optional[float] = 1.0 # speculative decoding configs max_n_spec_tokens: int = 5 @@ -213,7 +217,7 @@ class InferenceConfig: "do_sample": self.do_sample, "num_beams": self.beam_width, } - for type in ["top_k", "top_p", "min_p"]: + for type in ["repetition_penalty", "no_repeat_ngram_size", "temperature", "top_k", "top_p"]: if hasattr(self, type): meta_config[type] = getattr(self, type) for type in ["pad_token_id", "bos_token_id", "eos_token_id"]: diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py index 1ced54dd7..44f2c8f47 100644 --- a/colossalai/inference/core/engine.py +++ b/colossalai/inference/core/engine.py @@ -424,7 +424,7 @@ class InferenceEngine: # 2. Prefill main model (Verifier) - fill past kv cache for main model logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache) - next_tokens = self.request_handler.search_tokens(self.generation_config, logits) + next_tokens = self.request_handler.search_tokens(self.generation_config, logits, batch) # append new inputs to the batch, temporarily batch.append_batch_tokens(next_tokens) self.request_handler.allocate_batch_spec_dec(batch, 1) @@ -472,7 +472,7 @@ class InferenceEngine: input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch) logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache) - next_tokens = self.request_handler.search_tokens(self.generation_config, logits) + next_tokens = self.request_handler.search_tokens(self.generation_config, logits, batch) # 5. Compare and process the results diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec)) @@ -738,7 +738,7 @@ class InferenceEngine: logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache) if self.inference_config.pad_input: logits = logits[:, -1, :] - next_tokens = self.request_handler.search_tokens(self.generation_config, logits) + next_tokens = self.request_handler.search_tokens(self.generation_config, logits, batch) self.request_handler.append_next_tokens(next_tokens) finished_sequences = self.request_handler.update() diff --git a/colossalai/inference/core/request_handler.py b/colossalai/inference/core/request_handler.py index 03b4d2305..c514eeccf 100644 --- a/colossalai/inference/core/request_handler.py +++ b/colossalai/inference/core/request_handler.py @@ -11,12 +11,9 @@ from colossalai.inference.kv_cache import KVCacheManager from colossalai.inference.logit_processors import logit_processor from colossalai.inference.sampler import * from colossalai.inference.struct import RequestStatus, Sequence -from colossalai.logging import get_dist_logger __all__ = ["RunningList", "RequestHandler"] -logger = get_dist_logger(__name__) - class RunningList: """ @@ -331,15 +328,21 @@ class RequestHandler: def total_requests_in_batch_bucket(self) -> int: return self.prefill_bb.current_batch_size + self.running_bb.current_batch_size - def search_tokens(self, generation_config: GenerationConfig, logits): + def search_tokens(self, generation_config: GenerationConfig, logits, cur_batch: BatchBucket): """ Sample tokens for finished requests. 
""" + # NOTE: need to decide the granularity to process logits (sequence or batch) + config_dict = generation_config.to_dict() + # process repetition_penalty, no_repeat_ngram_size + for type in ["repetition_penalty", "no_repeat_ngram_size"]: + if type in config_dict and config_dict[type] is not None: + logits = logit_processor(type, logits, config_dict[type], cur_batch) + # do logit processor if generation_config.do_sample: - # NOTE: need to decide the granularity to process logits (sequence or batch) - config_dict = generation_config.to_dict() + # process temperature, top_k, top_p for type in ["temperature", "top_k", "top_p"]: if type in config_dict and config_dict[type] is not None: logits = logit_processor(type, logits, config_dict[type]) diff --git a/colossalai/inference/logit_processors.py b/colossalai/inference/logit_processors.py index 39044fcec..b7119a221 100644 --- a/colossalai/inference/logit_processors.py +++ b/colossalai/inference/logit_processors.py @@ -1,6 +1,10 @@ +# This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/generation/logits_process.py + import torch import torch.nn.functional as F +from colossalai.inference.batch_bucket import BatchBucket + _LOGIT_PROCESSOR_MAP = {} @@ -17,6 +21,66 @@ def register_logit_processor(process_type): return register +@register_logit_processor("no_repeat_ngram_size") +def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch: BatchBucket): + """ + enforces no repetition of n-grams to avoid repetitions of word sequences. + """ + + if not isinstance(ngram_size, int) or ngram_size < 0: + raise ValueError(f"'temperature={ngram_size}' should be a strictly positive integer.") + + if ngram_size != 0: + batch_token_ids = batch.batch_token_ids + batch_size = len(batch_token_ids) + + for batch_id in range(batch_size): + current_token_ids = batch_token_ids[batch_id] + current_len = len(current_token_ids) + if current_len + 1 < ngram_size: + continue + + ngrams_dict = {} + + for ngram in zip(*[current_token_ids[i:] for i in range(ngram_size)]): + prev_ngram_tuple = tuple(ngram[:-1]) + ngrams_dict[prev_ngram_tuple] = ngrams_dict.get(prev_ngram_tuple, []) + [ngram[-1]] + + prev_ngrams = tuple(current_token_ids[current_len + 1 - ngram_size : current_len]) + banned_token = ngrams_dict.get(prev_ngrams, []) + + logits[batch_id, banned_token] = -float("inf") + + return logits + + +@register_logit_processor("repetition_penalty") +def repetition_penalty_logit_process(logits, penalty: float, batch: BatchBucket): + """ + apply the penalty to the tokens present in the prompt. + """ + + if not isinstance(penalty, float) or not (penalty > 0): + raise ValueError(f"'penalty={penalty}' has to be a strictly positive float and greater than 0.") + + logit_list = [] + + # TODO(yuehuayingxueluo) This is only a temporary implementation. Later, we will implement presence_penalties, frequency_penalties, and repetition_penalties using CUDA kernels. 
+    if penalty != 1.0:
+        batch_token_ids = batch.batch_token_ids
+        for batch_id in range(len(batch_token_ids)):
+            current_logit = logits[batch_id]
+            current_token = torch.tensor(batch_token_ids[batch_id], dtype=torch.long, device=logits.device)
+
+            current_score = torch.gather(current_logit, 0, current_token)
+            current_score = torch.where(current_score < 0, current_score * penalty, current_score / penalty)
+            logit_list.append(current_logit.scatter(0, current_token, current_score))
+
+        logits = torch.stack(logit_list)
+
+    return logits
+
+
 @register_logit_processor("temperature")
 def temperature_logit_process(logits, temperature: float):
     """
@@ -68,14 +132,13 @@ def top_p_logit_processor(logits, top_p: float):
     return logits
 
 
-def logit_processor(processor: str, logits, attrs):
+def logit_processor(processor: str, logits, *args, **kwargs):
     """
     do logit process for given logits.
 
     Args:
         processor(str): the type of logit processor
         logits(torch.Tensor): input logits
-        attrs(dict): attrs of the logit processor
 
     Returns:
         logits after process
@@ -84,8 +147,5 @@ def logit_processor(processor: str, logits, attrs):
         return logits
     else:
         func = _LOGIT_PROCESSOR_MAP[processor]
-        try:
-            logits = func(logits, attrs)
-        except Exception:
-            return logits
+        logits = func(logits, *args, **kwargs)
     return logits
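
Note: the two processors added in this patch run per sequence over batch.batch_token_ids, which can be hard to follow inside the batched loops. Below is a minimal, self-contained PyTorch sketch of the same per-sequence math; the function names and the toy numbers are invented for illustration and are not part of the patch.

import torch


def apply_repetition_penalty(logits: torch.Tensor, token_ids: list, penalty: float) -> torch.Tensor:
    # Gather the scores of tokens already seen, penalize them (negative scores are
    # multiplied by the penalty, positive scores are divided by it), then scatter back.
    seen = torch.tensor(token_ids, dtype=torch.long, device=logits.device)
    scores = torch.gather(logits, 0, seen)
    scores = torch.where(scores < 0, scores * penalty, scores / penalty)
    return logits.scatter(0, seen, scores)


def ban_repeated_ngrams(logits: torch.Tensor, token_ids: list, ngram_size: int) -> torch.Tensor:
    # Map every (ngram_size - 1)-token prefix to the tokens that followed it, then ban the
    # tokens that would complete an n-gram which already occurred in the sequence.
    if ngram_size == 0 or len(token_ids) + 1 < ngram_size:
        return logits
    ngrams = {}
    for ngram in zip(*[token_ids[i:] for i in range(ngram_size)]):
        ngrams.setdefault(tuple(ngram[:-1]), []).append(ngram[-1])
    prefix = tuple(token_ids[len(token_ids) + 1 - ngram_size :])
    logits[ngrams.get(prefix, [])] = -float("inf")
    return logits


# Toy example: vocabulary of 6 tokens, sequence generated so far is [1, 3, 2, 1, 3].
logits = torch.tensor([2.0, 1.5, 0.1, -0.5, 0.3, 0.0])
logits = apply_repetition_penalty(logits, [1, 3, 2, 1, 3], penalty=1.2)  # pushes scores of tokens 1, 2, 3 down
logits = ban_repeated_ngrams(logits, [1, 3, 2, 1, 3], ngram_size=2)      # bans token 2, since the bigram (3, 2) already occurred
print(logits)

In the patch itself the same operations run once per batch_id inside the registered processors, driven by search_tokens, so each row of the batched logits should behave like this single-sequence sketch.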