wangbluo 2024-10-15 11:56:49 +08:00
parent 3dc08c8a5a
commit 6be9862aaf
6 changed files with 23 additions and 6 deletions

View File

@@ -1181,6 +1181,7 @@ class HybridParallelPlugin(PipelinePluginBase):
             fp8_communication=fp8_communication,
             inner_ring_size=inner_ring_size,
             pg_mesh=self.pg_mesh,
+            sp_axis=self.sp_axis,
         )
         self.amp_config = dict(
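Note: for orientation, a minimal usage sketch of how the plugin is typically configured so that its pg_mesh and the new sp_axis reach the ring-attention shard config. The constructor arguments shown (sp_size, enable_sequence_parallelism, sequence_parallelism_mode) are assumed from the plugin's existing sequence-parallelism options and are not part of this diff; they may differ between versions.

    # Hedged sketch: argument names other than inner_ring_size are assumptions.
    from colossalai.booster import Booster
    from colossalai.booster.plugin import HybridParallelPlugin

    plugin = HybridParallelPlugin(
        tp_size=1,
        pp_size=1,
        sp_size=4,                              # sequence-parallel world size
        enable_sequence_parallelism=True,
        sequence_parallelism_mode="ring_attn",  # ring attention backend
        inner_ring_size=2,                      # forwarded to RingAttention
    )
    booster = Booster(plugin=plugin)
    # With this change the plugin also forwards self.sp_axis (the mesh axis used
    # for sequence parallelism), together with pg_mesh, into the shard config.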

View File

@@ -431,7 +431,7 @@ class RingAttention(torch.autograd.Function):
     INTER_RING_GROUP_COPY: dist.ProcessGroup = None

     @staticmethod
-    def get_double_ring_groups(sp_group, pg_mesh, inner_ring_size=None):
+    def get_double_ring_groups(sp_axis, pg_mesh, inner_ring_size=None):
         """
         Get 2D ring groups for the given process group. Generally, to avoid congestion, the inner ring size
         shouldn't be larger than the number of NICs on each node.
@@ -441,6 +441,9 @@ class RingAttention(torch.autograd.Function):
         Returns:
             Tuple[dist.ProcessGroup, dist.ProcessGroup]: Inner-ring process group and inter-ring process group.
         """
+        assert pg_mesh is not None, f"Error: The pg mesh is None! please check the process group initialization."
+        sp_group = pg_mesh.get_group_along_axis(sp_axis)
         sp_size = dist.get_world_size(sp_group)
         sp_rank = dist.get_rank(sp_group)
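get_double_ring_groups() now resolves the sequence-parallel group from the mesh axis instead of receiving it directly. For intuition, a simplified sketch of the 2D ring split it is responsible for (illustrative only; the library's actual grouping logic may differ):

    # Illustrative sketch of the inner/inter ring split over SP ranks.
    # Inner rings hold consecutive ranks (ideally within one node, bounded by the
    # number of NICs); inter rings connect ranks at the same local offset.
    def split_double_ring(sp_ranks: list, inner_ring_size: int):
        assert len(sp_ranks) % inner_ring_size == 0
        num_rings = len(sp_ranks) // inner_ring_size
        inner_rings = [
            sp_ranks[i * inner_ring_size:(i + 1) * inner_ring_size] for i in range(num_rings)
        ]
        inter_rings = [sp_ranks[offset::inner_ring_size] for offset in range(inner_ring_size)]
        return inner_rings, inter_rings

    # Example: 8 SP ranks with inner_ring_size=4 gives
    # inner rings [0,1,2,3], [4,5,6,7] and inter rings [0,4], [1,5], [2,6], [3,7].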
@@ -496,6 +499,7 @@ class RingAttention(torch.autograd.Function):
         return_softmax=False,
         inner_ring_size=None,
         pg_mesh=None,
+        sp_axis=None,
         **kwargs,
     ):
         """
@@ -506,7 +510,7 @@ class RingAttention(torch.autograd.Function):
             q (torch.Tensor): Query tensor. Shape should be [B, nHeads, Sq, D]
             k (torch.Tensor): Key tensor. Shape should be [B, nHeads, Sq, Sq, D]
             v (torch.Tensor): Value tensor. Shape should be [B, nHeads, Sq, Sq, D]
-            sp_group (Optional[dist.ProcessGroup]): Process group for sequence parallelism
+            sp_axis (Optional[int]): Sp axis for the global pg mesh.
             sp_tream (torch.cuda.Stream): An different stream for output correction.
             cu_seqlens (Optional[torch.Tensor], optional): The cumulative sequence lengths
                 of the sequences in the batch, used to index into q.
@@ -539,13 +543,13 @@ class RingAttention(torch.autograd.Function):
             attention_mask_type in RingAttention.SUPPORTED_MASK_TYPES
         ), f"Mask type {attention_mask_type} is not supported yet."
-        assert pg_mesh is not None, f"Error: The pg mesh is None! please check the process group initialization."

         clone_pg = lambda pg: dist.new_group(dist.get_process_group_ranks(pg))

         if inner_ring_size != None:
             RingAttention.SP_GROUP = sp_group
-            inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(
-                sp_group, pg_mesh, inner_ring_size
-            )
+            inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_axis, pg_mesh, inner_ring_size)
             RingAttention.INNER_RING_GROUP = inner_ring_group
             RingAttention.INTER_RING_GROUP = inter_ring_group
         else:
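Call sites now pass both the sp_group positional argument and the new sp_axis/pg_mesh keywords. A self-contained call sketch under a running multi-GPU torch.distributed setup, mirroring the updated test further below (import paths and tensor sizes are assumptions, not taken from this diff):

    # Hedged sketch; assumes torch.distributed is already initialized on GPUs.
    import torch
    import torch.distributed as dist
    from colossalai.cluster import ProcessGroupMesh
    from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention

    sp_axis = 2                                   # index of the SP axis in the mesh
    sp_size = dist.get_world_size()
    pg_mesh = ProcessGroupMesh(1, 1, sp_size, 1)  # (dp, pp, sp, tp)

    b, nheads, seq_len, d = 2, 8, 4096, 64
    q, k, v = (torch.randn(b, nheads, seq_len, d, dtype=torch.bfloat16, device="cuda") for _ in range(3))

    out, lse = RingAttention.attention(
        q, k, v, dist.group.WORLD, AttnMaskType.CAUSAL,
        return_softmax=True,
        inner_ring_size=2,
        pg_mesh=pg_mesh,
        sp_axis=sp_axis,
    )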

View File

@@ -869,6 +869,7 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None)
                 scale=scale,
                 inner_ring_size=shard_config.inner_ring_size,
                 pg_mesh=shard_config.pg_mesh,
+                sp_axis=shard_config.sp_axis,
             )
         else:
             attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale)

View File

@@ -572,6 +572,7 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
                 **attention_mask,
                 inner_ring_size=shard_config.inner_ring_size,
                 pg_mesh=shard_config.pg_mesh,
+                sp_axis=shard_config.sp_axis,
             )
         elif shard_config.enable_flash_attention:

View File

@@ -51,6 +51,7 @@ class ShardConfig:
     extra_kwargs: Dict[str, Any] = field(default_factory=dict)

     # For ring attention
+    sp_axis: Optional[int] = None
     pg_mesh: Optional[int] = None
     inner_ring_size: Optional[int] = None
     # for moe related

View File

@@ -24,6 +24,7 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype, inner_ring_size):
     sp_group = dist.group.WORLD
     dp_size, pp_size, tp_size = 1, 1, 1
     sp_size = dist.get_world_size()
+    sp_axis = 2
     pg_mesh = ProcessGroupMesh(dp_size, pp_size, sp_size, tp_size)
     # Some outliers may seem large, but our errors are still lower than
     # than Megatron-LM context parallel's
@@ -40,7 +41,15 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype, inner_ring_size):
     # Ring attention vs single GPU
     ring_out, ring_lse = RingAttention.attention(
-        q, k, v, sp_group, AttnMaskType.CAUSAL, return_softmax=True, inner_ring_size=inner_ring_size, pg_mesh=pg_mesh
+        q,
+        k,
+        v,
+        sp_group,
+        AttnMaskType.CAUSAL,
+        return_softmax=True,
+        inner_ring_size=inner_ring_size,
+        pg_mesh=pg_mesh,
+        sp_axis=sp_axis,
     )
     ring_out = ring_out.transpose(1, 2)
     out, lse, _ = flash_attn_qkvpacked_func(
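Why sp_axis == 2 in this test: the mesh is built as ProcessGroupMesh(dp_size, pp_size, sp_size, tp_size), so axis 0 is dp, 1 is pp, 2 is sp, and 3 is tp. A hypothetical sanity check (not part of the commit) one could add to the test:

    # Hypothetical check: the SP group recovered from the mesh axis matches the
    # expected sequence-parallel world size.
    sp_group_from_mesh = pg_mesh.get_group_along_axis(sp_axis)
    assert dist.get_world_size(sp_group_from_mesh) == sp_size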