Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-08 20:40:34 +00:00
[Inference] Fix Inference Generation Config and Sampling (#5710)
* refactor and add
* config default values
* fix gen config passing
* fix rpc generation config
```diff
@@ -76,6 +76,7 @@ class InferenceEngine:
         self.init_model(model_or_path, model_policy)
 
         self.generation_config = inference_config.to_generation_config(self.model_config)
+        self.generation_config_dict = self.generation_config.to_dict()
 
         self.tokenizer = tokenizer
         self.tokenizer.pad_token = self.tokenizer.eos_token
```
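The new line caches a plain-dict snapshot of the generation config next to the config object itself, so later code can look fields up by key. A minimal sketch of the idea using Hugging Face's `transformers.GenerationConfig` (the `InferenceConfig.to_generation_config` helper above is ColossalAI's own and is not reimplemented here):

```python
# Sketch: keep a dict snapshot alongside the config object so optional
# fields can be probed by key with a default, without attribute checks.
from transformers import GenerationConfig

generation_config = GenerationConfig(max_new_tokens=128, temperature=0.7, top_k=50)
generation_config_dict = generation_config.to_dict()

repetition_penalty = generation_config_dict.get("repetition_penalty", 1.0)
```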
```diff
@@ -524,12 +525,13 @@ class InferenceEngine:
         Returns:
             List[str]: Inference result returned by one generation.
         """
 
+        gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
+        prompts = [prompts] if isinstance(prompts, str) else prompts
+        request_ids = [request_ids] if isinstance(request_ids, int) else request_ids
+
         with torch.inference_mode():
-            if isinstance(prompts, str) and isinstance(request_ids, int):
-                prompts = [prompts]
-                request_ids = [request_ids]
             if prompts is not None or prompts_token_ids is not None:
-                gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
                 self.add_request(
                     request_ids=request_ids,
                     prompts=prompts,
```
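Two things move here. First, single-item inputs are now normalized to lists one argument at a time, so a `str` prompt paired with a list of request ids (or vice versa) is handled; the old combined `isinstance` check silently skipped that case. Second, `gen_config_dict` is built once, unconditionally, before entering `torch.inference_mode()`, so it exists even when no new request is added. A standalone sketch of the normalization idiom (toy function name, not the engine's signature):

```python
# Toy version of the per-argument normalization used above.
def normalize_inputs(prompts, request_ids):
    # Wrap each argument independently; mixed single/list inputs still work.
    prompts = [prompts] if isinstance(prompts, str) else prompts
    request_ids = [request_ids] if isinstance(request_ids, int) else request_ids
    return prompts, request_ids

assert normalize_inputs("hello", 0) == (["hello"], [0])
assert normalize_inputs("hello", [0]) == (["hello"], [0])  # the old check missed this
```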
```diff
@@ -543,6 +545,7 @@ class InferenceEngine:
                 )
 
         # intuition: if the user provides a generation config, it should replace the existing one.
         if generation_config is not None:
             self.generation_config = generation_config
+            self.generation_config_dict = gen_config_dict
 
         if self.use_spec_dec:
             assert self.drafter is not None, "Drafter Model is not initialized."
```
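The added line keeps the cached dict snapshot in sync when a per-call generation config replaces the engine default; without it, code reading `generation_config_dict` would still see values from the old config. A hypothetical stand-in class (names invented for illustration) showing the invariant:

```python
# Hypothetical miniature of the engine; only the sync invariant matters here.
from transformers import GenerationConfig

class TinyEngine:
    def __init__(self, default_config: GenerationConfig):
        self.generation_config = default_config
        self.generation_config_dict = default_config.to_dict()

    def generate(self, generation_config=None):
        gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
        if generation_config is not None:
            self.generation_config = generation_config
            self.generation_config_dict = gen_config_dict  # refresh the snapshot too
```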
```diff
@@ -688,11 +691,12 @@ class InferenceEngine:
         )
 
         batch_token_ids = None
-        config_dict = self.generation_config.to_dict()
-        # process repetition_penalty, no_repeat_ngram_size
-        for type in ["repetition_penalty", "no_repeat_ngram_size"]:
-            if type in config_dict and config_dict[type] is not None:
-                batch_token_ids = batch.batch_token_ids
+        if (
+            self.generation_config.repetition_penalty != 1.0
+            or self.generation_config.no_repeat_ngram_size > 0
+            or self.generation_config.forced_eos_token_id is not None
+        ):
+            batch_token_ids = batch.batch_token_ids
 
         # only when we have the graph for a specific decoding batch size can we use the cuda graph for inference
         use_cuda_graph = False
```
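The old loop probed string keys in a freshly built dict (shadowing the `type` builtin along the way) and, because those keys are typically present with non-None defaults, fetched the batch's token history even when no penalty was in effect. The new condition reads the config attributes directly and materializes `batch_token_ids` only when a history-dependent constraint is active. A toy illustration (assumed shapes, not the engine's sampler) of why the token history is only needed in that case:

```python
# HF-style repetition penalty: it needs the list of already-generated tokens,
# so fetching that history is pure overhead when penalty == 1.0.
import torch

def apply_repetition_penalty(logits, history, penalty):
    if penalty == 1.0 or not history:
        return logits  # no history-dependent work to do
    idx = torch.tensor(history)
    scores = logits[idx]
    logits[idx] = torch.where(scores > 0, scores / penalty, scores * penalty)
    return logits

logits = torch.zeros(8)
logits[3] = 2.0
apply_repetition_penalty(logits, history=[3], penalty=2.0)
print(logits[3])  # tensor(1.)
```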