[Inference] Fix Inference Generation Config and Sampling (#5710)

* refactor and add

* config default values

* fix gen config passing

* fix rpc generation config
Author: Yuanheng Zhao
Date: 2024-05-19 15:08:42 +08:00
Committed by: GitHub
Parent: 8bcfe360fd
Commit: 283c407a19
6 changed files with 124 additions and 68 deletions


@@ -1,27 +1,28 @@
# This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/generation/logits_process.py
from typing import List
import logging
from typing import List, Union
import torch
import torch.nn.functional as F
_LOGIT_PROCESSOR_MAP = {}
_LOGITS_PROCESSOR_MAP = {}
def register_logit_processor(process_type):
def register_logits_processor(process_type):
"""
Register a logits processor function under the given process type.
"""
def register(func):
global _LOGIT_PROCESSOR_MAP
_LOGIT_PROCESSOR_MAP[process_type] = func
global _LOGITS_PROCESSOR_MAP
_LOGITS_PROCESSOR_MAP[process_type] = func
return func
return register
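For reference, the renamed decorator populates the module-level _LOGITS_PROCESSOR_MAP that is used for dispatch further down. A minimal sketch of registering and looking up a processor; the "greedy_cap" name and apply_greedy_cap function are hypothetical, for illustration only:

import torch

@register_logits_processor("greedy_cap")  # hypothetical processor name, for illustration
def apply_greedy_cap(logits: torch.Tensor, cap: float):
    # Clamp logits from above; the decorator records this function under the key "greedy_cap".
    return torch.clamp(logits, max=cap)

# The decorator returns the function unchanged and stores it in the registry:
assert _LOGITS_PROCESSOR_MAP["greedy_cap"] is apply_greedy_cap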
@register_logit_processor("no_repeat_ngram_size")
def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch_token_ids: List[List[int]]):
@register_logits_processor("no_repeat_ngram_size")
def apply_no_repeat_ngram_size(logits, ngram_size: int, batch_token_ids: List[List[int]]):
"""
Enforces no repetition of n-grams, to avoid repeated word sequences in the output.
"""
@@ -52,8 +53,8 @@ def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch_token_ids:
return logits
@register_logit_processor("repetition_penalty")
def repetition_penalty_logit_process(logits, penalty: float, batch_token_ids: List[List[int]]):
@register_logits_processor("repetition_penalty")
def apply_repetition_penalty(logits, penalty: float, batch_token_ids: List[List[int]]):
"""
Apply the repetition penalty to tokens already present in each sequence.
"""
@@ -61,7 +62,7 @@ def repetition_penalty_logit_process(logits, penalty: float, batch_token_ids: Li
if not isinstance(penalty, float) or not (penalty > 0):
raise ValueError(f"'penalty={penalty}' has to be a strictly positive float.")
logit_list = []
logits_list = []
# TODO(yuehuayingxueluo) This is only a temporary implementation. Later, we will implement presence_penalties, frequency_penalties, and repetition_penalties using CUDA kernels.
if penalty != 1.0:
@@ -71,15 +72,15 @@ def repetition_penalty_logit_process(logits, penalty: float, batch_token_ids: Li
current_score = torch.gather(current_logit, 0, current_token)
current_score = torch.where(current_score < 0, current_score * penalty, current_score / penalty)
logit_list.append(current_logit.scatter(0, current_token, current_score))
logits_list.append(current_logit.scatter(0, current_token, current_score))
logits = torch.stack(logit_list)
logits = torch.stack(logits_list)
return logits
@register_logit_processor("temperature")
def temperature_logit_process(logits, temperature: float):
@register_logits_processor("temperature")
def apply_temperature(logits, temperature: float):
"""
apply temperature scaling.
"""
@@ -93,8 +94,8 @@ def temperature_logit_process(logits, temperature: float):
return logits if temperature == 1.0 else logits / temperature
@register_logit_processor("top_k")
def top_k_logit_processor(logits, top_k: int):
@register_logits_processor("top_k")
def apply_top_k(logits, top_k: int):
"""
top_k logit processor
"""
@@ -107,8 +108,8 @@ def top_k_logit_processor(logits, top_k: int):
return logits
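The top_k body is unchanged and elided from this hunk. For orientation only, a common top-k filtering pattern looks roughly like the following sketch; it is not necessarily this file's exact implementation:

import torch

def top_k_filter(logits: torch.Tensor, top_k: int) -> torch.Tensor:
    # Keep only the top_k largest logits per row and mask the rest to -inf.
    top_k = min(top_k, logits.size(-1))
    kth_value = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
    return logits.masked_fill(logits < kth_value, -float("inf"))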
@register_logit_processor("top_p")
def top_p_logit_processor(logits, top_p: float):
@register_logits_processor("top_p")
def apply_top_p(logits, top_p: float):
"""
top_p logit processor
"""
@@ -129,7 +130,46 @@ def top_p_logit_processor(logits, top_p: float):
return logits
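The top_p body is likewise elided. A common nucleus-sampling filter, shown as a sketch and not necessarily this file's exact implementation:

import torch
import torch.nn.functional as F

def top_p_filter(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    # Sort logits, accumulate probabilities, and drop tokens outside the top_p nucleus.
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_to_remove = cumulative_probs > top_p
    # Shift right so the first token that crosses the threshold is still kept.
    sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
    sorted_to_remove[..., 0] = False
    to_remove = sorted_to_remove.scatter(-1, sorted_indices, sorted_to_remove)
    return logits.masked_fill(to_remove, -float("inf"))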
def logit_processor(processor: str, logits, *args, **kwargs):
@register_logits_processor("forced_eos_token_id")
def apply_forced_eos_token_id(
logits: torch.Tensor,
sequence_lengths: Union[torch.Tensor, List[int]],
max_lengths: Union[torch.Tensor, List[int]],
eos_token_id: Union[int, List[int]],
):
"""
Enforces the specified token as the last generated token when the maximum output length
is reached. Notice that the maximum output lengths for different sequences, even if they're
in the same batch, can be different.
Args:
logits (torch.Tensor): the next-token logits
sequence_lengths (Union[torch.Tensor, List[int]]): current sequence lengths, including prompt and output tokens
max_lengths (Union[torch.Tensor, List[int]]): the maximum total length allowed for each sequence
eos_token_id (Union[int, List[int]]): the token id(s) to force as the last generated token
"""
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
if isinstance(sequence_lengths, torch.Tensor):
sequence_lengths = sequence_lengths.tolist()
if isinstance(max_lengths, torch.Tensor):
max_lengths = max_lengths.tolist()
select_indexes = []
num_sequences = logits.shape[0]
sequence_lengths = sequence_lengths[:num_sequences]
max_lengths = max_lengths[:num_sequences]
for i, (sequence_length, max_out_length) in enumerate(zip(sequence_lengths, max_lengths)):
if sequence_length == max_out_length - 1:
select_indexes.append(i)
if select_indexes:
logits[select_indexes, :] = -float("inf")
logits[select_indexes, eos_token_id] = 0
return logits
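A minimal usage sketch of the new processor with toy values (the vocabulary size 5 and eos token id 2 are assumptions for illustration): sequence 0 has reached its per-sequence maximum, so only its eos logit stays finite, while sequence 1 is untouched.

import torch

logits = torch.zeros(2, 5)      # toy batch of next-token logits
sequence_lengths = [7, 3]       # prompt + generated tokens so far
max_lengths = [8, 16]           # per-sequence maximum total length

out = apply_forced_eos_token_id(logits, sequence_lengths, max_lengths, eos_token_id=2)
# out[0] == tensor([-inf, -inf, 0., -inf, -inf]); out[1] remains all zeros.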
def get_logits_processor(processor: str, logits, *args, **kwargs):
"""
Apply the specified logits processor to the given logits.
@@ -140,9 +180,10 @@ def logit_processor(processor: str, logits, *args, **kwargs):
Returns:
the logits after processing
"""
if processor not in _LOGIT_PROCESSOR_MAP:
return logits
if processor not in _LOGITS_PROCESSOR_MAP:
logging.warning(f"Unsupported processor {processor}. Falling back to the original logits.")
else:
func = _LOGIT_PROCESSOR_MAP[processor]
func = _LOGITS_PROCESSOR_MAP[processor]
logits = func(logits, *args, **kwargs)
return logits
return logits
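A minimal sketch of driving the renamed dispatcher during sampling; the parameter values below are illustrative, not defaults introduced by this PR:

import torch

batch_token_ids = [[5, 11, 5], [7, 7, 2]]   # tokens generated so far, per sequence
logits = torch.randn(2, 32000)              # toy next-token logits

logits = get_logits_processor("repetition_penalty", logits, 1.2, batch_token_ids)
logits = get_logits_processor("temperature", logits, 0.7)
logits = get_logits_processor("top_k", logits, 50)
logits = get_logits_processor("top_p", logits, 0.9)
# An unregistered name only logs a warning and returns the logits unchanged:
logits = get_logits_processor("typical_p", logits, 0.8)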