diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 93538c49a..a4a8c81ae 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1056,8 +1056,8 @@ class HybridParallelPlugin(PipelinePluginBase):
         assert (
             not pp_style == "zbv" or scheduler_nodes is not None
         ), f"scheduler_nodes must not be None when using zero bubble pipeline."
-        if sp_size is None or sp_size <= 1:
-            enable_sequence_parallelism = False
+        # if sp_size is None or sp_size <= 1:
+        #     enable_sequence_parallelism = False
         if enable_sequence_parallelism:
             self.sequence_parallelism_mode = (
                 sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all"
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index de825606a..ee8cfc80f 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -607,7 +607,8 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
                 attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
             )
         else:
-            attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+            # attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+            attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()

         attn_output = self.o_proj(attn_output)

diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py
index 13048eae4..b97846408 100644
--- a/tests/test_shardformer/test_model/test_shard_llama.py
+++ b/tests/test_shardformer/test_model/test_shard_llama.py
@@ -162,9 +162,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     [
         # Double Ring Attention
         {
-            "tp_size": 2,
+            "tp_size": 1,
             "pp_size": 1,
-            "sp_size": 2,
+            "sp_size": 4,
             "num_microbatches": 1,
             "enable_sequence_parallelism": True,
             "sequence_parallelism_mode": "ring_attn",
@@ -226,12 +226,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "initial_scale": 1,
         },
         {
-            "tp_size": 1,
+            "tp_size": 2,
             "pp_size": 1,
-            "sp_size": 2,
+            "sp_size": 1,
             "num_microbatches": 1,
             "enable_sequence_parallelism": True,
-            "sequence_parallelism_mode": "all_to_all",
+            "sequence_parallelism_mode": "ring",
             "enable_flash_attention": True,
             "use_lazy_init": True,
             "zero_stage": 2,