diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py
index aec8e84a0..798fca88f 100644
--- a/colossalai/shardformer/modeling/gpt2.py
+++ b/colossalai/shardformer/modeling/gpt2.py
@@ -858,7 +858,6 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None)
 
         sp_mode = shard_config.sequence_parallelism_mode
         sp_group = shard_config.sequence_parallel_process_group
-
         if sp_mode == "ring_attn":
             attn_output = RingAttention.attention(
                 query,
diff --git a/tests/test_shardformer/test_layer/test_ring_attn.py b/tests/test_shardformer/test_layer/test_ring_attn.py
index 61ea0677f..bcb2c1f8a 100644
--- a/tests/test_shardformer/test_layer/test_ring_attn.py
+++ b/tests/test_shardformer/test_layer/test_ring_attn.py
@@ -17,7 +17,7 @@ from colossalai.utils import get_current_device
 @parameterize("nheads", [5])
 @parameterize("d", [128])
 @parameterize("dtype", [torch.bfloat16, torch.float16])
-def check_ring_attn(seq_len, bs, nheads, d, dtype, sp_size, tp_size=1):
+def check_ring_attn(seq_len, bs, nheads, d, dtype, sp_size):
     torch.cuda.manual_seed(2)
     device = get_current_device()
     sp_group = dist.group.WORLD
@@ -165,7 +165,7 @@ def launch_single_ring(rank, world_size, port):
 
 def launch_double_ring(rank, world_size, port):
     colossalai.launch(rank, world_size, "localhost", port)
-    check_ring_attn(sp_size=4, tp_size=2)
+    check_ring_attn(sp_size=world_size)
 
 
 @rerun_if_address_is_in_use()
@@ -175,7 +175,7 @@ def test_ring_attn(world_size):
 
 
 @rerun_if_address_is_in_use()
-@parameterize("world_size", [8])
+@parameterize("world_size", [4])
 def test_double_ring(world_size):
     spawn(launch_double_ring, nprocs=world_size)