[Inference] Fix Inference Generation Config and Sampling (#5710)
* refactor and add
* config default values
* fix gen config passing
* fix rpc generation config
This commit is contained in:
parent 8bcfe360fd, commit 283c407a19
@@ -202,11 +202,12 @@ class InferenceConfig(RPC_PARAM):
     ] = 1.2  # the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio
     pad_input: bool = False
     early_stopping: Optional[bool] = False
-    top_k: Optional[int] = None
-    top_p: Optional[float] = None
+    top_k: Optional[int] = 50
+    top_p: Optional[float] = 1.0
     temperature: Optional[float] = 1.0
     no_repeat_ngram_size: Optional[int] = 0
     repetition_penalty: Optional[float] = 1.0
+    forced_eos_token_id: int = None

     # speculative decoding configs
     max_n_spec_tokens: int = 5
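For reference, the new defaults match the stock Hugging Face sampling defaults, so a config built without user-supplied sampling options no longer carries `None` for `top_k` / `top_p`. A minimal illustration (not part of the PR), using only `transformers`:

```python
# Illustration only: the new InferenceConfig defaults (top_k=50, top_p=1.0,
# temperature=1.0, repetition_penalty=1.0) mirror GenerationConfig's own defaults.
from transformers.generation import GenerationConfig

cfg = GenerationConfig(top_k=50, top_p=1.0, temperature=1.0, repetition_penalty=1.0)
print(cfg.top_k, cfg.top_p)  # 50 1.0
```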
@@ -76,6 +76,7 @@ class InferenceEngine:
         self.init_model(model_or_path, model_policy)

         self.generation_config = inference_config.to_generation_config(self.model_config)
+        self.generation_config_dict = self.generation_config.to_dict()

         self.tokenizer = tokenizer
         self.tokenizer.pad_token = self.tokenizer.eos_token
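The newly cached `generation_config_dict` is just the plain-dict form of the engine's `GenerationConfig`; a small sketch of what that conversion produces:

```python
# Sketch: GenerationConfig.to_dict() returns an ordinary dict, which is simpler to
# hand to RPC workers than the GenerationConfig object itself (see the
# RPCInferenceEngine change further down).
from transformers.generation import GenerationConfig

generation_config = GenerationConfig(do_sample=True, top_k=50, top_p=0.9)
generation_config_dict = generation_config.to_dict()

assert isinstance(generation_config_dict, dict)
print(generation_config_dict["top_k"], generation_config_dict["top_p"])  # 50 0.9
```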
@@ -524,12 +525,13 @@ class InferenceEngine:
         Returns:
             List[str]: Inference result returned by one generation.
         """
+
+        gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
+        prompts = [prompts] if isinstance(prompts, str) else prompts
+        request_ids = [request_ids] if isinstance(request_ids, int) else request_ids
+
         with torch.inference_mode():
-            if isinstance(prompts, str) and isinstance(request_ids, int):
-                prompts = [prompts]
-                request_ids = [request_ids]
             if prompts is not None or prompts_token_ids is not None:
-                gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
                 self.add_request(
                     request_ids=request_ids,
                     prompts=prompts,
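The prompt and request-id normalization that used to live inside the `torch.inference_mode()` block is now a pair of one-liners at the top of `generate()`; in isolation they behave like this:

```python
# Stand-alone behaviour of the new normalization lines: a single prompt or
# request id is wrapped into a one-element list, lists pass through unchanged.
prompts = "Hello"
request_ids = 0

prompts = [prompts] if isinstance(prompts, str) else prompts
request_ids = [request_ids] if isinstance(request_ids, int) else request_ids

print(prompts, request_ids)  # ['Hello'] [0]
```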
@@ -543,6 +545,7 @@ class InferenceEngine:
             # intuition: If user provide a generation config, we should replace the existing one.
             if generation_config is not None:
                 self.generation_config = generation_config
+                self.generation_config_dict = gen_config_dict

             if self.use_spec_dec:
                 assert self.drafter is not None, "Drafter Model is not initialized."
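With this change a per-call `GenerationConfig` replaces both the engine's config object and its cached dict, so the override also reaches the RPC workers. A hedged usage sketch, assuming `engine` is an already-initialized `InferenceEngine`:

```python
# Usage sketch under the assumption that `engine` is an initialized InferenceEngine;
# the per-call config below would replace engine.generation_config and
# engine.generation_config_dict for this and subsequent forward passes.
from transformers.generation import GenerationConfig

user_config = GenerationConfig(do_sample=True, temperature=0.7, top_p=0.9, max_length=128)
# outputs = engine.generate(prompts=["Tell me a joke."], generation_config=user_config)
```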
@@ -688,11 +691,12 @@ class InferenceEngine:
         )

         batch_token_ids = None
-        config_dict = self.generation_config.to_dict()
-        # process repetition_penalty, no_repeat_ngram_size
-        for type in ["repetition_penalty", "no_repeat_ngram_size"]:
-            if type in config_dict and config_dict[type] is not None:
-                batch_token_ids = batch.batch_token_ids
+        if (
+            self.generation_config.repetition_penalty != 1.0
+            or self.generation_config.no_repeat_ngram_size > 0
+            or self.generation_config.forced_eos_token_id is not None
+        ):
+            batch_token_ids = batch.batch_token_ids

         # only when we have the graph for specific decoding batch size can we use the cuda graph for inference
         use_cuda_graph = False
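The old loop and the new boolean gate answer the same question: does any active processor need per-sequence token histories? The new condition also counts `forced_eos_token_id`. A toy restatement (not ColossalAI code):

```python
# Toy restatement of the new gate: token histories are only materialized when a
# processor that consumes them is active; forced_eos_token_id now counts as one.
def needs_token_history(repetition_penalty: float, no_repeat_ngram_size: int, forced_eos_token_id) -> bool:
    return (
        repetition_penalty != 1.0
        or no_repeat_ngram_size > 0
        or forced_eos_token_id is not None
    )

print(needs_token_history(1.0, 0, None))  # False -> batch_token_ids stays None
print(needs_token_history(1.2, 0, None))  # True
```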
@@ -257,7 +257,12 @@ class RPCInferenceEngine(InferenceEngine):
         assert len(self.workers) == self.tp_size, "init workers first"

         init_tasks = [
-            self.async_parallel_wrapper(worker.execute_model_forward, input_token_ids, input_meta_data.to_rpc_param())
+            self.async_parallel_wrapper(
+                worker.execute_model_forward,
+                input_token_ids,
+                input_meta_data.to_rpc_param(),
+                self.generation_config_dict,
+            )
             for worker in self.workers
         ]
         ret = await asyncio.gather(*init_tasks)
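The forward call is fanned out to every tensor-parallel worker and awaited together, now including the plain generation-config dict. A generic, self-contained sketch of that pattern (hypothetical `call_worker`, not ColossalAI's `async_parallel_wrapper`):

```python
# Generic fan-out sketch: the same request, including the generation-config dict,
# is dispatched to every worker and the results are gathered together.
import asyncio

async def call_worker(worker_id: int, token_ids: list, meta: dict, gen_cfg: dict) -> str:
    await asyncio.sleep(0)  # stand-in for the RPC round trip
    return f"worker {worker_id}: {len(token_ids)} tokens, top_k={gen_cfg.get('top_k')}"

async def main() -> None:
    tasks = [call_worker(i, [1, 2, 3], {}, {"top_k": 50}) for i in range(2)]
    print(await asyncio.gather(*tasks))

asyncio.run(main())
```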
@@ -97,7 +97,9 @@ class rpcWorkerService(rpyc.Service):
         )
         logger.info("physical cache init over")

-    def exposed_execute_model_forward(self, input_token_ids_param: List[int], input_meta_data_param: dict):
+    def exposed_execute_model_forward(
+        self, input_token_ids_param: List[int], input_meta_data_param: dict, generation_config_param: dict
+    ):
         # prepare the data for model forward
         input_meta_data = InputMetaData.from_rpc_param(input_meta_data_param)
         input_meta_data.fd_inter_tensor = self.fd_inter_tensor
@@ -120,7 +122,7 @@ class rpcWorkerService(rpyc.Service):
         if self.inference_config.pad_input:
             logits = logits[:, -1, :]
         next_tokens = search_tokens(
-            self.inference_config.to_generation_config(self.model_config),
+            generation_config_param,
             logits,
             input_meta_data.is_prompts,
             input_meta_data.batch_token_ids,
@@ -1,27 +1,28 @@
 # This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/generation/logits_process.py
-from typing import List
+import logging
+from typing import List, Union

 import torch
 import torch.nn.functional as F

-_LOGIT_PROCESSOR_MAP = {}
+_LOGITS_PROCESSOR_MAP = {}


-def register_logit_processor(process_type):
+def register_logits_processor(process_type):
     """
     register flops computation function for operation.
     """

     def register(func):
-        global _LOGIT_PROCESSOR_MAP
-        _LOGIT_PROCESSOR_MAP[process_type] = func
+        global _LOGITS_PROCESSOR_MAP
+        _LOGITS_PROCESSOR_MAP[process_type] = func
         return func

     return register


-@register_logit_processor("no_repeat_ngram_size")
-def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch_token_ids: List[List[int]]):
+@register_logits_processor("no_repeat_ngram_size")
+def apply_no_repeat_ngram_size(logits, ngram_size: int, batch_token_ids: List[List[int]]):
     """
     enforces no repetition of n-grams to avoid repetitions of word sequences.
     """
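Aside from the `_LOGIT_` to `_LOGITS_` rename and the `apply_*` function names, the registration mechanism is unchanged. A self-contained sketch of the decorator/registry pattern it uses:

```python
# Self-contained sketch of the registry pattern behind register_logits_processor:
# each processor is stored under a string key so the dispatcher can look it up by name.
_LOGITS_PROCESSOR_MAP = {}

def register_logits_processor(process_type):
    def register(func):
        _LOGITS_PROCESSOR_MAP[process_type] = func
        return func
    return register

@register_logits_processor("identity")
def apply_identity(logits):
    return logits  # no-op processor, for illustration only

print(sorted(_LOGITS_PROCESSOR_MAP))  # ['identity']
```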
@@ -52,8 +53,8 @@ def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch_token_ids:
     return logits


-@register_logit_processor("repetition_penalty")
-def repetition_penalty_logit_process(logits, penalty: float, batch_token_ids: List[List[int]]):
+@register_logits_processor("repetition_penalty")
+def apply_repetition_penalty(logits, penalty: float, batch_token_ids: List[List[int]]):
     """
     apply the penalty to the tokens present in the prompt.
     """
@@ -61,7 +62,7 @@ def repetition_penalty_logit_process(logits, penalty: float, batch_token_ids: Li
     if not isinstance(penalty, float) or not (penalty > 0):
         raise ValueError(f"'penalty={penalty}' has to be a strictly positive float and greater than 0.")

-    logit_list = []
+    logits_list = []

     # TODO(yuehuayingxueluo) This is only a temporary implementation. Later, we will implement presence_penalties, frequency_penalties, and repetition_penalties using CUDA kernels.
     if penalty != 1.0:
@@ -71,15 +72,15 @@ def repetition_penalty_logit_process(logits, penalty: float, batch_token_ids: Li

             curretn_socre = torch.gather(current_logit, 0, current_token)
             curretn_socre = torch.where(curretn_socre < 0, curretn_socre * penalty, curretn_socre / penalty)
-            logit_list.append(current_logit.scatter(0, current_token, curretn_socre))
+            logits_list.append(current_logit.scatter(0, current_token, curretn_socre))

-        logits = torch.stack(logit_list)
+        logits = torch.stack(logits_list)

     return logits


-@register_logit_processor("temperature")
-def temperature_logit_process(logits, temperature: float):
+@register_logits_processor("temperature")
+def apply_temperature(logits, temperature: float):
     """
     apply temperature scaling.
     """
@@ -93,8 +94,8 @@ def temperature_logit_process(logits, temperature: float):
     return logits if temperature == 1.0 else logits / temperature


-@register_logit_processor("top_k")
-def top_k_logit_processor(logits, top_k: int):
+@register_logits_processor("top_k")
+def apply_top_k(logits, top_k: int):
     """
     top_k logit processor
     """
@@ -107,8 +108,8 @@ def top_k_logit_processor(logits, top_k: int):
     return logits


-@register_logit_processor("top_p")
-def top_p_logit_processor(logits, top_p: float):
+@register_logits_processor("top_p")
+def apply_top_p(logits, top_p: float):
     """
     top_p logit processor
     """
@@ -129,7 +130,46 @@ def top_p_logit_processor(logits, top_p: float):
     return logits


-def logit_processor(processor: str, logits, *args, **kwargs):
+@register_logits_processor("forced_eos_token_id")
+def apply_forced_eos_token_id(
+    logits: torch.Tensor,
+    sequence_lengths: Union[torch.Tensor, List[int]],
+    max_lengths: Union[torch.Tensor, List[int]],
+    eos_token_id: Union[int, List[int]],
+):
+    """
+    Enforces the specified token as the last generated token when the maximum output length
+    is reached. Notice that the maximum output lengths for different sequences, even if they're
+    in the same batch, can be different.
+
+    Args:
+        logits(torch.Tensor): logits
+        sequence_lengths(torch.Tensor): sequence lengths including prompt and output tokens
+        max_lengths(torch.Tensor): the maximum length for each sequence
+        eos_token_id(Union[int, List[int]]): forced eos token id
+    """
+    if isinstance(eos_token_id, int):
+        eos_token_id = [eos_token_id]
+    if isinstance(sequence_lengths, torch.Tensor):
+        sequence_lengths = sequence_lengths.tolist()
+    if isinstance(max_lengths, torch.Tensor):
+        max_lengths = max_lengths.tolist()
+
+    select_indexes = []
+    num_sequences = logits.shape[0]
+    sequence_lengths = sequence_lengths[:num_sequences]
+    max_lengths = max_lengths[:num_sequences]
+    for i, (sequence_length, max_out_length) in enumerate(zip(sequence_lengths, max_lengths)):
+        if sequence_length == max_out_length - 1:
+            select_indexes.append(i)
+    if select_indexes:
+        logits[select_indexes, :] = -float("inf")
+        logits[select_indexes, eos_token_id] = 0
+
+    return logits
+
+
+def get_logits_processor(processor: str, logits, *args, **kwargs):
     """
     do logit process for given logits.

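A small self-contained check of what the new `forced_eos_token_id` processor does (plain `torch` only, toy sizes): once a sequence is one token away from its maximum length, everything except the EOS id is masked, so softmax places all probability on EOS.

```python
# Toy demonstration of the forced-EOS masking: vocab of 5, eos id 4.
import torch

logits = torch.zeros(2, 5)                    # two sequences, five-token vocab
sequence_lengths, max_lengths, eos_token_id = [7, 3], [8, 16], 4

for i, (cur_len, max_len) in enumerate(zip(sequence_lengths, max_lengths)):
    if cur_len == max_len - 1:                # sequence 0 must emit EOS next
        logits[i, :] = -float("inf")
        logits[i, eos_token_id] = 0

probs = torch.softmax(logits, dim=-1)
print(probs[0])  # tensor([0., 0., 0., 0., 1.]) -> forced to EOS
print(probs[1])  # uniform 0.2 -> untouched, still far from its max length
```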
@@ -140,9 +180,10 @@ def logit_processor(processor: str, logits, *args, **kwargs):
     Returns:
         logits after process
     """
-    if processor not in _LOGIT_PROCESSOR_MAP:
-        return logits
+    if processor not in _LOGITS_PROCESSOR_MAP:
+        logging.warning(f"Unsupported processor {processor}. Fall back to the original logits.")
     else:
-        func = _LOGIT_PROCESSOR_MAP[processor]
+        func = _LOGITS_PROCESSOR_MAP[processor]
         logits = func(logits, *args, **kwargs)
-        return logits
+
+    return logits
@@ -1,13 +1,12 @@
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union

 import torch
 from transformers.generation import GenerationConfig

-from colossalai.inference.logit_processors import logit_processor
+from colossalai.inference.logit_processors import get_logits_processor


 def greedy_sample(
-    generation_config,
     logprobs: torch.Tensor,
 ) -> torch.Tensor:
     """
@@ -18,7 +17,6 @@ def greedy_sample(


 def multinomial_sample(
-    generation_config,
     probs: torch.Tensor,
 ) -> torch.Tensor:
     """
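With `generation_config` dropped from their signatures, `greedy_sample` and `multinomial_sample` reduce to the two core `torch` operations sketched below (illustration, not the exact function bodies):

```python
# Minimal sketch of the two sampling paths search_tokens now chooses between:
# argmax over log-probs (greedy) vs. one multinomial draw per row (sampling).
import torch

logits = torch.tensor([[2.0, 0.5, 0.1], [0.1, 3.0, 0.2]])
probs = torch.softmax(logits, dim=-1)
logprobs = torch.log_softmax(logits, dim=-1)

greedy_tokens = torch.argmax(logprobs, dim=-1)                       # -> tensor([0, 1])
sampled_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)  # random draw per row
print(greedy_tokens, sampled_tokens)
```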
@@ -29,7 +27,7 @@ def multinomial_sample(


 def beam_search_sample(
-    generation_config,
+    beam_width: int,
     logprobs: torch.Tensor,
     is_prompt: bool = False,
 ) -> List[Tuple[List[int], List[int]]]:
@@ -46,7 +44,6 @@ def beam_search_sample(
     # NOTE: this beam search sample function is wrong now.
     """

-    beam_width = generation_config.num_beams
     results = []
     if is_prompt:
         # Prompt phase.
@@ -64,20 +61,8 @@ def beam_search_sample(
     return results


-def _sample(probs: torch.Tensor, logprobs: torch.Tensor, generation_config: GenerationConfig, is_prompt: bool = False):
-    if generation_config.num_beams == 1:
-        if generation_config.do_sample:
-            sample_tokens = multinomial_sample(generation_config, probs)
-        else:
-            sample_tokens = greedy_sample(generation_config, logprobs)
-    else:
-        sample_tokens = beam_search_sample(generation_config, logprobs, is_prompt=is_prompt)
-
-    return sample_tokens
-
-
 def search_tokens(
-    generation_config: GenerationConfig,
+    generation_config: Union[GenerationConfig, dict],
     logits,
     is_prompt: bool = False,
     batch_token_ids: Optional[List[List[int]]] = None,
@@ -86,23 +71,41 @@ def search_tokens(
     Sample tokens for finished requests.
     """
     # NOTE: need to decide the granularity to process logits (sequence or batch)
-    config_dict = generation_config.to_dict()
-    # process repetition_penalty, no_repeat_ngram_size
-    for type in ["repetition_penalty", "no_repeat_ngram_size"]:
-        if type in config_dict and config_dict[type] is not None:
-            logits = logit_processor(type, logits, config_dict[type], batch_token_ids)

-    # do logit processor
-    if generation_config.do_sample:
-        # process temperature, top_k, top_p
-        for type in ["temperature", "top_k", "top_p"]:
-            if type in config_dict and config_dict[type] is not None:
-                logits = logit_processor(type, logits, config_dict[type])
+    # convert GenerationConfig to dict
+    # temporary fix for compatibility with the usage of RPCInferenceEngine
+    if isinstance(generation_config, GenerationConfig):
+        generation_config = generation_config.to_dict()
+
+    if (repetition_penalty := generation_config.get("repetition_penalty", 1.0)) != 1.0:
+        logits = get_logits_processor("repetition_penalty", logits, repetition_penalty, batch_token_ids)
+    if (no_repeat_ngram_size := generation_config.get("no_repeat_ngram_size", 0)) > 0:
+        logits = get_logits_processor("no_repeat_ngram_size", logits, no_repeat_ngram_size, batch_token_ids)
+    if (forced_eos_token_id := generation_config.get("forced_eos_token_id", None)) is not None:
+        sequence_lengths = [len(batch_token_ids[i]) for i in range(len(batch_token_ids))]
+        max_out_lengths = [generation_config.max_length for _ in range(len(batch_token_ids))]
+        logits = get_logits_processor(
+            "forced_eos_token_id", logits, sequence_lengths, max_out_lengths, forced_eos_token_id
+        )
+
+    if generation_config.get("do_sample"):
+        if (temperature := generation_config.get("temperature", 1.0)) != 1.0:
+            logits = get_logits_processor("temperature", logits, temperature)
+        if (top_k := generation_config.get("top_k", 0)) != 0:
+            logits = get_logits_processor("top_k", logits, top_k)
+        if (top_p := generation_config.get("top_p", 1.0)) < 1.0:
+            logits = get_logits_processor("top_p", logits, top_p)

     # calculate probs
     probs = torch.softmax(logits, dim=-1, dtype=torch.float)
     logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

     # sample the next tokens
-    sample_tokens = _sample(probs, logprobs, generation_config, is_prompt)
+    if generation_config.get("num_beams", 1) != 1:
+        raise NotImplementedError("Beam search is not supported yet.")
+    if generation_config.get("do_sample", False):
+        sample_tokens = multinomial_sample(probs)
+    else:
+        sample_tokens = greedy_sample(logprobs)
+
     return sample_tokens
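End to end, the reworked `search_tokens` accepts either a `GenerationConfig` or the plain dict an RPC worker receives. A hedged usage sketch, assuming the function lives in `colossalai.inference.sampler` (the module path is not shown in this diff):

```python
# Hedged sketch: both spellings below should be equivalent after this PR, since
# search_tokens converts a GenerationConfig to a dict internally. The import path
# colossalai.inference.sampler is an assumption not confirmed by the diff.
import torch
from transformers.generation import GenerationConfig
from colossalai.inference.sampler import search_tokens  # assumed path

logits = torch.randn(2, 32000)
batch_token_ids = [[1, 2, 3], [4, 5]]

cfg = GenerationConfig(do_sample=False)
tokens_from_config = search_tokens(cfg, logits, batch_token_ids=batch_token_ids)
tokens_from_dict = search_tokens(cfg.to_dict(), logits, batch_token_ids=batch_token_ids)
print(tokens_from_config, tokens_from_dict)
```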