Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-04-28 19:55:29 +00:00
fix the test
parent 4e0e99bb6a
commit 703bb5c18d
@@ -25,8 +25,6 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype, sp_size, tp_size=1):
     pg_mesh = ProcessGroupMesh(sp_size, tp_size)
     tp_group = pg_mesh.get_group_along_axis(tp_axis)
     sp_group = pg_mesh.get_group_along_axis(sp_axis)
-    # sp_group = dist.group.WORLD
-    # sp_size = dist.get_world_size()
     # Some outliers may seem large, but our errors are still lower than
     # than Megatron-LM context parallel's
     # (https://github.com/NVIDIA/TransformerEngine/blob/33a3d02f81c56e6f7b542c09bfa86657078d57fb/tests/pytorch/fused_attn/run_fused_attn_with_cp.py#L215)
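For context, a minimal sketch (not part of this commit) of how the test's process groups are derived, assuming sp_axis = 0 and tp_axis = 1 for the mesh layout and an already-initialized torch.distributed environment:

import torch.distributed as dist
from colossalai.cluster import ProcessGroupMesh

SP_AXIS, TP_AXIS = 0, 1  # assumed axis order: mesh shape is (sp_size, tp_size)

def build_groups(sp_size: int, tp_size: int = 1):
    # Every rank must map to exactly one (sp, tp) coordinate.
    assert dist.get_world_size() == sp_size * tp_size
    pg_mesh = ProcessGroupMesh(sp_size, tp_size)
    sp_group = pg_mesh.get_group_along_axis(SP_AXIS)  # ranks sharding the sequence
    tp_group = pg_mesh.get_group_along_axis(TP_AXIS)  # ranks sharding the weights
    return sp_group, tp_group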
||||||
@@ -50,7 +48,6 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype, sp_size, tp_size=1):
         return_softmax=True,
         inner_ring_size=max(2, sp_size // 2),
         tp_group=tp_group,
-        # inner_ring_size=4
     )
     ring_out = ring_out.transpose(1, 2)
     out, lse, _ = flash_attn_qkvpacked_func(
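The kept inner_ring_size=max(2, sp_size // 2) replaces the hardcoded value that the removed # inner_ring_size=4 comment hints at, so the ring layout now scales with the sequence-parallel world size. A self-contained illustration of the values this expression produces (inner_ring_size as a helper name is hypothetical; the test passes the expression inline):

def inner_ring_size(sp_size: int) -> int:
    # Use inner rings of at least 2 ranks, or half the sp group when larger.
    return max(2, sp_size // 2)

for sp in (2, 4, 8, 16):
    print(sp, inner_ring_size(sp))  # prints: 2 2 / 4 2 / 8 4 / 16 8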