[Inference] Fix Inference Generation Config and Sampling (#5710)

* refactor and add

* config default values

* fix gen config passing

* fix rpc generation config
Yuanheng Zhao 2024-05-19 15:08:42 +08:00 committed by GitHub
parent 8bcfe360fd
commit 283c407a19
6 changed files with 124 additions and 68 deletions


@@ -202,11 +202,12 @@ class InferenceConfig(RPC_PARAM):
     ] = 1.2  # the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio
     pad_input: bool = False
     early_stopping: Optional[bool] = False
-    top_k: Optional[int] = None
-    top_p: Optional[float] = None
+    top_k: Optional[int] = 50
+    top_p: Optional[float] = 1.0
     temperature: Optional[float] = 1.0
     no_repeat_ngram_size: Optional[int] = 0
     repetition_penalty: Optional[float] = 1.0
+    forced_eos_token_id: int = None
 
     # speculative decoding configs
     max_n_spec_tokens: int = 5
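
For quick reference, the sampling-related defaults the config now carries, written out as a plain Python dict (a reference sketch of the fields in the hunk above, not the library's API):

# Sampling defaults after this change (mirrors the InferenceConfig fields above).
sampling_defaults = {
    "early_stopping": False,
    "top_k": 50,                  # was None
    "top_p": 1.0,                 # was None
    "temperature": 1.0,
    "no_repeat_ngram_size": 0,
    "repetition_penalty": 1.0,
    "forced_eos_token_id": None,  # new field
}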


@@ -76,6 +76,7 @@ class InferenceEngine:
         self.init_model(model_or_path, model_policy)
         self.generation_config = inference_config.to_generation_config(self.model_config)
+        self.generation_config_dict = self.generation_config.to_dict()
 
         self.tokenizer = tokenizer
         self.tokenizer.pad_token = self.tokenizer.eos_token
@@ -524,12 +525,13 @@ class InferenceEngine:
         Returns:
             List[str]: Inference result returned by one generation.
         """
+        gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
+        prompts = [prompts] if isinstance(prompts, str) else prompts
+        request_ids = [request_ids] if isinstance(request_ids, int) else request_ids
+
         with torch.inference_mode():
-            if isinstance(prompts, str) and isinstance(request_ids, int):
-                prompts = [prompts]
-                request_ids = [request_ids]
-
             if prompts is not None or prompts_token_ids is not None:
-                gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
                 self.add_request(
                     request_ids=request_ids,
                     prompts=prompts,
@@ -543,6 +545,7 @@ class InferenceEngine:
             # intuition: If user provide a generation config, we should replace the existing one.
             if generation_config is not None:
                 self.generation_config = generation_config
+                self.generation_config_dict = gen_config_dict
 
             if self.use_spec_dec:
                 assert self.drafter is not None, "Drafter Model is not initialized."
@@ -688,11 +691,12 @@
         )
 
         batch_token_ids = None
-        config_dict = self.generation_config.to_dict()
-        # process repetition_penalty, no_repeat_ngram_size
-        for type in ["repetition_penalty", "no_repeat_ngram_size"]:
-            if type in config_dict and config_dict[type] is not None:
-                batch_token_ids = batch.batch_token_ids
+        if (
+            self.generation_config.repetition_penalty != 1.0
+            or self.generation_config.no_repeat_ngram_size > 0
+            or self.generation_config.forced_eos_token_id is not None
+        ):
+            batch_token_ids = batch.batch_token_ids
 
         # only when we have the graph for specific decoding batch size can we use the cuda graph for inference
         use_cuda_graph = False
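
The engine now keeps the generation config in two forms: the GenerationConfig object it already used locally, plus a plain dict that can be forwarded to RPC workers. A minimal sketch of that pairing using Hugging Face's GenerationConfig (the engine itself is omitted; values are illustrative):

from transformers import GenerationConfig

# Object form stays on the engine; the dict form is what RPC workers receive
# (see the RPCInferenceEngine diff below).
generation_config = GenerationConfig(do_sample=True, top_k=50, top_p=0.9, temperature=0.7)
generation_config_dict = generation_config.to_dict()

assert generation_config_dict["top_k"] == 50
assert generation_config_dict["do_sample"] is True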


@@ -257,7 +257,12 @@ class RPCInferenceEngine(InferenceEngine):
         assert len(self.workers) == self.tp_size, "init workers first"
 
         init_tasks = [
-            self.async_parallel_wrapper(worker.execute_model_forward, input_token_ids, input_meta_data.to_rpc_param())
+            self.async_parallel_wrapper(
+                worker.execute_model_forward,
+                input_token_ids,
+                input_meta_data.to_rpc_param(),
+                self.generation_config_dict,
+            )
             for worker in self.workers
         ]
         ret = await asyncio.gather(*init_tasks)


@@ -97,7 +97,9 @@ class rpcWorkerService(rpyc.Service):
         )
         logger.info("physical cache init over")
 
-    def exposed_execute_model_forward(self, input_token_ids_param: List[int], input_meta_data_param: dict):
+    def exposed_execute_model_forward(
+        self, input_token_ids_param: List[int], input_meta_data_param: dict, generation_config_param: dict
+    ):
         # prepare the data for model forward
         input_meta_data = InputMetaData.from_rpc_param(input_meta_data_param)
         input_meta_data.fd_inter_tensor = self.fd_inter_tensor
@@ -120,7 +122,7 @@
         if self.inference_config.pad_input:
             logits = logits[:, -1, :]
         next_tokens = search_tokens(
-            self.inference_config.to_generation_config(self.model_config),
+            generation_config_param,
            logits,
            input_meta_data.is_prompts,
            input_meta_data.batch_token_ids,
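
On the worker side, generation_config_param is a plain dict handed straight to search_tokens. One plausible reading of the "temporary fix for compatibility" comment in the sampler diff below is that a dict crosses the RPC boundary more cleanly than a GenerationConfig object; a minimal sketch of that round trip (json is only a stand-in for whatever serialization the RPC layer actually performs):

import json

from transformers import GenerationConfig

# Client side (RPCInferenceEngine): keep the object, forward the dict.
cfg = GenerationConfig(do_sample=True, top_k=50, top_p=0.9)
payload = json.dumps(cfg.to_dict())

# Worker side (rpcWorkerService): receive a plain dict, pass it to search_tokens.
generation_config_param = json.loads(payload)
assert generation_config_param["top_p"] == 0.9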


@@ -1,27 +1,28 @@
 # This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/generation/logits_process.py
-from typing import List
+import logging
+from typing import List, Union
 
 import torch
 import torch.nn.functional as F
 
-_LOGIT_PROCESSOR_MAP = {}
+_LOGITS_PROCESSOR_MAP = {}
 
 
-def register_logit_processor(process_type):
+def register_logits_processor(process_type):
     """
     register flops computation function for operation.
     """
 
     def register(func):
-        global _LOGIT_PROCESSOR_MAP
-        _LOGIT_PROCESSOR_MAP[process_type] = func
+        global _LOGITS_PROCESSOR_MAP
+        _LOGITS_PROCESSOR_MAP[process_type] = func
         return func
 
     return register
 
 
-@register_logit_processor("no_repeat_ngram_size")
-def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch_token_ids: List[List[int]]):
+@register_logits_processor("no_repeat_ngram_size")
+def apply_no_repeat_ngram_size(logits, ngram_size: int, batch_token_ids: List[List[int]]):
     """
     enforces no repetition of n-grams to avoid repetitions of word sequences.
     """
@@ -52,8 +53,8 @@ def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch_token_ids:
     return logits
 
 
-@register_logit_processor("repetition_penalty")
-def repetition_penalty_logit_process(logits, penalty: float, batch_token_ids: List[List[int]]):
+@register_logits_processor("repetition_penalty")
+def apply_repetition_penalty(logits, penalty: float, batch_token_ids: List[List[int]]):
     """
     apply the penalty to the tokens present in the prompt.
     """
@@ -61,7 +62,7 @@ def repetition_penalty_logit_process(logits, penalty: float, batch_token_ids: Li
     if not isinstance(penalty, float) or not (penalty > 0):
         raise ValueError(f"'penalty={penalty}' has to be a strictly positive float and greater than 0.")
 
-    logit_list = []
+    logits_list = []
 
     # TODO(yuehuayingxueluo) This is only a temporary implementation. Later, we will implement presence_penalties, frequency_penalties, and repetition_penalties using CUDA kernels.
     if penalty != 1.0:
@@ -71,15 +72,15 @@ def repetition_penalty_logit_process(logits, penalty: float, batch_token_ids: Li
             curretn_socre = torch.gather(current_logit, 0, current_token)
             curretn_socre = torch.where(curretn_socre < 0, curretn_socre * penalty, curretn_socre / penalty)
-            logit_list.append(current_logit.scatter(0, current_token, curretn_socre))
+            logits_list.append(current_logit.scatter(0, current_token, curretn_socre))
 
-        logits = torch.stack(logit_list)
+        logits = torch.stack(logits_list)
 
     return logits
 
 
-@register_logit_processor("temperature")
-def temperature_logit_process(logits, temperature: float):
+@register_logits_processor("temperature")
+def apply_temperature(logits, temperature: float):
     """
     apply temperature scaling.
     """
@@ -93,8 +94,8 @@ def temperature_logit_process(logits, temperature: float):
     return logits if temperature == 1.0 else logits / temperature
 
 
-@register_logit_processor("top_k")
-def top_k_logit_processor(logits, top_k: int):
+@register_logits_processor("top_k")
+def apply_top_k(logits, top_k: int):
     """
     top_k logit processor
     """
@@ -107,8 +108,8 @@ def top_k_logit_processor(logits, top_k: int):
     return logits
 
 
-@register_logit_processor("top_p")
-def top_p_logit_processor(logits, top_p: float):
+@register_logits_processor("top_p")
+def apply_top_p(logits, top_p: float):
     """
     top_p logit processor
     """
@@ -129,7 +130,46 @@ def top_p_logit_processor(logits, top_p: float):
     return logits
 
 
-def logit_processor(processor: str, logits, *args, **kwargs):
+@register_logits_processor("forced_eos_token_id")
+def apply_forced_eos_token_id(
+    logits: torch.Tensor,
+    sequence_lengths: Union[torch.Tensor, List[int]],
+    max_lengths: Union[torch.Tensor, List[int]],
+    eos_token_id: Union[int, List[int]],
+):
+    """
+    Enforces the specified token as the last generated token when the maximum output length
+    is reached. Notice that the maximum output lengths for different sequences, even if they're
+    in the same batch, can be different.
+
+    Args:
+        logits(torch.Tensor): logits
+        sequence_lengths(torch.Tensor): sequence lengths including prompt and output tokens
+        max_lengths(torch.Tensor): the maximum length for each sequence
+        eos_token_id(Union[int, List[int]]): forced eos token id
+    """
+    if isinstance(eos_token_id, int):
+        eos_token_id = [eos_token_id]
+    if isinstance(sequence_lengths, torch.Tensor):
+        sequence_lengths = sequence_lengths.tolist()
+    if isinstance(max_lengths, torch.Tensor):
+        max_lengths = max_lengths.tolist()
+
+    select_indexes = []
+    num_sequences = logits.shape[0]
+    sequence_lengths = sequence_lengths[:num_sequences]
+    max_lengths = max_lengths[:num_sequences]
+    for i, (sequence_length, max_out_length) in enumerate(zip(sequence_lengths, max_lengths)):
+        if sequence_length == max_out_length - 1:
+            select_indexes.append(i)
+    if select_indexes:
+        logits[select_indexes, :] = -float("inf")
+        logits[select_indexes, eos_token_id] = 0
+
+    return logits
+
+
+def get_logits_processor(processor: str, logits, *args, **kwargs):
     """
     do logit process for given logits.
@@ -140,9 +180,10 @@ def logit_processor(processor: str, logits, *args, **kwargs):
     Returns:
         logits after process
     """
-    if processor not in _LOGIT_PROCESSOR_MAP:
-        return logits
+    if processor not in _LOGITS_PROCESSOR_MAP:
+        logging.warning(f"Unsupported processor {processor}. Fall back to the original logits.")
     else:
-        func = _LOGIT_PROCESSOR_MAP[processor]
+        func = _LOGITS_PROCESSOR_MAP[processor]
         logits = func(logits, *args, **kwargs)
-        return logits
+
+    return logits
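
A toy check of the new forced-EOS processor, invoked through the registry the same way the sampler does (the module path colossalai.inference.logit_processors is taken from the sampler imports below):

import torch

from colossalai.inference.logit_processors import get_logits_processor

# Two sequences over a 5-token vocabulary.
logits = torch.zeros(2, 5)
sequence_lengths = [4, 2]  # prompt + generated tokens so far
max_lengths = [5, 8]       # per-sequence maximum lengths
eos_token_id = 1

logits = get_logits_processor(
    "forced_eos_token_id", logits, sequence_lengths, max_lengths, eos_token_id
)

# Sequence 0 is one step away from its maximum length (4 == 5 - 1), so every
# token except EOS is masked to -inf for it; sequence 1 is left untouched.
assert logits[0].argmax().item() == eos_token_id
assert torch.all(logits[1] == 0)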


@@ -1,13 +1,12 @@
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union
 
 import torch
 from transformers.generation import GenerationConfig
 
-from colossalai.inference.logit_processors import logit_processor
+from colossalai.inference.logit_processors import get_logits_processor
 
 
 def greedy_sample(
-    generation_config,
     logprobs: torch.Tensor,
 ) -> torch.Tensor:
     """
@@ -18,7 +17,6 @@ def greedy_sample(
 
 
 def multinomial_sample(
-    generation_config,
     probs: torch.Tensor,
 ) -> torch.Tensor:
     """
@@ -29,7 +27,7 @@ def multinomial_sample(
 
 
 def beam_search_sample(
-    generation_config,
+    beam_width: int,
     logprobs: torch.Tensor,
     is_prompt: bool = False,
 ) -> List[Tuple[List[int], List[int]]]:
@@ -46,7 +44,6 @@ def beam_search_sample(
     # NOTE: this beam search sample function is wrong now.
     """
 
-    beam_width = generation_config.num_beams
     results = []
     if is_prompt:
         # Prompt phase.
@@ -64,20 +61,8 @@ def beam_search_sample(
     return results
 
 
-def _sample(probs: torch.Tensor, logprobs: torch.Tensor, generation_config: GenerationConfig, is_prompt: bool = False):
-    if generation_config.num_beams == 1:
-        if generation_config.do_sample:
-            sample_tokens = multinomial_sample(generation_config, probs)
-        else:
-            sample_tokens = greedy_sample(generation_config, logprobs)
-    else:
-        sample_tokens = beam_search_sample(generation_config, logprobs, is_prompt=is_prompt)
-    return sample_tokens
-
-
 def search_tokens(
-    generation_config: GenerationConfig,
+    generation_config: Union[GenerationConfig, dict],
     logits,
     is_prompt: bool = False,
     batch_token_ids: Optional[List[List[int]]] = None,
@@ -86,23 +71,41 @@ def search_tokens(
     Sample tokens for finished requests.
     """
     # NOTE: need to decide the granularity to process logits (sequence or batch)
-    config_dict = generation_config.to_dict()
-    # process repetition_penalty, no_repeat_ngram_size
-    for type in ["repetition_penalty", "no_repeat_ngram_size"]:
-        if type in config_dict and config_dict[type] is not None:
-            logits = logit_processor(type, logits, config_dict[type], batch_token_ids)
 
-    # do logit processor
-    if generation_config.do_sample:
-        # process temperature, top_k, top_p
-        for type in ["temperature", "top_k", "top_p"]:
-            if type in config_dict and config_dict[type] is not None:
-                logits = logit_processor(type, logits, config_dict[type])
+    # convert GenerationConfig to dict
+    # temporary fix for compatibility with the usage of RPCInferenceEngine
+    if isinstance(generation_config, GenerationConfig):
+        generation_config = generation_config.to_dict()
+
+    if (repetition_penalty := generation_config.get("repetition_penalty", 1.0)) != 1.0:
+        logits = get_logits_processor("repetition_penalty", logits, repetition_penalty, batch_token_ids)
+    if (no_repeat_ngram_size := generation_config.get("no_repeat_ngram_size", 0)) > 0:
+        logits = get_logits_processor("no_repeat_ngram_size", logits, no_repeat_ngram_size, batch_token_ids)
+    if (forced_eos_token_id := generation_config.get("forced_eos_token_id", None)) is not None:
+        sequence_lengths = [len(batch_token_ids[i]) for i in range(len(batch_token_ids))]
+        max_out_lengths = [generation_config.max_length for _ in range(len(batch_token_ids))]
+        logits = get_logits_processor(
+            "forced_eos_token_id", logits, sequence_lengths, max_out_lengths, forced_eos_token_id
+        )
+
+    if generation_config.get("do_sample"):
+        if (temperature := generation_config.get("temperature", 1.0)) != 1.0:
+            logits = get_logits_processor("temperature", logits, temperature)
+        if (top_k := generation_config.get("top_k", 0)) != 0:
+            logits = get_logits_processor("top_k", logits, top_k)
+        if (top_p := generation_config.get("top_p", 1.0)) < 1.0:
+            logits = get_logits_processor("top_p", logits, top_p)
 
     # calculate probs
     probs = torch.softmax(logits, dim=-1, dtype=torch.float)
     logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
 
     # sample the next tokens
-    sample_tokens = _sample(probs, logprobs, generation_config, is_prompt)
+    if generation_config.get("num_beams", 1) != 1:
+        raise NotImplementedError("Beam search is not supported yet.")
+
+    if generation_config.get("do_sample", False):
+        sample_tokens = multinomial_sample(probs)
+    else:
+        sample_tokens = greedy_sample(logprobs)
+
     return sample_tokens
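
A minimal greedy-path sketch of the reworked search_tokens, called with the plain-dict config the RPC workers now send (the module path colossalai.inference.sampler is assumed, as it is not shown in this diff):

import torch

from colossalai.inference.sampler import search_tokens  # assumed module path

# Plain-dict generation config, as received from RPCInferenceEngine.
gen_config = {"do_sample": False, "repetition_penalty": 1.0, "no_repeat_ngram_size": 0}

logits = torch.tensor([[0.1, 2.0, 0.3], [1.5, 0.2, 0.1]])
batch_token_ids = [[7, 7, 3], [5, 9]]

# No penalties apply, do_sample is False and num_beams defaults to 1, so this
# takes the greedy path: argmax per sequence, i.e. tokens [1, 0].
next_tokens = search_tokens(gen_config, logits, is_prompt=False, batch_token_ids=batch_token_ids)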