This commit is contained in:
wangbluo 2024-10-14 18:07:56 +08:00
parent 3201377e94
commit fe9208feac

View File

@ -623,7 +623,6 @@ class RingAttention(torch.autograd.Function):
is_packed: Optional[bool] = False, is_packed: Optional[bool] = False,
inner_ring_group: Optional[dist.ProcessGroup] = None, inner_ring_group: Optional[dist.ProcessGroup] = None,
inter_ring_group: Optional[dist.ProcessGroup] = None, inter_ring_group: Optional[dist.ProcessGroup] = None,
tp_group: Optional[dist.ProcessGroup] = None,
): ):
cu_seqlens_q = cu_seqlens_kv = cu_seqlens cu_seqlens_q = cu_seqlens_kv = cu_seqlens
@ -1120,7 +1119,7 @@ class RingAttention(torch.autograd.Function):
if not is_packed: if not is_packed:
dq, dk, dv = [x.view(b, sq, *x.shape[-2:]) for x in (dq, dk, dv)] dq, dk, dv = [x.view(b, sq, *x.shape[-2:]) for x in (dq, dk, dv)]
return (dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None) return (dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None)
@staticmethod @staticmethod
def prepare_varlen_batch( def prepare_varlen_batch(