mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 11:32:10 +00:00)
[Inference] Fix Inference Generation Config and Sampling (#5710)
* refactor and add
* config default values
* fix gen config passing
* fix rpc generation config
@@ -1,27 +1,28 @@
 # This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/generation/logits_process.py
-from typing import List
+import logging
+from typing import List, Union
 
 import torch
 import torch.nn.functional as F
 
-_LOGIT_PROCESSOR_MAP = {}
+_LOGITS_PROCESSOR_MAP = {}
 
 
-def register_logit_processor(process_type):
+def register_logits_processor(process_type):
     """
     register flops computation function for operation.
     """
 
     def register(func):
-        global _LOGIT_PROCESSOR_MAP
-        _LOGIT_PROCESSOR_MAP[process_type] = func
+        global _LOGITS_PROCESSOR_MAP
+        _LOGITS_PROCESSOR_MAP[process_type] = func
         return func
 
     return register
 
 
-@register_logit_processor("no_repeat_ngram_size")
-def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch_token_ids: List[List[int]]):
+@register_logits_processor("no_repeat_ngram_size")
+def apply_no_repeat_ngram_size(logits, ngram_size: int, batch_token_ids: List[List[int]]):
     """
     enforces no repetition of n-grams to avoid repetitions of word sequences.
     """
@@ -52,8 +53,8 @@ def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch_token_ids:
     return logits
 
 
-@register_logit_processor("repetition_penalty")
-def repetition_penalty_logit_process(logits, penalty: float, batch_token_ids: List[List[int]]):
+@register_logits_processor("repetition_penalty")
+def apply_repetition_penalty(logits, penalty: float, batch_token_ids: List[List[int]]):
     """
     apply the penalty to the tokens present in the prompt.
     """
@@ -61,7 +62,7 @@ def repetition_penalty_logit_process(logits, penalty: float, batch_token_ids: Li
     if not isinstance(penalty, float) or not (penalty > 0):
         raise ValueError(f"'penalty={penalty}' has to be a strictly positive float and greater than 0.")
 
-    logit_list = []
+    logits_list = []
 
     # TODO(yuehuayingxueluo) This is only a temporary implementation. Later, we will implement presence_penalties, frequency_penalties, and repetition_penalties using CUDA kernels.
     if penalty != 1.0:
@@ -71,15 +72,15 @@ def repetition_penalty_logit_process(logits, penalty: float, batch_token_ids: Li
 
             curretn_socre = torch.gather(current_logit, 0, current_token)
             curretn_socre = torch.where(curretn_socre < 0, curretn_socre * penalty, curretn_socre / penalty)
-            logit_list.append(current_logit.scatter(0, current_token, curretn_socre))
+            logits_list.append(current_logit.scatter(0, current_token, curretn_socre))
 
-        logits = torch.stack(logit_list)
+        logits = torch.stack(logits_list)
 
     return logits
 
 
-@register_logit_processor("temperature")
-def temperature_logit_process(logits, temperature: float):
+@register_logits_processor("temperature")
+def apply_temperature(logits, temperature: float):
     """
     apply temperature scaling.
     """
@@ -93,8 +94,8 @@ def temperature_logit_process(logits, temperature: float):
     return logits if temperature == 1.0 else logits / temperature
 
 
-@register_logit_processor("top_k")
-def top_k_logit_processor(logits, top_k: int):
+@register_logits_processor("top_k")
+def apply_top_k(logits, top_k: int):
     """
     top_k logit processor
     """
@@ -107,8 +108,8 @@ def top_k_logit_processor(logits, top_k: int):
     return logits
 
 
-@register_logit_processor("top_p")
-def top_p_logit_processor(logits, top_p: float):
+@register_logits_processor("top_p")
+def apply_top_p(logits, top_p: float):
     """
     top_p logit processor
     """
@@ -129,7 +130,46 @@ def top_p_logit_processor(logits, top_p: float):
     return logits
 
 
-def logit_processor(processor: str, logits, *args, **kwargs):
+@register_logits_processor("forced_eos_token_id")
+def apply_forced_eos_token_id(
+    logits: torch.Tensor,
+    sequence_lengths: Union[torch.Tensor, List[int]],
+    max_lengths: Union[torch.Tensor, List[int]],
+    eos_token_id: Union[int, List[int]],
+):
+    """
+    Enforces the specified token as the last generated token when the maximum output length
+    is reached. Notice that the maximum output lengths for different sequences, even if they're
+    in the same batch, can be different.
+
+    Args:
+        logits(torch.Tensor): logits
+        sequence_lengths(torch.Tensor): sequence lengths including prompt and output tokens
+        max_lengths(torch.Tensor): the maximum length for each sequence
+        eos_token_id(Union[int, List[int]]): forced eos token id
+    """
+    if isinstance(eos_token_id, int):
+        eos_token_id = [eos_token_id]
+    if isinstance(sequence_lengths, torch.Tensor):
+        sequence_lengths = sequence_lengths.tolist()
+    if isinstance(max_lengths, torch.Tensor):
+        max_lengths = max_lengths.tolist()
+
+    select_indexes = []
+    num_sequences = logits.shape[0]
+    sequence_lengths = sequence_lengths[:num_sequences]
+    max_lengths = max_lengths[:num_sequences]
+    for i, (sequence_length, max_out_length) in enumerate(zip(sequence_lengths, max_lengths)):
+        if sequence_length == max_out_length - 1:
+            select_indexes.append(i)
+    if select_indexes:
+        logits[select_indexes, :] = -float("inf")
+        logits[select_indexes, eos_token_id] = 0
+
+    return logits
+
+
+def get_logits_processor(processor: str, logits, *args, **kwargs):
     """
     do logit process for given logits.
 
@@ -140,9 +180,10 @@ def logit_processor(processor: str, logits, *args, **kwargs):
     Returns:
         logits after process
     """
-    if processor not in _LOGIT_PROCESSOR_MAP:
-        return logits
+    if processor not in _LOGITS_PROCESSOR_MAP:
+        logging.warning(f"Unsupported processor {processor}. Fall back to the original logits.")
     else:
-        func = _LOGIT_PROCESSOR_MAP[processor]
+        func = _LOGITS_PROCESSOR_MAP[processor]
         logits = func(logits, *args, **kwargs)
-        return logits
+
+    return logits
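
For readers skimming the diff, here is a minimal, self-contained sketch of the registry-and-dispatch pattern the commit renames: a decorator stores each processing function in _LOGITS_PROCESSOR_MAP under a string key, and get_logits_processor looks the function up by name, logging a warning and returning the logits unchanged for unknown keys. This is illustrative only, not the ColossalAI module itself; the toy temperature processor and the example tensors are assumptions made for the demo.

# Minimal sketch of the registry/dispatch pattern shown in the diff above.
# Standalone and illustrative; not the actual ColossalAI source file.
import logging

import torch

_LOGITS_PROCESSOR_MAP = {}


def register_logits_processor(process_type):
    """Register a logits-processing function under a string key."""

    def register(func):
        _LOGITS_PROCESSOR_MAP[process_type] = func
        return func

    return register


@register_logits_processor("temperature")
def apply_temperature(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    # A temperature of 1.0 leaves the distribution unchanged.
    return logits if temperature == 1.0 else logits / temperature


def get_logits_processor(processor: str, logits: torch.Tensor, *args, **kwargs) -> torch.Tensor:
    """Dispatch to a registered processor; unknown names fall back to the input logits."""
    if processor not in _LOGITS_PROCESSOR_MAP:
        logging.warning(f"Unsupported processor {processor}. Fall back to the original logits.")
    else:
        logits = _LOGITS_PROCESSOR_MAP[processor](logits, *args, **kwargs)
    return logits


if __name__ == "__main__":
    batch_logits = torch.randn(2, 32000)  # (batch_size, vocab_size); values are placeholders
    scaled = get_logits_processor("temperature", batch_logits, 0.7)
    untouched = get_logits_processor("not_registered", batch_logits)  # warns, returns input as-is
    print(scaled.shape, torch.equal(untouched, batch_logits))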
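
The newly added forced_eos_token_id processor works by masking the entire vocabulary to -inf and resetting only the EOS position to 0 for every sequence that has exactly one generation step left, so the next sampled or argmax token must be EOS. A small hedged example of that masking step in isolation (the batch size, vocabulary size, lengths, and token id below are made up for illustration):

# Illustrative only: the forced-EOS masking from apply_forced_eos_token_id on a toy batch.
import torch

vocab_size, eos_token_id = 8, 2
logits = torch.zeros(3, vocab_size)  # batch of 3 sequences
sequence_lengths = [5, 9, 4]         # prompt + generated tokens so far, per sequence
max_lengths = [16, 10, 16]           # per-sequence maximum lengths

# A sequence whose next token is its last allowed one gets EOS forced.
select_indexes = [
    i
    for i, (seq_len, max_len) in enumerate(zip(sequence_lengths, max_lengths))
    if seq_len == max_len - 1
]  # -> [1]

if select_indexes:
    logits[select_indexes, :] = -float("inf")  # rule out every token...
    logits[select_indexes, eos_token_id] = 0   # ...except EOS, which now dominates sampling/argmax

print(logits[1].argmax().item())  # prints 2, the EOS id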