overlap kv comm with output rescale (#6017)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
parent 26493b97d3
commit f1c3266a94
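In short, the patch factors the ring KV exchange into a `_kv_comm(i)` helper and calls `_kv_comm(i + 1)` right after `RingAttention.ATTN_DONE.record()`, so the next send/recv runs while the current block's output correction (the log-sum-exp based rescale) executes, instead of waiting for the next flash-attention call. Below is a minimal single-GPU sketch of that pipelining pattern; the attention math is a toy placeholder, an async device copy stands in for the real `send_recv`, and all names (`ring_like_forward`, `kv_comm`, `attn_done`, `comm_stream`) are illustrative rather than the actual RingAttention API.

# A minimal, single-GPU sketch of the overlap pattern (not the actual
# RingAttention code): an async device copy stands in for the ring send_recv,
# and the attention math is a toy placeholder.
import torch


def _block_attn(q, k, v):
    # Toy attention block: returns the block output and its log-sum-exp.
    scores = q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5
    lse = scores.logsumexp(dim=-1, keepdim=True)
    return scores.softmax(dim=-1) @ v, lse


def ring_like_forward(q, k, v, steps=4):
    comm_stream = torch.cuda.Stream()
    attn_done = torch.cuda.Event()  # plays the role of RingAttention.ATTN_DONE
    kv0 = torch.stack([k, v])
    kv_bufs = [kv0, torch.empty_like(kv0)]
    out = lse = None

    def kv_comm(i):
        # Avoid overwriting a buffer the in-flight attention kernel still reads.
        if not attn_done.query():
            kv_bufs[(i + 1) % 2] = torch.empty_like(kv_bufs[i % 2])
        if i < steps - 1:
            with torch.cuda.stream(comm_stream):
                # Stand-in for local_kv_comms[i % 2].send_recv(...).
                kv_bufs[(i + 1) % 2].copy_(kv_bufs[i % 2], non_blocking=True)

    for i in range(steps):
        if i > 0:
            # The KV block for this step must have arrived.
            torch.cuda.current_stream().wait_stream(comm_stream)
        if i == 0:
            kv_comm(i)  # kick off the first exchange

        block_out, block_lse = _block_attn(q, kv_bufs[i % 2][0], kv_bufs[i % 2][1])
        attn_done.record()

        # Key change: start the *next* KV exchange here, so it overlaps with the
        # output correction below instead of with the next attention call.
        kv_comm(i + 1)

        # Output correction: merge the block output via log-sum-exp weights.
        if out is None:
            out, lse = block_out, block_lse
        else:
            new_lse = torch.logaddexp(lse, block_lse)
            out = out * (lse - new_lse).exp() + block_out * (block_lse - new_lse).exp()
            lse = new_lse
    return out, lse
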
@@ -690,6 +690,13 @@ class RingAttention(torch.autograd.Function):
             )
             return out, softmax_lse, rng_state
 
+        def _kv_comm(i):
+            # Avoid overwriting attn input when it shares mem with buffer
+            if not RingAttention.ATTN_DONE.query():
+                kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2])
+            if i < local_sp_size - 1:
+                local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
+
         def _local_ring_forward():
             # (Hopefully) overlap output correction with next flash attn
             for i in range(local_sp_size):
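The helper's first branch relies on `ATTN_DONE` being queryable without blocking: if the previously launched attention kernel may still be reading `kv_buffers[(i + 1) % 2]`, a fresh receive buffer is allocated instead of letting `send_recv` overwrite it. A tiny standalone illustration of that non-blocking probe, assuming `ATTN_DONE` is a `torch.cuda.Event` (which the `query()`/`record()` calls suggest):

import torch

done = torch.cuda.Event()
x = torch.randn(4096, 4096, device="cuda")

y = x @ x      # enqueue a sizeable kernel on the current stream
done.record()  # mark the point after that kernel

print(done.query())        # usually False: the matmul is likely still running
torch.cuda.synchronize()   # wait for all queued work
print(done.query())        # True: everything recorded before has finished
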
@@ -698,12 +705,8 @@ class RingAttention(torch.autograd.Function):
                     # NOTE: waiting outside the current stream will NOT correctly synchronize.
                     if i > 0:
                         local_kv_comms[(i + 1) % 2].wait()
-                    # Avoid overwriting attn input when it shares mem with buffer
-                    if not RingAttention.ATTN_DONE.query():
-                        kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2])
-                    if i < local_sp_size - 1:
-                        local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
-
+                    if i == 0:
+                        _kv_comm(i)
 
                     if i == 0:
                         # Compute with local KV; no mask
@@ -734,6 +737,9 @@ class RingAttention(torch.autograd.Function):
                             rng_states[i],
                         ) = _forward(q_block, kv_block[0], kv_block[1], causal=False)
                     RingAttention.ATTN_DONE.record()
+                    # Pipeline the next KV comm with output correction instead of the next flash attn
+                    # to minimize idle time when comm takes longer than attn.
+                    _kv_comm(i + 1)
 
                     block_softmax_lse[i % 2] = (
                         block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
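Not part of the diff: one hedged way to confirm that the exchange now overlaps with the correction is to capture a trace of a forward pass and check that the communication kernels run alongside the rescale kernels on a different CUDA stream. `run_forward` below is a placeholder for whatever callable exercises the ring-attention path.

import torch
from torch.profiler import ProfilerActivity, profile


def trace_overlap(run_forward, out_path="ring_attn_overlap.json"):
    # run_forward: any callable that executes one ring-attention forward pass.
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        run_forward()
        torch.cuda.synchronize()
    # In the exported Chrome trace, overlapped comm shows up as send/recv
    # kernels running concurrently with the correction kernels.
    prof.export_chrome_trace(out_path)
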
@@ -761,15 +767,13 @@ class RingAttention(torch.autograd.Function):
             # all new KVs from the previous inner ring
             for i in range(local_sp_size):
                 with torch.cuda.stream(sp_streams[i % 2]):
-                    if not RingAttention.ATTN_DONE.query():
-                        kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2])
-                    if i < local_sp_size - 1:
-                        local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
-
                     # Send & recv KV
                     if i > 0:
                         local_kv_comms[(i + 1) % 2].wait()
 
+                    if i == 0:
+                        _kv_comm(i)
+
                     if ring_num_idx > inter_ring_rank:
                         kv_block = kv_buffers[i % 2]
                         (
@@ -778,6 +782,8 @@ class RingAttention(torch.autograd.Function):
                             rng_states[i + local_sp_size * ring_num_idx],
                         ) = _forward(q1, kv_block[0], kv_block[1], causal=False)
                         RingAttention.ATTN_DONE.record()
+
+                        _kv_comm(i + 1)
                         block_softmax_lse[i % 2] = (
                             block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
                         )
@@ -792,6 +798,8 @@ class RingAttention(torch.autograd.Function):
                             rng_states[i + local_sp_size * ring_num_idx],
                         ) = _forward(q, kv_block[0], kv_block[1], causal=False)
                         RingAttention.ATTN_DONE.record()
+
+                        _kv_comm(i + 1)
                         block_softmax_lse[i % 2] = (
                             block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
                         )