From ced6b5e1c317cbf26914b1551145c19776e0589b Mon Sep 17 00:00:00 2001
From: wangbluo <2538539015@qq.com>
Date: Fri, 16 May 2025 11:39:50 +0800
Subject: [PATCH] fix: always apply Command flash-attention forward and fix
 attention dropout

---
 colossalai/shardformer/modeling/command.py    |  8 ++++++--
 colossalai/shardformer/policies/command.py    | 15 ++++++++-------
 .../test_model/test_shard_command.py          | 16 +++++++++++-----
 3 files changed, 25 insertions(+), 14 deletions(-)

diff --git a/colossalai/shardformer/modeling/command.py b/colossalai/shardformer/modeling/command.py
index fe494b996..aa11b7043 100644
--- a/colossalai/shardformer/modeling/command.py
+++ b/colossalai/shardformer/modeling/command.py
@@ -123,6 +123,7 @@ class CommandPipelineForwards:
 
         # embed positions, for the first stage, hidden_states is the input embeddings,
         # for the other stages, hidden_states is the output of the previous stage
+        shard_config.enable_flash_attention = True
         if shard_config.enable_flash_attention:
             # in this case, attention_mask is a dict rather than a tensor
             mask_shape = (batch_size, 1, seq_length, seq_length_with_past)
@@ -391,6 +392,8 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None,
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
         attn_weights = None
+        shard_config.enable_flash_attention = True
+
         if shard_config.enable_flash_attention:
             assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
             attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
@@ -402,7 +405,7 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None,
                 attn_weights = attn_weights + causal_mask
 
             attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-            dropout = (0.0 if not self.training else self.attention_dropout,)
+            dropout = 0.0 if not self.training else self.attention_dropout
             attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=self.training)
             attn_output = torch.matmul(attn_weights, value_states)
             attn_output = attn_output.transpose(1, 2).contiguous()
@@ -477,6 +480,8 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
+        shard_config.enable_flash_attention = True
+
         # in this case, attention_mask is a dict rather than a tensor
         if shard_config.enable_flash_attention:
             mask_shape = (inputs_embeds.shape[0], 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len)
@@ -524,7 +529,6 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
                 cache_position,
                 position_embeddings,
             )
-
         else:
             layer_outputs = decoder_layer(
                 hidden_states,
diff --git a/colossalai/shardformer/policies/command.py b/colossalai/shardformer/policies/command.py
index 3cba8ded0..8a510307e 100644
--- a/colossalai/shardformer/policies/command.py
+++ b/colossalai/shardformer/policies/command.py
@@ -75,14 +75,15 @@ class CommandPolicy(Policy):
             policy[attn_cls] = ModulePolicyDescription(
                 attribute_replacement=decoder_attribute_replacement,
             )
+
+        self.append_or_create_method_replacement(
+            description={
+                "forward": get_command_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
+            },
+            policy=policy,
+            target_key=attn_cls,
+        )
         if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
-            self.append_or_create_method_replacement(
-                description={
-                    "forward": get_command_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
-                },
-                policy=policy,
-                target_key=attn_cls,
-            )
             if self.pipeline_stage_manager is None:
                 self.append_or_create_method_replacement(
                     description={
diff --git a/tests/test_shardformer/test_model/test_shard_command.py b/tests/test_shardformer/test_model/test_shard_command.py
index 384943675..4d156d84d 100644
--- a/tests/test_shardformer/test_model/test_shard_command.py
+++ b/tests/test_shardformer/test_model/test_shard_command.py
@@ -213,6 +213,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "precision": "fp16",
             "initial_scale": 1,
         },
+        {
+            "tp_size": 4,
+            "pp_size": 1,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "all_to_all",
+            "enable_flash_attention": False,
+            "use_lazy_init": True,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
         {
             "tp_size": 1,
             "pp_size": 1,
@@ -220,7 +231,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "num_microbatches": 1,
             "enable_sequence_parallelism": True,
             "sequence_parallelism_mode": "all_to_all",
-            "enable_flash_attention": True,
             "use_lazy_init": True,
             "zero_stage": 2,
             "precision": "fp16",
@@ -231,7 +241,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "pp_size": 2,
             "num_microbatches": 2,
             "enable_all_optimization": True,
-            "enable_flash_attention": True,
             "use_lazy_init": True,
             "precision": "fp16",
             "initial_scale": 1,
@@ -243,7 +252,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "pp_size": 2,
             "num_microbatches": 4,
             "use_lazy_init": False,
-            "enable_flash_attention": True,
             "precision": "fp32",
             "enable_gradient_checkpointing": True,
             "gradient_checkpoint_config": PipelineGradientCheckpointConfig(num_ckpt_layers_per_stage=[4, 0]),
@@ -252,7 +260,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "tp_size": 2,
             "pp_size": 1,
             "enable_all_optimization": True,
-            "enable_flash_attention": True,
             "use_lazy_init": True,
             "zero_stage": 2,
             "precision": "fp16",
@@ -263,7 +270,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "pp_size": 2,
             "num_microbatches": 2,
             "enable_all_optimization": True,
-            "enable_flash_attention": True,
             "use_lazy_init": True,
             "zero_stage": 1,
             "precision": "fp16",
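
Note on the dropout change in get_command_flash_attention_forward: the removed line ended with a
trailing comma, so `dropout` was bound to a one-element tuple rather than a float, and
`nn.functional.dropout(p=dropout)` then fails when it range-checks `p` on the eager (non-flash)
path. Below is a minimal, self-contained sketch of the corrected eager-attention dropout step; it
uses plain PyTorch with made-up tensor shapes and variable names for illustration, not the actual
ColossalAI attention module.

import torch
from torch import nn

# Illustrative shapes only: (batch, heads, seq, seq) scores and (batch, heads, seq, head_dim) values.
attn_weights = torch.randn(2, 8, 16, 16)
value_states = torch.randn(2, 8, 16, 64)
training = True
attention_dropout = 0.1

# Upcast to fp32 for the softmax, then cast back, mirroring the patched code path.
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)

# Old (buggy) form: dropout = (0.0 if not training else attention_dropout,)
# The trailing comma makes `dropout` the tuple (0.1,), which F.dropout rejects as a probability.
dropout = 0.0 if not training else attention_dropout
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=training)

attn_output = torch.matmul(attn_weights, value_states)   # (2, 8, 16, 64)
attn_output = attn_output.transpose(1, 2).contiguous()   # (2, 16, 8, 64)
print(attn_output.shape)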