[Inference/SpecDec] Add Speculative Decoding Implementation (#5423)

* fix flash decoding mask during verification

* add spec-dec

* add test for spec-dec

* revise drafter init

* remove drafter sampling

* retire past kv in drafter

* (trivial) rename attrs

* (trivial) rename arg

* revise how we enable/disable spec-dec
Author: Yuanheng Zhao
Date: 2024-03-11 09:51:42 +08:00
Committed by: Yuanheng
Parent: 5a9b05f7b2
Commit: a37f82629d
11 changed files with 484 additions and 133 deletions
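Before the diff itself, a quick orientation on the feature being added: speculative decoding lets a small drafter model propose several tokens which the main model then verifies in one batched forward pass, accepting the longest matching prefix. Below is a minimal, framework-agnostic sketch of that loop, assuming HF-style causal LMs with a `.logits` output; all function and argument names here are hypothetical and are not the API introduced by this commit.

```python
import torch

@torch.no_grad()
def speculative_generate(target_model, drafter_model, input_ids, n_spec_tokens=5, max_new_tokens=64):
    """Greedy draft-then-verify loop (illustrative only; no KV cache, batch size 1)."""
    generated = input_ids
    while generated.size(1) - input_ids.size(1) < max_new_tokens:
        # 1. The small drafter proposes n_spec_tokens greedily, one token at a time.
        draft = generated
        for _ in range(n_spec_tokens):
            next_tok = drafter_model(draft).logits[:, -1].argmax(dim=-1, keepdim=True)
            draft = torch.cat([draft, next_tok], dim=-1)
        draft_tokens = draft[:, generated.size(1):]  # [1, n_spec_tokens]

        # 2. The target model scores every drafted position in a single forward pass
        #    (this is the "verification" stage the diff below adds to llama_model_forward).
        logits = target_model(draft).logits
        target_preds = logits[:, generated.size(1) - 1 : -1].argmax(dim=-1)  # [1, n_spec_tokens]

        # 3. Accept the longest prefix on which drafter and target agree, then append
        #    one correct token from the target itself (the "+ 1" seen in the diff).
        matches = (target_preds == draft_tokens).int().cumprod(dim=-1)
        n_accepted = int(matches.sum())
        if n_accepted == n_spec_tokens:
            bonus = logits[:, -1].argmax(dim=-1, keepdim=True)
        else:
            bonus = target_preds[:, n_accepted : n_accepted + 1]
        generated = torch.cat([generated, draft_tokens[:, :n_accepted], bonus], dim=-1)
    return generated
```

The verification stage of this loop is where the `num_tokens_to_verify + 1` query positions handled in the diff come from.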


@@ -18,6 +18,7 @@ from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import (
context_attention_unpadded,
copy_k_to_blocked_cache,
decoding_fused_rotary_embedding,
flash_decoding_attention,
get_xine_cache,
@@ -84,9 +85,9 @@ def llama_model_forward(
"""This function will replace the forward function of LlamaModel.
Args:
batch (BatchInfo): It stores the necessary input information for this inference.
k_caches (List[torch.Tensor]): It holds the GPU memory for the key cache.
v_caches (List[torch.Tensor]): It holds the GPU memory for the value cache.
batch (BatchInfo, optional): It stores the necessary input information for this inference. Defaults to None.
k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None.
v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None.
high_precision (Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision. Defaults to False.
"""
block_tables = inputmetadata.block_tables
@@ -101,7 +102,25 @@ def llama_model_forward(
use_cuda_kernel = False
hidden_states = self.embed_tokens(input_tokens_ids)
if use_cuda_kernel:
cu_seqlens = None
# NOTE (yuanheng-zhao): we do not use cuda kernels for speculative-decoding for now
if inputmetadata.use_spec_dec:
# For speculative-decoding Prefill and Verifying Stage
if inputmetadata.is_prompts:
# output tensor shape is the same as normal Prefill Stage
o_tensor_size = (sequence_lengths.sum().item(), inputmetadata.num_heads * inputmetadata.head_dim)
rotary_indexes = [torch.arange(0, length) for length in sequence_lengths]
else:
# the number of tokens to be verified in parallel plus the correct token in the last step
n_tokens = inputmetadata.num_tokens_to_verify + 1
assert n_tokens == hidden_states.size(0)
o_tensor_size = (batch_size * n_tokens, inputmetadata.num_heads * inputmetadata.head_dim)
rotary_indexes = [(length - n_tokens + i).view(-1) for i in range(n_tokens) for length in sequence_lengths]
rotary_indexes = torch.cat(rotary_indexes, dim=-1)
cos_sin = (self._cos_cached[rotary_indexes], self._sin_cached[rotary_indexes])
elif use_cuda_kernel:
if inputmetadata.dtype != torch.float32 and use_flash_attn2:
cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.int32), (1, 0))
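To make the verification-stage indexing above concrete, here is a small standalone reproduction of the `rotary_indexes` and `cu_seqlens` constructions with toy lengths (values chosen only for illustration):

```python
import torch
import torch.nn.functional as F

# Toy verification step: 2 sequences whose current KV lengths are 5 and 7,
# with num_tokens_to_verify = 2, so n_tokens = 2 + 1 = 3 query positions per sequence.
sequence_lengths = torch.tensor([5, 7])
n_tokens = 3

# Same construction as the verifying branch above: for every offset i,
# take position (length - n_tokens + i) of each sequence.
rotary_indexes = [(length - n_tokens + i).view(-1) for i in range(n_tokens) for length in sequence_lengths]
rotary_indexes = torch.cat(rotary_indexes, dim=-1)
print(rotary_indexes)  # tensor([2, 4, 3, 5, 4, 6]) -> cos/sin rows for the positions being verified

# For comparison, the flash-attn-2 branch builds cumulative sequence boundaries instead:
cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.int32), (1, 0))
print(cu_seqlens)  # tensor([0, 5, 12], dtype=torch.int32)
```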
@@ -113,14 +132,22 @@ def llama_model_forward(
self._cos_cached, self._sin_cached, cos, sin, sequence_lengths, kv_seq_len, inputmetadata.is_prompts
)
cos_sin = (cos, sin)
else:
cu_seqlens = None
cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, inputmetadata.is_prompts)
# TODO (yuanheng-zhao): revise the logic here
# if batch.is_prompts:
# output_tensor = torch.zeros(
# (sequence_lengths.sum().item(), batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
# )
# else:
# output_tensor = torch.zeros(
# (batch_size, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
# )
sm_scale = 1.0 / (inputmetadata.head_dim**0.5)
norm_output = torch.empty_like(hidden_states)
tokens_to_verify = inputmetadata.num_tokens_to_verify if inputmetadata.use_spec_dec else None
residual = None
for layer_id, decoder_layer in enumerate(self.layers):
@@ -131,6 +158,8 @@ def llama_model_forward(
k_cache=k_caches[layer_id],
v_cache=v_caches[layer_id],
is_prompts=inputmetadata.is_prompts,
is_verifier=inputmetadata.use_spec_dec,
tokens_to_verify=tokens_to_verify,
sequence_lengths=sequence_lengths,
cos_sin=cos_sin,
fd_inter_tensor=inputmetadata.fd_inter_tensor,
@@ -144,9 +173,9 @@ def llama_model_forward(
)
if inputmetadata.is_prompts:
last_token_indexs = sequence_lengths.cumsum(dim=-1)
hidden_states = hidden_states[last_token_indexs - 1].contiguous()
residual = residual[last_token_indexs - 1].contiguous()
seq_len_cumsum = sequence_lengths.cumsum(dim=0)
hidden_states = hidden_states[seq_len_cumsum - 1].contiguous()
residual = residual[seq_len_cumsum - 1].contiguous()
norm_output = torch.empty_like(hidden_states)
hidden_states, _ = self.norm(hidden_states, norm_output, residual, use_cuda_kernel)
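The prompt-stage gather above keeps only the last hidden state of each packed (unpadded) sequence; a quick illustration of the indexing with toy shapes (hidden size 4 assumed):

```python
import torch

# Two prompts of length 3 and 5 packed into one unpadded batch of 8 token states.
sequence_lengths = torch.tensor([3, 5])
hidden_states = torch.arange(8 * 4, dtype=torch.float32).view(8, 4)  # [num_tokens, hidden]

seq_len_cumsum = sequence_lengths.cumsum(dim=0)   # tensor([3, 8])
last_hidden = hidden_states[seq_len_cumsum - 1]   # rows 2 and 7: last token of each prompt
print(last_hidden.shape)                          # torch.Size([2, 4])
```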
@@ -164,6 +193,8 @@ def llama_decoder_layer_forward(
cos_sin: Tuple[torch.Tensor],
fd_inter_tensor: FDIntermTensors,
is_prompts: bool = True,
is_verifier: bool = False,
tokens_to_verify: int = None,
kv_seq_len: int = 0,
output_tensor: torch.Tensor = None,
norm_output: torch.Tensor = None,
@@ -202,6 +233,9 @@ def llama_decoder_layer_forward(
block_tables=block_tables,
k_cache=k_cache,
v_cache=v_cache,
is_prompts=is_prompts,
is_verifier=is_verifier,
tokens_to_verify=tokens_to_verify,
sequence_lengths=sequence_lengths,
cos_sin=cos_sin,
fd_inter_tensor=fd_inter_tensor,
@@ -312,6 +346,8 @@ class NopadLlamaAttention(LlamaAttention):
cos_sin: Tuple[torch.Tensor],
fd_inter_tensor: FDIntermTensors,
is_prompts: bool = True,
is_verifier: bool = False,
tokens_to_verify: int = None,
kv_seq_len: int = 0,
output_tensor: torch.Tensor = None,
sm_scale: int = None,
@@ -355,7 +391,7 @@ class NopadLlamaAttention(LlamaAttention):
block_size = k_cache.size(-2)
if is_prompts:
if use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2:
if not is_verifier and use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2:
# flash attn 2 currently only supports FP16/BF16.
inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
inference_ops.context_kv_cache_memcpy(
@@ -405,17 +441,27 @@ class NopadLlamaAttention(LlamaAttention):
high_precision,
)
else:
decoding_fused_rotary_embedding(
query_states,
key_states,
value_states,
cos_sin[0],
cos_sin[1],
k_cache,
v_cache,
block_tables,
sequence_lengths,
)
q_len = tokens_to_verify + 1 if is_verifier else 1
if is_verifier:
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
copy_k_to_blocked_cache(
key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
)
copy_k_to_blocked_cache(
value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
)
else:
decoding_fused_rotary_embedding(
query_states,
key_states,
value_states,
cos_sin[0],
cos_sin[1],
k_cache,
v_cache,
block_tables,
sequence_lengths,
)
attn_output = flash_decoding_attention(
q=query_states,
k_cache=k_cache,
@@ -428,8 +474,10 @@ class NopadLlamaAttention(LlamaAttention):
mid_output=fd_inter_tensor.mid_output,
mid_output_lse=fd_inter_tensor.mid_output_lse,
sm_scale=sm_scale,
q_len=q_len,
)
attn_output = attn_output.view(-1, self.hidden_size)
attn_output = torch.mm(attn_output, self.o_proj_weight)
return attn_output
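Finally, the verifier decode path passes `q_len = tokens_to_verify + 1` trailing queries to `flash_decoding_attention`, and the first commit bullet mentions fixing the flash decoding mask for this case. As a hedged reference, here is a naive, unpaged PyTorch version of the kind of masked multi-token decode that computation corresponds to, for a single head and sequence; the helper name and shapes are assumptions, not the Triton kernel's interface:

```python
import math
import torch

def naive_verify_attention(q, k_cache, v_cache, kv_len, sm_scale):
    """q: [q_len, head_dim] queries for the last q_len positions of a sequence.
    k_cache / v_cache: [kv_len, head_dim] keys/values, already including those positions.
    Query i may attend to all cached tokens up to its own position
    (kv_len - q_len + i + 1 tokens)."""
    q_len, head_dim = q.shape
    scores = q @ k_cache[:kv_len].T * sm_scale                   # [q_len, kv_len]
    # Causal mask inside the verified window: query i sees positions <= kv_len - q_len + i.
    pos = torch.arange(kv_len)
    limit = (kv_len - q_len) + torch.arange(q_len).unsqueeze(1)  # [q_len, 1]
    scores = scores.masked_fill(pos.unsqueeze(0) > limit, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v_cache[:kv_len]      # [q_len, head_dim]

# Toy run: kv_len = 7 cached tokens, verifying q_len = 3 trailing positions.
head_dim, kv_len, q_len = 8, 7, 3
q = torch.randn(q_len, head_dim)
k_cache = torch.randn(kv_len, head_dim)
v_cache = torch.randn(kv_len, head_dim)
out = naive_verify_attention(q, k_cache, v_cache, kv_len, sm_scale=1.0 / math.sqrt(head_dim))
print(out.shape)  # torch.Size([3, 8])
```

The paged-KV, multi-head Triton kernel in the diff is of course organized very differently, but the per-query masking boundary sketched here is the property the verification stage relies on.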