Merge branch 'feature/colossal-infer' into colossal-infer-cuda-graph

Runyu Lu authored 2024-03-14 10:37:05 +08:00, committed by GitHub
53 changed files with 2133 additions and 252 deletions


@@ -9,6 +9,7 @@ from transformers.models.llama.modeling_llama import (
     LlamaForCausalLM,
     LlamaMLP,
     LlamaModel,
+    LlamaRMSNorm,
 )
 
 from colossalai.inference.config import InputMetaData
@@ -19,6 +20,7 @@ from colossalai.kernel.triton import (
     decoding_fused_rotary_embedding,
     flash_decoding_attention,
     get_xine_cache,
+    rms_layernorm,
     rotary_embedding,
 )
 from colossalai.logging import get_dist_logger
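
Note: the hunks below call inference_ops, whose import lies outside the context shown in this diff. A minimal sketch, assuming ColossalAI's kernel-loader API (an assumption, not confirmed by the visible hunks), of how that handle is typically created at module scope:

    # Sketch only: assumed loader API, not part of this diff.
    from colossalai.kernel.kernel_loader import InferenceOpsLoader

    # Loads the compiled C++/CUDA inference extension once; the
    # inference_ops.* calls in the hunks below dispatch into it.
    inference_ops = InferenceOpsLoader().load()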
@@ -121,7 +123,8 @@ def llama_model_forward(
         hidden_states = hidden_states[last_token_indexs - 1].contiguous()
         residual = residual[last_token_indexs - 1].contiguous()
     norm_output = torch.empty_like(hidden_states)  # NOTE non-functional, but cuda graph only capture decoding only
-    hidden_states, _ = self.norm(hidden_states, norm_output, residual)
+    hidden_states, _ = self.norm(hidden_states, norm_output, residual, use_cuda_kernel)
 
     return hidden_states
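
Note: norm_output is preallocated here because the decoding pass is captured as a CUDA graph, and a captured graph replays against fixed tensor addresses rather than allocating fresh outputs. A minimal, self-contained sketch of that constraint (toy shapes and epsilon are assumptions, unrelated to the file above):

    # Illustration of CUDA graph capture with static buffers; not from the diff.
    import torch

    static_in = torch.zeros(8, 128, device="cuda")
    static_out = torch.empty_like(static_in)

    # Warm up once outside capture so lazy kernel/workspace init is not recorded.
    static_out.copy_(static_in * torch.rsqrt(static_in.pow(2).mean(-1, keepdim=True) + 1e-6))
    torch.cuda.synchronize()

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        # Tensors written here must keep their addresses across replays,
        # which is why the model preallocates norm_output up front.
        var = static_in.pow(2).mean(-1, keepdim=True)
        static_out.copy_(static_in * torch.rsqrt(var + 1e-6))

    static_in.copy_(torch.randn(8, 128, device="cuda"))
    g.replay()  # reruns the captured kernels on the updated static_in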
@@ -164,7 +167,7 @@ def llama_decoder_layer_forward(
         use_cuda_kernel: (bool, optional): Whether to use cuda kernel. Defaults to True.
     """
 
-    hidden_states, residual = self.input_layernorm(hidden_states, norm_output, residual)
+    hidden_states, residual = self.input_layernorm(hidden_states, norm_output, residual, use_cuda_kernel)
     # Self Attention
     hidden_states = self.self_attn(
         hidden_states=hidden_states,
@@ -182,12 +185,32 @@ def llama_decoder_layer_forward(
     )
     # Fully Connected
-    hidden_states, residual = self.post_attention_layernorm(hidden_states, norm_output, residual)
+    hidden_states, residual = self.post_attention_layernorm(hidden_states, norm_output, residual, use_cuda_kernel)
     hidden_states = self.mlp(hidden_states)
     return hidden_states, residual
+
+
+def llama_rmsnorm_forward(
+    self: LlamaRMSNorm,
+    hidden_states: torch.Tensor,
+    norm_output: torch.Tensor,
+    residual: torch.Tensor = None,
+    use_cuda_kernel: bool = True,
+):
+    if use_cuda_kernel:
+        if residual is not None:
+            inference_ops.fused_add_rms_layernorm(hidden_states, residual, self.weight.data, self.variance_epsilon)
+            return hidden_states, residual
+
+        if norm_output is None:
+            norm_output = torch.empty_like(hidden_states)
+        inference_ops.rms_layernorm(norm_output, hidden_states, self.weight.data, self.variance_epsilon)
+        return norm_output, hidden_states
+    else:
+        return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output, residual)
+
+
 class NopadLlamaAttention(LlamaAttention):
     def __init__(
         self,
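
Note: llama_rmsnorm_forward above selects between the CUDA kernels (fused_add_rms_layernorm / rms_layernorm) and the Triton rms_layernorm fallback. A plain-PyTorch reference sketch of the semantics both paths are expected to share, including the (normalized_output, new_residual) return convention (hedged illustration, not taken from the kernels; rmsnorm_reference is a hypothetical helper):

    # Reference-only RMSNorm with an optional residual add.
    import torch

    def rmsnorm_reference(hidden_states, weight, eps, residual=None):
        if residual is not None:
            hidden_states = hidden_states + residual    # fused-add path
        new_residual = hidden_states                     # pre-norm activations feed the next layer's residual
        variance = hidden_states.float().pow(2).mean(-1, keepdim=True)
        normed = hidden_states.float() * torch.rsqrt(variance + eps)
        return weight * normed.to(hidden_states.dtype), new_residual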
@@ -295,8 +318,12 @@ class NopadLlamaAttention(LlamaAttention):
         )
         block_size = k_cache.size(-2)
         if is_prompts:
-            rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
+            if use_cuda_kernel:
+                inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
+            else:
+                rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
             attn_output = context_attention_unpadded(
                 q=query_states,
                 k=key_states,
@@ -312,9 +339,16 @@ class NopadLlamaAttention(LlamaAttention):
             )
         else:
             if use_cuda_kernel:
-                rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
-                inference_ops.decode_kv_cache_memcpy(
-                    key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables
-                )
+                inference_ops.rotary_embedding_and_cache_copy(
+                    query_states,
+                    key_states,
+                    value_states,
+                    cos_sin[0],
+                    cos_sin[1],
+                    k_cache,
+                    v_cache,
+                    sequence_lengths,
+                    block_tables,
+                )
             else:
                 decoding_fused_rotary_embedding(