Mirror of https://github.com/hpcaitech/ColossalAI.git
Synced 2025-09-16 14:41:53 +00:00
[Inference] Adapt repetition_penalty and no_repeat_ngram_size (#5708)

* Adapt repetition_penalty and no_repeat_ngram_size
* Fix no_repeat_ngram_size_logit_process
* Remove batch_updated
* Fix annotation
* Modify code based on review feedback
* Remove get_batch_token_ids
@@ -424,7 +424,7 @@ class InferenceEngine:
         # 2. Prefill main model (Verifier) - fill past kv cache for main model
         logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
-        next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
+        next_tokens = self.request_handler.search_tokens(self.generation_config, logits, batch)
         # append new inputs to the batch, temporarily
         batch.append_batch_tokens(next_tokens)
         self.request_handler.allocate_batch_spec_dec(batch, 1)
@@ -472,7 +472,7 @@ class InferenceEngine:
         input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
         logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)

-        next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
+        next_tokens = self.request_handler.search_tokens(self.generation_config, logits, batch)

         # 5. Compare and process the results
         diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
@@ -738,7 +738,7 @@ class InferenceEngine:
         logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
         if self.inference_config.pad_input:
             logits = logits[:, -1, :]
-        next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
+        next_tokens = self.request_handler.search_tokens(self.generation_config, logits, batch)
         self.request_handler.append_next_tokens(next_tokens)
         finished_sequences = self.request_handler.update()
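
All three call sites above now thread the live batch into search_tokens so the penalty processors can see each sequence's already-generated tokens. From the caller's side, the penalties are switched on through the standard GenerationConfig fields. A minimal usage sketch follows; the engine constructor and generate signatures here are assumptions for illustration, not part of this diff:

# Hypothetical usage sketch: the engine/config constructor arguments are
# assumptions; only repetition_penalty and no_repeat_ngram_size are what
# this commit wires into search_tokens.
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
engine = InferenceEngine(model, tokenizer, InferenceConfig(max_batch_size=8))

generation_config = GenerationConfig(
    do_sample=True,
    temperature=0.7,
    top_k=50,
    repetition_penalty=1.2,  # values > 1.0 penalize already-generated tokens
    no_repeat_ngram_size=3,  # forbid any 3-gram from occurring twice
)
outputs = engine.generate(prompts=["Hello, my name is"], generation_config=generation_config)
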
@@ -11,12 +11,9 @@ from colossalai.inference.kv_cache import KVCacheManager
 from colossalai.inference.logit_processors import logit_processor
 from colossalai.inference.sampler import *
 from colossalai.inference.struct import RequestStatus, Sequence
 from colossalai.logging import get_dist_logger

 __all__ = ["RunningList", "RequestHandler"]

 logger = get_dist_logger(__name__)


 class RunningList:
     """
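
The logit_processor imported above is called by name with a variable argument list. A plausible shape for that dispatcher is sketched below as an assumption; the registry and decorator names are hypothetical, and only the logit_processor(type, logits, *args) call pattern is visible in this diff:

# Hypothetical sketch of a name-keyed dispatcher matching the
# logit_processor(type, logits, *args) call sites in this diff.
_LOGIT_PROCESSOR_MAP = {}


def register_logit_processor(process_type: str):
    # Register a processor function under a string key such as
    # "repetition_penalty" or "top_k".
    def register(func):
        _LOGIT_PROCESSOR_MAP[process_type] = func
        return func

    return register


def logit_processor(processor: str, logits, *args):
    # Look up a processor by name; unknown keys leave the logits unchanged.
    func = _LOGIT_PROCESSOR_MAP.get(processor)
    if func is None:
        return logits
    return func(logits, *args)
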
@@ -331,15 +328,21 @@ class RequestHandler:
     def total_requests_in_batch_bucket(self) -> int:
         return self.prefill_bb.current_batch_size + self.running_bb.current_batch_size

-    def search_tokens(self, generation_config: GenerationConfig, logits):
+    def search_tokens(self, generation_config: GenerationConfig, logits, cur_batch: BatchBucket):
         """
         Sample tokens for finished requests.
         """

+        # NOTE: need to decide the granularity to process logits (sequence or batch)
+        config_dict = generation_config.to_dict()
+        # process repetition_penalty, no_repeat_ngram_size
+        for type in ["repetition_penalty", "no_repeat_ngram_size"]:
+            if type in config_dict and config_dict[type] is not None:
+                logits = logit_processor(type, logits, config_dict[type], cur_batch)

         # do logit processor
         if generation_config.do_sample:
             # NOTE: need to decide the granularity to process logits (sequence or batch)
             config_dict = generation_config.to_dict()
             # process temperature, top_k, top_p
             for type in ["temperature", "top_k", "top_p"]:
                 if type in config_dict and config_dict[type] is not None:
                     logits = logit_processor(type, logits, config_dict[type])
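
Since the new block sits outside the do_sample branch, the two penalties also apply under greedy decoding. For orientation, here are minimal sketches of what the two processors compute, assuming logits is a [batch_size, vocab_size] tensor and that cur_batch can yield each sequence's generated token ids as a list; the function names and the batch_token_ids argument are hypothetical, not the ColossalAI implementations:

# Minimal sketches of the two penalties; illustrative only.
from typing import List

import torch


def apply_repetition_penalty(logits: torch.Tensor, penalty: float, batch_token_ids: List[List[int]]) -> torch.Tensor:
    # CTRL-style penalty: scale down positive logits and scale up negative
    # logits for every token the sequence has already produced.
    for i, token_ids in enumerate(batch_token_ids):
        ids = torch.tensor(token_ids, dtype=torch.long, device=logits.device)
        scores = logits[i].index_select(0, ids)
        scores = torch.where(scores > 0, scores / penalty, scores * penalty)
        logits[i].index_copy_(0, ids, scores)
    return logits


def apply_no_repeat_ngram(logits: torch.Tensor, ngram_size: int, batch_token_ids: List[List[int]]) -> torch.Tensor:
    # Ban any token that would recreate an n-gram already present in the
    # sequence.
    if ngram_size < 2:
        return logits  # sizes below 2 are not handled in this sketch
    for i, token_ids in enumerate(batch_token_ids):
        if len(token_ids) < ngram_size:
            continue
        prefix = tuple(token_ids[-(ngram_size - 1):])  # trailing n-1 tokens
        banned = [
            token_ids[j + ngram_size - 1]
            for j in range(len(token_ids) - ngram_size + 1)
            if tuple(token_ids[j : j + ngram_size - 1]) == prefix
        ]
        if banned:
            logits[i, banned] = -float("inf")
    return logits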