[Inference/SpecDec] Add Speculative Decoding Implementation (#5423)

* fix flash decoding mask during verification

* add spec-dec

* add test for spec-dec

* revise drafter init

* remove drafter sampling

* retire past kv in drafter

* (trivial) rename attrs

* (trivial) rename arg

* revise how we enable/disable spec-dec
Author: Yuanheng Zhao
Date: 2024-03-11 09:51:42 +08:00
Committed by: Yuanheng
Parent: 5a9b05f7b2
Commit: a37f82629d
11 changed files with 484 additions and 133 deletions
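For orientation, speculative decoding pairs a small drafter model with the main model: the drafter cheaply proposes a few tokens, and the main model verifies all of them in one forward pass, keeping only the longest prefix it agrees with. Below is a minimal sketch of that draft-then-verify loop with greedy drafting and greedy acceptance (consistent with "remove drafter sampling" above); every name in it (drafter, main_model, n_spec_tokens) is illustrative and not the API introduced by this PR.

import torch

def speculative_decode_step(main_model, drafter, input_ids, n_spec_tokens=5):
    # Minimal draft-then-verify sketch (greedy drafting, greedy verification).
    # Assumes both callables map token ids of shape (1, seq_len) to logits of
    # shape (1, seq_len, vocab_size); this is NOT the interface added by this PR.
    prompt_len = input_ids.shape[1]

    # 1) Drafter proposes n_spec_tokens tokens autoregressively (greedy argmax).
    draft_ids = input_ids
    for _ in range(n_spec_tokens):
        next_id = drafter(draft_ids)[:, -1, :].argmax(dim=-1, keepdim=True)
        draft_ids = torch.cat([draft_ids, next_id], dim=-1)
    drafted = draft_ids[:, prompt_len:]  # (1, n_spec_tokens)

    # 2) Main model scores the prompt plus all drafted tokens in a single pass.
    logits = main_model(draft_ids)  # (1, prompt_len + n_spec_tokens, vocab_size)
    # The main model's prediction for each drafted position comes from the
    # logits of the token immediately before it.
    verify_ids = logits[:, prompt_len - 1 : -1, :].argmax(dim=-1)

    # 3) Accept the longest matching prefix, then append one bonus token taken
    #    from the main model at the first unverified position.
    matches = (verify_ids == drafted).long().cumprod(dim=-1)
    n_accept = int(matches.sum())
    bonus = logits[:, prompt_len - 1 + n_accept, :].argmax(dim=-1, keepdim=True)
    return torch.cat([input_ids, drafted[:, :n_accept], bonus], dim=-1)

With greedy drafting, verification reduces to exact token matching; a sampling drafter would need the usual probabilistic accept/reject rule instead, which this sketch omits.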

@@ -44,6 +44,7 @@ def _flash_decoding_fwd_kernel(
     cur_seq_idx = cur_token_idx // q_len
     if cur_seq_idx >= batch_size:
         return
+    cur_token_off = (cur_token_idx % q_len) - q_len + 1
     cur_head_idx = tl.program_id(1)
     block_start_kv = tl.program_id(2)  # for splitting k/v
@@ -52,7 +53,8 @@ def _flash_decoding_fwd_kernel(
     # and then support calculating multiple kv cache blocks on an instance
     tl.static_assert(BLOCK_KV == BLOCK_SIZE)
     # get the current (kv) sequence length
-    cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx)
+    # cur_token_off is used as a "mask" here for spec-dec during verification process
+    cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx) + cur_token_off
     if block_start_kv * BLOCK_KV >= cur_kv_seq_len:
         return
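The added cur_token_off turns the single kv_seq_len stored per sequence into a per-token effective length. During verification the kernel processes q_len speculative tokens of a sequence at once, and token i of that block receives offset (i % q_len) - q_len + 1, ranging from -(q_len - 1) up to 0, so earlier tokens cannot attend to later drafted entries that already sit in the KV cache. For regular decoding q_len == 1, the offset is always 0, and the kernel behaves as before. A quick NumPy illustration of the arithmetic (not the Triton kernel itself):

import numpy as np

# Illustrative only: how the per-token offset shortens the visible kv length so
# that each of the q_len speculative tokens stays causal during verification.
def effective_kv_lens(kv_seq_len: int, q_len: int) -> np.ndarray:
    token_idx = np.arange(q_len)                     # position within the speculative block
    cur_token_off = (token_idx % q_len) - q_len + 1  # ranges from -(q_len - 1) to 0
    return kv_seq_len + cur_token_off                # kv entries each token may attend to

print(effective_kv_lens(kv_seq_len=10, q_len=3))  # -> [ 8  9 10]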
@@ -150,7 +152,9 @@ def _flash_decoding_fwd_reduce_kernel(
         return
     cur_head_idx = tl.program_id(1)
-    cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx)
+    # cur_token_off is used as a "mask" here for spec-dec during verification process
+    cur_token_off = (cur_token_idx % q_len) - q_len + 1
+    cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx) + cur_token_off
     offsets_dmodel = tl.arange(0, HEAD_DIM)
     # NOTE currently the block size BLOCK_KV splitting kv is relatively small as we have
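The reduce kernel applies the same per-token offset so that, for each token, it combines only the partial results that the forward kernel actually produced for it; using the unadjusted kv_seq_len would pull in KV splits the token never attended to. A rough sketch of how the per-token effective length sets the number of splits to reduce (the helper name and the BLOCK_KV value are illustrative):

import math

# Illustrative only: with split-KV flash decoding, each token's partial results
# are reduced over ceil(effective_kv_len / BLOCK_KV) splits, so the per-token
# offset has to be applied consistently in the forward and reduce kernels.
def num_kv_splits(kv_seq_len: int, q_len: int, block_kv: int) -> list:
    return [math.ceil((kv_seq_len + (i % q_len) - q_len + 1) / block_kv)
            for i in range(q_len)]

print(num_kv_splits(kv_seq_len=10, q_len=3, block_kv=4))  # -> [2, 3, 3]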