diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py index 1addea1d4..553c89018 100644 --- a/colossalai/inference/core/engine.py +++ b/colossalai/inference/core/engine.py @@ -115,8 +115,9 @@ class InferenceEngine: tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None. Returns: - nn.Module: _description_ + nn.Module: The model optimized by Shardformer. """ + shardconfig = ShardConfig( tensor_parallel_process_group=tp_group, pipeline_stage_manager=stage_manager, @@ -149,25 +150,25 @@ class InferenceEngine: Returns: List[str]: Inference result returned by one generation. """ + with torch.inference_mode(): + self.generation_config = generation_config + if prompts is not None or prompts_token_ids is not None: + self.add_request(prompts=prompts, prompts_token_ids=prompts_token_ids) - self.generation_config = generation_config - if prompts is not None or prompts_token_ids is not None: - self.add_request(prompts=prompts, prompts_token_ids=prompts_token_ids) + output_seqs_list = [] + output_tokens_list = [] - output_seqs_list = [] - output_tokens_list = [] + while self.request_handler.check_unfinished_seqs(): + output_seqs_list += self.step() - while self.request_handler.check_unfinished_seqs(): - output_seqs_list += self.step() + output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id)) - output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id)) + for seq in output_seqs_list: + output_tokens_list.append(seq.input_token_id + seq.output_token_id) - for seq in output_seqs_list: - output_tokens_list.append(seq.input_token_id + seq.output_token_id) + output_str = self.tokenizer.batch_decode(output_tokens_list, skip_special_tokens=True) - output_str = self.tokenizer.batch_decode(output_tokens_list, skip_special_tokens=True) - - return output_str + return output_str def add_request( self, diff --git a/colossalai/inference/modeling/layers/attention.py b/colossalai/inference/modeling/layers/attention.py index e4dd02b60..43ccdc430 100644 --- a/colossalai/inference/modeling/layers/attention.py +++ b/colossalai/inference/modeling/layers/attention.py @@ -6,7 +6,6 @@ import torch.nn.functional as F from transformers.modeling_attn_mask_utils import AttentionMaskConverter -@torch.no_grad def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"): """ Func: copy key/value into key/value cache. @@ -41,7 +40,6 @@ def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"): return cache -@torch.no_grad def convert_kvcache(cache, lengths, block_tables, pad_id=0): """ Func: convert key/value cache for calculation @@ -81,7 +79,6 @@ class PagedAttention: """ @staticmethod - @torch.no_grad def pad_and_reshape(tensor, seq_lengths, max_seq_len, num_heads, head_size): """ Transform 1D no_pad tensor into 2D padded tensor with shape [bsz,seq_len,num_heads,head_size] @@ -97,14 +94,12 @@ class PagedAttention: return padded_tensor @staticmethod - @torch.no_grad def generate_padding_mask(lengths, max_seq_len): range_tensor = torch.arange(max_seq_len).expand(len(lengths), max_seq_len) padding_mask = range_tensor < lengths.unsqueeze(1) return padding_mask @staticmethod - @torch.no_grad def repeat_kv(hidden_states: torch.Tensor, n_rep: int = 1) -> torch.Tensor: """ Essential component for MQA. Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). @@ -122,7 +117,6 @@ class PagedAttention: return hidden_states.reshape(batch, num_attention_heads, seq_len, head_dim) @staticmethod - @torch.no_grad def nopad_context_forward( q: torch.Tensor, # [num_tokens, num_heads, head_size] k: torch.Tensor, # [num_tokens, num_kv_heads, head_size] @@ -191,7 +185,6 @@ class PagedAttention: return attn_output @staticmethod - @torch.no_grad def pad_context_forward( q: torch.Tensor, # [batch_size, seq_len, num_heads, head_size] k: torch.Tensor, # [batch_size, seq_len, num_kv_heads, head_size] @@ -249,7 +242,6 @@ class PagedAttention: return attn_output @staticmethod - @torch.no_grad def pad_decoding_forward( q: torch.Tensor, # [bsz, 1, num_heads, head_size] k: torch.Tensor, # [bsz, 1, num_kv_heads, head_size] @@ -306,7 +298,6 @@ class PagedAttention: return attn_output @staticmethod - @torch.no_grad def no_pad_decoding_forward( self, q: torch.Tensor, # [num_tokens, num_heads, head_size] diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py index 3fadb1905..355140bc1 100644 --- a/colossalai/inference/modeling/models/nopadding_llama.py +++ b/colossalai/inference/modeling/models/nopadding_llama.py @@ -32,7 +32,6 @@ except ImportError: logger.warning(f"triton has not been installed yet, we will use torch to complete the attention calculation.") -@torch.no_grad() def llama_causal_lm_forward( self: LlamaForCausalLM, batch: BatchInfo = None, @@ -58,7 +57,6 @@ def llama_causal_lm_forward( return logits -@torch.no_grad() def llama_model_forward( self: LlamaModel, batch: BatchInfo = None, @@ -120,7 +118,6 @@ def llama_model_forward( return hidden_states -@torch.no_grad() def llama_decoder_layer_forward( self: LlamaDecoderLayer, hidden_states: torch.Tensor, @@ -139,7 +136,7 @@ def llama_decoder_layer_forward( """This function will replace the forward function of LlamaDecoderLayer. Args: - hidden_states (torch.Tensor): input to the layer of shape `(token_num, embed_dim)`. + hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim]. block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence], storing mapping of token_position_id -> block_id. Defaults to None. k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None. @@ -154,8 +151,8 @@ def llama_decoder_layer_forward( norm_output (torch.Tensor, optional): The mid tensor holds the output of layernorm. Defaults to None. sm_scale (int, optional): Used for flash attention. Defaults to None. """ - residual = hidden_states + residual = hidden_states hidden_states = self.input_layernorm(hidden_states, norm_output) # Self Attention hidden_states = self.self_attn( @@ -240,7 +237,6 @@ class NopadLlamaAttention(LlamaAttention): return attn_layer # Replace transformers.models.llama.modeling_llama.LlamaAttention.forward - @torch.no_grad() def forward( self, hidden_states: torch.Tensor, @@ -258,8 +254,8 @@ class NopadLlamaAttention(LlamaAttention): ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ Args: - hidden_states (torch.Tensor): input to the layer of shape `(token_num, embed_dim)` - residual (torch.Tensor): shape `(token_num, embed_dim)`, used to be added to hidden_states in out_proj. + hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim]. + residual (torch.Tensor): shape [token_num, embed_dim], used to be added to hidden_states in out_proj. block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence], storing mapping of token_position_id -> block_id. Defaults to None. k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None. @@ -321,7 +317,7 @@ class NopadLlamaAttention(LlamaAttention): sm_scale=sm_scale, ) - attn_output = attn_output.reshape(-1, self.hidden_size) + attn_output = attn_output.view(-1, self.hidden_size) attn_output = torch.addmm(residual, attn_output, self.o_proj.weight) return attn_output @@ -345,9 +341,10 @@ class NopadLlamaMLP(LlamaMLP): mlp_dproj_w (torch.Tensor, optional): The transposed down_proj weight. Defaults to None. """ super().__init__(config) - self.gate_proj.weight = Parameter(mlp_gproj_w, requires_grad=False) - self.up_proj.weight = Parameter(mlp_uproj_w, requires_grad=False) + self.gate_up_weight = Parameter(torch.stack([mlp_gproj_w, mlp_uproj_w], dim=0), requires_grad=False) self.down_proj.weight = Parameter(mlp_dproj_w, requires_grad=False) + self.gate_proj = None + self.up_proj = None @staticmethod def from_native_module(module: LlamaMLP, *args, **kwargs) -> LlamaMLP: @@ -371,15 +368,14 @@ class NopadLlamaMLP(LlamaMLP): return mlp_layer - @torch.no_grad() def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor: """ Args: - hidden_states (torch.Tensor): input to the layer of shape `(token_num, embed_dim)`. - residual (torch.Tensor): shape `(token_num, embed_dim)`, used to be added to hidden_states in down_proj. + hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim]. + residual (torch.Tensor): shape [token_num, embed_dim], used to be added to hidden_states in down_proj. """ - gate_proj_out = torch.mm(hidden_states, self.gate_proj.weight) - act_out = torch.nn.functional.silu(gate_proj_out, inplace=True) - up_proj_out = torch.mm(hidden_states, self.up_proj.weight) - tmp_out = act_out * up_proj_out + hidden_states = hidden_states.expand(2, -1, -1) + gate_up_proj_out = torch.bmm(hidden_states, self.gate_up_weight) + act_out = torch.nn.functional.silu(gate_up_proj_out[0], inplace=True) + tmp_out = act_out * gate_up_proj_out[1] return torch.addmm(residual, tmp_out, self.down_proj.weight) diff --git a/colossalai/inference/modeling/models/padding_llama.py b/colossalai/inference/modeling/models/padding_llama.py new file mode 100644 index 000000000..2eac07d76 --- /dev/null +++ b/colossalai/inference/modeling/models/padding_llama.py @@ -0,0 +1,450 @@ +# This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/llama/modeling_llama.py +from typing import List, Optional, Tuple + +import torch +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaConfig, + LlamaDecoderLayer, + LlamaForCausalLM, + LlamaModel, +) + +from colossalai.inference.flash_decoding_utils import FDIntermTensors +from colossalai.inference.modeling.layers.attention import PagedAttention +from colossalai.inference.struct import BatchInfo +from colossalai.kernel.triton import ( + context_attention_unpadded, + copy_kv_to_blocked_cache, + flash_decoding_attention, + get_xine_cache, + rotary_embedding, +) +from colossalai.logging import get_dist_logger + +from flash_attn.bert_padding import index_first_axis, pad_input # noqa + +logger = get_dist_logger(__name__) + +try: + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + logger.warning(f"triton has not been installed yet, we will use torch to complete the attention calculation.") + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. + cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def llama_causal_lm_forward( + self: LlamaForCausalLM, + batch: BatchInfo = None, + k_caches: List[torch.Tensor] = None, + v_caches: List[torch.Tensor] = None, +): + """This function will replace the forward function of LlamaForCausalLM. + + Args: + batch (BatchInfo, optional): It stores the necessary input information for this inference. Defaults to None. + k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None. + v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None. + """ + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + hidden_states = llama_model_forward( + self.model, + batch=batch, + k_caches=k_caches, + v_caches=v_caches, + ) + logits = self.lm_head(hidden_states) + return logits + + +def llama_model_forward( + self: LlamaModel, + batch: BatchInfo = None, + k_caches: List[torch.Tensor] = None, + v_caches: List[torch.Tensor] = None, +): + """This function will replace the forward function of LlamaModel. + + Args: + batch (BatchInfo, optional): It stores the necessary input information for this inference.. Defaults to None. + k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None. + v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None. + """ + input_ids = batch.get_batch_inputs() + block_tables = batch.get_block_table_tensor() + attention_mask = batch.get_attn_mask() + + if attention_mask is not None: + if HAS_TRITON: + sequence_lengths = attention_mask.sum(dim=-1, dtype=torch.int32) + else: + sequence_lengths = batch.get_sequence_lengths() + else: + sequence_lengths = batch.get_sequence_lengths() + + batch_size, _ = input_ids.shape + kv_seq_len = sequence_lengths.max().item() + + if attention_mask is not None: + if batch.is_prompts: + # Here, we generate position_ids through the input tensor, which can align with the output precision of the transformer. + position_ids = generate_padding_position_id(attention_mask) + else: + position_ids = (attention_mask.sum(dim=-1) - 1).reshape(-1, 1) + else: + if batch.is_prompts: + position_ids = torch.arange(kv_seq_len, dtype=torch.long, device=batch.device) + position_ids = position_ids.unsqueeze(0) + else: + position_ids = torch.arange(kv_seq_len - 1, kv_seq_len, dtype=torch.long, device=batch.device) + position_ids = position_ids.unsqueeze(0) + + hidden_states = self.embed_tokens(input_ids) + + cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts) + + if batch.is_prompts: + output_tensor = torch.zeros( + (sequence_lengths.sum().item(), batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device + ) + else: + output_tensor = torch.zeros( + (batch_size, batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device + ) + sm_scale = 1.0 / (batch.head_dim**0.5) + + norm_output = torch.empty_like(hidden_states) + + for layer_id, decoder_layer in enumerate(self.layers): + hidden_states = decoder_layer( + hidden_states, + position_ids=position_ids, + block_tables=block_tables, + k_cache=k_caches[layer_id], + v_cache=v_caches[layer_id], + is_prompts=batch.is_prompts, + sequence_lengths=sequence_lengths, + attention_mask=attention_mask, + kv_seq_len=kv_seq_len, + cos_sin=cos_sin, + fd_inter_tensor=batch.fd_inter_tensor, + output_tensor=output_tensor, + norm_output=norm_output, + sm_scale=sm_scale, + ) + + if batch.is_prompts: + hidden_states = hidden_states[:, -1, :].unsqueeze(dim=1).contiguous() + norm_output = torch.empty_like(hidden_states) + hidden_states = self.norm(hidden_states.reshape(-1, hidden_states.shape[-1]), norm_output) + + return hidden_states + + +def llama_decoder_layer_forward( + self: LlamaDecoderLayer, + hidden_states: torch.Tensor, + position_ids: torch.LongTensor, + block_tables: torch.Tensor = None, + k_cache: torch.Tensor = None, + v_cache: torch.Tensor = None, + is_prompts: bool = True, + sequence_lengths: torch.Tensor = None, + attention_mask: torch.Tensor = None, + kv_seq_len: int = 0, + cos_sin: Tuple[torch.Tensor] = None, + fd_inter_tensor: FDIntermTensors = None, + output_tensor: torch.Tensor = None, + norm_output: torch.Tensor = None, + sm_scale: int = None, +) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """This function will replace the forward function of LlamaDecoderLayer. + + Args: + hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim]. + position_ids (torch.LongTensor), The position ids of input sequences. + block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence], + storing mapping of token_position_id -> block_id. Defaults to None. + k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None. + v_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None. + is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True. + sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence. Defaults to None. + kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0. + cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin. Defaults to None. + fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for storing intermediate values in flash-decoding. Defaults to None. + output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None. + norm_output (torch.Tensor, optional): The mid tensor holds the output of layernorm. Defaults to None. + sm_scale (int, optional): Used for flash attention. Defaults to None. + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states.reshape(-1, hidden_states.shape[-1]), norm_output) + # Self Attention + hidden_states = self.self_attn( + hidden_states=hidden_states, + position_ids=position_ids, + block_tables=block_tables, + k_cache=k_cache, + v_cache=v_cache, + is_prompts=is_prompts, + sequence_lengths=sequence_lengths, + attention_mask=attention_mask, + kv_seq_len=kv_seq_len, + cos_sin=cos_sin, + fd_inter_tensor=fd_inter_tensor, + output_tensor=output_tensor, + sm_scale=sm_scale, + ) + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states.reshape(-1, hidden_states.shape[-1]), norm_output) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class PadLlamaAttention(LlamaAttention): + def __init__( + self, + config: LlamaConfig, + layer_idx: Optional[int] = None, + attn_qproj_w: torch.nn.Parameter = None, + attn_kproj_w: torch.nn.Parameter = None, + attn_vproj_w: torch.nn.Parameter = None, + attn_oproj_w: torch.nn.Parameter = None, + ): + """This layer will replace the LlamaAttention. + + Args: + config (LlamaConfig): Holding the Llama model config. + layer_idx (Optional[int], optional): The decode layer id of this attention layer. Defaults to None. + attn_qproj_w (torch.nn.Parameter, optional): The q_proj weight. Defaults to None. + attn_kproj_w (torch.nn.Parameter, optional): The k_proj weight. Defaults to None. + attn_vproj_w (torch.nn.Parameter, optional): The v_proj weight. Defaults to None. + attn_oproj_w (torch.nn.Parameter, optional): The o_proj weight. Defaults to None. + """ + super().__init__(config, layer_idx) + self.q_proj.weight = attn_qproj_w + self.k_proj.weight = attn_kproj_w + self.v_proj.weight = attn_vproj_w + self.o_proj.weight = attn_oproj_w + + @staticmethod + def from_native_module(module: LlamaAttention, *args, **kwargs) -> LlamaAttention: + """Used for initialize the weight of NopadLlamaAttention by origin LlamaAttention + + Args: + module (LlamaAttention): The origin LlamaAttention layer. + """ + config = module.config + layer_idx = module.layer_idx + + attn_qproj_w = module.q_proj.weight + attn_kproj_w = module.k_proj.weight + attn_vproj_w = module.v_proj.weight + attn_oproj_w = module.o_proj.weight + + attn_layer = PadLlamaAttention( + config=config, + layer_idx=layer_idx, + attn_qproj_w=attn_qproj_w, + attn_kproj_w=attn_kproj_w, + attn_vproj_w=attn_vproj_w, + attn_oproj_w=attn_oproj_w, + ) + + return attn_layer + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.LongTensor, + block_tables: torch.Tensor = None, + k_cache: torch.Tensor = None, + v_cache: torch.Tensor = None, + is_prompts: bool = True, + sequence_lengths: torch.Tensor = None, + attention_mask: torch.Tensor = None, + kv_seq_len: int = 0, + cos_sin: Tuple[torch.Tensor] = None, + fd_inter_tensor: FDIntermTensors = None, + output_tensor: torch.Tensor = None, + sm_scale: int = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Args: + hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim] + position_ids (torch.LongTensor), The position ids of input sequences. + block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence], + storing mapping of token_position_id -> block_id. Defaults to None. + k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None. + v_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None. + is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True. + sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence. Defaults to None. + attention_mask (torch.Tensor, optional): The padding mask - corresponds to a tensor of size [batch_size, seq_len] + where 0 stands for the position of padding tokens and 1 for the position of non-padding tokens. + kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0. + cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin. Defaults to None. + fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for + storing intermediate values in flash-decoding. Defaults to None. + output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None. + sm_scale (int, optional): Used for flash attention. Defaults to None. + """ + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim) + + if HAS_TRITON: + if is_prompts: + if attention_mask is not None: + query_states, key_states, value_states, indices = unpading_input( + query_states, key_states, value_states, attention_mask + ) + else: + query_states = query_states.view(-1, self.num_heads, self.head_dim) + key_states = key_states.view(-1, self.num_heads, self.head_dim) + value_states = value_states.view(-1, self.num_heads, self.head_dim) + else: + query_states = query_states.squeeze(dim=1) + key_states = key_states.squeeze(dim=1) + value_states = value_states.squeeze(dim=1) + + rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) + + block_size = k_cache.size(-2) + + if is_prompts: + attn_output = context_attention_unpadded( + q=query_states, + k=key_states, + v=value_states, + k_cache=k_cache, + v_cache=v_cache, + context_lengths=sequence_lengths, + block_tables=block_tables, + block_size=block_size, + output=output_tensor, + max_seq_len=kv_seq_len, + sm_scale=sm_scale, + ) + if attention_mask is not None: + attn_output = pad_input(attn_output, indices, bsz, q_len) + else: + copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables) + copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables) + attn_output = flash_decoding_attention( + q=query_states, + k_cache=k_cache, + v_cache=v_cache, + kv_seq_len=sequence_lengths, + block_tables=block_tables, + block_size=block_size, + max_seq_len_in_batch=kv_seq_len, + output=output_tensor, + mid_output=fd_inter_tensor.mid_output, + mid_output_lse=fd_inter_tensor.mid_output_lse, + sm_scale=sm_scale, + ) + attn_output = attn_output.squeeze(1) + else: + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + if is_prompts: + attn_output = PagedAttention.pad_context_forward( + query_states, + key_states, + value_states, + k_cache, + v_cache, + sequence_lengths, + block_tables, + attention_mask, + ) + else: + attn_output = PagedAttention.pad_decoding_forward( + query_states, + key_states, + value_states, + k_cache, + v_cache, + sequence_lengths, + block_tables, + attention_mask, + ) + + attn_output = attn_output.view(bsz, q_len, self.num_heads, self.head_dim) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + + return attn_output + + +def generate_padding_position_id(attention_mask: torch.Tensor) -> torch.Tensor: + """Generate padding position_id through attention mask. + + Args: + attention_mask (`torch.Tensor` of shape [batch_size, sequence_length]: + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + Returns: + torch.Tensor: The padding position_id. + """ + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + return position_ids + + +def unpading_input(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attention_mask: torch.Tensor): + """Convert padding input to nopad input. + + Args: + q (torch.Tensor): [batch_size, q_seq_len, head_num, head_dim] + k (torch.Tensor): [batch_size, q_seq_len, head_num, head_dim] + v (torch.Tensor): [batch_size, q_seq_len, head_num, head_dim] + attention_mask (torch.Tensor): [batch_size, sequence_length] + + Returns: + Tuple[torch.Tensor]: The unpad q, k, v and The index of valid data in each batch. + + """ + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + batch_size, kv_seq_len, num_key_value_heads, head_dim = q.shape + q = index_first_axis(q.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices) + k = index_first_axis(k.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices) + v = index_first_axis(v.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices) + return (q, k, v, indices) diff --git a/colossalai/inference/sampler.py b/colossalai/inference/sampler.py index 93e55fcf3..7547c32b0 100644 --- a/colossalai/inference/sampler.py +++ b/colossalai/inference/sampler.py @@ -10,7 +10,7 @@ def greedy_sample( """ Sample tokens greedyly. """ - results = torch.argmax(logprobs, dim=-1).cpu() + results = torch.argmax(logprobs, dim=-1) return results diff --git a/colossalai/kernel/triton/flash_decoding.py b/colossalai/kernel/triton/flash_decoding.py index 37fcd504c..07351d023 100644 --- a/colossalai/kernel/triton/flash_decoding.py +++ b/colossalai/kernel/triton/flash_decoding.py @@ -220,7 +220,7 @@ def flash_decoding_attention( num_kv_group (int, optional): Number of key/value groups. Defaults to 1. Returns: - Output tensor with shape [bsz, num_heads, q_len, head_dim] + Output tensor with shape [bsz, num_heads, head_dim] """ q = q.squeeze() if q.dim() == 4 else q assert q.dim() == 3, f"Incompatible q dim: {q.dim()}" @@ -261,6 +261,8 @@ def flash_decoding_attention( # NOTE use `triton.next_power_of_2` here to utilize the cache mechanism of triton # To optimize, revise batching/scheduling to batch 2^n sequences in a batch (preferred) grid = (triton.next_power_of_2(bsz), num_heads, triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV)) + output = torch.empty((bsz, num_heads, head_dim), dtype=q.dtype, device=q.device) if output is None else output + _flash_decoding_fwd_kernel[grid]( q, k_cache, @@ -292,9 +294,7 @@ def flash_decoding_attention( BLOCK_SIZE=block_size, HEAD_DIM=head_dim, ) - - output = torch.empty((bsz, num_heads, head_dim), dtype=q.dtype, device=q.device) if output is None else output - + grid = (triton.next_power_of_2(bsz), num_heads) _flash_decoding_fwd_reduce_kernel[grid]( diff --git a/colossalai/kernel/triton/fused_rotary_embedding.py b/colossalai/kernel/triton/fused_rotary_embedding.py index 237b088a4..cf2a70f7b 100644 --- a/colossalai/kernel/triton/fused_rotary_embedding.py +++ b/colossalai/kernel/triton/fused_rotary_embedding.py @@ -117,7 +117,6 @@ def fused_rotary_emb( ) -@torch.no_grad() def fused_rotary_embedding( q: torch.Tensor, k: torch.Tensor, diff --git a/colossalai/kernel/triton/no_pad_rotary_embedding.py b/colossalai/kernel/triton/no_pad_rotary_embedding.py index 89bd40b40..9194319d5 100644 --- a/colossalai/kernel/triton/no_pad_rotary_embedding.py +++ b/colossalai/kernel/triton/no_pad_rotary_embedding.py @@ -274,7 +274,6 @@ def fused_rotary_embedding_kernel( ) -@torch.no_grad() def rotary_embedding( q: torch.Tensor, k: torch.Tensor, diff --git a/colossalai/kernel/triton/rms_layernorm.py b/colossalai/kernel/triton/rms_layernorm.py index e4424eb33..fb4fa02bc 100644 --- a/colossalai/kernel/triton/rms_layernorm.py +++ b/colossalai/kernel/triton/rms_layernorm.py @@ -49,7 +49,6 @@ if HAS_TRITON: # Write output tl.store(Y + cols, y.to(tl.float16), mask=mask) - @torch.no_grad() def rms_layernorm(x, weight, eps, norm_output=None): # allocate output y = torch.empty_like(x) if norm_output is None else norm_output diff --git a/colossalai/kernel/triton/rotary_cache_copy.py b/colossalai/kernel/triton/rotary_cache_copy.py index 6b064ed4a..48dc7de43 100644 --- a/colossalai/kernel/triton/rotary_cache_copy.py +++ b/colossalai/kernel/triton/rotary_cache_copy.py @@ -77,7 +77,6 @@ def decoding_cache_kernel( ) -@torch.no_grad() def get_xine_cache(lengths: torch.Tensor, cos_cache: torch.Tensor, sin_cache: torch.Tensor, is_prompts: bool = False): """ Transform cos/sin cache into no pad sequence, with two different modes.