fix the test

wangbluo 2024-10-11 17:31:40 +08:00
parent 1507a7528f
commit 4e0e99bb6a


@@ -5,6 +5,7 @@ from flash_attn import flash_attn_qkvpacked_func, flash_attn_varlen_qkvpacked_fu
 from torch.testing import assert_close
 import colossalai
+from colossalai.cluster import ProcessGroupMesh
 from colossalai.shardformer.layer import AttnMaskType
 from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention
 from colossalai.shardformer.layer.utils import split_batch_zigzag, split_varlen_zigzag
@@ -17,11 +18,15 @@ from colossalai.utils import get_current_device
 @parameterize("nheads", [5])
 @parameterize("d", [128])
 @parameterize("dtype", [torch.bfloat16, torch.float16])
-def check_ring_attn(seq_len, bs, nheads, d, dtype):
+def check_ring_attn(seq_len, bs, nheads, d, dtype, sp_size, tp_size=1):
     torch.cuda.manual_seed(2)
     device = get_current_device()
-    sp_group = dist.group.WORLD
-    sp_size = dist.get_world_size()
+    sp_axis, tp_axis = 0, 1
+    pg_mesh = ProcessGroupMesh(sp_size, tp_size)
+    tp_group = pg_mesh.get_group_along_axis(tp_axis)
+    sp_group = pg_mesh.get_group_along_axis(sp_axis)
+    # sp_group = dist.group.WORLD
+    # sp_size = dist.get_world_size()
     # Some outliers may seem large, but our errors are still lower than
     # than Megatron-LM context parallel's
     # (https://github.com/NVIDIA/TransformerEngine/blob/33a3d02f81c56e6f7b542c09bfa86657078d57fb/tests/pytorch/fused_attn/run_fused_attn_with_cp.py#L215)
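
The core of the fix is visible in the hunk above: instead of treating the whole WORLD group as the sequence-parallel group, the test now builds a (sp_size, tp_size) ProcessGroupMesh and takes the sequence-parallel and tensor-parallel groups from its two axes. The standalone sketch below (plain Python, no distributed init) only illustrates which ranks would land in the same group for the 8-rank double-ring configuration, assuming the usual row-major rank ordering of the mesh; it is an illustration, not the library code.

# Illustrative only: which ranks share a group in a (sp_size, tp_size) mesh,
# assuming row-major ordering (rank r -> coordinate (r // tp_size, r % tp_size)).
sp_size, tp_size = 4, 2          # the double-ring test's configuration (8 ranks)
ranks = list(range(sp_size * tp_size))

mesh = [ranks[i * tp_size:(i + 1) * tp_size] for i in range(sp_size)]

# Groups along axis 0 (sp_axis): ranks that differ only in their sp coordinate.
sp_groups = [[row[j] for row in mesh] for j in range(tp_size)]
# Groups along axis 1 (tp_axis): ranks that differ only in their tp coordinate.
tp_groups = mesh

print(sp_groups)  # [[0, 2, 4, 6], [1, 3, 5, 7]]
print(tp_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
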
@@ -44,6 +49,7 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype):
         AttnMaskType.CAUSAL,
         return_softmax=True,
         inner_ring_size=max(2, sp_size // 2),
+        tp_group=tp_group,
         # inner_ring_size=4
     )
     ring_out = ring_out.transpose(1, 2)
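
The hunk above shows only the tail of the attention call. For context, a hedged reconstruction of the surrounding call site is sketched below; the q, k, v tensor names and the ring_lse return value are assumptions, while the keyword arguments are copied from the diff.

# Hypothetical reconstruction of the call site (q, k, v and ring_lse are assumed names):
ring_out, ring_lse = RingAttention.attention(
    q,                                    # query shard held by this rank
    k,
    v,
    sp_group,                             # sequence-parallel group from the ProcessGroupMesh
    AttnMaskType.CAUSAL,
    return_softmax=True,
    inner_ring_size=max(2, sp_size // 2),
    tp_group=tp_group,                    # newly threaded through by this commit
    # inner_ring_size=4
)
ring_out = ring_out.transpose(1, 2)
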
@@ -161,12 +167,12 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype):
 def launch_single_ring(rank, world_size, port):
     colossalai.launch(rank, world_size, "localhost", port)
     check_packed_seq()
-    check_ring_attn()
+    check_ring_attn(sp_size=world_size)


 def launch_double_ring(rank, world_size, port):
     colossalai.launch(rank, world_size, "localhost", port)
-    check_ring_attn()
+    check_ring_attn(sp_size=4, tp_size=2)


 @rerun_if_address_is_in_use()
@@ -176,7 +182,7 @@ def test_ring_attn(world_size):
 @rerun_if_address_is_in_use()
-@parameterize("world_size", [4])
+@parameterize("world_size", [8])
 def test_double_ring(world_size):
     spawn(launch_double_ring, nprocs=world_size)
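
For anyone rerunning the suite, the world_size bump from 4 to 8 follows directly from the mesh shape: the double-ring launcher now asks for sp_size * tp_size = 4 * 2 = 8 ranks, while the single-ring launcher keeps using the whole world as the sequence-parallel group. Below is a minimal sketch of driving the tests directly, assuming colossalai's parameterize decorator fills in the listed world_size when the wrapped test is called without arguments.

# Minimal sketch; assumes the parameterize-wrapped tests can be invoked directly,
# with the decorator supplying world_size. The double-ring case needs 8 GPUs,
# since ProcessGroupMesh(4, 2) requires sp_size * tp_size = 8 ranks.
if __name__ == "__main__":
    test_ring_attn()      # spawns `world_size` single-ring processes
    test_double_ring()    # spawns 8 processes -> ProcessGroupMesh(4, 2)
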