diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index d1ad84604..fe102eecf 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -4,7 +4,6 @@ from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.distributed
-import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -141,7 +140,9 @@ class LlamaPipelineForwards:
                 invert=(sp_mode != "ring_attn"),
             )
         else:
-            attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, hidden_states, cache_position)
+            attn_kwargs: torch.Tensor = self._update_causal_mask(
+                attention_mask, hidden_states, cache_position, past_key_values
+            )
 
         # Support SP + PP. Later stages have already received the split input.
         split_input = disable_pp or stage_manager.is_first_stage()
@@ -177,6 +178,7 @@ class LlamaPipelineForwards:
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None
         start_idx, end_idx = (0, len(self.layers)) if disable_pp else (stage_index[0], stage_index[1])
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
         num_ckpt_layers = 0
         if self.gradient_checkpointing and self.training:
@@ -204,6 +206,7 @@ class LlamaPipelineForwards:
                     output_attentions,
                     use_cache,
                     cache_position,
+                    position_embeddings,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -214,6 +217,7 @@ class LlamaPipelineForwards:
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                     cache_position=cache_position,
+                    position_embeddings=position_embeddings,
                 )
 
             hidden_states = layer_outputs[0]
@@ -486,8 +490,8 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
     def forward(
         self,
         hidden_states: torch.Tensor,
+        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
         attention_mask: Optional[Union[torch.Tensor, Dict]] = None,
-        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
@@ -505,30 +509,14 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
             )
 
         bsz, q_len, _ = hidden_states.size()
+        input_shape = hidden_states.shape[:-1]
         # sp: modify sp_len when sequence parallel mode is ring
         if is_share_sp_tp(sp_mode):
             q_len *= sp_size
 
-        if self.config.pretraining_tp > 1:
-            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
-            query_slices = self.q_proj.weight.split(
-                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
-            )
-            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
-            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
-
-            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
-            query_states = torch.cat(query_states, dim=-1)
-
-            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
-            key_states = torch.cat(key_states, dim=-1)
-
-            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
-            value_states = torch.cat(value_states, dim=-1)
-        else:
-            query_states = self.q_proj(hidden_states)
-            key_states = self.k_proj(hidden_states)
-            value_states = self.v_proj(hidden_states)
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
 
         # sp: all-to-all comminucation when introducing sequence parallel
         if sp_mode == "all_to_all":
@@ -537,9 +525,9 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
             value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
             bsz, q_len, _ = query_states.size()
 
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
 
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
@@ -552,7 +540,7 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
-        cos, sin = self.rotary_emb(value_states, position_ids)
+        cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
         if past_key_value is not None:
@@ -610,17 +598,12 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
                 attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
             )
         else:
-            attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+            attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
 
-        if self.config.pretraining_tp > 1:
-            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
-            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
-            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
-        else:
-            attn_output = self.o_proj(attn_output)
+        attn_output = self.o_proj(attn_output)
 
         if not output_attentions:
             attn_weights = None
-        return attn_output, attn_weights, past_key_value
+        return attn_output, attn_weights
 
     return forward
diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py
index e8f9471f9..9ad63dd7f 100644
--- a/colossalai/shardformer/policies/llama.py
+++ b/colossalai/shardformer/policies/llama.py
@@ -33,22 +33,9 @@ class LlamaPolicy(Policy):
         return self.model
 
     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        from transformers.models.llama.modeling_llama import (
-            LlamaAttention,
-            LlamaDecoderLayer,
-            LlamaFlashAttention2,
-            LlamaModel,
-            LlamaSdpaAttention,
-        )
+        from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel
 
-        ATTN_IMPLEMENTATION = {
-            "eager": LlamaAttention,
-            "flash_attention_2": LlamaFlashAttention2,
-            "sdpa": LlamaSdpaAttention,
-        }
         policy = {}
-
-        attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
         embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
             embedding_cls = VocabParallelEmbedding1D
@@ -82,7 +69,7 @@ class LlamaPolicy(Policy):
                 num_kv_heads //= sp_size
                 decoder_attribute_replacement["num_key_value_heads"] = num_kv_heads
 
-            policy[attn_cls] = ModulePolicyDescription(
+            policy[LlamaAttention] = ModulePolicyDescription(
                 attribute_replacement=decoder_attribute_replacement,
             )
         if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
@@ -91,7 +78,7 @@ class LlamaPolicy(Policy):
                     "forward": get_llama_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
                 },
                 policy=policy,
-                target_key=attn_cls,
+                target_key=LlamaAttention,
             )
 
         if self.pipeline_stage_manager is None:
@@ -354,6 +341,7 @@ class LlamaPolicy(Policy):
         stage_manager = self.pipeline_stage_manager
 
         held_layers = []
+        held_layers.append(module.rotary_emb)
        if stage_manager.is_interleave:
             assert stage_manager.num_model_chunks is not None
             layers_per_stage = stage_manager.distribute_layers(len(module.layers))
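Note on the convention these hunks adopt: the model now computes rotary embeddings once per forward pass (position_embeddings = self.rotary_emb(hidden_states, position_ids)) and threads the resulting (cos, sin) tuple through every decoder layer, instead of each attention module recomputing cos/sin from position_ids. Accordingly, the policy holds module.rotary_emb on every pipeline stage and targets the single LlamaAttention class, since newer transformers releases no longer ship separate eager/SDPA/flash attention subclasses. Below is a minimal, self-contained sketch of that calling pattern; ToyRotaryEmbedding, ToyDecoderLayer, and all shapes are illustrative stand-ins, not the ColossalAI or transformers implementations.

import torch
from torch import nn

class ToyRotaryEmbedding(nn.Module):
    """Illustrative stand-in for the model's shared rotary embedding module."""

    def __init__(self, head_dim: int, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, x: torch.Tensor, position_ids: torch.LongTensor):
        # (bsz, seq) integer positions -> (bsz, seq, head_dim) cos/sin tables
        freqs = position_ids[..., None].float() * self.inv_freq
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos().to(x.dtype), emb.sin().to(x.dtype)

class ToyDecoderLayer(nn.Module):
    """Accepts the precomputed (cos, sin) tuple instead of position_ids."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, position_embeddings):
        cos, sin = position_embeddings  # unpacked, not recomputed per layer
        return self.proj(hidden_states)  # rotary application elided for brevity

bsz, seq, hidden_size, head_dim = 2, 8, 32, 16
hidden_states = torch.randn(bsz, seq, hidden_size)
position_ids = torch.arange(seq).unsqueeze(0).expand(bsz, -1)

# Compute the cos/sin tables once in the model forward...
rotary_emb = ToyRotaryEmbedding(head_dim)
position_embeddings = rotary_emb(hidden_states, position_ids)

# ...then pass the same tuple to every decoder layer.
for layer in nn.ModuleList(ToyDecoderLayer(hidden_size) for _ in range(2)):
    hidden_states = layer(hidden_states, position_embeddings=position_embeddings)

Computing the tables once avoids redundant per-layer work, and replicating only the small rotary_emb module on each pipeline stage (as get_held_layers now does) keeps the decoder layers themselves partitioned.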