Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-01 09:07:51 +00:00
Merge branch 'feature/colossal-infer' into colossal-infer-cuda-graph
@@ -9,6 +9,7 @@ from transformers.models.llama.modeling_llama import (
     LlamaForCausalLM,
     LlamaMLP,
     LlamaModel,
+    LlamaRMSNorm,
 )

 from colossalai.inference.config import InputMetaData
@@ -19,6 +20,7 @@ from colossalai.kernel.triton import (
     decoding_fused_rotary_embedding,
     flash_decoding_attention,
     get_xine_cache,
+    rms_layernorm,
     rotary_embedding,
 )
 from colossalai.logging import get_dist_logger
@@ -121,7 +123,8 @@ def llama_model_forward(
         hidden_states = hidden_states[last_token_indexs - 1].contiguous()
         residual = residual[last_token_indexs - 1].contiguous()
         norm_output = torch.empty_like(hidden_states)  # NOTE: non-functional, but CUDA graph captures decoding only
-        hidden_states, _ = self.norm(hidden_states, norm_output, residual)
+        hidden_states, _ = self.norm(hidden_states, norm_output, residual, use_cuda_kernel)

     return hidden_states
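Context for the preallocated norm_output above: CUDA graph capture (which this branch adds for the decoding path) replays kernels against fixed tensor addresses, so buffers have to be allocated once and reused. A minimal, self-contained sketch of that pattern in PyTorch (tensor names and shapes are illustrative, not from this PR):

import torch

# Static buffers: CUDA graphs replay fixed pointers, so inputs/outputs must be
# preallocated once and written in place (names/shapes are illustrative only).
static_in = torch.randn(8, 4096, device="cuda")
static_out = torch.empty_like(static_in)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    # Everything captured here must write into preallocated tensors.
    static_out.copy_(torch.nn.functional.relu(static_in))

static_in.copy_(torch.randn_like(static_in))  # update inputs in place
g.replay()                                    # rerun the captured kernels on the new data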
@@ -164,7 +167,7 @@ def llama_decoder_layer_forward(
         use_cuda_kernel (bool, optional): Whether to use the CUDA kernel. Defaults to True.
     """

-    hidden_states, residual = self.input_layernorm(hidden_states, norm_output, residual)
+    hidden_states, residual = self.input_layernorm(hidden_states, norm_output, residual, use_cuda_kernel)
     # Self Attention
     hidden_states = self.self_attn(
         hidden_states=hidden_states,
@@ -182,12 +185,32 @@ def llama_decoder_layer_forward(
     )

     # Fully Connected
-    hidden_states, residual = self.post_attention_layernorm(hidden_states, norm_output, residual)
+    hidden_states, residual = self.post_attention_layernorm(hidden_states, norm_output, residual, use_cuda_kernel)
     hidden_states = self.mlp(hidden_states)

     return hidden_states, residual


+def llama_rmsnorm_forward(
+    self: LlamaRMSNorm,
+    hidden_states: torch.Tensor,
+    norm_output: torch.Tensor,
+    residual: torch.Tensor = None,
+    use_cuda_kernel: bool = True,
+):
+    if use_cuda_kernel:
+        if residual is not None:
+            inference_ops.fused_add_rms_layernorm(hidden_states, residual, self.weight.data, self.variance_epsilon)
+            return hidden_states, residual
+
+        if norm_output is None:
+            norm_output = torch.empty_like(hidden_states)
+        inference_ops.rms_layernorm(norm_output, hidden_states, self.weight.data, self.variance_epsilon)
+        return norm_output, hidden_states
+    else:
+        return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output, residual)
+
+
 class NopadLlamaAttention(LlamaAttention):
     def __init__(
         self,
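Both branches of the new llama_rmsnorm_forward are expected to compute the standard Llama RMSNorm, with the residual add fused in when a residual is passed. A rough eager-mode reference, for illustration only (the CUDA/Triton kernels operate in place and may differ in memory layout):

import torch

def reference_rmsnorm(hidden_states, weight, eps, residual=None):
    # Optional fused residual add, mirroring the fused_add_rms_layernorm path.
    if residual is not None:
        hidden_states = hidden_states + residual
    # RMSNorm: x * rsqrt(mean(x^2) + eps), scaled by the learned weight.
    variance = hidden_states.float().pow(2).mean(-1, keepdim=True)
    normed = hidden_states.float() * torch.rsqrt(variance + eps)
    return weight * normed.to(hidden_states.dtype), hidden_states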
@@ -295,8 +318,12 @@ class NopadLlamaAttention(LlamaAttention):
         )

         block_size = k_cache.size(-2)

         if is_prompts:
-            rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
+            if use_cuda_kernel:
+                inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
+            else:
+                rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
             attn_output = context_attention_unpadded(
                 q=query_states,
                 k=key_states,
@@ -312,9 +339,16 @@ class NopadLlamaAttention(LlamaAttention):
             )
         else:
             if use_cuda_kernel:
-                rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
-                inference_ops.decode_kv_cache_memcpy(
-                    key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables
+                inference_ops.rotary_embedding_and_cache_copy(
+                    query_states,
+                    key_states,
+                    value_states,
+                    cos_sin[0],
+                    cos_sin[1],
+                    k_cache,
+                    v_cache,
+                    sequence_lengths,
+                    block_tables,
                 )
             else:
                 decoding_fused_rotary_embedding(
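The new rotary_embedding_and_cache_copy kernel fuses two steps of the decode path: applying rotary position embeddings to the query/key states and copying the new key/value into the paged KV cache. For reference, the rotary part alone can be written in eager PyTorch roughly as follows (half-rotation layout assumed; the actual kernel layout and the cache copy are not shown):

import torch

def rotate_half(x):
    # Split the head dim in two and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary(q, k, cos, sin):
    # cos/sin are per-token tables gathered for the current positions;
    # unsqueeze to broadcast over the head dimension (layout assumed).
    cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin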
@@ -1,6 +1,5 @@
 from functools import partial

-import torch
 from torch.nn import Parameter
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, LlamaRMSNorm

@@ -10,6 +9,7 @@ from colossalai.inference.modeling.models.nopadding_llama import (
     llama_causal_lm_forward,
     llama_decoder_layer_forward,
     llama_model_forward,
+    llama_rmsnorm_forward,
 )
 from colossalai.inference.utils import init_to_get_rotary
 from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
@@ -17,27 +17,6 @@ from colossalai.shardformer.policies.base_policy import ModulePolicyDescription,
 # import colossalai
 from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy

-try:
-    from colossalai.kernel.triton import rms_layernorm
-
-    HAS_TRITON_RMSNORM = True
-except:
-    print("you should install triton from https://github.com/openai/triton")
-    HAS_TRITON_RMSNORM = False
-
-
-def get_triton_rmsnorm_forward():
-    if HAS_TRITON_RMSNORM:
-
-        def _triton_rmsnorm_forward(
-            self: LlamaRMSNorm, hidden_states: torch.Tensor, norm_output: torch.Tensor, residual: torch.Tensor = None
-        ):
-            return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output, residual)
-
-        return _triton_rmsnorm_forward
-    else:
-        return None
-

 class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
     def __init__(self) -> None:
@@ -84,15 +63,9 @@ class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
             description=method_replacement, policy=policy, target_key=LlamaDecoderLayer
         )

-        infer_forward = None
-        if HAS_TRITON_RMSNORM:
-            infer_forward = get_triton_rmsnorm_forward()
-
-        if infer_forward is not None:
-            method_replacement = {"forward": partial(infer_forward)}
-            self.append_or_create_method_replacement(
-                description=method_replacement, policy=policy, target_key=LlamaRMSNorm
-            )
+        infer_forward = llama_rmsnorm_forward
+        method_replacement = {"forward": partial(infer_forward)}
+        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaRMSNorm)

         return policy
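With this change the policy no longer guards on Triton availability: it always registers llama_rmsnorm_forward for LlamaRMSNorm, and the CUDA-vs-Triton choice moves inside that forward via use_cuda_kernel. Conceptually the replacement boils down to binding the free function as the module's forward; a simplified illustration only, not the actual ShardFormer replacement mechanism:

import types

from transformers.models.llama.modeling_llama import LlamaRMSNorm

from colossalai.inference.modeling.models.nopadding_llama import llama_rmsnorm_forward

# Illustrative only: bind the function as this instance's forward so that
# norm(hidden_states, norm_output, residual) dispatches on use_cuda_kernel.
norm = LlamaRMSNorm(hidden_size=4096)
norm.forward = types.MethodType(llama_rmsnorm_forward, norm)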
@@ -47,5 +47,5 @@ def init_to_get_rotary(self, base=10000, use_elem=False):
     t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
     freqs = torch.outer(t, inv_freq)

-    self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
-    self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
+    self._cos_cached = torch.cos(freqs).to(self.dtype).cuda()
+    self._sin_cached = torch.sin(freqs).to(self.dtype).cuda()
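Casting with self.dtype keeps the cached cos/sin tables in the model's own dtype (e.g. bf16) instead of hard-coding fp16. For context, the standard rotary cache construction feeding these two lines looks roughly like this (dimensions and concrete values are assumed for illustration; only the hunk's variable names come from the source):

import torch

# Assumed setup: per-dimension inverse frequencies and per-position angles,
# cached once as cos/sin tables in the model's dtype (illustrative values).
dim, base, max_seq_len, rope_scaling_factor = 128, 10000, 4096, 1.0
dtype = torch.bfloat16  # stands in for self.dtype

inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
t = torch.arange(max_seq_len + 1024 * 64, dtype=torch.float32) / rope_scaling_factor
freqs = torch.outer(t, inv_freq)

cos_cached = torch.cos(freqs).to(dtype)
sin_cached = torch.sin(freqs).to(dtype)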