diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index 5f0e9261c..1f897c1be 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -4,6 +4,7 @@ from typing import Callable, Dict, Optional, Tuple
 import torch
 import torch.distributed
 import torch.distributed as dist
+from torch.distributed import ProcessGroup
 import torch.nn.functional as F

 from einops import rearrange
@@ -431,7 +432,7 @@ class RingAttention(torch.autograd.Function):
     INTER_RING_GROUP_COPY: dist.ProcessGroup = None

     @staticmethod
-    def get_double_ring_groups(sp_group, inner_ring_size=None):
+    def get_double_ring_groups(sp_group, tp_group, inner_ring_size=None):
         """
         Get 2D ring groups for the given process group. Generally, to avoid congestion, the inner ring size
         shouldn't be larger than the number of NICs on each node.
@@ -443,6 +444,7 @@ class RingAttention(torch.autograd.Function):
         """
         sp_size = dist.get_world_size(sp_group)
         sp_rank = dist.get_rank(sp_group)
+        tp_size = dist.get_world_size(tp_group)

         if inner_ring_size is None:
             if torch.cuda.device_count() >= dist.get_world_size():
@@ -471,19 +473,24 @@ class RingAttention(torch.autograd.Function):
         inner_ring_group = None
         inter_ring_group = None

+        world_size = dist.get_world_size()
+        rank = dist.get_rank()
+        groups = int(world_size / sp_size)
         # Create inner ring groups
-        for i in range(inner_ring_size):
-            ranks = list(range(i * inner_ring_size, (i + 1) * inner_ring_size))
-            group = dist.new_group(ranks)
-            if sp_rank in ranks:
-                inner_ring_group = group
+        for group_id in range(groups):
+            for i in range(inner_ring_size):
+                ranks = list(range(i + (group_id * sp_size), (1 + group_id) * sp_size, inner_ring_size))
+                group = dist.new_group(ranks)
+                if rank in ranks:
+                    inner_ring_group = group
         # Create inter ring groups
-        for i in range(num_rings):
-            ranks = list(range(i, sp_size, num_rings))
-            group = dist.new_group(ranks)
-            if sp_rank in ranks:
-                inter_ring_group = group
+        for group_id in range(groups):
+            for i in range(num_rings):
+                ranks = list(range(i + group_id * num_rings, world_size, sp_size))
+                group = dist.new_group(ranks)
+                if rank in ranks:
+                    inter_ring_group = group

         return inner_ring_group, inter_ring_group

@@ -493,6 +500,7 @@ class RingAttention(torch.autograd.Function):
         k,
         v,
         sp_group,
+        tp_group,
         attention_mask_type,
         cu_seqlens=None,
         max_seqlen=None,
@@ -537,7 +545,6 @@ class RingAttention(torch.autograd.Function):
             RingAttention.ATTN_DONE = torch.cuda.Event()
         if RingAttention.SP_STREAM is None:
             RingAttention.SP_STREAM = torch.cuda.Stream()
-
         assert (
             q.shape[2] == k.shape[2]
         ), "Q, K and V having different sequence lengths (inference or cross-attn)\
@@ -550,7 +557,7 @@ class RingAttention(torch.autograd.Function):

         if RingAttention.SP_GROUP is not sp_group:
             RingAttention.SP_GROUP = sp_group
-            inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_group, inner_ring_size)
+            inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_group, tp_group, inner_ring_size)
             RingAttention.INNER_RING_GROUP = inner_ring_group
             RingAttention.INTER_RING_GROUP = inter_ring_group
         else:
@@ -597,6 +604,7 @@ class RingAttention(torch.autograd.Function):
             attention_mask_type == AttnMaskType.PADDED_CAUSAL,
             inner_ring_group,
             inter_ring_group,
+            tp_group,
         )

         if attention_mask_type == AttnMaskType.PADDED_CAUSAL:
@@ -627,6 +635,7 @@ class RingAttention(torch.autograd.Function):
         is_packed: Optional[bool] = False,
         inner_ring_group: Optional[dist.ProcessGroup] = None,
         inter_ring_group: Optional[dist.ProcessGroup] = None,
+        tp_group: Optional[dist.ProcessGroup] = None,
     ):
         cu_seqlens_q = cu_seqlens_kv = cu_seqlens
@@ -1123,7 +1132,7 @@ class RingAttention(torch.autograd.Function):
         if not is_packed:
             dq, dk, dv = [x.view(b, sq, *x.shape[-2:]) for x in (dq, dk, dv)]

-        return (dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None)
+        return (dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None)

     @staticmethod
     def prepare_varlen_batch(
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 47c17e749..fc5bcac6b 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -563,12 +563,14 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)

+        tp_group = shard_config.tensor_parallel_process_group
         if sp_mode == "ring_attn":
             attn_output = RingAttention.attention(
                 query_states,
                 key_states,
                 value_states,
                 sp_group,
+                tp_group,
                 **attention_mask,
                 inner_ring_size=shard_config.inner_ring_size,
             )
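
For reference, below is a minimal standalone sketch (not part of the patch) of the rank layout that the revised get_double_ring_groups builds. The sizes world_size = 8, sp_size = 4 and inner_ring_size = 2 are assumed purely for illustration, and dist.new_group is replaced by collecting the rank lists so the snippet runs without torch.distributed.

# Sketch of the ring layout built by the patched get_double_ring_groups.
# The sizes below are illustrative assumptions, not values taken from the patch.
world_size = 8        # assumed total number of ranks
sp_size = 4           # assumed sequence-parallel group size
inner_ring_size = 2   # assumed inner ring size (e.g. NICs per node)

num_rings = sp_size // inner_ring_size
groups = world_size // sp_size  # one set of rings per block of sp_size ranks

inner_rings, inter_rings = [], []
for group_id in range(groups):
    # Inner rings: ranks strided by inner_ring_size within each block of sp_size consecutive ranks.
    for i in range(inner_ring_size):
        inner_rings.append(list(range(i + group_id * sp_size, (1 + group_id) * sp_size, inner_ring_size)))
    # Inter rings: ranks strided by sp_size across the whole world.
    for i in range(num_rings):
        inter_rings.append(list(range(i + group_id * num_rings, world_size, sp_size)))

print("inner:", inner_rings)  # [[0, 2], [1, 3], [4, 6], [5, 7]]
print("inter:", inter_rings)  # [[0, 4], [1, 5], [2, 6], [3, 7]]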