[Inference] Adapt repetition_penalty and no_repeat_ngram_size (#5708)

* Adapt repetition_penalty and no_repeat_ngram_size

* fix no_repeat_ngram_size_logit_process

* remove batch_updated

* fix annotation

* modify code based on review feedback

* rm get_batch_token_ids
yuehuayingxueluo
2024-05-11 15:13:25 +08:00
committed by GitHub
parent 50104ab340
commit de4bf3dedf
5 changed files with 94 additions and 18 deletions
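
Both samplers named in the title need each sequence's previously generated token ids, which is why `batch` is now threaded into `search_tokens` in the hunks below. The following is a minimal sketch of the n-gram-ban step under that assumption; the function name, signature, and the `batch_token_ids` argument are illustrative, not this repository's actual API:

import torch


def no_repeat_ngram_logit_process(logits: torch.Tensor, batch_token_ids: list, ngram_size: int) -> torch.Tensor:
    """Ban tokens that would complete an already-seen n-gram (hypothetical helper).

    logits:          [batch_size, vocab_size] next-token scores
    batch_token_ids: per-sequence lists of previously generated token ids
    ngram_size:      length of the n-gram window that must not repeat
    """
    for batch_idx, token_ids in enumerate(batch_token_ids):
        if len(token_ids) < ngram_size:
            continue
        # The most recent (ngram_size - 1) tokens form the prefix of the next n-gram.
        prefix = tuple(token_ids[-(ngram_size - 1):])
        banned = set()
        # Scan the history for earlier occurrences of the same prefix and ban their continuations.
        for i in range(len(token_ids) - ngram_size + 1):
            if tuple(token_ids[i : i + ngram_size - 1]) == prefix:
                banned.add(token_ids[i + ngram_size - 1])
        if banned:
            logits[batch_idx, list(banned)] = -float("inf")
    return logits

A repetition-penalty step walks the same per-sequence histories; a sketch of that follows the diff below.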


@@ -424,7 +424,7 @@ class InferenceEngine:
# 2. Prefill main model (Verifier) - fill past kv cache for main model
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
- next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
+ next_tokens = self.request_handler.search_tokens(self.generation_config, logits, batch)
# append new inputs to the batch, temporarily
batch.append_batch_tokens(next_tokens)
self.request_handler.allocate_batch_spec_dec(batch, 1)
@@ -472,7 +472,7 @@ class InferenceEngine:
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
- next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
+ next_tokens = self.request_handler.search_tokens(self.generation_config, logits, batch)
# 5. Compare and process the results
diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
@@ -738,7 +738,7 @@ class InferenceEngine:
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
if self.inference_config.pad_input:
logits = logits[:, -1, :]
- next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
+ next_tokens = self.request_handler.search_tokens(self.generation_config, logits, batch)
self.request_handler.append_next_tokens(next_tokens)
finished_sequences = self.request_handler.update()
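
All three call sites above now hand `batch` to `search_tokens` so the logit processors can see what each sequence has already produced. Below is a minimal sketch of the repetition-penalty step under the same assumption as before (hypothetical `apply_repetition_penalty` helper and `batch_token_ids` list, not the repository's API). It follows the usual convention of dividing positive scores by the penalty and multiplying negative ones, so a penalty greater than 1.0 discourages repeats:

import torch


def apply_repetition_penalty(logits: torch.Tensor, batch_token_ids: list, penalty: float) -> torch.Tensor:
    """Penalize tokens that already appear in each sequence (hypothetical helper).

    logits:          [batch_size, vocab_size] next-token scores
    batch_token_ids: per-sequence lists of previously generated token ids
    penalty:         repetition penalty; values > 1.0 push repeated tokens down
    """
    for batch_idx, token_ids in enumerate(batch_token_ids):
        if not token_ids:
            continue
        prev = torch.tensor(sorted(set(token_ids)), dtype=torch.long, device=logits.device)
        scores = logits[batch_idx, prev]
        # Divide positive scores, multiply negative ones, so the penalty always lowers them.
        logits[batch_idx, prev] = torch.where(scores > 0, scores / penalty, scores * penalty)
    return logits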