Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-09 04:50:17 +00:00)
[Inference/SpecDec] Add Speculative Decoding Implementation (#5423)
* fix flash decoding mask during verification
* add spec-dec
* add test for spec-dec
* revise drafter init
* remove drafter sampling
* retire past kv in drafter
* (trivial) rename attrs
* (trivial) rename arg
* revise how we enable/disable spec-dec
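For orientation, speculative decoding pairs the target model with a small drafter: the drafter cheaply proposes a few tokens, and the target model checks them all in one batched forward pass, keeping the longest agreeing prefix plus one corrected token. The sketch below is a minimal greedy draft-then-verify step for intuition only; `main_model`, `drafter`, and `n_spec_tokens` are illustrative stand-ins, not the drafter/engine APIs added by this PR.

```python
import torch


@torch.no_grad()
def speculative_decode_step(main_model, drafter, input_ids, n_spec_tokens):
    """One greedy draft-then-verify step (illustrative sketch, not the ColossalAI API).

    `main_model` and `drafter` are callables mapping (1, seq_len) token ids to
    (1, seq_len, vocab) logits.
    """
    prompt_len = input_ids.size(1)

    # 1) The drafter proposes n_spec_tokens greedily, one cheap step at a time.
    draft = input_ids
    for _ in range(n_spec_tokens):
        next_tok = drafter(draft)[:, -1, :].argmax(dim=-1, keepdim=True)
        draft = torch.cat([draft, next_tok], dim=-1)
    drafted = draft[:, prompt_len:]                              # (1, n_spec_tokens)

    # 2) The main model scores all drafted tokens in a single forward pass.
    logits = main_model(draft)                                   # (1, prompt_len + n, vocab)
    target = logits[:, prompt_len - 1 :, :].argmax(dim=-1)       # (1, n_spec_tokens + 1)

    # 3) Accept the longest prefix on which drafter and main model agree,
    #    then append the main model's own token after the last accepted one.
    n_accept = 0
    while n_accept < n_spec_tokens and drafted[0, n_accept] == target[0, n_accept]:
        n_accept += 1
    accepted = torch.cat([drafted[:, :n_accept], target[:, n_accept : n_accept + 1]], dim=-1)
    return torch.cat([input_ids, accepted], dim=-1)
```

Accepted tokens only cost drafter forward passes, while any rejection falls back to the main model's own prediction, so the output matches plain greedy decoding.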
@@ -18,6 +18,7 @@ from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import (
    context_attention_unpadded,
    copy_k_to_blocked_cache,
    decoding_fused_rotary_embedding,
    flash_decoding_attention,
    get_xine_cache,
@@ -84,9 +85,9 @@ def llama_model_forward(
    """This function will replace the forward function of LlamaModel.

    Args:
        batch (BatchInfo): It stores the necessary input information for this inference.
        k_caches (List[torch.Tensor]): It holds the GPU memory for the key cache.
        v_caches (List[torch.Tensor]): It holds the GPU memory for the value cache.
        batch (BatchInfo, optional): It stores the necessary input information for this inference. Defaults to None.
        k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None.
        v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None.
        high_precision (Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision. Defaults to False.
    """
    block_tables = inputmetadata.block_tables
@@ -101,7 +102,25 @@ def llama_model_forward(
    use_cuda_kernel = False

    hidden_states = self.embed_tokens(input_tokens_ids)
    if use_cuda_kernel:
    cu_seqlens = None

    # NOTE (yuanheng-zhao): we do not use cuda kernels for speculative-decoding for now
    if inputmetadata.use_spec_dec:
        # For speculative-decoding Prefill and Verifying Stage
        if inputmetadata.is_prompts:
            # output tensor shape is the same as normal Prefill Stage
            o_tensor_size = (sequence_lengths.sum().item(), inputmetadata.num_heads * inputmetadata.head_dim)
            rotary_indexes = [torch.arange(0, length) for length in sequence_lengths]
        else:
            # the number of tokens to be verified in parallel plus the correct token in the last step
            n_tokens = inputmetadata.num_tokens_to_verify + 1
            assert n_tokens == hidden_states.size(0)
            o_tensor_size = (batch_size * n_tokens, inputmetadata.num_heads * inputmetadata.head_dim)
            rotary_indexes = [(length - n_tokens + i).view(-1) for i in range(n_tokens) for length in sequence_lengths]
        rotary_indexes = torch.cat(rotary_indexes, dim=-1)
        cos_sin = (self._cos_cached[rotary_indexes], self._sin_cached[rotary_indexes])

    elif use_cuda_kernel:
        if inputmetadata.dtype != torch.float32 and use_flash_attn2:
            cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.torch.int32), (1, 0))
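The verification pass runs the last `num_tokens_to_verify + 1` tokens of every sequence through the model in one batch, so each of those tokens needs its own rotary position id, taken from the tail of its sequence. A small numeric illustration of the index construction above, with hypothetical sequence lengths:

```python
import torch

# Hypothetical: two sequences whose lengths already count the tokens being verified.
sequence_lengths = torch.tensor([10, 7])
n_tokens = 4 + 1  # num_tokens_to_verify + 1

# Same construction as in the diff: for each verification step i,
# take position (length - n_tokens + i) for every sequence in the batch.
rotary_indexes = [(length - n_tokens + i).view(-1) for i in range(n_tokens) for length in sequence_lengths]
rotary_indexes = torch.cat(rotary_indexes, dim=-1)

print(rotary_indexes)
# tensor([5, 2, 6, 3, 7, 4, 8, 5, 9, 6])
# step 0 -> positions [5, 2], step 1 -> [6, 3], ..., step 4 -> [9, 6];
# indexing _cos_cached/_sin_cached with these ids gives one (cos, sin) row per token.
```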
@@ -113,14 +132,22 @@ def llama_model_forward(
            self._cos_cached, self._sin_cached, cos, sin, sequence_lengths, kv_seq_len, inputmetadata.is_prompts
        )
        cos_sin = (cos, sin)

    else:
        cu_seqlens = None
        cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, inputmetadata.is_prompts)

    # TODO (yuanheng-zhao): revise the logic here
    # if batch.is_prompts:
    #     output_tensor = torch.zeros(
    #         (sequence_lengths.sum().item(), batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
    #     )
    # else:
    #     output_tensor = torch.zeros(
    #         (batch_size, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
    #     )
    sm_scale = 1.0 / (inputmetadata.head_dim**0.5)

    norm_output = torch.empty_like(hidden_states)
    tokens_to_verify = inputmetadata.num_tokens_to_verify if inputmetadata.use_spec_dec else None
    residual = None

    for layer_id, decoder_layer in enumerate(self.layers):
@@ -131,6 +158,8 @@ def llama_model_forward(
            k_cache=k_caches[layer_id],
            v_cache=v_caches[layer_id],
            is_prompts=inputmetadata.is_prompts,
            is_verifier=inputmetadata.use_spec_dec,
            tokens_to_verify=tokens_to_verify,
            sequence_lengths=sequence_lengths,
            cos_sin=cos_sin,
            fd_inter_tensor=inputmetadata.fd_inter_tensor,
@@ -144,9 +173,9 @@ def llama_model_forward(
        )

    if inputmetadata.is_prompts:
        last_token_indexs = sequence_lengths.cumsum(dim=-1)
        hidden_states = hidden_states[last_token_indexs - 1].contiguous()
        residual = residual[last_token_indexs - 1].contiguous()
    seq_len_cumsum = sequence_lengths.cumsum(dim=0)
    hidden_states = hidden_states[seq_len_cumsum - 1].contiguous()
    residual = residual[seq_len_cumsum - 1].contiguous()
    norm_output = torch.empty_like(hidden_states)
    hidden_states, _ = self.norm(hidden_states, norm_output, residual, use_cuda_kernel)
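Since the batch is packed without padding (all sequences concatenated along the token dimension), the hidden state of each sequence's final token sits at its cumulative-length offset minus one, which is what the `seq_len_cumsum` gather above selects. A toy illustration with hypothetical shapes:

```python
import torch

# Three sequences of lengths 4, 2 and 3 packed back to back: 9 token rows in total.
sequence_lengths = torch.tensor([4, 2, 3])
hidden_states = torch.arange(9, dtype=torch.float32).unsqueeze(-1)  # (9, hidden) with hidden=1

seq_len_cumsum = sequence_lengths.cumsum(dim=0)   # tensor([4, 6, 9]) -> row after each sequence
last_hidden = hidden_states[seq_len_cumsum - 1]   # rows 3, 5, 8: one row per sequence

print(last_hidden.squeeze(-1))  # tensor([3., 5., 8.])
```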
@@ -164,6 +193,8 @@ def llama_decoder_layer_forward(
    cos_sin: Tuple[torch.Tensor],
    fd_inter_tensor: FDIntermTensors,
    is_prompts: bool = True,
    is_verifier: bool = False,
    tokens_to_verify: int = None,
    kv_seq_len: int = 0,
    output_tensor: torch.Tensor = None,
    norm_output: torch.Tensor = None,
@@ -202,6 +233,9 @@ def llama_decoder_layer_forward(
        block_tables=block_tables,
        k_cache=k_cache,
        v_cache=v_cache,
        is_prompts=is_prompts,
        is_verifier=is_verifier,
        tokens_to_verify=tokens_to_verify,
        sequence_lengths=sequence_lengths,
        cos_sin=cos_sin,
        fd_inter_tensor=fd_inter_tensor,
@@ -312,6 +346,8 @@ class NopadLlamaAttention(LlamaAttention):
        cos_sin: Tuple[torch.Tensor],
        fd_inter_tensor: FDIntermTensors,
        is_prompts: bool = True,
        is_verifier: bool = False,
        tokens_to_verify: int = None,
        kv_seq_len: int = 0,
        output_tensor: torch.Tensor = None,
        sm_scale: int = None,
@@ -355,7 +391,7 @@ class NopadLlamaAttention(LlamaAttention):
        block_size = k_cache.size(-2)

        if is_prompts:
            if use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2:
            if not is_verifier and use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2:
                # flash attn 2 currently only supports FP16/BF16.
                inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
                inference_ops.context_kv_cache_memcpy(
@@ -405,17 +441,27 @@ class NopadLlamaAttention(LlamaAttention):
                    high_precision,
                )
            else:
                decoding_fused_rotary_embedding(
                    query_states,
                    key_states,
                    value_states,
                    cos_sin[0],
                    cos_sin[1],
                    k_cache,
                    v_cache,
                    block_tables,
                    sequence_lengths,
                )
            q_len = tokens_to_verify + 1 if is_verifier else 1
            if is_verifier:
                rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
                copy_k_to_blocked_cache(
                    key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
                )
                copy_k_to_blocked_cache(
                    value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
                )
            else:
                decoding_fused_rotary_embedding(
                    query_states,
                    key_states,
                    value_states,
                    cos_sin[0],
                    cos_sin[1],
                    k_cache,
                    v_cache,
                    block_tables,
                    sequence_lengths,
                )
            attn_output = flash_decoding_attention(
                q=query_states,
                k_cache=k_cache,
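During verification the attention layer handles `q_len = tokens_to_verify + 1` new tokens per sequence instead of one, so their keys and values are written into the block-paged cache before attention runs (the two `copy_k_to_blocked_cache` calls above, invoked once for K and once for V). Below is a plain-PyTorch sketch of what such a blocked-cache write does; the layouts and argument semantics are assumptions for illustration and may differ from the actual Triton kernel.

```python
import torch

def copy_to_blocked_cache_ref(new_kv, cache, kv_lengths, block_tables, n):
    """Reference sketch: write the last `n` tokens of each sequence into a paged cache.

    new_kv:       (bsz * n, num_kv_heads, head_dim), assumed sequence-major layout
    cache:        (num_blocks, num_kv_heads, block_size, head_dim)
    kv_lengths:   (bsz,) sequence lengths *including* the n new tokens
    block_tables: (bsz, max_blocks_per_seq) logical -> physical block ids
    """
    bsz = kv_lengths.size(0)
    block_size = cache.size(-2)
    new_kv = new_kv.view(bsz, n, *new_kv.shape[1:])
    for seq in range(bsz):
        for i in range(n):
            pos = kv_lengths[seq].item() - n + i           # global token position in the sequence
            block = block_tables[seq, pos // block_size]   # physical block holding that position
            cache[block, :, pos % block_size, :] = new_kv[seq, i]
```

In the diff, the same kernel is simply called twice, once with `key_states`/`k_cache` and once with `value_states`/`v_cache`.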
@@ -428,8 +474,10 @@ class NopadLlamaAttention(LlamaAttention):
                mid_output=fd_inter_tensor.mid_output,
                mid_output_lse=fd_inter_tensor.mid_output_lse,
                sm_scale=sm_scale,
                q_len=q_len,
            )

        attn_output = attn_output.view(-1, self.hidden_size)
        attn_output = torch.mm(attn_output, self.o_proj_weight)

        return attn_output
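Passing `q_len` lets the flash-decoding kernel attend with several query tokens per sequence during verification; the epilogue then flattens the heads and applies the output projection directly on the packed (token, hidden) layout. A shape-only sketch with hypothetical sizes (the weight is assumed to be stored so that `torch.mm` applies it without a transpose, as in the diff):

```python
import torch

bsz, q_len, num_heads, head_dim = 2, 5, 8, 64   # q_len = tokens_to_verify + 1 when verifying
hidden_size = num_heads * head_dim

attn_output = torch.randn(bsz * q_len, num_heads, head_dim)   # one row per packed query token
o_proj_weight = torch.randn(hidden_size, hidden_size)

out = attn_output.view(-1, hidden_size)   # (bsz * q_len, hidden_size)
out = torch.mm(out, o_proj_weight)        # (bsz * q_len, hidden_size)
print(out.shape)                          # torch.Size([10, 512])
```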