Merge branch 'feature/colossal-infer' into colossal-infer-cuda-graph

Runyu Lu
2024-03-14 10:37:05 +08:00
committed by GitHub
53 changed files with 2133 additions and 252 deletions


@@ -9,6 +9,7 @@ from transformers.models.llama.modeling_llama import (
LlamaForCausalLM,
LlamaMLP,
LlamaModel,
LlamaRMSNorm,
)
from colossalai.inference.config import InputMetaData
@@ -19,6 +20,7 @@ from colossalai.kernel.triton import (
decoding_fused_rotary_embedding,
flash_decoding_attention,
get_xine_cache,
rms_layernorm,
rotary_embedding,
)
from colossalai.logging import get_dist_logger
@@ -121,7 +123,8 @@ def llama_model_forward(
hidden_states = hidden_states[last_token_indexs - 1].contiguous()
residual = residual[last_token_indexs - 1].contiguous()
norm_output = torch.empty_like(hidden_states) # NOTE non-functional, but cuda graph only capture decoding only
- hidden_states, _ = self.norm(hidden_states, norm_output, residual)
+ hidden_states, _ = self.norm(hidden_states, norm_output, residual, use_cuda_kernel)
return hidden_states
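
Note on the preallocated norm_output: a captured CUDA graph replays a fixed sequence of kernel launches against fixed buffer addresses, so tensors used inside the captured (decode-only) region are threaded through as reusable buffers rather than allocated inside the graph. The following is a minimal capture-and-replay sketch with stock PyTorch APIs, illustrative only and not part of this commit; the Linear module and shapes are placeholders.

import torch

net = torch.nn.Linear(64, 64).cuda().half()
static_in = torch.zeros(8, 64, device="cuda", dtype=torch.half)  # shape/address stay fixed

with torch.no_grad():
    # Warm up on a side stream before capture, as recommended by PyTorch.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            net(static_in)
    torch.cuda.current_stream().wait_stream(s)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_out = net(static_in)  # output buffer is baked into the graph

    # At run time: refill the static input in place, then replay the graph.
    static_in.copy_(torch.randn_like(static_in))
    graph.replay()
    print(static_out.float().sum().item())
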
@@ -164,7 +167,7 @@ def llama_decoder_layer_forward(
use_cuda_kernel: (bool, optional): Whether to use cuda kernel. Defaults to True.
"""
- hidden_states, residual = self.input_layernorm(hidden_states, norm_output, residual)
+ hidden_states, residual = self.input_layernorm(hidden_states, norm_output, residual, use_cuda_kernel)
# Self Attention
hidden_states = self.self_attn(
hidden_states=hidden_states,
@@ -182,12 +185,32 @@ def llama_decoder_layer_forward(
)
# Fully Connected
- hidden_states, residual = self.post_attention_layernorm(hidden_states, norm_output, residual)
+ hidden_states, residual = self.post_attention_layernorm(hidden_states, norm_output, residual, use_cuda_kernel)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
+ def llama_rmsnorm_forward(
+     self: LlamaRMSNorm,
+     hidden_states: torch.Tensor,
+     norm_output: torch.Tensor,
+     residual: torch.Tensor = None,
+     use_cuda_kernel: bool = True,
+ ):
+     if use_cuda_kernel:
+         if residual is not None:
+             inference_ops.fused_add_rms_layernorm(hidden_states, residual, self.weight.data, self.variance_epsilon)
+             return hidden_states, residual
+         if norm_output is None:
+             norm_output = torch.empty_like(hidden_states)
+         inference_ops.rms_layernorm(norm_output, hidden_states, self.weight.data, self.variance_epsilon)
+         return norm_output, hidden_states
+     else:
+         return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output, residual)
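
Both branches of llama_rmsnorm_forward are meant to compute the same RMSNorm, optionally fused with a residual add; only the kernel backend (CUDA extension ops vs. the Triton rms_layernorm) differs. A plain-PyTorch reference of that math, with an assumed return convention mirroring the function above (normalized output plus the pre-norm sum to carry as the next residual); the real kernels write in place or into norm_output, so this is only an illustration.

import torch

def rmsnorm_ref(hidden_states, weight, eps, residual=None):
    # Optional residual add, then x * rsqrt(mean(x^2) + eps) * weight.
    if residual is not None:
        hidden_states = hidden_states + residual
    new_residual = hidden_states
    variance = hidden_states.float().pow(2).mean(-1, keepdim=True)
    normed = hidden_states.float() * torch.rsqrt(variance + eps)
    return normed.to(hidden_states.dtype) * weight, new_residual
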
class NopadLlamaAttention(LlamaAttention):
def __init__(
self,
@@ -295,8 +318,12 @@ class NopadLlamaAttention(LlamaAttention):
)
block_size = k_cache.size(-2)
if is_prompts:
-     rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
+     if use_cuda_kernel:
+         inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
+     else:
+         rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
attn_output = context_attention_unpadded(
q=query_states,
k=key_states,
@@ -312,9 +339,16 @@ class NopadLlamaAttention(LlamaAttention):
)
else:
if use_cuda_kernel:
-     rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
-     inference_ops.decode_kv_cache_memcpy(
-         key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables
-     )
+     inference_ops.rotary_embedding_and_cache_copy(
+         query_states,
+         key_states,
+         value_states,
+         cos_sin[0],
+         cos_sin[1],
+         k_cache,
+         v_cache,
+         sequence_lengths,
+         block_tables,
+     )
else:
decoding_fused_rotary_embedding(

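In the decode branch, the separate rotary embedding and decode_kv_cache_memcpy calls are replaced by one fused inference_ops.rotary_embedding_and_cache_copy, which rotates q/k and writes the new token's k/v into the paged cache in a single kernel. Below is an unfused reference sketch of that step under assumed shapes and layout (q/k/v as [bsz, num_heads, head_dim] for one decoding token each, cos/sin as [bsz, head_dim // 2], caches as [num_blocks, num_heads, block_size, head_dim], consistent with block_size = k_cache.size(-2)); the RoPE pairing convention and exact kernel signatures are backend-specific, so treat this only as an illustration.

import torch

def rope_and_cache_copy_ref(q, k, v, cos, sin, k_cache, v_cache, seq_lens, block_tables):
    # Apply rotary position embedding to q and k (one common pairing convention).
    def rotate(x):
        x1, x2 = x[..., 0::2], x[..., 1::2]
        out = torch.empty_like(x)
        out[..., 0::2] = x1 * cos.unsqueeze(1) - x2 * sin.unsqueeze(1)
        out[..., 1::2] = x2 * cos.unsqueeze(1) + x1 * sin.unsqueeze(1)
        return out

    q, k = rotate(q), rotate(k)

    # Scatter the current token's k/v into the block-paged KV cache.
    block_size = k_cache.size(-2)
    for i, seq_len in enumerate(seq_lens.tolist()):
        pos = seq_len - 1                      # slot of the token being decoded
        block = block_tables[i, pos // block_size]
        k_cache[block, :, pos % block_size, :] = k[i]
        v_cache[block, :, pos % block_size, :] = v[i]
    return q, k
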

@@ -1,6 +1,5 @@
from functools import partial
import torch
from torch.nn import Parameter
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, LlamaRMSNorm
@@ -10,6 +9,7 @@ from colossalai.inference.modeling.models.nopadding_llama import (
llama_causal_lm_forward,
llama_decoder_layer_forward,
llama_model_forward,
llama_rmsnorm_forward,
)
from colossalai.inference.utils import init_to_get_rotary
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
@@ -17,27 +17,6 @@ from colossalai.shardformer.policies.base_policy import ModulePolicyDescription,
# import colossalai
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
- try:
-     from colossalai.kernel.triton import rms_layernorm
-     HAS_TRITON_RMSNORM = True
- except:
-     print("you should install triton from https://github.com/openai/triton")
-     HAS_TRITON_RMSNORM = False
- def get_triton_rmsnorm_forward():
-     if HAS_TRITON_RMSNORM:
-         def _triton_rmsnorm_forward(
-             self: LlamaRMSNorm, hidden_states: torch.Tensor, norm_output: torch.Tensor, residual: torch.Tensor = None
-         ):
-             return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output, residual)
-         return _triton_rmsnorm_forward
-     else:
-         return None
class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
def __init__(self) -> None:
@@ -84,15 +63,9 @@ class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
description=method_replacement, policy=policy, target_key=LlamaDecoderLayer
)
- infer_forward = None
- if HAS_TRITON_RMSNORM:
-     infer_forward = get_triton_rmsnorm_forward()
- if infer_forward is not None:
-     method_replacement = {"forward": partial(infer_forward)}
-     self.append_or_create_method_replacement(
-         description=method_replacement, policy=policy, target_key=LlamaRMSNorm
-     )
+ infer_forward = llama_rmsnorm_forward
+ method_replacement = {"forward": partial(infer_forward)}
+ self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaRMSNorm)
return policy
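
The policy no longer gates the RMSNorm replacement on Triton availability; it binds llama_rmsnorm_forward unconditionally, and that function itself picks the CUDA or Triton kernel via use_cuda_kernel. The replacement mechanism amounts to ordinary attribute patching on the class (partial(infer_forward) with no bound arguments behaves like the bare function). A toy sketch of the idea, illustrative only and not the Shardformer API:

from functools import partial

class Greeter:
    def forward(self, name):
        return f"hello {name}"

def fast_forward(self, name):
    return f"HELLO {name}"

# Rebinding the class attribute swaps the behavior for every instance, which is
# essentially what append_or_create_method_replacement arranges for
# LlamaRMSNorm.forward via {"forward": partial(llama_rmsnorm_forward)}.
Greeter.forward = partial(fast_forward)
print(Greeter().forward(Greeter(), "llama") if isinstance(Greeter.forward, partial) else None)
# Simpler equivalent without partial:
Greeter.forward = fast_forward
print(Greeter().forward("llama"))  # -> "HELLO llama"
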


@@ -47,5 +47,5 @@ def init_to_get_rotary(self, base=10000, use_elem=False):
t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
freqs = torch.outer(t, inv_freq)
- self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
- self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
+ self._cos_cached = torch.cos(freqs).to(self.dtype).cuda()
+ self._sin_cached = torch.sin(freqs).to(self.dtype).cuda()
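
The utils change stops hard-coding float16 for the cached rotary tables and follows the module's own dtype, so bfloat16 or fp32 models no longer silently downcast their cos/sin caches. A small self-contained sketch of the difference (kept on CPU; base, dim, and max_len are placeholder values, whereas the real code also moves the tables to CUDA):

import torch

base, dim, max_len = 10000, 128, 64
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
t = torch.arange(max_len, dtype=torch.float32)
freqs = torch.outer(t, inv_freq)

model_dtype = torch.bfloat16                     # e.g. a bf16 checkpoint
cos_fixed = torch.cos(freqs).to(torch.float16)   # old behavior: always fp16
cos_dtype = torch.cos(freqs).to(model_dtype)     # new behavior: follow self.dtype

print(cos_fixed.dtype, cos_dtype.dtype)          # torch.float16 torch.bfloat16
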