flybird11111 2025-04-24 15:44:20 +08:00
parent 686982764c
commit e891501c55
3 changed files with 9 additions and 8 deletions


@@ -1056,8 +1056,8 @@ class HybridParallelPlugin(PipelinePluginBase):
         assert (
             not pp_style == "zbv" or scheduler_nodes is not None
         ), f"scheduler_nodes must not be None when using zero bubble pipeline."
-        if sp_size is None or sp_size <= 1:
-            enable_sequence_parallelism = False
+        # if sp_size is None or sp_size <= 1:
+        #     enable_sequence_parallelism = False
         if enable_sequence_parallelism:
             self.sequence_parallelism_mode = (
                 sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all"

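For context on the hunk above: with the guard commented out, an explicit enable_sequence_parallelism=True is no longer silently reverted when sp_size is unset or 1. A minimal standalone sketch of the before/after control flow (names borrowed from the hunk; the real logic lives inside the plugin's constructor):

    # Minimal sketch, not the plugin itself: shows what dropping the guard changes.
    def resolve_sp_mode(enable_sequence_parallelism, sp_size, sequence_parallelism_mode):
        # Old behavior (now commented out above): sp_size in (None, 0, 1)
        # force-disabled sequence parallelism before the mode was ever read.
        # if sp_size is None or sp_size <= 1:
        #     enable_sequence_parallelism = False
        if enable_sequence_parallelism:
            # Same default as the hunk: fall back to "all_to_all" when unset.
            return sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all"
        return None

    # Under the old guard this returned None; it now resolves to "ring",
    # the combination the updated test case at the bottom relies on.
    assert resolve_sp_mode(True, 1, "ring") == "ring"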

@@ -607,7 +607,8 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
             attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
         )
     else:
-        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        # attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
     attn_output = self.o_proj(attn_output)

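The hunk above swaps the reshape target from *input_shape to an explicit (bsz, q_len) pair before the output projection; a plausible reading is that under sequence parallelism the locally held sequence length need not match the shape the input was captured with, so deriving the reshape from bsz and q_len is the safe choice. A quick shape check with illustrative sizes (all four dimensions are assumptions, not values from the repo):

    import torch

    # Illustrative sizes only; the point is the final shape o_proj expects.
    bsz, q_len, num_heads, head_dim = 2, 16, 8, 64
    attn_output = torch.randn(bsz, q_len, num_heads, head_dim)

    # reshape(bsz, q_len, -1) folds the head dimensions back into the hidden
    # size, giving (bsz, q_len, num_heads * head_dim) for the o_proj matmul.
    attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
    assert attn_output.shape == (bsz, q_len, num_heads * head_dim)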

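The two test-matrix hunks below move the Llama sequence-parallel cases onto configurations the guard removal makes reachable: the Double Ring Attention case gives the whole world to SP (sp_size 2 -> 4, tp_size 2 -> 1), and an all_to_all case becomes a "ring" case with tp_size=2 and sp_size=1. A sketch of the two entries as they now read; the 4-rank world size and the claim that "ring" mode reuses the tensor-parallel group are assumptions about the plugin, not statements in this diff:

    # Double Ring Attention: SP now spans all (presumably 4) ranks instead of
    # splitting them 2x2 between tensor and sequence parallelism.
    double_ring_case = {
        "tp_size": 1,
        "pp_size": 1,
        "sp_size": 4,
        "num_microbatches": 1,
        "enable_sequence_parallelism": True,
        "sequence_parallelism_mode": "ring_attn",
        # ...remaining keys truncated in the hunk...
    }

    # "ring" mode with sp_size=1: if ring SP borrows the TP group, sp_size
    # stays at 1, so this case only runs because the plugin no longer
    # force-disables SP when sp_size <= 1 (first hunk of this commit).
    ring_case = {
        "tp_size": 2,
        "pp_size": 1,
        "sp_size": 1,
        "num_microbatches": 1,
        "enable_sequence_parallelism": True,
        "sequence_parallelism_mode": "ring",
        "enable_flash_attention": True,
        "use_lazy_init": True,
        "zero_stage": 2,
        # ...remaining keys truncated in the hunk...
    }
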
@@ -162,9 +162,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     [
         # Double Ring Attention
         {
-            "tp_size": 2,
+            "tp_size": 1,
             "pp_size": 1,
-            "sp_size": 2,
+            "sp_size": 4,
             "num_microbatches": 1,
             "enable_sequence_parallelism": True,
             "sequence_parallelism_mode": "ring_attn",
@@ -226,12 +226,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "initial_scale": 1,
         },
         {
-            "tp_size": 1,
+            "tp_size": 2,
             "pp_size": 1,
-            "sp_size": 2,
+            "sp_size": 1,
             "num_microbatches": 1,
             "enable_sequence_parallelism": True,
-            "sequence_parallelism_mode": "all_to_all",
+            "sequence_parallelism_mode": "ring",
             "enable_flash_attention": True,
             "use_lazy_init": True,
             "zero_stage": 2,