Mirror of https://github.com/hpcaitech/ColossalAI.git
[Inference/SpecDec] Add Speculative Decoding Implementation (#5423)
* fix flash decoding mask during verification
* add spec-dec
* add test for spec-dec
* revise drafter init
* remove drafter sampling
* retire past kv in drafter
* (trivial) rename attrs
* (trivial) rename arg
* revise how we enable/disable spec-dec
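For context, speculative decoding lets a small drafter model propose several tokens which the target model then verifies in a single forward pass. The sketch below is a minimal, generic illustration of greedy verification, not code from this commit; `draft` and `verify` are hypothetical placeholder callables, not ColossalAI APIs.

# Minimal, generic sketch of one speculative-decoding step with greedy
# acceptance. `draft` and `verify` are hypothetical callables, NOT
# ColossalAI APIs: draft(prefix, k) returns k proposed token ids, and
# verify(tokens) returns, for every position i, the target model's
# greedy prediction for position i + 1.
def speculative_step(draft, verify, prefix, k):
    spec = draft(prefix, k)            # drafter proposes k tokens
    preds = verify(prefix + spec)      # target scores all k tokens at once
    n = 0
    # accept drafted tokens while they match the target's own predictions
    while n < k and spec[n] == preds[len(prefix) + n - 1]:
        n += 1
    # keep the n accepted tokens plus one "bonus" token from the target
    return spec[:n] + [preds[len(prefix) + n - 1]]

Verifying all drafted tokens in one pass is what requires the per-token masking that the kernel changes below implement.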
@@ -44,6 +44,7 @@ def _flash_decoding_fwd_kernel(
     cur_seq_idx = cur_token_idx // q_len
     if cur_seq_idx >= batch_size:
         return
+    cur_token_off = (cur_token_idx % q_len) - q_len + 1
     cur_head_idx = tl.program_id(1)
     block_start_kv = tl.program_id(2)  # for splitting k/v
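The added `cur_token_off` maps each of the `q_len` tokens being verified to an offset in `-(q_len - 1) .. 0`: earlier tokens in the window get more negative offsets, the last token gets 0. A plain-Python sketch of that arithmetic (illustrative only, not part of the diff):

# Plain-Python sketch of the offset arithmetic added above (not Triton).
# Each sequence contributes q_len speculative tokens; within a sequence,
# token positions 0..q_len-1 map to offsets -(q_len-1)..0.
q_len = 4                               # speculative tokens per sequence
for cur_token_idx in range(2 * q_len):  # two sequences, flattened
    cur_seq_idx = cur_token_idx // q_len
    cur_token_off = (cur_token_idx % q_len) - q_len + 1
    print(cur_seq_idx, cur_token_idx % q_len, cur_token_off)
# seq 0 -> offsets -3, -2, -1, 0; seq 1 -> the same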
@@ -52,7 +53,8 @@ def _flash_decoding_fwd_kernel(
     # and then support calculating multiple kv cache blocks on an instance
     tl.static_assert(BLOCK_KV == BLOCK_SIZE)
     # get the current (kv) sequence length
-    cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx)
+    # cur_token_off is used as a "mask" here for spec-dec during verification process
+    cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx) + cur_token_off
     if block_start_kv * BLOCK_KV >= cur_kv_seq_len:
         return
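Adding `cur_token_off` to the loaded sequence length truncates the kv range each speculative token may attend to, which is exactly the causal mask needed during verification; kv-split program instances whose block starts past the truncated length exit early via the `return` above. A small numeric illustration with assumed values:

# Illustration (assumed values): kv_seq_len already counts the q_len
# speculative tokens, so earlier tokens in the window must see fewer
# kv positions than later ones.
kv_seq_len = 10   # cached kv length, including the 4 speculative tokens
q_len = 4
for i in range(q_len):
    cur_token_off = i - q_len + 1
    effective_len = kv_seq_len + cur_token_off
    print(f"token {i} attends to kv positions [0, {effective_len})")
# token 0 -> [0, 7), token 1 -> [0, 8), token 2 -> [0, 9), token 3 -> [0, 10)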
@@ -150,7 +152,9 @@ def _flash_decoding_fwd_reduce_kernel(
         return
     cur_head_idx = tl.program_id(1)

-    cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx)
+    # cur_token_off is used as a "mask" here for spec-dec during verification process
+    cur_token_off = (cur_token_idx % q_len) - q_len + 1
+    cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx) + cur_token_off
     offsets_dmodel = tl.arange(0, HEAD_DIM)

     # NOTE currently the block size BLOCK_KV splitting kv is relatively small as we have