[Inference] Fix Inference Generation Config and Sampling (#5710)

* refactor and add

* config default values

* fix gen config passing

* fix rpc generation config
Author: Yuanheng Zhao
Date: 2024-05-19 15:08:42 +08:00
Committed by: GitHub
Parent: 8bcfe360fd
Commit: 283c407a19
6 changed files with 124 additions and 68 deletions

@@ -76,6 +76,7 @@ class InferenceEngine:
         self.init_model(model_or_path, model_policy)
         self.generation_config = inference_config.to_generation_config(self.model_config)
+        self.generation_config_dict = self.generation_config.to_dict()
         self.tokenizer = tokenizer
         self.tokenizer.pad_token = self.tokenizer.eos_token
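
Below is a minimal sketch of the caching pattern this hunk introduces: the generation config is converted to a dict once at construction time and reused later, rather than calling to_dict() on every request. It uses Hugging Face's GenerationConfig, which the engine builds via inference_config.to_generation_config(); the TinyEngine class and its sampling values are illustrative assumptions, not ColossalAI's API.

# Minimal sketch (assumption: TinyEngine stands in for InferenceEngine).
from transformers import GenerationConfig

class TinyEngine:
    def __init__(self):
        # Build the config once and cache its dict form; later code that
        # forwards sampling kwargs can reuse self.generation_config_dict
        # without re-converting on every call.
        self.generation_config = GenerationConfig(max_new_tokens=64, do_sample=True, top_p=0.9)
        self.generation_config_dict = self.generation_config.to_dict()
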
@@ -524,12 +525,13 @@ class InferenceEngine:
         Returns:
             List[str]: Inference result returned by one generation.
         """
+        gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
+        prompts = [prompts] if isinstance(prompts, str) else prompts
+        request_ids = [request_ids] if isinstance(request_ids, int) else request_ids
         with torch.inference_mode():
-            if isinstance(prompts, str) and isinstance(request_ids, int):
-                prompts = [prompts]
-                request_ids = [request_ids]
             if prompts is not None or prompts_token_ids is not None:
-                gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
                 self.add_request(
                     request_ids=request_ids,
                     prompts=prompts,
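
The hoisted normalization above also fixes a subtle bug: the old code only wrapped prompts and request_ids when both were scalars, so a str prompt paired with a list of request ids slipped through unwrapped. A self-contained sketch of the new behavior (the normalize helper is hypothetical):

from typing import List, Optional, Union

def normalize(
    prompts: Optional[Union[str, List[str]]],
    request_ids: Optional[Union[int, List[int]]],
):
    # Wrap each scalar independently, mirroring the two added lines above.
    prompts = [prompts] if isinstance(prompts, str) else prompts
    request_ids = [request_ids] if isinstance(request_ids, int) else request_ids
    return prompts, request_ids

# The old combined isinstance check would have left "hi" unwrapped here:
assert normalize("hi", [0]) == (["hi"], [0])
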
@@ -543,6 +545,7 @@ class InferenceEngine:
             # intuition: If the user provides a generation config, we should replace the existing one.
             if generation_config is not None:
                 self.generation_config = generation_config
+                self.generation_config_dict = gen_config_dict
             if self.use_spec_dec:
                 assert self.drafter is not None, "Drafter Model is not initialized."
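
The added line keeps the cached dict in step with the config object. A sketch of the invariant, with set_generation_config as a hypothetical helper rather than an engine method:

from typing import Optional
from transformers import GenerationConfig

def set_generation_config(engine, generation_config: Optional[GenerationConfig]):
    # Replace the config object and its cached dict together, so readers of
    # engine.generation_config_dict never observe a stale snapshot.
    if generation_config is not None:
        engine.generation_config = generation_config
        engine.generation_config_dict = generation_config.to_dict()
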
@@ -688,11 +691,12 @@
         )
         batch_token_ids = None
-        config_dict = self.generation_config.to_dict()
-        # process repetition_penalty, no_repeat_ngram_size
-        for type in ["repetition_penalty", "no_repeat_ngram_size"]:
-            if type in config_dict and config_dict[type] is not None:
-                batch_token_ids = batch.batch_token_ids
+        if (
+            self.generation_config.repetition_penalty != 1.0
+            or self.generation_config.no_repeat_ngram_size > 0
+            or self.generation_config.forced_eos_token_id is not None
+        ):
+            batch_token_ids = batch.batch_token_ids
         # only when we have the graph for specific decoding batch size can we use the cuda graph for inference
         use_cuda_graph = False
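
The rewritten gate only gathers per-sequence token ids when a sampling option that needs token history is active, and it drops the old dict scan (which also shadowed the type builtin). A standalone sketch under the assumption that batch exposes a batch_token_ids attribute; the defaults named in the comment match transformers' GenerationConfig:

from transformers import GenerationConfig

def maybe_collect_token_ids(generation_config: GenerationConfig, batch):
    # Defaults in GenerationConfig: repetition_penalty=1.0,
    # no_repeat_ngram_size=0, forced_eos_token_id=None -- with these,
    # the (potentially large) token-id gather is skipped entirely.
    batch_token_ids = None
    if (
        generation_config.repetition_penalty != 1.0
        or generation_config.no_repeat_ngram_size > 0
        or generation_config.forced_eos_token_id is not None
    ):
        batch_token_ids = batch.batch_token_ids
    return batch_token_ids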