mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 02:51:59 +00:00
[Inference]Adapt repetition_penalty and no_repeat_ngram_size (#5708)
* Adapt repetition_penalty and no_repeat_ngram_size * fix no_repeat_ngram_size_logit_process * remove batch_updated * fix annotation * modified codes based on the review feedback. * rm get_batch_token_ids
This commit is contained in:
@@ -102,6 +102,13 @@ class BatchBucket:
|
||||
def num_tokens_to_verify(self) -> int:
|
||||
return self._num_tokens_to_verify
|
||||
|
||||
@property
|
||||
def batch_token_ids(self) -> List[List[int]]:
|
||||
out = []
|
||||
for seq in self.seqs_li:
|
||||
out.append(seq.input_token_id + seq.output_token_id)
|
||||
return out
|
||||
|
||||
def set_use_spec_dec(self, num_tokens_to_verify: int = 5) -> None:
|
||||
"""Set batch bucket to use speculatvie decoding.
|
||||
This will notify the adjust the lengths of inputs during modeling,
|
||||
@@ -328,6 +335,7 @@ class BatchBucket:
|
||||
seqs.append(seq)
|
||||
if not self.is_compact:
|
||||
self._make_compact()
|
||||
|
||||
return seqs, block_tables
|
||||
|
||||
def pop_finished(
|
||||
@@ -432,6 +440,7 @@ class BatchBucket:
|
||||
block_tables = torch.stack(block_tables_li)
|
||||
self.add_seqs(seqs, alloc_block_tables=block_tables)
|
||||
unmerged_ids = other.seqs_ids
|
||||
|
||||
return unmerged_ids
|
||||
|
||||
########## The following methods are expected to be used in modeling ###########
|
||||
|
Reference in New Issue
Block a user