diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 1e0f7be24..93538c49a 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1056,6 +1056,8 @@ class HybridParallelPlugin(PipelinePluginBase):
         assert (
             not pp_style == "zbv" or scheduler_nodes is not None
         ), f"scheduler_nodes must not be None when using zero bubble pipeline."
+        if sp_size is None or sp_size <= 1:
+            enable_sequence_parallelism = False
         if enable_sequence_parallelism:
             self.sequence_parallelism_mode = (
                 sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all"
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index d1ad84604..de825606a 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -94,6 +94,7 @@ class LlamaPipelineForwards:
             batch_size, seq_length = input_shape
             device = hidden_states.device
 
+        # Support SP + PP
         sp_mode = shard_config.sequence_parallelism_mode
         sp_group = shard_config.sequence_parallel_process_group
 
@@ -112,6 +113,7 @@ class LlamaPipelineForwards:
                 raise ValueError("cache_position is a required argument when using StaticCache.")
             cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=device)
 
+        seq_length_with_past = seq_length + past_seen_tokens
 
         if output_attentions:
@@ -141,7 +143,7 @@ class LlamaPipelineForwards:
                 invert=(sp_mode != "ring_attn"),
             )
         else:
-            attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, hidden_states, cache_position)
+            attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, hidden_states, cache_position, past_key_values)
 
         # Support SP + PP. Later stages have already received the split input.
         split_input = disable_pp or stage_manager.is_first_stage()
@@ -177,6 +179,7 @@ class LlamaPipelineForwards:
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None
         start_idx, end_idx = (0, len(self.layers)) if disable_pp else (stage_index[0], stage_index[1])
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
         num_ckpt_layers = 0
         if self.gradient_checkpointing and self.training:
@@ -204,6 +207,7 @@ class LlamaPipelineForwards:
                     output_attentions,
                     use_cache,
                     cache_position,
+                    position_embeddings
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -214,6 +218,7 @@ class LlamaPipelineForwards:
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                     cache_position=cache_position,
+                    position_embeddings=position_embeddings
                 )
 
             hidden_states = layer_outputs[0]
@@ -486,8 +491,8 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
     def forward(
         self,
         hidden_states: torch.Tensor,
+        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
         attention_mask: Optional[Union[torch.Tensor, Dict]] = None,
-        position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
@@ -505,30 +510,21 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
             )
 
         bsz, q_len, _ = hidden_states.size()
+        input_shape = hidden_states.shape[:-1]
 
         # sp: modify sp_len when sequence parallel mode is ring
         if is_share_sp_tp(sp_mode):
             q_len *= sp_size
 
-        if self.config.pretraining_tp > 1:
-            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
-            query_slices = self.q_proj.weight.split(
-                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
-            )
-            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
-            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
+        # if sp_mode == "all_to_all":
+        #     # query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
+        #     # key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
+        #     # value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
+        #     # bsz, q_len, _ = query_states.size()
+        #     # hidden_states = all_to_all_comm(hidden_states, sp_group, fp8_communication=shard_config.fp8_communication)
 
-            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
-            query_states = torch.cat(query_states, dim=-1)
-
-            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
-            key_states = torch.cat(key_states, dim=-1)
-
-            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
-            value_states = torch.cat(value_states, dim=-1)
-        else:
-            query_states = self.q_proj(hidden_states)
-            key_states = self.k_proj(hidden_states)
-            value_states = self.v_proj(hidden_states)
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
 
         # sp: all-to-all comminucation when introducing sequence parallel
         if sp_mode == "all_to_all":
@@ -537,9 +533,9 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
             value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
             bsz, q_len, _ = query_states.size()
 
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
 
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
@@ -552,7 +548,8 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
-        cos, sin = self.rotary_emb(value_states, position_ids)
+        # cos, sin = self.rotary_emb(value_states, position_ids)
+        cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
         if past_key_value is not None:
@@ -610,17 +607,13 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
                 attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
             )
         else:
-            attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+            attn_output = attn_output.reshape(*input_shape, -1).contiguous()
 
-        if self.config.pretraining_tp > 1:
-            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
-            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
-            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
-        else:
-            attn_output = self.o_proj(attn_output)
+        attn_output = self.o_proj(attn_output)
 
         if not output_attentions:
             attn_weights = None
-        return attn_output, attn_weights, past_key_value
+        # return attn_output, attn_weights, past_key_value
+        return attn_output, attn_weights
 
     return forward
diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py
index e8f9471f9..ae718dd94 100644
--- a/colossalai/shardformer/policies/llama.py
+++ b/colossalai/shardformer/policies/llama.py
@@ -36,19 +36,19 @@ class LlamaPolicy(Policy):
         from transformers.models.llama.modeling_llama import (
             LlamaAttention,
             LlamaDecoderLayer,
-            LlamaFlashAttention2,
+            # LlamaFlashAttention2,
             LlamaModel,
-            LlamaSdpaAttention,
+            # LlamaSdpaAttention,
         )
 
-        ATTN_IMPLEMENTATION = {
-            "eager": LlamaAttention,
-            "flash_attention_2": LlamaFlashAttention2,
-            "sdpa": LlamaSdpaAttention,
-        }
+        # ATTN_IMPLEMENTATION = {
+        #     "eager": LlamaAttention,
+        #     "flash_attention_2": LlamaFlashAttention2,
+        #     "sdpa": LlamaSdpaAttention,
+        # }
         policy = {}
-
-        attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
+        attn_cls = LlamaAttention
+        # attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
 
         embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
             embedding_cls = VocabParallelEmbedding1D
@@ -354,6 +354,7 @@ class LlamaPolicy(Policy):
         stage_manager = self.pipeline_stage_manager
 
         held_layers = []
+        held_layers.append(module.rotary_emb)
         if stage_manager.is_interleave:
             assert stage_manager.num_model_chunks is not None
             layers_per_stage = stage_manager.distribute_layers(len(module.layers))
diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py
index ee2f1f405..f3997a158 100644
--- a/colossalai/shardformer/shard/sharder.py
+++ b/colossalai/shardformer/shard/sharder.py
@@ -225,6 +225,7 @@ class ModelSharder(object):
         """
         if self.shard_config and self.shard_config.pipeline_stage_manager:
             held_layers = self.policy.get_held_layers()
+            print("held_layers", held_layers)
             set_tensors_to_none(self.model, exclude=set(held_layers))
             return set(self._get_recursive_held_layers(held_layers))
         return None
diff --git a/colossalai/shardformer/shard/utils.py b/colossalai/shardformer/shard/utils.py
index 2bac37bfe..5ae7e9de7 100644
--- a/colossalai/shardformer/shard/utils.py
+++ b/colossalai/shardformer/shard/utils.py
@@ -16,4 +16,6 @@ def set_tensors_to_none(model: nn.Module, exclude: Set[nn.Module] = set()) -> No
     for n, p in model.named_parameters(recurse=False):
         setattr(model, n, None)
     for n, buf in model.named_buffers(recurse=False):
+        import torch
+        print("buffer", n, torch.distributed.get_rank())
         setattr(model, n, None)
diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py
index b97846408..13048eae4 100644
--- a/tests/test_shardformer/test_model/test_shard_llama.py
+++ b/tests/test_shardformer/test_model/test_shard_llama.py
@@ -162,9 +162,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     [
         # Double Ring Attention
         {
-            "tp_size": 1,
+            "tp_size": 2,
             "pp_size": 1,
-            "sp_size": 4,
+            "sp_size": 2,
             "num_microbatches": 1,
             "enable_sequence_parallelism": True,
             "sequence_parallelism_mode": "ring_attn",
@@ -226,12 +226,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "initial_scale": 1,
         },
         {
-            "tp_size": 2,
+            "tp_size": 1,
             "pp_size": 1,
-            "sp_size": 1,
+            "sp_size": 2,
             "num_microbatches": 1,
             "enable_sequence_parallelism": True,
-            "sequence_parallelism_mode": "ring",
+            "sequence_parallelism_mode": "all_to_all",
             "enable_flash_attention": True,
             "use_lazy_init": True,
             "zero_stage": 2,
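
Below is a minimal, self-contained sketch of the rotary-embedding threading pattern this patch adopts for LlamaModel: the (cos, sin) pair is computed once per forward from the model-level rotary embedding and passed to every decoder layer as position_embeddings, instead of each attention layer recomputing it from position_ids. The module names here (ToyRotary, ToyAttention) are illustrative stand-ins, not the actual transformers or ColossalAI classes.

import torch
import torch.nn as nn


class ToyRotary(nn.Module):
    # Minimal rotary-embedding stand-in: maps position_ids to a (cos, sin) pair.
    def __init__(self, head_dim: int, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, x: torch.Tensor, position_ids: torch.Tensor):
        # position_ids: [batch, seq_len] -> cos/sin: [batch, seq_len, head_dim]
        freqs = position_ids[..., None].float() * self.inv_freq
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos().to(x.dtype), emb.sin().to(x.dtype)


class ToyAttention(nn.Module):
    # Receives the precomputed (cos, sin) tuple instead of position_ids,
    # mirroring `cos, sin = position_embeddings` in the patched forward.
    def forward(self, hidden_states: torch.Tensor, position_embeddings):
        cos, sin = position_embeddings
        # ... rotary rotation of q/k and the attention computation would go here ...
        return hidden_states


if __name__ == "__main__":
    bsz, seq_len, hidden, head_dim = 2, 8, 32, 16
    x = torch.randn(bsz, seq_len, hidden)
    position_ids = torch.arange(seq_len).unsqueeze(0).expand(bsz, -1)

    rotary = ToyRotary(head_dim)
    layers = nn.ModuleList([ToyAttention() for _ in range(4)])

    # Compute once per forward and reuse the same tuple in every layer,
    # as llama_model_forward now does before its decoder-layer loop.
    position_embeddings = rotary(x, position_ids)
    for layer in layers:
        x = layer(x, position_embeddings=position_embeddings)
    print(x.shape)  # torch.Size([2, 8, 32])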