diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index 9625afc1b..7b0c74791 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -73,8 +73,8 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
             moe_dp_group: list(filter(is_moe_tensor, model.parameters())),
         }

-        if len(pg_param_list[dp_process_group]) == 0 or len(pg_param_list[moe_dp_group]) == 0:
-            raise ValueError("No parameters found in dp_process_group or moe_dp_group")
+        # if len(pg_param_list[dp_process_group]) == 0 or len(pg_param_list[moe_dp_group]) == 0:
+        #     raise ValueError("No parameters found in dp_process_group or moe_dp_group")

         super().__init__(
             model=model,
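Note: the hunk above relaxes a hard error rather than fixing its trigger. For context, a minimal sketch of how the two guarded parameter groups are populated; `is_moe_tensor` is the real helper used above, while the function and the stated rationale are assumptions for illustration:

# Illustrative sketch, not plugin code: how the two ZeRO param groups
# guarded by the (now disabled) check are derived from a model.
import torch.nn as nn

from colossalai.tensor.moe_tensor.api import is_moe_tensor


def split_param_groups(model: nn.Module):
    moe_params = [p for p in model.parameters() if is_moe_tensor(p)]
    dense_params = [p for p in model.parameters() if not is_moe_tensor(p)]
    # With pipeline parallelism (or first_k_dense_replace > 0), a stage can
    # hold only dense or only MoE layers, so one list may legitimately be
    # empty; presumably this is why the ValueError above was commented out.
    return dense_params, moe_params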
diff --git a/colossalai/shardformer/modeling/deepseek.py b/colossalai/shardformer/modeling/deepseek.py
index 468b890ab..91ba26d17 100644
--- a/colossalai/shardformer/modeling/deepseek.py
+++ b/colossalai/shardformer/modeling/deepseek.py
@@ -1,7 +1,8 @@
+import math
 import warnings
 from typing import List, Optional, Tuple, Union

 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed import ProcessGroup
@@ -34,6 +35,8 @@
 from colossalai.shardformer.shard import ShardConfig
 from colossalai.shardformer.shard.utils import set_tensors_to_none
 from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group

+from ..layer import ColoAttention
+

 # copied from modeling_deepseek.py
 class AddAuxiliaryLoss(torch.autograd.Function):
@@ -529,34 +532,28 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non
         output_attentions = False

         bsz, q_len, _ = hidden_states.size()

         # sp: modify sp_len when sequence parallel mode is ring
         if sp_mode in ["split_gather", "ring"]:
             q_len *= sp_size

-        rank = dist.get_rank()
-        print(f"{rank=}, hidden states:{hidden_states.shape}")
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
-        rank = dist.get_rank()
-        print(f"{rank=}, before all to all q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")

         # sp: all-to-all communication when introducing sequence parallel
         if sp_mode == "all_to_all":
             query_states = all_to_all_comm(query_states, sp_group)
             key_states = all_to_all_comm(key_states, sp_group)
             value_states = all_to_all_comm(value_states, sp_group)
             bsz, q_len, _ = query_states.size()
-            print(f"{rank=}, after all to all q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")

         # Flash attention requires the input to have the shape
         # batch_size x seq_length x head_dim x hidden_dim
         # therefore we just need to keep the original shape
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        print(f"{rank=}, after view to (b,s,h,d) q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")

         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
@@ -565,7 +562,6 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non
         query_states, key_states = apply_rotary_pos_emb(
             query_states, key_states, cos, sin, position_ids, unsqueeze_dim=0
         )
-        print(f"{rank=}, after rope q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")

         if past_key_value is not None:
             cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
@@ -573,13 +569,11 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non

         # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
         # to be able to avoid many of these transpose/reshape/view.
-        query_states = query_states.transpose(1, 2)
-        key_states = key_states.transpose(1, 2)
-        value_states = value_states.transpose(1, 2)
-        print(
-            f"{rank=}, after transpose to (b, nh, s, d) q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}"
-        )
-        dropout_rate = self.attention_dropout if self.training else 0.0
+
+        # query_states = query_states.transpose(1, 2)
+        # key_states = key_states.transpose(1, 2)
+        # value_states = value_states.transpose(1, 2)
+        # dropout_rate = self.attention_dropout if self.training else 0.0

         # In PEFT, usually we cast the layer norms in float32 for training stability reasons
         # therefore the input hidden states gets silently casted in float32. Hence, we need
@@ -606,22 +600,51 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non
             query_states = query_states.to(target_dtype)
             key_states = key_states.to(target_dtype)
             value_states = value_states.to(target_dtype)
-        print(f"{rank=}, before flash attn q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")

-        attn_output = self._flash_attention_forward(
-            query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
-        )
+        # attn_output = self._flash_attention_forward(
+        #     query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
+        # )
+
+        if shard_config.enable_flash_attention:
+            assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
+            attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
+        else:
+            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+            if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                    f" {attn_weights.size()}"
+                )
+
+            if attention_mask is not None:
+                if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                    raise ValueError(
+                        f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                    )
+                attn_weights = attn_weights + attention_mask
+
+            # upcast attention to fp32
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+            attn_output = torch.matmul(attn_weights, value_states)
+
+            if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+                raise ValueError(
+                    f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                    f" {attn_output.size()}"
+                )
+        attn_output = attn_output.transpose(1, 2).contiguous()

         # sp: all-to-all communication when introducing sequence parallel
         if sp_mode == "all_to_all":
-            attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()  # (1, 8, 128)
-            attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)  # (1, 4, 256)
+            attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()
+            attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
         else:
             attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

         attn_output = self.o_proj(attn_output)

         if not output_attentions:
             attn_weights = None
         return attn_output, attn_weights, past_key_value

     return forward
@@ -683,24 +706,38 @@ def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_si

         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
-
-        if self._use_flash_attention_2:
-            # 2d mask is passed through the layers
-            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-        elif self._use_sdpa and not output_attentions:
-            # output_attentions=True can not be supported when using SDPA, and we fall back on
-            # the manual implementation that requires a 4D causal mask in all cases.
-            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
-                attention_mask,
-                (batch_size, seq_length),
-                inputs_embeds,
-                past_key_values_length,
+        if shard_config.enable_flash_attention:
+            mask_shape = (
+                inputs_embeds.shape[0],
+                1,
+                past_key_values_length + inputs_embeds.shape[1],
+                past_key_values_length + inputs_embeds.shape[1],
+            )
+            attention_mask = ColoAttention.prepare_attn_kwargs(
+                mask_shape,
+                inputs_embeds.dtype,
+                inputs_embeds.device,
+                q_padding_mask=attention_mask,
+                is_causal=True,
             )
         else:
-            # 4d mask is passed through the layers
-            attention_mask = _prepare_4d_causal_attention_mask(
-                attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
-            )
+            if self._use_flash_attention_2:
+                # 2d mask is passed through the layers
+                attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+            elif self._use_sdpa and not output_attentions:
+                # output_attentions=True can not be supported when using SDPA, and we fall back on
+                # the manual implementation that requires a 4D causal mask in all cases.
+                attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                    attention_mask,
+                    (batch_size, seq_length),
+                    inputs_embeds,
+                    past_key_values_length,
+                )
+            else:
+                # 4d mask is passed through the layers
+                attention_mask = _prepare_4d_causal_attention_mask(
+                    attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+                )

         if sp_mode in ["ring", "split_gather"]:
             inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
@@ -746,8 +783,7 @@ def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_si
             if output_attentions:
                 all_self_attns += (layer_outputs[1],)

         hidden_states = self.norm(hidden_states)
-
         if sp_mode == "ring" or sp_mode == "split_gather":
             hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
         elif sp_mode == "all_to_all":
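Note: the modeling changes above switch DeepSeek attention onto ColoAttention, a two-step contract: the model forward builds a kwargs dict once via prepare_attn_kwargs, and every attention layer unpacks that dict into ColoAttention.attention. A minimal sketch of that contract; the call signatures mirror the hunks above, while the toy sizes and CUDA device are assumptions:

# Sketch of the ColoAttention contract relied on above; toy sizes and the
# CUDA device are assumptions.
import torch

from colossalai.shardformer.layer import ColoAttention

bsz, seq_len, num_heads, head_dim = 1, 8, 4, 16  # assumed toy sizes
mask_shape = (bsz, 1, seq_len, seq_len)
# Returns a dict of attention kwargs, which is why the attention forward
# asserts isinstance(attention_mask, dict).
attn_kwargs = ColoAttention.prepare_attn_kwargs(
    mask_shape, torch.float16, torch.device("cuda"), is_causal=True
)

# Layout is (bsz, num_heads, seq_len, head_dim), matching the q/k/v produced
# by view(...).transpose(1, 2) in the attention forward.
q = torch.randn(bsz, num_heads, seq_len, head_dim, dtype=torch.float16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)
out = ColoAttention.attention(q, k, v, **attn_kwargs)  # same layout as q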
diff --git a/colossalai/shardformer/policies/deepseek.py b/colossalai/shardformer/policies/deepseek.py
index 1e44aba4e..53515ffe4 100644
--- a/colossalai/shardformer/policies/deepseek.py
+++ b/colossalai/shardformer/policies/deepseek.py
@@ -1,4 +1,3 @@
-import warnings
 from functools import partial
 from typing import Callable, Dict, List, Union

@@ -194,11 +193,11 @@ class DeepseekPolicy(Policy):
                 target_key="DeepseekModel",
             )

-        if self.shard_config.enable_flash_attention:
-            warnings.warn(
-                "Flash attention has already been replaced in deepseek, and now set enable_flash_attention = False."
-            )
-            self.shard_config.enable_flash_attention = False
+        # if self.shard_config.enable_flash_attention:
+        #     warnings.warn(
+        #         "Flash attention has already been replaced in deepseek, and now set enable_flash_attention = False."
+        #     )
+        #     self.shard_config.enable_flash_attention = False

         return policy
diff --git a/tests/kit/model_zoo/transformers/deepseek.py b/tests/kit/model_zoo/transformers/deepseek.py
index b8b446b57..f50996110 100644
--- a/tests/kit/model_zoo/transformers/deepseek.py
+++ b/tests/kit/model_zoo/transformers/deepseek.py
@@ -59,7 +59,7 @@ def init_deepseek():
         num_attention_heads=8,
         num_key_value_heads=8,
         # vocab_size=2200,
-        first_k_dense_replace=1,
+        first_k_dense_replace=2,
         attn_implementation="flash_attention_2",
         torch_dtype="float16",
         n_routed_experts=8,
@@ -68,7 +68,6 @@ def init_deepseek():

     if hasattr(config, "pad_token_id"):
         config.pad_token_id = config.eos_token_id
-    print(config)
     model = transformers.AutoModel.from_config(config, trust_remote_code=True)

     return model
diff --git a/tests/test_shardformer/test_model/test_shard_deepseek_ghz.py b/tests/test_shardformer/test_model/test_shard_deepseek_ghz.py
index fdca11005..e749af699 100644
--- a/tests/test_shardformer/test_model/test_shard_deepseek_ghz.py
+++ b/tests/test_shardformer/test_model/test_shard_deepseek_ghz.py
@@ -30,7 +30,11 @@ os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
 def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
     # TODO: SGD failed for full dp
     org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
-        model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin, optim_class=torch.optim.SGD
+        model_fn,
+        loss_fn,
+        test_config,
+        pluggin_cls=MoeHybridParallelPlugin,
+        optim_class=torch.optim.SGD,
     )

     org_model = org_model.to(torch.float16)
@@ -39,16 +43,15 @@
     )
     stage_manager = booster.plugin.stage_manager
     tp_group = booster.plugin.tp_group

     # check last hidden state & loss
     if stage_manager is None or stage_manager.is_last_stage():
         if test_config["precision"] == "fp32":
             atol, rtol = 1e-5, 1e-3
         else:
             atol, rtol = 5e-3, 5e-3
-
-        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
         check_output_hidden_state(org_output, sharded_output, stage_manager, atol, rtol)
+        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)

     # unwrap model
     mixtral_model = unwrap_model(org_model, "DeepseekModel", "model")
@@ -178,12 +181,13 @@
             "sp_size": 2,
             "ep_size": 2,
             "enable_sequence_parallelism": True,
+            "enable_flash_attention": True,
             "sequence_parallelism_mode": "all_to_all",
             "zero_stage": 1,
             "overlap_communication": False,
             "precision": "fp16",
             "initial_scale": 1,
-            "find_unused_parameters": True,
+            # "find_unused_parameters": True,
         },
         # {
         #     "tp_size": 1,
@@ -224,7 +228,7 @@ def check_deepseek(rank, world_size, port):
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
 def test_mixtral():
-    spawn(check_deepseek, 4)
+    spawn(check_deepseek, 2)


 if __name__ == "__main__":
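Note: the shape comments deleted in the attention hunk, (1, 8, 128) before and (1, 4, 256) after, documented what all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) achieves: each rank trades its surplus sequence positions for the other ranks' head shards. A single-process simulation of that resharding with hypothetical toy sizes (all_to_all_comm itself is the real collective and needs an initialized process group):

# Single-process illustration of the output all-to-all: scatter the sequence
# dim back into sp_size local shards, gather head shards into the full hidden
# dim. Sizes are hypothetical toy values.
import torch

sp_size, bsz, q_len, local_hidden = 2, 1, 8, 128
# One attention output per rank: full sequence, 1/sp_size of the heads.
per_rank = [torch.randn(bsz, q_len, local_hidden) for _ in range(sp_size)]


def simulate_all_to_all(tensors, scatter_dim, gather_dim):
    chunks = [t.chunk(sp_size, dim=scatter_dim) for t in tensors]
    return [
        torch.cat([chunks[src][dst] for src in range(sp_size)], dim=gather_dim)
        for dst in range(sp_size)
    ]


out = simulate_all_to_all(per_rank, scatter_dim=1, gather_dim=2)
print(out[0].shape)  # torch.Size([1, 4, 256]): sequence local again, hidden full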