Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 11:02:05 +00:00)
[Inference]Adapt repetition_penalty and no_repeat_ngram_size (#5708)
* Adapt repetition_penalty and no_repeat_ngram_size
* Fix no_repeat_ngram_size_logit_process
* Remove batch_updated
* Fix annotation
* Modify code based on review feedback
* Remove get_batch_token_ids
@@ -1,6 +1,10 @@
# This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/generation/logits_process.py
import torch
import torch.nn.functional as F

from colossalai.inference.batch_bucket import BatchBucket

_LOGIT_PROCESSOR_MAP = {}
@@ -17,6 +21,66 @@ def register_logit_processor(process_type):
    return register

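The body of register_logit_processor is outside this hunk; only its closing "return register" appears as context. For orientation, a minimal sketch of how such a name-to-function registry decorator is commonly written (an assumption for illustration, not the repository's exact code):

    _LOGIT_PROCESSOR_MAP = {}

    def register_logit_processor(process_type):
        # Store the decorated function under the given name so that
        # logit_processor() can later look it up by string key.
        def register(func):
            _LOGIT_PROCESSOR_MAP[process_type] = func
            return func

        return register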
@register_logit_processor("no_repeat_ngram_size")
def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch: BatchBucket):
    """
    enforces no repetition of n-grams to avoid repetitions of word sequences.
    """

    if not isinstance(ngram_size, int) or ngram_size < 0:
        raise ValueError(f"'ngram_size={ngram_size}' should be a non-negative integer.")

    if ngram_size != 0:
        batch_token_ids = batch.batch_token_ids
        batch_size = len(batch_token_ids)

        for batch_id in range(batch_size):
            current_token_ids = batch_token_ids[batch_id]
            current_len = len(current_token_ids)
            if current_len + 1 < ngram_size:
                continue

            ngrams_dict = {}

            for ngram in zip(*[current_token_ids[i:] for i in range(ngram_size)]):
                prev_ngram_tuple = tuple(ngram[:-1])
                ngrams_dict[prev_ngram_tuple] = ngrams_dict.get(prev_ngram_tuple, []) + [ngram[-1]]

            prev_ngrams = tuple(current_token_ids[current_len + 1 - ngram_size : current_len])
            banned_token = ngrams_dict.get(prev_ngrams, [])

            logits[batch_id, banned_token] = -float("inf")

    return logits

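A minimal sketch of the banning behavior above (not part of the commit; SimpleNamespace stands in for a BatchBucket and only provides batch_token_ids):

    from types import SimpleNamespace

    import torch

    # The sequence's history already contains the bigram (5, 7) and its last
    # token is 5 again, so with ngram_size=2 emitting 7 would repeat that bigram.
    batch = SimpleNamespace(batch_token_ids=[[3, 5, 7, 2, 5]])
    logits = torch.zeros(1, 10)

    out = no_repeat_ngram_size_logit_process(logits, ngram_size=2, batch=batch)
    print(out[0, 7])  # tensor(-inf): token 7 is banned for this step
    print(out[0, 2])  # tensor(0.): other tokens are untouched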
@register_logit_processor("repetition_penalty")
def repetition_penalty_logit_process(logits, penalty: float, batch: BatchBucket):
    """
    apply the penalty to the tokens present in the prompt.
    """

    if not isinstance(penalty, float) or not (penalty > 0):
        raise ValueError(f"'penalty={penalty}' has to be a strictly positive float.")

    logit_list = []

    # TODO(yuehuayingxueluo) This is only a temporary implementation. Later, we will implement
    # presence_penalties, frequency_penalties, and repetition_penalties using CUDA kernels.
    if penalty != 1.0:
        batch_token_ids = batch.batch_token_ids
        for batch_id in range(len(batch_token_ids)):
            current_logit = logits[batch_id]
            current_token = torch.tensor(batch_token_ids[batch_id], dtype=torch.long, device=logits.device)

            current_score = torch.gather(current_logit, 0, current_token)
            current_score = torch.where(current_score < 0, current_score * penalty, current_score / penalty)
            logit_list.append(current_logit.scatter(0, current_token, current_score))

        logits = torch.stack(logit_list)

    return logits

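A quick numeric check of the gather/where/scatter arithmetic above, again with a hypothetical stand-in batch rather than a real BatchBucket:

    from types import SimpleNamespace

    import torch

    batch = SimpleNamespace(batch_token_ids=[[0, 1]])  # tokens 0 and 1 already appeared
    logits = torch.tensor([[2.0, -1.0, 0.5]])

    out = repetition_penalty_logit_process(logits, penalty=2.0, batch=batch)
    # token 0: positive score is divided by the penalty  ->  2.0 / 2.0 =  1.0
    # token 1: negative score is multiplied by it        -> -1.0 * 2.0 = -2.0
    # token 2: not in the history, left unchanged        ->  0.5
    print(out)  # tensor([[ 1.0000, -2.0000,  0.5000]])

Scaling positive scores down and negative scores further down both make previously seen tokens less likely, matching the usual convention for repetition_penalty values above 1.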
@register_logit_processor("temperature")
def temperature_logit_process(logits, temperature: float):
    """
@@ -68,14 +132,13 @@ def top_p_logit_processor(logits, top_p: float):
    return logits


-def logit_processor(processor: str, logits, attrs):
+def logit_processor(processor: str, logits, *args, **kwargs):
    """
    do logit process for given logits.

    Args:
        processor(str): the type of logit processor
        logits(torch.Tensor): input logits
        attrs(dict): attrs of the logit processor

    Returns:
        logits after process
@@ -84,8 +147,5 @@ def logit_processor(processor: str, logits, attrs):
        return logits
    else:
        func = _LOGIT_PROCESSOR_MAP[processor]
-        try:
-            logits = func(logits, attrs)
-        except Exception:
-            return logits
+        logits = func(logits, *args, **kwargs)
    return logits
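With the new dispatch signature, each processor receives its own positional arguments instead of a single attrs value. A usage sketch (not part of the commit; the stand-in batch object and the vocabulary size are assumptions for illustration):

    from types import SimpleNamespace

    import torch

    logits = torch.randn(2, 32000)  # [batch_size, vocab_size]
    # Stand-in for a BatchBucket holding each sequence's generated token ids.
    batch = SimpleNamespace(batch_token_ids=[[1, 5, 9], [2, 2, 7]])

    # Registered names dispatch through _LOGIT_PROCESSOR_MAP; the guard that
    # returns logits unchanged for unknown names is shown as context above.
    logits = logit_processor("repetition_penalty", logits, 1.2, batch)
    logits = logit_processor("no_repeat_ngram_size", logits, 3, batch)
    logits = logit_processor("temperature", logits, 0.7)

Because the try/except wrapper was removed, errors raised inside a processor (for example a non-float penalty) now propagate to the caller instead of being silently swallowed.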