[Inference/SpecDec] Add Speculative Decoding Implementation (#5423)

* fix flash decoding mask during verification

* add spec-dec

* add test for spec-dec

* revise drafter init

* remove drafter sampling

* retire past kv in drafter

* (trivial) rename attrs

* (trivial) rename arg

* revise how we enable/disable spec-dec
Author: Yuanheng Zhao
Authored: 2024-03-11 09:51:42 +08:00
Committed by: Yuanheng
Parent: 5a9b05f7b2
Commit: a37f82629d
11 changed files with 484 additions and 133 deletions
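
For readers unfamiliar with the feature, the idea behind spec-dec is: a small drafter model proposes several tokens per step, and the main model verifies them all in a single forward pass, keeping only the longest prefix it agrees with plus one bonus token of its own. Below is a minimal, self-contained sketch of that greedy draft-and-verify loop, not this repo's implementation; it assumes HuggingFace-style models whose forward call returns .logits and accepts the full token sequence (no KV cache), and batch size 1 for the acceptance check.

    import torch

    @torch.no_grad()
    def speculative_step(main_model, drafter_model, input_ids, n_spec_tokens=5):
        # Drafter proposes n_spec_tokens tokens autoregressively (greedy drafting,
        # in line with the "remove drafter sampling" item above).
        draft = input_ids
        for _ in range(n_spec_tokens):
            logits = drafter_model(draft).logits              # [B, L, V]
            next_tok = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            draft = torch.cat([draft, next_tok], dim=-1)

        # Main model scores the whole draft in one forward pass (verification).
        logits = main_model(draft).logits
        prompt_len = input_ids.size(1)
        # Main-model prediction for each drafted position.
        preds = logits[:, prompt_len - 1 : -1, :].argmax(dim=-1)
        drafted = draft[:, prompt_len:]

        # Accept the longest prefix where drafter and main model agree (batch size 1 here).
        n_accept = 0
        while n_accept < n_spec_tokens and drafted[0, n_accept] == preds[0, n_accept]:
            n_accept += 1

        # The main model also yields one bonus token right after the accepted prefix.
        bonus = logits[:, prompt_len - 1 + n_accept, :].argmax(dim=-1, keepdim=True)
        return torch.cat([input_ids, drafted[:, :n_accept], bonus], dim=-1)

The real engine additionally manages paged KV caches, batching, and verification masks, which is what the diff below is about.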


@@ -134,8 +134,12 @@ class RequestHandler:
         if fd_inter_tensor._tensors_initialized:
             fd_inter_tensor._reset()
 
+        # For Spec-Dec, process the speculated tokens plus the token in the last step for each seq
+        max_n_tokens = self.max_batch_size
+        max_n_tokens *= self.inference_config.max_n_spec_tokens + 1
+
         fd_inter_tensor.initialize(
-            max_batch_size=self.max_batch_size,
+            max_batch_size=max_n_tokens,
             num_attn_heads=model_config.num_attention_heads,
             kv_max_split_num=kv_max_split_num,
             head_dim=head_dim,
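
The resized initialization above is needed because, during verification, every sequence sends its speculated tokens plus the token carried over from the previous step through flash-decoding in one pass, so the intermediate tensors must hold up to max_batch_size * (max_n_spec_tokens + 1) token rows rather than max_batch_size. A quick illustration with assumed values (not the repo's defaults):

    # Illustrative numbers only.
    max_batch_size = 8
    max_n_spec_tokens = 5

    # Regular decoding: one query token per sequence per step.
    regular_rows = max_batch_size                              # 8

    # Spec-Dec verification: n speculated tokens + 1 token from the last step, per sequence.
    spec_dec_rows = max_batch_size * (max_n_spec_tokens + 1)   # 48

    print(regular_rows, spec_dec_rows)  # 8 48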
@@ -230,6 +234,13 @@ class RequestHandler:
         return self.running_bb
 
+    def allocate_batch_spec_dec(self, batch: BatchBucket, n: int):
+        assert batch.use_spec_dec
+        if n > 0:
+            self.cache_manager.allocate_n_tokens_from_block_tables(
+                batch.block_tables, batch.seq_lengths, batch.current_batch_size, n=n
+            )
+
     def add_sequence(self, req: Sequence):
         """
         Add the request to waiting list.
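
The new allocate_batch_spec_dec reserves KV-cache block space ahead of a spec-dec step, since a sequence may grow by several tokens at once instead of exactly one. A hedged sketch of how driver code might call it; the helper name and the choice of n are assumptions, only allocate_batch_spec_dec itself comes from the diff:

    # Hypothetical driver snippet; the real engine decides how many slots to reserve.
    def reserve_spec_dec_slots(request_handler, batch, n_new_tokens: int):
        # n_new_tokens: upper bound on tokens each sequence may gain this step
        # (speculated tokens that survive verification plus the bonus token).
        if batch.use_spec_dec:
            request_handler.allocate_batch_spec_dec(batch, n=n_new_tokens)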
@@ -282,13 +293,21 @@ class RequestHandler:
         return sample_tokens
 
-    def mark_finished(self, sequence: Sequence, generation_config: GenerationConfig):
+    def update_seq_finished(self, sequence: Sequence, generation_config: GenerationConfig):
         if (
-            sequence.output_token_id[-1] == generation_config.eos_id
-            or sequence.output_len >= generation_config.max_output_len
+            sequence.output_token_id[-1] == generation_config.eos_token_id
+            or sequence.output_len >= generation_config.max_length
         ):
             sequence.mark_finished()
 
+    def update_batch_finished(self, batch: BatchBucket, generation_config: GenerationConfig):
+        for seq in batch.seqs_li:
+            if (
+                seq.output_token_id[-1] == generation_config.eos_token_id
+                or seq.output_len >= generation_config.max_length
+            ):
+                seq.mark_finished()
+
     def check_unfinished_seqs(self) -> bool:
         return self._has_waiting() or not self.running_list.is_empty()
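
Besides the rename from mark_finished to update_seq_finished, the hunk above corrects the GenerationConfig attribute names (eos_token_id, max_length), and update_batch_finished applies the same stop check to every sequence in a batch, which spec-dec needs because several tokens can be appended per step. A standalone illustration of that stop condition; the _Seq class below is a simplified stand-in, not the repo's Sequence:

    from dataclasses import dataclass, field
    from typing import List

    @dataclass
    class _Seq:  # simplified stand-in for Sequence
        output_token_id: List[int] = field(default_factory=list)
        finished: bool = False

        @property
        def output_len(self) -> int:
            return len(self.output_token_id)

        def mark_finished(self) -> None:
            self.finished = True

    def is_stopped(seq: _Seq, eos_token_id: int, max_length: int) -> bool:
        # Same condition as update_seq_finished: last token is EOS,
        # or the output length budget is exhausted.
        return seq.output_token_id[-1] == eos_token_id or seq.output_len >= max_length

    seq = _Seq(output_token_id=[42, 7, 2])
    if is_stopped(seq, eos_token_id=2, max_length=256):
        seq.mark_finished()
    assert seq.finished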
@@ -309,9 +328,20 @@ class RequestHandler:
         # sample the next tokens
         sample_tokens = self._sample(probs, logprobs, generation_config)
         return sample_tokens
 
+    def append_next_tokens(self, sample_tokens: torch.Tensor):
+        assert sample_tokens.dim() == 1
+        n_elements = sample_tokens.size(0)
+        if not self.prefill_bb.is_empty:
+            assert (
+                self.prefill_bb.current_batch_size == n_elements
+            ), f"Incompatible size: {n_elements} tokens to append while prefill batch size {self.prefill_bb.current_batch_size}"
+            self.prefill_bb.append_batch_tokens(sample_tokens)
+        else:
+            assert (
+                self.running_bb.current_batch_size == n_elements
+            ), f"Incompatible size: {n_elements} tokens to append while running batch size {self.running_bb.current_batch_size}"
+            self.running_bb.append_batch_tokens(sample_tokens)
+
     def update(self):
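
append_next_tokens takes a 1-D tensor with one sampled token per sequence and appends it to whichever batch is active, preferring the prefill batch and falling back to the running decode batch, with asserts guarding against a size mismatch. A rough sketch of how an engine step might use it; search_tokens is an assumed name for the sampling method whose body appears earlier in this hunk, while append_next_tokens and update come from the diff:

    def engine_decode_step(request_handler, logits, generation_config):
        # Hypothetical glue code around the methods shown in the diff.
        sample_tokens = request_handler.search_tokens(generation_config, logits)  # name assumed
        assert sample_tokens.dim() == 1  # one new token id per sequence in the batch
        request_handler.append_next_tokens(sample_tokens)  # prefill batch first, else running batch
        request_handler.update()  # advance bookkeeping / release finished sequences
        return sample_tokens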