fix the ring attn

wangbluo 2024-09-25 18:34:29 +08:00
parent 10e4f7da72
commit cfd9eda628
2 changed files with 25 additions and 14 deletions

View File

@@ -4,6 +4,7 @@ from typing import Callable, Dict, Optional, Tuple
 import torch
 import torch.distributed
 import torch.distributed as dist
+from torch.distributed import ProcessGroup
 import torch.nn.functional as F
 from einops import rearrange
@@ -431,7 +432,7 @@ class RingAttention(torch.autograd.Function):
     INTER_RING_GROUP_COPY: dist.ProcessGroup = None

     @staticmethod
-    def get_double_ring_groups(sp_group, inner_ring_size=None):
+    def get_double_ring_groups(sp_group, tp_group, inner_ring_size=None):
         """
         Get 2D ring groups for the given process group. Generally, to avoid congestion, the inner ring size
         shouldn't be larger than the number of NICs on each node.
@@ -443,6 +444,7 @@ class RingAttention(torch.autograd.Function):
         """
         sp_size = dist.get_world_size(sp_group)
         sp_rank = dist.get_rank(sp_group)
+        tp_size = dist.get_world_size(tp_group)

         if inner_ring_size is None:
             if torch.cuda.device_count() >= dist.get_world_size():
@@ -471,18 +473,23 @@ class RingAttention(torch.autograd.Function):
         inner_ring_group = None
         inter_ring_group = None

+        world_size = dist.get_world_size()
+        rank = dist.get_rank()
+        groups = int(world_size / sp_size)
         # Create inner ring groups
-        for i in range(inner_ring_size):
-            ranks = list(range(i * inner_ring_size, (i + 1) * inner_ring_size))
-            group = dist.new_group(ranks)
-            if sp_rank in ranks:
-                inner_ring_group = group
+        for group_id in range(groups):
+            for i in range(inner_ring_size):
+                ranks = list(range(i + (group_id * sp_size), (1 + group_id) * sp_size, inner_ring_size))
+                group = dist.new_group(ranks)
+                if rank in ranks:
+                    inner_ring_group = group

         # Create inter ring groups
-        for i in range(num_rings):
-            ranks = list(range(i, sp_size, num_rings))
-            group = dist.new_group(ranks)
-            if sp_rank in ranks:
-                inter_ring_group = group
+        for group_id in range(groups):
+            for i in range(num_rings):
+                ranks = list(range(i + group_id * num_rings, world_size, sp_size))
+                group = dist.new_group(ranks)
+                if rank in ranks:
+                    inter_ring_group = group

         return inner_ring_group, inter_ring_group
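To make the new rank bookkeeping easier to check, here is a minimal standalone sketch (not part of the commit) that mirrors the rank arithmetic of the rewritten loops above, so the inner/inter ring memberships can be printed without launching torch.distributed. The world_size, sp_size, and inner_ring_size values are hypothetical placeholders, and num_rings is assumed to be sp_size // inner_ring_size, matching how it is used above.

def sketch_double_ring_ranks(world_size: int, sp_size: int, inner_ring_size: int):
    """Reproduce the rank lists built by the new grouping loops, without dist calls."""
    num_rings = sp_size // inner_ring_size  # assumption: not shown in this hunk
    groups = world_size // sp_size          # number of sequence-parallel groups in the world

    inner_rings = []
    for group_id in range(groups):
        for i in range(inner_ring_size):
            # same range() expression as the committed inner-ring loop
            inner_rings.append(list(range(i + group_id * sp_size, (1 + group_id) * sp_size, inner_ring_size)))

    inter_rings = []
    for group_id in range(groups):
        for i in range(num_rings):
            # same range() expression as the committed inter-ring loop
            inter_rings.append(list(range(i + group_id * num_rings, world_size, sp_size)))

    return inner_rings, inter_rings


if __name__ == "__main__":
    # Hypothetical layout: 16 ranks total, sp_size=8, inner_ring_size=4.
    inner_rings, inter_rings = sketch_double_ring_ranks(16, 8, 4)
    print("inner ring rank lists:", inner_rings)
    print("inter ring rank lists:", inter_rings)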
@@ -493,6 +500,7 @@ class RingAttention(torch.autograd.Function):
         k,
         v,
         sp_group,
+        tp_group,
         attention_mask_type,
         cu_seqlens=None,
         max_seqlen=None,
@@ -537,7 +545,6 @@ class RingAttention(torch.autograd.Function):
             RingAttention.ATTN_DONE = torch.cuda.Event()
         if RingAttention.SP_STREAM is None:
             RingAttention.SP_STREAM = torch.cuda.Stream()
-
         assert (
             q.shape[2] == k.shape[2]
         ), "Q, K and V having different sequence lengths (inference or cross-attn)\
@@ -550,7 +557,7 @@ class RingAttention(torch.autograd.Function):

         if RingAttention.SP_GROUP is not sp_group:
             RingAttention.SP_GROUP = sp_group
-            inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_group, inner_ring_size)
+            inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_group, tp_group, inner_ring_size)
             RingAttention.INNER_RING_GROUP = inner_ring_group
             RingAttention.INTER_RING_GROUP = inter_ring_group
         else:
@@ -597,6 +604,7 @@ class RingAttention(torch.autograd.Function):
             attention_mask_type == AttnMaskType.PADDED_CAUSAL,
             inner_ring_group,
             inter_ring_group,
+            tp_group,
         )

         if attention_mask_type == AttnMaskType.PADDED_CAUSAL:
@@ -627,6 +635,7 @@ class RingAttention(torch.autograd.Function):
         is_packed: Optional[bool] = False,
         inner_ring_group: Optional[dist.ProcessGroup] = None,
         inter_ring_group: Optional[dist.ProcessGroup] = None,
+        tp_group: Optional[dist.ProcessGroup] = None,
     ):

         cu_seqlens_q = cu_seqlens_kv = cu_seqlens
@@ -1123,7 +1132,7 @@ class RingAttention(torch.autograd.Function):
         if not is_packed:
             dq, dk, dv = [x.view(b, sq, *x.shape[-2:]) for x in (dq, dk, dv)]

-        return (dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None)
+        return (dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None)

     @staticmethod
     def prepare_varlen_batch(
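For context on the lengthened return tuple above: torch.autograd.Function.backward must return one entry per argument passed to forward, with None for non-differentiable inputs such as process groups, so the new tp_group argument accounts for the extra trailing None. A minimal sketch of that contract (illustrative only, not code from this commit; the class and variable names are made up):

import torch


class _ScaleBy(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale, group):
        # x is a tensor; scale and group stand in for non-tensor extras like sp_group/tp_group
        ctx.scale = scale
        return x * scale

    @staticmethod
    def backward(ctx, grad_out):
        # One slot per forward input: grad for x, None for scale, None for group.
        # Adding another forward argument would require another trailing None here.
        return grad_out * ctx.scale, None, None


x = torch.ones(2, requires_grad=True)
_ScaleBy.apply(x, 2.0, None).sum().backward()
print(x.grad)  # tensor([2., 2.])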

View File

@@ -563,12 +563,14 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)

+        tp_group = shard_config.tensor_parallel_process_group
         if sp_mode == "ring_attn":
             attn_output = RingAttention.attention(
                 query_states,
                 key_states,
                 value_states,
                 sp_group,
+                tp_group,
                 **attention_mask,
                 inner_ring_size=shard_config.inner_ring_size,
             )