Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-07 12:01:39 +00:00)
[Inference] Fix Inference Generation Config and Sampling (#5710)
* refactor and add
* config default values
* fix gen config passing
* fix rpc generation config
@@ -1,13 +1,12 @@
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union

 import torch
 from transformers.generation import GenerationConfig

-from colossalai.inference.logit_processors import logit_processor
+from colossalai.inference.logit_processors import get_logits_processor


 def greedy_sample(
-    generation_config,
     logprobs: torch.Tensor,
 ) -> torch.Tensor:
     """
@@ -18,7 +17,6 @@ def greedy_sample(


 def multinomial_sample(
-    generation_config,
     probs: torch.Tensor,
 ) -> torch.Tensor:
     """
@@ -29,7 +27,7 @@ def multinomial_sample(


 def beam_search_sample(
-    generation_config,
+    beam_width: int,
     logprobs: torch.Tensor,
     is_prompt: bool = False,
 ) -> List[Tuple[List[int], List[int]]]:
@@ -46,7 +44,6 @@ def beam_search_sample(
     # NOTE: this beam search sample function is wrong now.
     """

-    beam_width = generation_config.num_beams
     results = []
     if is_prompt:
         # Prompt phase.
@@ -64,20 +61,8 @@ def beam_search_sample(
     return results


-def _sample(probs: torch.Tensor, logprobs: torch.Tensor, generation_config: GenerationConfig, is_prompt: bool = False):
-    if generation_config.num_beams == 1:
-        if generation_config.do_sample:
-            sample_tokens = multinomial_sample(generation_config, probs)
-        else:
-            sample_tokens = greedy_sample(generation_config, logprobs)
-    else:
-        sample_tokens = beam_search_sample(generation_config, logprobs, is_prompt=is_prompt)
-
-    return sample_tokens
-
-
 def search_tokens(
-    generation_config: GenerationConfig,
+    generation_config: Union[GenerationConfig, dict],
     logits,
     is_prompt: bool = False,
     batch_token_ids: Optional[List[List[int]]] = None,
@@ -86,23 +71,41 @@ def search_tokens(
     Sample tokens for finished requests.
     """
     # NOTE: need to decide the granularity to process logits (sequence or batch)
-    config_dict = generation_config.to_dict()
-    # process repetition_penalty, no_repeat_ngram_size
-    for type in ["repetition_penalty", "no_repeat_ngram_size"]:
-        if type in config_dict and config_dict[type] is not None:
-            logits = logit_processor(type, logits, config_dict[type], batch_token_ids)
-
-    # do logit processor
-    if generation_config.do_sample:
-        # process temperature, top_k, top_p
-        for type in ["temperature", "top_k", "top_p"]:
-            if type in config_dict and config_dict[type] is not None:
-                logits = logit_processor(type, logits, config_dict[type])
+    # convert GenerationConfig to dict
+    # temporary fix for compatibility with the usage of RPCInferenceEngine
+    if isinstance(generation_config, GenerationConfig):
+        generation_config = generation_config.to_dict()
+
+    if (repetition_penalty := generation_config.get("repetition_penalty", 1.0)) != 1.0:
+        logits = get_logits_processor("repetition_penalty", logits, repetition_penalty, batch_token_ids)
+    if (no_repeat_ngram_size := generation_config.get("no_repeat_ngram_size", 0)) > 0:
+        logits = get_logits_processor("no_repeat_ngram_size", logits, no_repeat_ngram_size, batch_token_ids)
+    if (forced_eos_token_id := generation_config.get("forced_eos_token_id", None)) is not None:
+        sequence_lengths = [len(batch_token_ids[i]) for i in range(len(batch_token_ids))]
+        max_out_lengths = [generation_config.max_length for _ in range(len(batch_token_ids))]
+        logits = get_logits_processor(
+            "forced_eos_token_id", logits, sequence_lengths, max_out_lengths, forced_eos_token_id
+        )
+
+    if generation_config.get("do_sample"):
+        if (temperature := generation_config.get("temperature", 1.0)) != 1.0:
+            logits = get_logits_processor("temperature", logits, temperature)
+        if (top_k := generation_config.get("top_k", 0)) != 0:
+            logits = get_logits_processor("top_k", logits, top_k)
+        if (top_p := generation_config.get("top_p", 1.0)) < 1.0:
+            logits = get_logits_processor("top_p", logits, top_p)

     # calculate probs
     probs = torch.softmax(logits, dim=-1, dtype=torch.float)
     logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

     # sample the next tokens
-    sample_tokens = _sample(probs, logprobs, generation_config, is_prompt)
+    if generation_config.get("num_beams", 1) != 1:
+        raise NotImplementedError("Beam search is not supported yet.")
+    if generation_config.get("do_sample", False):
+        sample_tokens = multinomial_sample(probs)
+    else:
+        sample_tokens = greedy_sample(logprobs)

     return sample_tokens
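The diff only touches the signatures of greedy_sample and multinomial_sample; their bodies are not shown here. A minimal sketch of what greedy and multinomial sampling look like with the new argument-free-of-config signatures (assumed for illustration, not copied from the file):

import torch

def greedy_sample(logprobs: torch.Tensor) -> torch.Tensor:
    # pick the highest-scoring token for every sequence in the batch
    return torch.argmax(logprobs, dim=-1)

def multinomial_sample(probs: torch.Tensor) -> torch.Tensor:
    # draw one token per sequence from the already-filtered distribution
    return torch.multinomial(probs, num_samples=1).squeeze(1)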
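With the new Union[GenerationConfig, dict] signature, both call styles below should be accepted. This is a hypothetical call-site sketch: the module path, the logits tensor, and batch_token_ids are assumptions, not values taken from the repository.

import torch
from transformers.generation import GenerationConfig
from colossalai.inference.sampler import search_tokens  # module path assumed; the diff does not name the file

logits = torch.randn(2, 32000)          # placeholder batch of next-token logits
batch_token_ids = [[1, 5, 7], [1, 9]]   # placeholder per-sequence token histories

gen_cfg = GenerationConfig(do_sample=True, temperature=0.7, top_k=50, top_p=0.9)

# local engine path: pass the GenerationConfig object directly
tokens = search_tokens(gen_cfg, logits, batch_token_ids=batch_token_ids)

# RPC path: a config serialized for transport arrives as a plain dict
tokens = search_tokens(gen_cfg.to_dict(), logits, batch_token_ids=batch_token_ids)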
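The walrus-operator gates only invoke a processor when a setting differs from its neutral value, so a stock config leaves the logits untouched and falls through to greedy decoding. A quick way to see which gates a default Hugging Face config would trigger:

from transformers.generation import GenerationConfig

cfg = GenerationConfig().to_dict()
print(cfg.get("repetition_penalty", 1.0))  # 1.0   -> repetition penalty skipped
print(cfg.get("no_repeat_ngram_size", 0))  # 0     -> no-repeat-ngram skipped
print(cfg.get("forced_eos_token_id"))      # None  -> forced-EOS skipped
print(cfg.get("do_sample", False))         # False -> temperature/top-k/top-p skipped, greedy path
print(cfg.get("num_beams", 1))             # 1     -> no beam-search NotImplementedError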
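get_logits_processor itself lives in colossalai.inference.logit_processors and is not part of this diff. As a rough reference, the temperature, top-k, and top-p transforms it dispatches to conventionally amount to the following; this is a self-contained sketch of the standard technique, not the ColossalAI implementation. As in search_tokens above, these run on raw logits before the softmax.

import torch

def apply_temperature(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    # temperature > 1.0 flattens the distribution, < 1.0 sharpens it
    return logits / temperature

def apply_top_k(logits: torch.Tensor, top_k: int) -> torch.Tensor:
    # mask every logit smaller than the k-th largest one in each row
    kth_largest = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
    return logits.masked_fill(logits < kth_largest, float("-inf"))

def apply_top_p(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    # nucleus filtering: keep the smallest set of tokens whose probability mass reaches top_p
    sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
    sorted_probs = torch.softmax(sorted_logits, dim=-1)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # drop a token once the mass accumulated before it already exceeds top_p (the top token is always kept)
    remove = (cumulative - sorted_probs) > top_p
    sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
    # scatter the filtered logits back to their original vocabulary positions
    return torch.full_like(logits, float("-inf")).scatter(-1, sorted_idx, sorted_logits)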