fix the ring attn

Author: wangbluo
Date:   2024-09-25 18:34:29 +08:00
Parent: 10e4f7da72
Commit: cfd9eda628
2 changed files with 25 additions and 14 deletions

Changed file 1 of 2:

@@ -4,6 +4,7 @@ from typing import Callable, Dict, Optional, Tuple
 
 import torch
 import torch.distributed
 import torch.distributed as dist
+from torch.distributed import ProcessGroup
 import torch.nn.functional as F
 from einops import rearrange
@@ -431,7 +432,7 @@ class RingAttention(torch.autograd.Function):
     INTER_RING_GROUP_COPY: dist.ProcessGroup = None
 
     @staticmethod
-    def get_double_ring_groups(sp_group, inner_ring_size=None):
+    def get_double_ring_groups(sp_group, tp_group, inner_ring_size=None):
         """
         Get 2D ring groups for the given process group. Generally, to avoid congestion, the inner ring size
         shouldn't be larger than the number of NICs on each node.
@@ -443,6 +444,7 @@ class RingAttention(torch.autograd.Function):
         """
         sp_size = dist.get_world_size(sp_group)
         sp_rank = dist.get_rank(sp_group)
+        tp_size = dist.get_world_size(tp_group)
 
         if inner_ring_size is None:
             if torch.cuda.device_count() >= dist.get_world_size():
@@ -471,19 +473,24 @@ class RingAttention(torch.autograd.Function):
         inner_ring_group = None
         inter_ring_group = None
 
+        world_size = dist.get_world_size()
+        rank = dist.get_rank()
+        groups = int(world_size / sp_size)
+
         # Create inner ring groups
-        for i in range(inner_ring_size):
-            ranks = list(range(i * inner_ring_size, (i + 1) * inner_ring_size))
-            group = dist.new_group(ranks)
-            if sp_rank in ranks:
-                inner_ring_group = group
+        for group_id in range(groups):
+            for i in range(inner_ring_size):
+                ranks = list(range(i + (group_id * sp_size), (1 + group_id) * sp_size, inner_ring_size))
+                group = dist.new_group(ranks)
+                if rank in ranks:
+                    inner_ring_group = group
 
         # Create inter ring groups
-        for i in range(num_rings):
-            ranks = list(range(i, sp_size, num_rings))
-            group = dist.new_group(ranks)
-            if sp_rank in ranks:
-                inter_ring_group = group
+        for group_id in range(groups):
+            for i in range(num_rings):
+                ranks = list(range(i + group_id * num_rings, world_size, sp_size))
+                group = dist.new_group(ranks)
+                if rank in ranks:
+                    inter_ring_group = group
 
         return inner_ring_group, inter_ring_group
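
Note: the rank arithmetic in the rewritten loops above can be checked with a standalone dry run. The snippet below only replays the range() calls on example values (world_size=8, sp_size=4, inner_ring_size=2 are assumptions chosen for illustration) and skips dist.new_group(), so it runs in a single process; in the real method every rank creates all groups collectively and keeps only the group that contains its own global rank.

world_size = 8        # assumed total number of ranks
sp_size = 4           # assumed ranks per sequence-parallel group
inner_ring_size = 2   # assumed ring size within a node
num_rings = sp_size // inner_ring_size
groups = int(world_size / sp_size)

inner_rings = []
for group_id in range(groups):
    for i in range(inner_ring_size):
        # same expression as the added inner-ring line in the diff
        inner_rings.append(list(range(i + (group_id * sp_size), (1 + group_id) * sp_size, inner_ring_size)))

inter_rings = []
for group_id in range(groups):
    for i in range(num_rings):
        # same expression as the added inter-ring line in the diff
        inter_rings.append(list(range(i + group_id * num_rings, world_size, sp_size)))

print("inner rings:", inner_rings)  # [[0, 2], [1, 3], [4, 6], [5, 7]]
print("inter rings:", inter_rings)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
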
@@ -493,6 +500,7 @@ class RingAttention(torch.autograd.Function):
         q,
         k,
         v,
         sp_group,
+        tp_group,
         attention_mask_type,
         cu_seqlens=None,
         max_seqlen=None,
@@ -537,7 +545,6 @@ class RingAttention(torch.autograd.Function):
             RingAttention.ATTN_DONE = torch.cuda.Event()
         if RingAttention.SP_STREAM is None:
             RingAttention.SP_STREAM = torch.cuda.Stream()
-
         assert (
             q.shape[2] == k.shape[2]
         ), "Q, K and V having different sequence lengths (inference or cross-attn)\
@@ -550,7 +557,7 @@ class RingAttention(torch.autograd.Function):
         if RingAttention.SP_GROUP is not sp_group:
             RingAttention.SP_GROUP = sp_group
-            inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_group, inner_ring_size)
+            inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_group, tp_group, inner_ring_size)
             RingAttention.INNER_RING_GROUP = inner_ring_group
             RingAttention.INTER_RING_GROUP = inter_ring_group
         else:
@@ -597,6 +604,7 @@ class RingAttention(torch.autograd.Function):
             attention_mask_type == AttnMaskType.PADDED_CAUSAL,
             inner_ring_group,
             inter_ring_group,
+            tp_group,
         )
 
         if attention_mask_type == AttnMaskType.PADDED_CAUSAL:
@@ -627,6 +635,7 @@ class RingAttention(torch.autograd.Function):
         is_packed: Optional[bool] = False,
         inner_ring_group: Optional[dist.ProcessGroup] = None,
         inter_ring_group: Optional[dist.ProcessGroup] = None,
+        tp_group: Optional[dist.ProcessGroup] = None,
     ):
         cu_seqlens_q = cu_seqlens_kv = cu_seqlens
@@ -1123,7 +1132,7 @@ class RingAttention(torch.autograd.Function):
         if not is_packed:
             dq, dk, dv = [x.view(b, sq, *x.shape[-2:]) for x in (dq, dk, dv)]
 
-        return (dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None)
+        return (dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None)
 
     @staticmethod
     def prepare_varlen_batch(

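Note on the extra None in the return above: torch.autograd.Function requires backward() to return exactly one value per argument of forward(), with None for arguments that need no gradient. Since forward() gains the tp_group argument in this commit, backward() has to pad its return tuple with one more None. A minimal, self-contained illustration of that rule (not the RingAttention code itself):

import torch

class ScaleByGroupSize(torch.autograd.Function):
    """Toy Function with one tensor input and one non-tensor input."""

    @staticmethod
    def forward(ctx, x, group_size):
        ctx.group_size = group_size
        return x * group_size

    @staticmethod
    def backward(ctx, grad_out):
        # One slot per forward() argument: a gradient for x, None for group_size.
        return grad_out * ctx.group_size, None

x = torch.ones(2, requires_grad=True)
ScaleByGroupSize.apply(x, 4).sum().backward()
print(x.grad)  # tensor([4., 4.])
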
Changed file 2 of 2:

@@ -563,12 +563,14 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
+        tp_group = shard_config.tensor_parallel_process_group
         if sp_mode == "ring_attn":
             attn_output = RingAttention.attention(
                 query_states,
                 key_states,
                 value_states,
                 sp_group,
+                tp_group,
                 **attention_mask,
                 inner_ring_size=shard_config.inner_ring_size,
             )
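
Note: the call site above now forwards both the sequence-parallel group and shard_config.tensor_parallel_process_group. For a quick CPU-only check of how two such groups can coexist, the sketch below spawns 4 gloo ranks and builds a 2-way SP x 2-way TP layout; the specific rank layout (consecutive ranks for SP, strided ranks for TP) and the address/port are assumptions for the example, not something taken from this commit.

import os

import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank, world_size=4, sp_size=2):
    # Assumed single-node test setup; address/port are placeholders.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    sp_group = tp_group = None
    # Sequence-parallel groups over consecutive ranks: [0, 1] and [2, 3].
    for start in range(0, world_size, sp_size):
        ranks = list(range(start, start + sp_size))
        group = dist.new_group(ranks)
        if rank in ranks:
            sp_group = group
    # Tensor-parallel groups over strided ranks: [0, 2] and [1, 3].
    for offset in range(sp_size):
        ranks = list(range(offset, world_size, sp_size))
        group = dist.new_group(ranks)
        if rank in ranks:
            tp_group = group

    # These two handles are what a caller would pass to RingAttention.attention(...).
    print(f"rank {rank}: sp_size={dist.get_world_size(sp_group)}, tp_size={dist.get_world_size(tp_group)}")
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run, nprocs=4)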