diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py index 4dfe6dbd7..5fa1e7161 100644 --- a/colossalai/inference/modeling/models/nopadding_llama.py +++ b/colossalai/inference/modeling/models/nopadding_llama.py @@ -2,7 +2,6 @@ from typing import List, Optional, Tuple import torch -from torch.nn import Parameter from transformers.models.llama.modeling_llama import ( LlamaAttention, LlamaConfig, @@ -82,19 +81,21 @@ def llama_model_forward( if batch.is_prompts: output_tensor = torch.zeros( - (sequence_lengths.sum().item(), batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device + (sequence_lengths.sum().item(), batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device ) else: output_tensor = torch.zeros( - (batch_size, batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device + (batch_size, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device ) sm_scale = 1.0 / (batch.head_dim**0.5) norm_output = torch.empty_like(hidden_states) + residual = None for layer_id, decoder_layer in enumerate(self.layers): - hidden_states = decoder_layer( + hidden_states, residual = decoder_layer( hidden_states, + residual=residual, block_tables=block_tables, k_cache=k_caches[layer_id], v_cache=v_caches[layer_id], @@ -111,8 +112,9 @@ def llama_model_forward( if batch.is_prompts: last_token_indexs = sequence_lengths.cumsum(dim=-1) hidden_states = hidden_states[last_token_indexs - 1].contiguous() + residual = residual[last_token_indexs - 1].contiguous() norm_output = torch.empty_like(hidden_states) - hidden_states = self.norm(hidden_states, norm_output) + hidden_states, _ = self.norm(hidden_states, norm_output, residual) return hidden_states @@ -120,6 +122,7 @@ def llama_model_forward( def llama_decoder_layer_forward( self: LlamaDecoderLayer, hidden_states: torch.Tensor, + residual: torch.Tensor, block_tables: torch.Tensor = None, k_cache: torch.Tensor = None, v_cache: torch.Tensor = None, @@ -136,6 +139,7 @@ def llama_decoder_layer_forward( Args: hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim]. + residual (torch.Tensor): shape [token_num, embed_dim], used to be added to hidden_states in out_proj. block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence], storing mapping of token_position_id -> block_id. Defaults to None. k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None. @@ -151,12 +155,10 @@ def llama_decoder_layer_forward( sm_scale (int, optional): Used for flash attention. Defaults to None. 
""" - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states, norm_output) + hidden_states, residual = self.input_layernorm(hidden_states, norm_output, residual) # Self Attention hidden_states = self.self_attn( hidden_states=hidden_states, - residual=residual, block_tables=block_tables, k_cache=k_cache, v_cache=v_cache, @@ -170,11 +172,10 @@ def llama_decoder_layer_forward( ) # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states, norm_output) - hidden_states = self.mlp(hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, norm_output, residual) + hidden_states = self.mlp(hidden_states) - return hidden_states + return hidden_states, residual class NopadLlamaAttention(LlamaAttention): @@ -198,16 +199,18 @@ class NopadLlamaAttention(LlamaAttention): attn_oproj_w (torch.Tensor, optional): The transposed o_proj weight. Defaults to None. """ super().__init__(config, layer_idx) - self.q_proj.weight = Parameter(attn_qproj_w, requires_grad=False) - self.k_proj.weight = Parameter(attn_kproj_w, requires_grad=False) - self.v_proj.weight = Parameter(attn_vproj_w, requires_grad=False) - self.o_proj.weight = Parameter(attn_oproj_w, requires_grad=False) + self.q_proj_weight = attn_qproj_w + self.k_proj_weight = attn_kproj_w + self.v_proj_weight = attn_vproj_w + self.o_proj_weight = attn_oproj_w + if self.num_heads == self.num_key_value_heads: - qkv_weight_list = [self.q_proj.weight, self.k_proj.weight, self.v_proj.weight] + qkv_weight_list = [self.q_proj_weight, self.k_proj_weight, self.v_proj_weight] self.qkv_weight = torch.stack(qkv_weight_list, dim=0) - self.q_proj = None - self.k_proj = None - self.v_proj = None + + self.q_proj = None + self.k_proj = None + self.v_proj = None @staticmethod def from_native_module(module: LlamaAttention, *args, **kwargs) -> LlamaAttention: @@ -239,7 +242,6 @@ class NopadLlamaAttention(LlamaAttention): def forward( self, hidden_states: torch.Tensor, - residual: torch.Tensor, block_tables: torch.Tensor = None, k_cache: torch.Tensor = None, v_cache: torch.Tensor = None, @@ -254,7 +256,6 @@ class NopadLlamaAttention(LlamaAttention): """ Args: hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim]. - residual (torch.Tensor): shape [token_num, embed_dim], used to be added to hidden_states in out_proj. block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence], storing mapping of token_position_id -> block_id. Defaults to None. k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None. 
@@ -270,9 +271,9 @@ class NopadLlamaAttention(LlamaAttention): """ if self.num_heads != self.num_key_value_heads: - query_states = torch.mm(hidden_states, self.q_proj.weight).view(-1, self.num_heads, self.head_dim) - key_states = torch.mm(hidden_states, self.k_proj.weight).view(-1, self.num_key_value_heads, self.head_dim) - value_states = torch.mm(hidden_states, self.v_proj.weight).view(-1, self.num_key_value_heads, self.head_dim) + query_states = torch.mm(hidden_states, self.q_proj_weight).view(-1, self.num_heads, self.head_dim) + key_states = torch.mm(hidden_states, self.k_proj_weight).view(-1, self.num_key_value_heads, self.head_dim) + value_states = torch.mm(hidden_states, self.v_proj_weight).view(-1, self.num_key_value_heads, self.head_dim) else: # fused qkv token_nums = hidden_states.size(0) @@ -324,8 +325,7 @@ class NopadLlamaAttention(LlamaAttention): sm_scale=sm_scale, ) - attn_output = attn_output.view(-1, self.hidden_size) - attn_output = torch.addmm(residual, attn_output, self.o_proj.weight) + attn_output = torch.mm(attn_output, self.o_proj_weight) return attn_output @@ -348,10 +348,11 @@ class NopadLlamaMLP(LlamaMLP): mlp_dproj_w (torch.Tensor, optional): The transposed down_proj weight. Defaults to None. """ super().__init__(config) - self.gate_up_weight = Parameter(torch.stack([mlp_gproj_w, mlp_uproj_w], dim=0), requires_grad=False) - self.down_proj.weight = Parameter(mlp_dproj_w, requires_grad=False) + self.gate_up_weight = torch.stack([mlp_gproj_w, mlp_uproj_w], dim=0) + self.down_proj_weight = mlp_dproj_w self.gate_proj = None self.up_proj = None + self.down_proj = None @staticmethod def from_native_module(module: LlamaMLP, *args, **kwargs) -> LlamaMLP: @@ -375,14 +376,13 @@ class NopadLlamaMLP(LlamaMLP): return mlp_layer - def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """ Args: hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim]. - residual (torch.Tensor): shape [token_num, embed_dim], used to be added to hidden_states in down_proj. 
""" hidden_states = hidden_states.expand(2, -1, -1) gate_up_proj_out = torch.bmm(hidden_states, self.gate_up_weight) act_out = torch.nn.functional.silu(gate_up_proj_out[0], inplace=True) tmp_out = act_out * gate_up_proj_out[1] - return torch.addmm(residual, tmp_out, self.down_proj.weight) + return torch.mm(tmp_out, self.down_proj_weight) diff --git a/colossalai/inference/modeling/policy/nopadding_llama.py b/colossalai/inference/modeling/policy/nopadding_llama.py index c8bb7dae3..13695b835 100644 --- a/colossalai/inference/modeling/policy/nopadding_llama.py +++ b/colossalai/inference/modeling/policy/nopadding_llama.py @@ -29,8 +29,10 @@ except: def get_triton_rmsnorm_forward(): if HAS_TRITON_RMSNORM: - def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor, norm_output: torch.Tensor): - return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output) + def _triton_rmsnorm_forward( + self: LlamaRMSNorm, hidden_states: torch.Tensor, norm_output: torch.Tensor, residual: torch.Tensor = None + ): + return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output, residual) return _triton_rmsnorm_forward else: diff --git a/colossalai/kernel/triton/context_attn_unpad.py b/colossalai/kernel/triton/context_attn_unpad.py index 68baffd53..3f494b97f 100644 --- a/colossalai/kernel/triton/context_attn_unpad.py +++ b/colossalai/kernel/triton/context_attn_unpad.py @@ -205,7 +205,7 @@ def context_attention_unpadded( assert k_cache.shape == v_cache.shape assert context_lengths.shape[0] == block_tables.shape[0] - num_tokens, num_heads, _ = q.shape + num_tokens, num_heads, head_dim = q.shape num_kv_heads = k.shape[-2] assert num_kv_heads > 0 and num_heads % num_kv_heads == 0 num_kv_group = num_heads // num_kv_heads @@ -213,7 +213,9 @@ def context_attention_unpadded( num_seqs, max_blocks_per_seq = block_tables.shape max_seq_len = context_lengths.max().item() if max_seq_len is None else max_seq_len sm_scale = 1.0 / (Lq**0.5) if sm_scale is None else sm_scale - output = torch.zeros_like(q) if output is None else output + output = ( + torch.empty((num_tokens, num_heads * head_dim), dtype=q.dtype, device=q.device) if output is None else output + ) # NOTE For now, BLOCK_M and BLOCK_N are supposed to be equivalent with # the size of physical cache block (i.e. `block_size`) @@ -243,8 +245,8 @@ def context_attention_unpadded( v.stride(1), v.stride(2), output.stride(0), - output.stride(1), - output.stride(2), + head_dim, + 1, k_cache.stride(0), k_cache.stride(1), k_cache.stride(2), diff --git a/colossalai/kernel/triton/flash_decoding.py b/colossalai/kernel/triton/flash_decoding.py index 07351d023..d351b20da 100644 --- a/colossalai/kernel/triton/flash_decoding.py +++ b/colossalai/kernel/triton/flash_decoding.py @@ -211,7 +211,7 @@ def flash_decoding_attention( records the (kv) sequence lengths incorporating past kv sequence lengths. block_tables (torch.Tensor): [batch_size, max_blocks_per_sequence] max_seq_len_in_batch (int): Maximum sequence length in the batch. - output (torch.Tensor): [bsz, num_heads, head_dim] + output (torch.Tensor): [bsz, num_heads * head_dim] mid_output (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num, head_dim] Intermediate output tensor. `max_bsz` should be greater than or equal to `bsz`. mid_output_lse (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num] @@ -220,7 +220,7 @@ def flash_decoding_attention( num_kv_group (int, optional): Number of key/value groups. Defaults to 1. 
Returns: - Output tensor with shape [bsz, num_heads, head_dim] + Output tensor with shape [bsz, num_heads * head_dim] """ q = q.squeeze() if q.dim() == 4 else q assert q.dim() == 3, f"Incompatible q dim: {q.dim()}" @@ -261,7 +261,7 @@ def flash_decoding_attention( # NOTE use `triton.next_power_of_2` here to utilize the cache mechanism of triton # To optimize, revise batching/scheduling to batch 2^n sequences in a batch (preferred) grid = (triton.next_power_of_2(bsz), num_heads, triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV)) - output = torch.empty((bsz, num_heads, head_dim), dtype=q.dtype, device=q.device) if output is None else output + output = torch.empty((bsz, num_heads * head_dim), dtype=q.dtype, device=q.device) if output is None else output _flash_decoding_fwd_kernel[grid]( q, @@ -294,7 +294,7 @@ def flash_decoding_attention( BLOCK_SIZE=block_size, HEAD_DIM=head_dim, ) - + grid = (triton.next_power_of_2(bsz), num_heads) _flash_decoding_fwd_reduce_kernel[grid]( @@ -311,8 +311,8 @@ def flash_decoding_attention( mid_output_lse.stride(1), mid_output_lse.stride(2), output.stride(0), - output.stride(1), - output.stride(2), + head_dim, + 1, BLOCK_KV=block_size, HEAD_DIM=head_dim, ) diff --git a/colossalai/kernel/triton/rms_layernorm.py b/colossalai/kernel/triton/rms_layernorm.py index fb4fa02bc..dcf478561 100644 --- a/colossalai/kernel/triton/rms_layernorm.py +++ b/colossalai/kernel/triton/rms_layernorm.py @@ -49,7 +49,50 @@ if HAS_TRITON: # Write output tl.store(Y + cols, y.to(tl.float16), mask=mask) - def rms_layernorm(x, weight, eps, norm_output=None): + @triton.jit + def _rmsnorm_with_residual_kernel( + X, # pointer to the input + Y, # pointer to the output + R, # pointer to the residual + W, # pointer to the weights + stride, # how much to increase the pointer when moving by 1 row + N, # number of columns in X + eps, # epsilon to avoid division by zero + BLOCK_SIZE: tl.constexpr, + ): + # This triton kernel implements Root Mean Square Layer Norm (RMSNorm). + + # Map the program id to the row of X and Y it should compute. 
+ row = tl.program_id(0) + Y += row * stride + X += row * stride + R += row * stride + # Compute variance + _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + x = tl.where(cols < N, x, 0.0) + r = tl.load(R + cols, mask=cols < N, other=0.0).to(tl.float32) + r = tl.where(cols < N, r, 0.0) + x = x + r + _var += x * x + mask = cols < N + tl.store(X + cols, x.to(tl.float16), mask=mask) + var = tl.sum(_var, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + # Normalize and apply linear transformation + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + w = tl.load(W + cols, mask=mask) + x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) + x_hat = x * rstd + y = x_hat * w + # Write output + tl.store(Y + cols, y.to(tl.float16), mask=mask) + + def rms_layernorm(x, weight, eps, norm_output=None, residual=None): # allocate output y = torch.empty_like(x) if norm_output is None else norm_output M, N = x.shape @@ -64,5 +107,10 @@ if HAS_TRITON: num_warps = min(max(triton.next_power_of_2(N) // 256, 8), 32) # enqueue kernel - _rmsnorm_kernel[(M,)](x, y, weight, x.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps) - return y + if residual is None: + _rmsnorm_kernel[(M,)](x, y, weight, x.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps) + else: + _rmsnorm_with_residual_kernel[(M,)]( + x, y, residual, weight, x.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps + ) + return y, x diff --git a/examples/inference/benchmark_llama.py b/examples/inference/benchmark_llama.py index 8098f4891..a6cbf2ee1 100644 --- a/examples/inference/benchmark_llama.py +++ b/examples/inference/benchmark_llama.py @@ -95,7 +95,7 @@ def benchmark_inference(args): else: assert args.model_path, "When testing pretrained weights, the model path must be provided.'" model = transformers.LlamaForCausalLM.from_pretrained(args.model_path).cuda() - tokenizer = AutoTokenizer.from_pretrained(args.model_path) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") model = model.eval() @@ -122,6 +122,7 @@ def benchmark_inference(args): elif args.mode == "vllm": engine = LLM( model=args.model_path, + tokenizer="hf-internal-testing/llama-tokenizer", max_num_seqs=mbsz, dtype="float16", enforce_eager=True, diff --git a/tests/test_infer/test_ops/triton/test_context_attn_unpad.py b/tests/test_infer/test_ops/triton/test_context_attn_unpad.py index b529e76d1..f2c64d392 100644 --- a/tests/test_infer/test_ops/triton/test_context_attn_unpad.py +++ b/tests/test_infer/test_ops/triton/test_context_attn_unpad.py @@ -100,10 +100,14 @@ def test_context_attention( k_cache_triton = torch.zeros_like(k_cache_ref) v_cache_triton = torch.zeros_like(v_cache_ref) + _, num_heads, head_dim = q_unpad.shape + out_triton = context_attention_unpadded( q_unpad, k_unpad, v_unpad, k_cache_triton, v_cache_triton, context_lengths, block_tables, block_size ) + out_triton = out_triton.view(-1, num_heads, head_dim) + out_torch = torch_attn_unpad(q_unpad, k_unpad, v_unpad, context_lengths, num_attn_heads, num_kv_heads) assert out_torch.shape == out_triton.shape diff --git a/tests/test_infer/test_ops/triton/test_rmsnorm_triton.py b/tests/test_infer/test_ops/triton/test_rmsnorm_triton.py index cc0ef292f..5ce852164 100644 --- a/tests/test_infer/test_ops/triton/test_rmsnorm_triton.py +++ b/tests/test_infer/test_ops/triton/test_rmsnorm_triton.py 
@@ -3,6 +3,7 @@ import torch import triton from packaging import version from transformers.models.llama.modeling_llama import LlamaRMSNorm +from vllm.model_executor.layers.layernorm import RMSNorm from colossalai.kernel.triton import rms_layernorm from colossalai.testing.utils import parameterize @@ -29,15 +30,28 @@ def test_layer_norm(M, N): x_shape = (M, N) w_shape = (x_shape[-1],) weight = torch.ones(w_shape, dtype=dtype, device="cuda") + residual = torch.rand(x_shape, dtype=dtype, device="cuda") + residual_copy = residual.clone() rms_norm = LlamaRMSNorm(hidden_size=N, eps=eps).cuda() x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda") + x_copy = x.clone() - y_triton = rms_layernorm(x, weight, eps=eps) + y_triton, _ = rms_layernorm(x, weight, eps=eps) y_llama = rms_norm.forward(x).to(dtype) assert y_triton.shape == y_llama.shape assert torch.allclose(y_triton, y_llama, atol=1e-5, rtol=1e-3) + y_triton, residual = rms_layernorm(x, weight, eps=eps, residual=residual) + + x = x_copy + residual_copy + + y_llama = rms_norm.forward(x).to(dtype) + + assert y_triton.shape == y_llama.shape + assert torch.allclose(y_triton, y_llama, atol=1e-5, rtol=1e-3) + assert torch.allclose(x, residual, atol=1e-5, rtol=1e-3) + # Triton benchmark plot attributions configs = [ @@ -45,9 +59,19 @@ configs = [ x_names=["SEQUENCE_TOTAL"], x_vals=[i for i in range(128, 1025, 128)], line_arg="provider", - line_vals=["torch_rms_layernorm", "triton_rms_layernorm"], - line_names=["torch_rms_layernorm", "triton_rms_layernorm"], - styles=[("red", "-"), ("blue", "-")], + line_vals=[ + "vllm_rms_layernorm", + "triton_rms_layernorm", + "triton_rms_layernorm_with_residual", + "vllm_rms_layernorm_with_residual", + ], + line_names=[ + "vllm_rms_layernorm", + "triton_rms_layernorm", + "triton_rms_layernorm_with_residual", + "vllm_rms_layernorm_with_residual", + ], + styles=[("red", "-"), ("blue", "-"), ("yellow", "-"), ("green", "-")], ylabel="ms", plot_name=f"RMSNorm benchmarking results", args={"HIDDEN_SIZE": 1024}, @@ -68,13 +92,18 @@ def benchmark_rms_layernorm( eps = 1e-5 x_shape = (SEQUENCE_TOTAL, HIDDEN_SIZE) w_shape = (x_shape[-1],) + residual = torch.rand(x_shape, dtype=dtype, device="cuda") weight = torch.ones(w_shape, dtype=dtype, device="cuda") - torch_norm = LlamaRMSNorm(hidden_size=HIDDEN_SIZE).to(dtype=dtype, device="cuda") + vllm_norm = RMSNorm(hidden_size=HIDDEN_SIZE, eps=eps).to(dtype=dtype, device="cuda") x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda") - if provider == "torch_rms_layernorm": - fn = lambda: torch_norm(x) + if provider == "vllm_rms_layernorm": + fn = lambda: vllm_norm(x) elif provider == "triton_rms_layernorm": fn = lambda: rms_layernorm(x, weight, eps=eps) + elif provider == "vllm_rms_layernorm_with_residual": + fn = lambda: vllm_norm(x, residual=residual) + elif provider == "triton_rms_layernorm_with_residual": + fn = lambda: rms_layernorm(x, weight, eps=eps, residual=residual) else: raise ValueError("Undefined provider.")
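
The patch drops the torch.addmm residual fusion from NopadLlamaAttention.forward and NopadLlamaMLP.forward and instead lets the next RMSNorm call add the residual. A minimal sketch of why the two formulations are numerically equivalent (tensor names and sizes below are illustrative, not taken from the code):

import torch

torch.manual_seed(0)
token_num, hidden_size = 4, 8                      # illustrative sizes
attn_out = torch.randn(token_num, hidden_size)     # attention output before o_proj
o_proj_w = torch.randn(hidden_size, hidden_size)   # transposed o_proj weight
residual = torch.randn(token_num, hidden_size)

# Old path: the residual was added inside the projection via addmm.
fused = torch.addmm(residual, attn_out, o_proj_w)
# New path: the projection is a plain mm; the residual is added later,
# inside the next (fused) RMSNorm call.
deferred = torch.mm(attn_out, o_proj_w) + residual

assert torch.allclose(fused, deferred, atol=1e-5)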
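
The new _rmsnorm_with_residual_kernel adds the residual to the input, writes the sum back into the input buffer (so it becomes the residual stream handed to the next layer), and normalizes the sum. Below is a non-Triton sketch of that contract; the helper name is illustrative, and the in-place update mirrors the kernel's store back into X:

import torch

def rms_layernorm_with_residual_ref(x, weight, eps, residual):
    # Mirrors the fused Triton path: x is overwritten with x + residual
    # (the updated residual stream), and RMSNorm is applied to that sum.
    x += residual
    var = x.float().pow(2).mean(dim=-1, keepdim=True)
    y = (x.float() * torch.rsqrt(var + eps) * weight.float()).to(x.dtype)
    # rms_layernorm now returns a pair: (normalized output, updated residual).
    return y, x

Because rms_layernorm now returns a tuple in both branches, call sites have to unpack it, as the updated test does: y, new_residual = rms_layernorm(x, weight, eps=eps, residual=residual), and y, _ = rms_layernorm(x, weight, eps=eps) when there is no residual.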
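
context_attention_unpadded and flash_decoding_attention now allocate their output as a 2-D [num_tokens, num_heads * head_dim] buffer and pass the constants head_dim and 1 where the kernel launches previously read output.stride(1) and output.stride(2). This works because the flattened buffer has exactly the element layout of a contiguous 3-D tensor; a quick check of that assumption:

import torch

num_tokens, num_heads, head_dim = 16, 8, 64        # illustrative sizes
out_3d = torch.empty(num_tokens, num_heads, head_dim)
out_2d = torch.empty(num_tokens, num_heads * head_dim)

# Strides the kernels used to take from the 3-D output:
assert out_3d.stride() == (num_heads * head_dim, head_dim, 1)
# The 2-D buffer keeps the same row stride, and viewing it as 3-D
# reproduces the per-head / per-dim strides (head_dim, 1) that are now
# hard-coded at the kernel launch:
assert out_2d.stride(0) == num_heads * head_dim
assert out_2d.view(num_tokens, num_heads, head_dim).stride() == (num_heads * head_dim, head_dim, 1)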
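
When num_heads == num_key_value_heads, NopadLlamaAttention now keeps a stacked self.qkv_weight of shape [3, hidden_size, hidden_size] and discards the per-projection modules. The fused projection itself is elided in the hunk above (only the "# fused qkv" branch header is visible); the sketch below is purely hypothetical and simply mirrors the bmm pattern NopadLlamaMLP uses for its gate_up_weight:

import torch

def fused_qkv_sketch(hidden_states, qkv_weight, num_heads, head_dim):
    # hidden_states: [token_num, hidden_size]
    # qkv_weight:    [3, hidden_size, hidden_size] (stacked transposed q/k/v weights)
    token_num = hidden_states.size(0)
    # One batched matmul instead of three separate torch.mm calls.
    qkv = torch.bmm(hidden_states.expand(3, -1, -1), qkv_weight)
    query_states = qkv[0].view(token_num, num_heads, head_dim)
    key_states = qkv[1].view(token_num, num_heads, head_dim)
    value_states = qkv[2].view(token_num, num_heads, head_dim)
    return query_states, key_states, value_states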