Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-08-15 06:33:16 +00:00

[Inference]Adapt repetition_penalty and no_repeat_ngram_size (#5708)

* Adapt repetition_penalty and no_repeat_ngram_size
* fix no_repeat_ngram_size_logit_process
* remove batch_updated
* fix annotation
* modified codes based on the review feedback.
* rm get_batch_token_ids

This commit is contained in:
parent 50104ab340
commit de4bf3dedf
@@ -102,6 +102,13 @@ class BatchBucket:
     def num_tokens_to_verify(self) -> int:
         return self._num_tokens_to_verify
 
+    @property
+    def batch_token_ids(self) -> List[List[int]]:
+        out = []
+        for seq in self.seqs_li:
+            out.append(seq.input_token_id + seq.output_token_id)
+        return out
+
     def set_use_spec_dec(self, num_tokens_to_verify: int = 5) -> None:
         """Set batch bucket to use speculatvie decoding.
         This will notify the adjust the lengths of inputs during modeling,
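The new batch_token_ids property concatenates each sequence's prompt tokens with the tokens generated so far, giving the penalty processors introduced below a per-sequence token history to work from. A minimal sketch of the idea, using hypothetical stand-ins for BatchBucket and for the Sequence objects held in seqs_li (the real Sequence lives in colossalai.inference.struct and carries many more fields):

# Minimal sketch, not the ColossalAI implementation: a stand-in Sequence
# with only the two fields that batch_token_ids relies on.
from dataclasses import dataclass, field
from typing import List


@dataclass
class FakeSequence:  # hypothetical stand-in for colossalai.inference.struct.Sequence
    input_token_id: List[int] = field(default_factory=list)   # prompt tokens
    output_token_id: List[int] = field(default_factory=list)  # generated tokens


class FakeBatchBucket:  # hypothetical stand-in for BatchBucket
    def __init__(self, seqs: List[FakeSequence]):
        self.seqs_li = seqs

    @property
    def batch_token_ids(self) -> List[List[int]]:
        # prompt + generated tokens, one list per sequence in the bucket
        return [seq.input_token_id + seq.output_token_id for seq in self.seqs_li]


bucket = FakeBatchBucket([FakeSequence([1, 2, 3], [7, 8]), FakeSequence([4, 5], [9])])
print(bucket.batch_token_ids)  # [[1, 2, 3, 7, 8], [4, 5, 9]]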
@@ -328,6 +335,7 @@ class BatchBucket:
             seqs.append(seq)
         if not self.is_compact:
             self._make_compact()
+
         return seqs, block_tables
 
     def pop_finished(
@@ -432,6 +440,7 @@ class BatchBucket:
         block_tables = torch.stack(block_tables_li)
         self.add_seqs(seqs, alloc_block_tables=block_tables)
         unmerged_ids = other.seqs_ids
+
         return unmerged_ids
 
     ########## The following methods are expected to be used in modeling ###########
@@ -99,7 +99,9 @@ class InferenceConfig:
         early_stopping (Optional[bool]): Whether to stop the generation when all beam hypotheses have finished or not, defaults to False.
         top_k (Optional[int]): The number of highest probability vocabulary tokens to keep for top-k-filtering, defaults to None.
         top_p (Optional[float]): The cumulative probability threshold for retaining tokens with a total probability above it, defaults to None.
-        min_p (Optional[float]): The minimum probability to keep for top-p filtering, defaults to None.
+        temperature (Optional[float]): Randomness used to control randomization, defaults to 1.0.
+        repetition_penalty (Optional[float]): The parameter that influences the model's treatment of new tokens in relation to their appearance in the prompt and the generated text. Values greater than 1 incentivize the model to introduce new tokens, whereas values less than 1 incentivize token repetition, defaults to 1.0.
+        no_repeat_ngram_size (Optional[int]): If no_repeat_ngram_size > 0, the consecutive tokens of ngram size can only appear once in inference sentences.
         n_spec_tokens (int): The maximum number of speculating tokens, defaults to None.
         glimpse_large_kv (bool): Whether to use large KV in drafter model, defaults to False.
         block_size (int): The number of blocks in a logical block, defaults to 16.
@@ -136,7 +138,9 @@ class InferenceConfig:
     early_stopping: Optional[bool] = False
     top_k: Optional[int] = None
     top_p: Optional[float] = None
-    min_p: Optional[float] = None
+    temperature: Optional[float] = 1.0
+    no_repeat_ngram_size: Optional[int] = 0
+    repetition_penalty: Optional[float] = 1.0
 
     # speculative decoding configs
     max_n_spec_tokens: int = 5
@@ -213,7 +217,7 @@ class InferenceConfig:
             "do_sample": self.do_sample,
             "num_beams": self.beam_width,
         }
-        for type in ["top_k", "top_p", "min_p"]:
+        for type in ["repetition_penalty", "no_repeat_ngram_size", "temperature", "top_k", "top_p"]:
             if hasattr(self, type):
                 meta_config[type] = getattr(self, type)
         for type in ["pad_token_id", "bos_token_id", "eos_token_id"]:
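With the new fields in place, to_generation_config forwards repetition_penalty, no_repeat_ngram_size and temperature to the generation config alongside top_k and top_p. A toy sketch of that forwarding loop, using a hypothetical ToyConfig stand-in rather than the real InferenceConfig (which takes many more arguments):

# Sketch of the forwarding logic; ToyConfig is illustrative only.
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyConfig:  # hypothetical stand-in for InferenceConfig
    do_sample: bool = False
    beam_width: int = 1
    temperature: Optional[float] = 1.0
    no_repeat_ngram_size: Optional[int] = 0
    repetition_penalty: Optional[float] = 1.0
    top_k: Optional[int] = None
    top_p: Optional[float] = None


def to_meta_config(cfg: ToyConfig) -> dict:
    meta_config = {"do_sample": cfg.do_sample, "num_beams": cfg.beam_width}
    # only attributes that exist on the config are forwarded
    for name in ["repetition_penalty", "no_repeat_ngram_size", "temperature", "top_k", "top_p"]:
        if hasattr(cfg, name):
            meta_config[name] = getattr(cfg, name)
    return meta_config


print(to_meta_config(ToyConfig(repetition_penalty=1.2, no_repeat_ngram_size=3)))
# {'do_sample': False, 'num_beams': 1, 'repetition_penalty': 1.2,
#  'no_repeat_ngram_size': 3, 'temperature': 1.0, 'top_k': None, 'top_p': None}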
@@ -424,7 +424,7 @@ class InferenceEngine:
 
         # 2. Prefill main model (Verifier) - fill past kv cache for main model
         logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
-        next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
+        next_tokens = self.request_handler.search_tokens(self.generation_config, logits, batch)
         # append new inputs to the batch, temporarily
         batch.append_batch_tokens(next_tokens)
         self.request_handler.allocate_batch_spec_dec(batch, 1)
@@ -472,7 +472,7 @@ class InferenceEngine:
             input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
             logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
 
-            next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
+            next_tokens = self.request_handler.search_tokens(self.generation_config, logits, batch)
 
             # 5. Compare and process the results
             diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
@@ -738,7 +738,7 @@ class InferenceEngine:
             logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
             if self.inference_config.pad_input:
                 logits = logits[:, -1, :]
-            next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
+            next_tokens = self.request_handler.search_tokens(self.generation_config, logits, batch)
             self.request_handler.append_next_tokens(next_tokens)
             finished_sequences = self.request_handler.update()
 
@@ -11,12 +11,9 @@ from colossalai.inference.kv_cache import KVCacheManager
 from colossalai.inference.logit_processors import logit_processor
 from colossalai.inference.sampler import *
 from colossalai.inference.struct import RequestStatus, Sequence
-from colossalai.logging import get_dist_logger
 
 __all__ = ["RunningList", "RequestHandler"]
 
-logger = get_dist_logger(__name__)
-
 
 class RunningList:
     """
@@ -331,15 +328,21 @@ class RequestHandler:
     def total_requests_in_batch_bucket(self) -> int:
         return self.prefill_bb.current_batch_size + self.running_bb.current_batch_size
 
-    def search_tokens(self, generation_config: GenerationConfig, logits):
+    def search_tokens(self, generation_config: GenerationConfig, logits, cur_batch: BatchBucket):
         """
         Sample tokens for finished requests.
         """
 
+        # NOTE: need to decide the granularity to process logits (sequence or batch)
+        config_dict = generation_config.to_dict()
+        # process repetition_penalty, no_repeat_ngram_size
+        for type in ["repetition_penalty", "no_repeat_ngram_size"]:
+            if type in config_dict and config_dict[type] is not None:
+                logits = logit_processor(type, logits, config_dict[type], cur_batch)
+
         # do logit processor
         if generation_config.do_sample:
-            # NOTE: need to decide the granularity to process logits (sequence or batch)
-            config_dict = generation_config.to_dict()
+            # process temperature, top_k, top_p
             for type in ["temperature", "top_k", "top_p"]:
                 if type in config_dict and config_dict[type] is not None:
                     logits = logit_processor(type, logits, config_dict[type])
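The reworked search_tokens now runs in two stages: the batch-aware penalties (repetition_penalty, no_repeat_ngram_size) are applied whenever they are configured, while the sampling-time processors (temperature, top_k, top_p) still run only when do_sample is set. A self-contained sketch of that dispatch order, with a fake_logit_processor placeholder standing in for the real registry and a greedy argmax standing in for the real sampler:

# Toy reimplementation of the dispatch order in search_tokens; the processor
# body below is a placeholder, not the ColossalAI implementation.
import torch


def fake_logit_processor(name, logits, attr, batch=None):
    # stand-in for colossalai.inference.logit_processors.logit_processor
    print(f"applying {name} (attr={attr}, batch_aware={batch is not None})")
    return logits


def search_tokens_sketch(config_dict: dict, logits: torch.Tensor, cur_batch) -> torch.Tensor:
    # stage 1: batch-aware penalties, applied whenever they are configured
    for name in ["repetition_penalty", "no_repeat_ngram_size"]:
        if name in config_dict and config_dict[name] is not None:
            logits = fake_logit_processor(name, logits, config_dict[name], cur_batch)

    # stage 2: sampling shaping, only when sampling is enabled
    if config_dict.get("do_sample"):
        for name in ["temperature", "top_k", "top_p"]:
            if name in config_dict and config_dict[name] is not None:
                logits = fake_logit_processor(name, logits, config_dict[name])

    # greedy fallback here stands in for the real sampler
    return torch.argmax(logits, dim=-1)


cfg = {"do_sample": True, "repetition_penalty": 1.2, "no_repeat_ngram_size": 3, "temperature": 0.8}
print(search_tokens_sketch(cfg, torch.randn(2, 10), cur_batch=[[1, 2], [3]]))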
@@ -1,6 +1,10 @@
+# This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/generation/logits_process.py
+
 import torch
 import torch.nn.functional as F
 
+from colossalai.inference.batch_bucket import BatchBucket
+
 _LOGIT_PROCESSOR_MAP = {}
 
 
@@ -17,6 +21,66 @@ def register_logit_processor(process_type):
     return register
 
 
+@register_logit_processor("no_repeat_ngram_size")
+def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch: BatchBucket):
+    """
+    enforces no repetition of n-grams to avoid repetitions of word sequences.
+    """
+
+    if not isinstance(ngram_size, int) or ngram_size < 0:
+        raise ValueError(f"'temperature={ngram_size}' should be a strictly positive integer.")
+
+    if ngram_size != 0:
+        batch_token_ids = batch.batch_token_ids
+        batch_size = len(batch_token_ids)
+
+        for batch_id in range(batch_size):
+            current_token_ids = batch_token_ids[batch_id]
+            current_len = len(current_token_ids)
+            if current_len + 1 < ngram_size:
+                continue
+
+            ngrams_dict = {}
+
+            for ngram in zip(*[current_token_ids[i:] for i in range(ngram_size)]):
+                prev_ngram_tuple = tuple(ngram[:-1])
+                ngrams_dict[prev_ngram_tuple] = ngrams_dict.get(prev_ngram_tuple, []) + [ngram[-1]]
+
+            prev_ngrams = tuple(current_token_ids[current_len + 1 - ngram_size : current_len])
+            banned_token = ngrams_dict.get(prev_ngrams, [])
+
+            logits[batch_id, banned_token] = -float("inf")
+
+    return logits
+
+
+@register_logit_processor("repetition_penalty")
+def repetition_penalty_logit_process(logits, penalty: float, batch: BatchBucket):
+    """
+    apply the penalty to the tokens present in the prompt.
+    """
+
+    if not isinstance(penalty, float) or not (penalty > 0):
+        raise ValueError(f"'penalty={penalty}' has to be a strictly positive float and greater than 0.")
+
+    logit_list = []
+
+    # TODO(yuehuayingxueluo) This is only a temporary implementation. Later, we will implement presence_penalties, frequency_penalties, and repetition_penalties using CUDA kernels.
+    if penalty != 1.0:
+        batch_token_ids = batch.batch_token_ids
+        for batch_id in range(len(batch_token_ids)):
+            current_logit = logits[batch_id]
+            current_token = torch.tensor(batch_token_ids[batch_id], dtype=torch.long, device=logits.device)
+
+            curretn_socre = torch.gather(current_logit, 0, current_token)
+            curretn_socre = torch.where(curretn_socre < 0, curretn_socre * penalty, curretn_socre / penalty)
+            logit_list.append(current_logit.scatter(0, current_token, curretn_socre))
+
+        logits = torch.stack(logit_list)
+
+    return logits
+
+
 @register_logit_processor("temperature")
 def temperature_logit_process(logits, temperature: float):
     """
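Both new processors read batch.batch_token_ids and edit the logits row by row: no_repeat_ngram_size bans any token that would complete an n-gram already present in the sequence, and repetition_penalty re-scales every token the sequence has already produced (dividing positive scores and multiplying negative ones by the penalty, following the HuggingFace scheme the file header cites). A small self-contained demo of both transformations on a single toy sequence, written against plain Python lists instead of a real BatchBucket:

# Toy single-sequence versions of the two processors above; names and
# structure are illustrative, not the ColossalAI API.
import torch


def ban_repeated_ngrams(logits, token_ids, ngram_size):
    # collect every (ngram_size - 1)-token prefix seen so far and ban its
    # continuations whenever the current tail matches that prefix
    if ngram_size == 0 or len(token_ids) + 1 < ngram_size:
        return logits
    ngrams = {}
    for ngram in zip(*[token_ids[i:] for i in range(ngram_size)]):
        ngrams.setdefault(tuple(ngram[:-1]), []).append(ngram[-1])
    tail = tuple(token_ids[len(token_ids) + 1 - ngram_size:])
    banned = ngrams.get(tail, [])
    if banned:
        logits[banned] = -float("inf")
    return logits


def penalize_repetition(logits, token_ids, penalty):
    # previously seen tokens get their score divided (if positive) or
    # multiplied (if negative) by the penalty
    seen = torch.tensor(token_ids, dtype=torch.long)
    score = torch.gather(logits, 0, seen)
    score = torch.where(score < 0, score * penalty, score / penalty)
    return logits.scatter(0, seen, score)


history = [5, 6, 7, 5, 6]          # the 2-gram (5, 6) was already followed by 7
logits = torch.zeros(10)
logits[7] = 2.0

print(ban_repeated_ngrams(logits.clone(), history, ngram_size=3))  # index 7 -> -inf
print(penalize_repetition(logits.clone(), history, penalty=1.5))   # index 7: 2.0 -> 2.0 / 1.5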
@@ -68,14 +132,13 @@ def top_p_logit_processor(logits, top_p: float):
     return logits
 
 
-def logit_processor(processor: str, logits, attrs):
+def logit_processor(processor: str, logits, *args, **kwargs):
     """
     do logit process for given logits.
 
     Args:
         processor(str): the type of logit processor
         logits(torch.Tensor): input logits
-        attrs(dict): attrs of the logit processor
 
     Returns:
         logits after process
@@ -84,8 +147,5 @@ def logit_processor(processor: str, logits, attrs):
         return logits
     else:
         func = _LOGIT_PROCESSOR_MAP[processor]
-        try:
-            logits = func(logits, attrs)
-        except Exception:
-            return logits
+        logits = func(logits, *args, **kwargs)
     return logits
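Switching the dispatcher to *args/**kwargs lets each registered processor declare whatever extra arguments it needs (a penalty value, the BatchBucket, and so on) instead of a single attrs dict, and removing the try/except means a failing processor now raises instead of silently returning the logits unchanged. A compact sketch of the same registry-and-dispatch pattern with hypothetical toy processors:

# Toy registry mirroring _LOGIT_PROCESSOR_MAP and the reworked dispatcher;
# the processor names and bodies here are illustrative only.
import torch

_PROCESSORS = {}


def register(name):
    def wrap(func):
        _PROCESSORS[name] = func
        return func
    return wrap


@register("temperature")
def apply_temperature(logits, temperature: float):
    return logits / temperature


@register("presence_penalty")
def apply_presence_penalty(logits, penalty: float, token_history):
    # extra positional argument (the token history), possible thanks to *args forwarding
    seen = torch.tensor(sorted(set(token_history)), dtype=torch.long)
    logits = logits.clone()
    logits[seen] -= penalty
    return logits


def dispatch(name, logits, *args, **kwargs):
    # mirrors the reworked logit_processor: unknown names pass logits through,
    # known names receive every extra argument unchanged
    if name not in _PROCESSORS:
        return logits
    return _PROCESSORS[name](logits, *args, **kwargs)


logits = torch.tensor([1.0, 2.0, 3.0, 4.0])
logits = dispatch("presence_penalty", logits, 1.0, [1, 3, 3])
logits = dispatch("temperature", logits, 0.5)
print(logits)  # tensor([2., 2., 6., 6.])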