mirror of https://github.com/hpcaitech/ColossalAI.git
[Inference/SpecDec] Add Speculative Decoding Implementation (#5423)
* fix flash decoding mask during verification
* add spec-dec
* add test for spec-dec
* revise drafter init
* remove drafter sampling
* retire past kv in drafter
* (trivial) rename attrs
* (trivial) rename arg
* revise how we enable/disable spec-dec
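
For context on the feature itself: in speculative decoding, a small drafter model proposes a short run of tokens and the main model verifies them in a single forward pass; the longest prefix on which both models agree is kept, plus one extra token from the main model. The sketch below only illustrates that greedy accept rule on plain token tensors; verify_greedy, drafter_tokens, and target_tokens are illustrative names, not part of this PR's API.

import torch

def verify_greedy(drafter_tokens: torch.Tensor, target_tokens: torch.Tensor) -> torch.Tensor:
    # drafter_tokens: (n_spec,) tokens proposed by the drafter model
    # target_tokens:  (n_spec + 1,) greedy tokens from the main model at the same positions
    n_spec = drafter_tokens.size(0)
    n_accepted = 0
    for i in range(n_spec):
        if drafter_tokens[i].item() != target_tokens[i].item():
            break
        n_accepted += 1
    # Keep the agreed-upon prefix plus one bonus token from the main model,
    # which is why each sequence can emit up to max_n_spec_tokens + 1 tokens per step.
    return target_tokens[: n_accepted + 1]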
@@ -134,8 +134,12 @@ class RequestHandler:
         if fd_inter_tensor._tensors_initialized:
             fd_inter_tensor._reset()
 
+        # For Spec-Dec, process the speculated tokens plus the token in the last step for each seq
+        max_n_tokens = self.max_batch_size
+        max_n_tokens *= self.inference_config.max_n_spec_tokens + 1
+
         fd_inter_tensor.initialize(
-            max_batch_size=self.max_batch_size,
+            max_batch_size=max_n_tokens,
             num_attn_heads=model_config.num_attention_heads,
             kv_max_split_num=kv_max_split_num,
             head_dim=head_dim,
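
A note on the sizing change above: during verification every sequence submits its speculated tokens plus the token produced in the previous step, so the flash-decoding intermediate buffers must be sized per token rather than per sequence. A minimal sketch of the arithmetic, with made-up values (the concrete numbers are assumptions, not taken from the PR):

# Illustrative Spec-Dec buffer sizing; the values below are examples only.
max_batch_size = 8        # sequences decoded together
max_n_spec_tokens = 5     # draft tokens verified per step

regular_slots = max_batch_size                             # 8 query tokens per step
spec_dec_slots = max_batch_size * (max_n_spec_tokens + 1)  # 48 query tokens per step

print(regular_slots, spec_dec_slots)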
@@ -230,6 +234,13 @@ class RequestHandler:
 
         return self.running_bb
 
+    def allocate_batch_spec_dec(self, batch: BatchBucket, n: int):
+        assert batch.use_spec_dec
+        if n > 0:
+            self.cache_manager.allocate_n_tokens_from_block_tables(
+                batch.block_tables, batch.seq_lengths, batch.current_batch_size, n=n
+            )
+
     def add_sequence(self, req: Sequence):
         """
         Add the request to waiting list.
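
The new allocate_batch_spec_dec reserves n additional KV-cache slots per sequence in the batch's block tables before a verification forward pass. A hedged sketch of a possible call site; everything other than allocate_batch_spec_dec itself is assumed for illustration:

# Hypothetical caller: names other than allocate_batch_spec_dec are assumptions.
def reserve_cache_for_verification(request_handler, batch, n_tokens_per_seq: int):
    # Each sequence is about to append `n_tokens_per_seq` tokens in one step
    # (its speculated tokens plus the carry-over token), so grow the block
    # tables accordingly before running the main model.
    request_handler.allocate_batch_spec_dec(batch, n=n_tokens_per_seq)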
@@ -282,13 +293,21 @@ class RequestHandler:
 
         return sample_tokens
 
-    def mark_finished(self, sequence: Sequence, generation_config: GenerationConfig):
+    def update_seq_finished(self, sequence: Sequence, generation_config: GenerationConfig):
         if (
-            sequence.output_token_id[-1] == generation_config.eos_id
-            or sequence.output_len >= generation_config.max_output_len
+            sequence.output_token_id[-1] == generation_config.eos_token_id
+            or sequence.output_len >= generation_config.max_length
         ):
             sequence.mark_finished()
 
+    def update_batch_finished(self, batch: BatchBucket, generation_config: GenerationConfig):
+        for seq in batch.seqs_li:
+            if (
+                seq.output_token_id[-1] == generation_config.eos_token_id
+                or seq.output_len >= generation_config.max_length
+            ):
+                seq.mark_finished()
+
     def check_unfinished_seqs(self) -> bool:
         return self._has_waiting() or not self.running_list.is_empty()
 
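
Renaming mark_finished to update_seq_finished avoids confusion with Sequence.mark_finished, which it calls, and the new update_batch_finished applies the same stop check (EOS token emitted or max_length reached) to every sequence in a BatchBucket. A minimal usage sketch, assuming a handler, a decoding batch, and a generation config are already in hand (the wrapper function is illustrative only):

# Illustrative wrapper; only update_batch_finished comes from the diff above.
def finalize_step(request_handler, batch, generation_config):
    # After the newly sampled tokens have been appended, flag every sequence
    # that produced the EOS token or reached generation_config.max_length.
    request_handler.update_batch_finished(batch, generation_config=generation_config)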
@@ -309,9 +328,20 @@ class RequestHandler:
 
         # sample the next tokens
         sample_tokens = self._sample(probs, logprobs, generation_config)
-        if not self.prefill_bb.is_empty:
-            self.prefill_bb.append_batch_tokens(sample_tokens)
-        else:
-            self.running_bb.append_batch_tokens(sample_tokens)
+        return sample_tokens
+
+    def append_next_tokens(self, sample_tokens: torch.Tensor):
+        assert sample_tokens.dim() == 1
+        n_elements = sample_tokens.size(0)
+        if not self.prefill_bb.is_empty:
+            assert (
+                self.prefill_bb.current_batch_size == n_elements
+            ), f"Incompatible size: {n_elements} tokens to append while prefill batch size {self.prefill_bb.current_batch_size}"
+            self.prefill_bb.append_batch_tokens(sample_tokens)
+        else:
+            assert (
+                self.running_bb.current_batch_size == n_elements
+            ), f"Incompatible size: {n_elements} tokens to append while running batch size {self.running_bb.current_batch_size}"
+            self.running_bb.append_batch_tokens(sample_tokens)
 
     def update(self):
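
With this split, search_tokens is left to sample and return the next tokens, while append_next_tokens writes them into whichever batch is active (the prefill batch if one is pending, otherwise the running decode batch). A hedged sketch of how an engine step might chain the two; the wrapper and its arguments are assumptions, only the two RequestHandler methods come from the diff:

# Hypothetical engine step; only search_tokens / append_next_tokens are from this PR.
def engine_step(request_handler, logits, generation_config):
    # Sample one next token per active sequence from the model's logits ...
    next_tokens = request_handler.search_tokens(generation_config, logits)
    # ... then append them to the pending prefill batch if there is one,
    # otherwise to the running decode batch.
    request_handler.append_next_tokens(next_tokens)
    return next_tokens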