This commit is contained in:
wangbluo 2025-05-16 11:39:50 +08:00
parent 10bc6af2b1
commit ced6b5e1c3
3 changed files with 25 additions and 14 deletions

View File

@ -123,6 +123,7 @@ class CommandPipelineForwards:
# embed positions, for the first stage, hidden_states is the input embeddings,
# for the other stages, hidden_states is the output of the previous stage
shard_config.enable_flash_attention = True
if shard_config.enable_flash_attention:
# in this case, attention_mask is a dict rather than a tensor
mask_shape = (batch_size, 1, seq_length, seq_length_with_past)
@ -391,6 +392,8 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None,
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = None
shard_config.enable_flash_attention = True
if shard_config.enable_flash_attention:
assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
@ -402,7 +405,7 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None,
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
dropout = (0.0 if not self.training else self.attention_dropout,)
dropout = 0.0 if not self.training else self.attention_dropout
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
@ -477,6 +480,8 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
shard_config.enable_flash_attention = True
# in this case, attention_mask is a dict rather than a tensor
if shard_config.enable_flash_attention:
mask_shape = (inputs_embeds.shape[0], 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len)
@ -524,7 +529,6 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,

View File

@ -75,14 +75,15 @@ class CommandPolicy(Policy):
policy[attn_cls] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
)
self.append_or_create_method_replacement(
description={
"forward": get_command_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
},
policy=policy,
target_key=attn_cls,
)
if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
self.append_or_create_method_replacement(
description={
"forward": get_command_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
},
policy=policy,
target_key=attn_cls,
)
if self.pipeline_stage_manager is None:
self.append_or_create_method_replacement(
description={

View File

@ -213,6 +213,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 4,
"pp_size": 1,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "all_to_all",
"enable_flash_attention": False,
"use_lazy_init": True,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 1,
"pp_size": 1,
@ -220,7 +231,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "all_to_all",
"enable_flash_attention": True,
"use_lazy_init": True,
"zero_stage": 2,
"precision": "fp16",
@ -231,7 +241,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"pp_size": 2,
"num_microbatches": 2,
"enable_all_optimization": True,
"enable_flash_attention": True,
"use_lazy_init": True,
"precision": "fp16",
"initial_scale": 1,
@ -243,7 +252,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"pp_size": 2,
"num_microbatches": 4,
"use_lazy_init": False,
"enable_flash_attention": True,
"precision": "fp32",
"enable_gradient_checkpointing": True,
"gradient_checkpoint_config": PipelineGradientCheckpointConfig(num_ckpt_layers_per_stage=[4, 0]),
@ -252,7 +260,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"tp_size": 2,
"pp_size": 1,
"enable_all_optimization": True,
"enable_flash_attention": True,
"use_lazy_init": True,
"zero_stage": 2,
"precision": "fp16",
@ -263,7 +270,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"pp_size": 2,
"num_microbatches": 2,
"enable_all_optimization": True,
"enable_flash_attention": True,
"use_lazy_init": True,
"zero_stage": 1,
"precision": "fp16",