From 703bb5c18dd3a701209017de828e9cdc8ca58356 Mon Sep 17 00:00:00 2001
From: wangbluo <2538539015@qq.com>
Date: Fri, 11 Oct 2024 17:34:20 +0800
Subject: [PATCH] fix the test

---
 tests/test_shardformer/test_layer/test_ring_attn.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/tests/test_shardformer/test_layer/test_ring_attn.py b/tests/test_shardformer/test_layer/test_ring_attn.py
index 38dfcb239..b9291a061 100644
--- a/tests/test_shardformer/test_layer/test_ring_attn.py
+++ b/tests/test_shardformer/test_layer/test_ring_attn.py
@@ -25,8 +25,6 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype, sp_size, tp_size=1):
     pg_mesh = ProcessGroupMesh(sp_size, tp_size)
     tp_group = pg_mesh.get_group_along_axis(tp_axis)
     sp_group = pg_mesh.get_group_along_axis(sp_axis)
-    # sp_group = dist.group.WORLD
-    # sp_size = dist.get_world_size()
     # Some outliers may seem large, but our errors are still lower than
     # than Megatron-LM context parallel's
     # (https://github.com/NVIDIA/TransformerEngine/blob/33a3d02f81c56e6f7b542c09bfa86657078d57fb/tests/pytorch/fused_attn/run_fused_attn_with_cp.py#L215)
@@ -50,7 +48,6 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype, sp_size, tp_size=1):
         return_softmax=True,
         inner_ring_size=max(2, sp_size // 2),
         tp_group=tp_group,
-        # inner_ring_size=4
     )
     ring_out = ring_out.transpose(1, 2)
     out, lse, _ = flash_attn_qkvpacked_func(