[Inference]Adapt repetition_penalty and no_repeat_ngram_size (#5708)

* Adapt repetition_penalty and no_repeat_ngram_size

* fix no_repeat_ngram_size_logit_process

* remove batch_updated

* fix annotation

* modified codes based on the review feedback.

* rm get_batch_token_ids
This commit is contained in:
yuehuayingxueluo
2024-05-11 15:13:25 +08:00
committed by GitHub
parent 50104ab340
commit de4bf3dedf
5 changed files with 94 additions and 18 deletions

View File

@@ -11,12 +11,9 @@ from colossalai.inference.kv_cache import KVCacheManager
from colossalai.inference.logit_processors import logit_processor
from colossalai.inference.sampler import *
from colossalai.inference.struct import RequestStatus, Sequence
from colossalai.logging import get_dist_logger
__all__ = ["RunningList", "RequestHandler"]
logger = get_dist_logger(__name__)
class RunningList:
"""
@@ -331,15 +328,21 @@ class RequestHandler:
def total_requests_in_batch_bucket(self) -> int:
return self.prefill_bb.current_batch_size + self.running_bb.current_batch_size
def search_tokens(self, generation_config: GenerationConfig, logits):
def search_tokens(self, generation_config: GenerationConfig, logits, cur_batch: BatchBucket):
"""
Sample tokens for finished requests.
"""
# NOTE: need to decide the granularity to process logits (sequence or batch)
config_dict = generation_config.to_dict()
# process repetition_penalty, no_repeat_ngram_size
for type in ["repetition_penalty", "no_repeat_ngram_size"]:
if type in config_dict and config_dict[type] is not None:
logits = logit_processor(type, logits, config_dict[type], cur_batch)
# do logit processor
if generation_config.do_sample:
# NOTE: need to decide the granularity to process logits (sequence or batch)
config_dict = generation_config.to_dict()
# process temperature, top_k, top_p
for type in ["temperature", "top_k", "top_p"]:
if type in config_dict and config_dict[type] is not None:
logits = logit_processor(type, logits, config_dict[type])