This commit is contained in:
wangbluo
2024-10-15 11:01:34 +08:00
parent 8ff7d0c780
commit 3dc08c8a5a
6 changed files with 18 additions and 15 deletions

View File

@@ -1180,7 +1180,9 @@ class HybridParallelPlugin(PipelinePluginBase):
gradient_checkpoint_config=gradient_checkpoint_config,
fp8_communication=fp8_communication,
inner_ring_size=inner_ring_size,
pg_mesh=self.pg_mesh,
)
self.amp_config = dict(
initial_scale=initial_scale,
growth_factor=growth_factor,

View File

@@ -7,7 +7,6 @@ import torch.distributed as dist
import torch.nn.functional as F
from einops import rearrange
from colossalai.cluster import ProcessGroupMesh
from colossalai.kernel.kernel_loader import (
FlashAttentionDaoLoader,
FlashAttentionForFloatAndCustomMaskLoader,
@@ -432,7 +431,7 @@ class RingAttention(torch.autograd.Function):
INTER_RING_GROUP_COPY: dist.ProcessGroup = None
@staticmethod
def get_double_ring_groups(sp_group, inner_ring_size=None):
def get_double_ring_groups(sp_group, pg_mesh, inner_ring_size=None):
"""
Get 2D ring groups for the given process group. Generally, to avoid congestion, the inner ring size
shouldn't be larger than the number of NICs on each node.
@@ -465,20 +464,17 @@ class RingAttention(torch.autograd.Function):
inner_ring_group = None
inter_ring_group = None
inter_axis, inner_axis = 0, 1
pg_mesh = ProcessGroupMesh(num_rings, inner_ring_size)
# Create inner ring groups
for i in range(inner_ring_size):
ranks = list(range(i * inner_ring_size, (i + 1) * inner_ring_size))
group = pg_mesh.get_group_along_axis(inner_axis)
group = pg_mesh.get_group_along_axis(2, ranks)
if sp_rank in ranks:
inner_ring_group = group
# Create inter ring groups
for i in range(num_rings):
ranks = list(range(i, sp_size, num_rings))
group = pg_mesh.get_group_along_axis(inter_axis)
group = pg_mesh.get_group_along_axis(2, ranks)
if sp_rank in ranks:
inter_ring_group = group
@@ -499,6 +495,7 @@ class RingAttention(torch.autograd.Function):
deterministic=False,
return_softmax=False,
inner_ring_size=None,
pg_mesh=None,
**kwargs,
):
"""
@@ -546,7 +543,9 @@ class RingAttention(torch.autograd.Function):
if inner_ring_size != None:
RingAttention.SP_GROUP = sp_group
inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_group, inner_ring_size)
inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(
sp_group, pg_mesh, inner_ring_size
)
RingAttention.INNER_RING_GROUP = inner_ring_group
RingAttention.INTER_RING_GROUP = inter_ring_group
else:

View File

@@ -868,6 +868,7 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None)
dropout_p=dropout_p,
scale=scale,
inner_ring_size=shard_config.inner_ring_size,
pg_mesh=shard_config.pg_mesh,
)
else:
attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale)

View File

@@ -571,6 +571,7 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
sp_group,
**attention_mask,
inner_ring_size=shard_config.inner_ring_size,
pg_mesh=shard_config.pg_mesh,
)
elif shard_config.enable_flash_attention:

View File

@@ -51,6 +51,7 @@ class ShardConfig:
extra_kwargs: Dict[str, Any] = field(default_factory=dict)
# For ring attention
pg_mesh: Optional[int] = None
inner_ring_size: Optional[int] = None
# for moe related
moe_dp_group: Optional[ProcessGroup] = None