[Inference] Adapt repetition_penalty and no_repeat_ngram_size (#5708)

* Adapt repetition_penalty and no_repeat_ngram_size

* fix no_repeat_ngram_size_logit_process

* remove batch_updated

* fix annotation

* modify the code based on review feedback

* rm get_batch_token_ids
This commit is contained in:
yuehuayingxueluo
2024-05-11 15:13:25 +08:00
committed by GitHub
parent 50104ab340
commit de4bf3dedf
5 changed files with 94 additions and 18 deletions

View File

@@ -102,6 +102,13 @@ class BatchBucket:
def num_tokens_to_verify(self) -> int:
return self._num_tokens_to_verify
@property
def batch_token_ids(self) -> List[List[int]]:
    """Return the token ids of every sequence currently in the bucket.

    For each sequence in ``self.seqs_li`` (in that order) the entry is
    ``seq.input_token_id + seq.output_token_id``, i.e. the input ids
    followed by the output ids produced so far.

    Returns:
        A freshly built list with one entry per sequence; an empty list
        when the bucket holds no sequences. The outer list is rebuilt on
        every access, so callers needing it more than once should keep a
        local reference.
    """
    # ``+`` builds a new object per sequence, so the per-sequence
    # ``input_token_id`` / ``output_token_id`` attributes are not mutated.
    return [seq.input_token_id + seq.output_token_id for seq in self.seqs_li]
def set_use_spec_dec(self, num_tokens_to_verify: int = 5) -> None:
"""Set batch bucket to use speculative decoding.
This will notify the adjust the lengths of inputs during modeling,
@@ -328,6 +335,7 @@ class BatchBucket:
seqs.append(seq)
if not self.is_compact:
self._make_compact()
return seqs, block_tables
def pop_finished(
@@ -432,6 +440,7 @@ class BatchBucket:
block_tables = torch.stack(block_tables_li)
self.add_seqs(seqs, alloc_block_tables=block_tables)
unmerged_ids = other.seqs_ids
return unmerged_ids
########## The following methods are expected to be used in modeling ###########