mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 19:40:28 +00:00
[Feature] optimize PP overlap (#5735)
* update to fully overlap, still debugging * improve interface * fixed deadlock bug * debug NaN loss * (experimental) use one comm group for send_fw_recv_fw to fix NaN * cleaned up interfaces; use one batch p2p for all * clean up; removed the double p2p batch case * p2p test passsed * improve overlap: send fwd before backward * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * tentatively use 2 p2p batches * remove two p2p batches * fix typos * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove pp.sh --------- Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: root <root@notebook-c55824c0-7742-45e8-9591-c855bb77ad29-0.notebook-c55824c0-7742-45e8-9591-c855bb77ad29.colossal-ai.svc.cluster.local>
This commit is contained in:
@@ -15,8 +15,7 @@ WORLD_SIZE = 2
|
||||
def check_p2p_communication():
|
||||
pg_mesh = ProcessGroupMesh(WORLD_SIZE)
|
||||
stage_manager = PipelineStageManager(pg_mesh, 0)
|
||||
p2p = PipelineP2PCommunication(stage_manager)
|
||||
|
||||
p2p = PipelineP2PCommunication(stage_manager, overlap_p2p=False)
|
||||
rank = dist.get_rank()
|
||||
|
||||
tensor = torch.ones(1, device=get_accelerator().get_current_device())
|
||||
@@ -31,41 +30,40 @@ def check_p2p_communication():
|
||||
for obj in data:
|
||||
p2p.send_forward(obj)
|
||||
for i in range(len(data)):
|
||||
recv_obj = p2p.send_forward_recv_backward(data[i], send_prior_fallback=False)
|
||||
recv_obj, _ = p2p.send_forward_recv_backward(data[i], send_first=False)
|
||||
assert recv_obj == data[-(i + 1)]
|
||||
elif rank == 1:
|
||||
for obj in data:
|
||||
recv_obj = p2p.recv_forward()
|
||||
recv_obj, _ = p2p.recv_forward()
|
||||
assert recv_obj == obj
|
||||
for i in range(len(data)):
|
||||
p2p.send_backward(data[-(i + 1)])
|
||||
recv_obj = p2p.recv_forward()
|
||||
recv_obj, _ = p2p.recv_forward()
|
||||
assert recv_obj == data[i]
|
||||
|
||||
if rank == 1:
|
||||
for obj in data:
|
||||
p2p.send_backward(obj)
|
||||
for i in range(len(data)):
|
||||
recv_obj = p2p.send_backward_recv_forward(data[i], send_prior_fallback=True)
|
||||
recv_obj, _ = p2p.send_backward_recv_forward(data[i], send_first=True)
|
||||
assert recv_obj == data[-(i + 1)]
|
||||
elif rank == 0:
|
||||
for obj in data:
|
||||
recv_obj = p2p.recv_backward()
|
||||
recv_obj, _ = p2p.recv_backward()
|
||||
assert recv_obj == obj
|
||||
for i in range(len(data)):
|
||||
recv_obj = p2p.recv_backward()
|
||||
p2p.send_forward(data[-(i + 1)])
|
||||
recv_obj, _ = p2p.send_forward_recv_backward(data[-(i + 1)], send_first=False)
|
||||
assert recv_obj == data[i]
|
||||
|
||||
if rank == 0:
|
||||
recv_obj = p2p.send_forward_recv_backward(
|
||||
recv_obj, _ = p2p.send_forward_recv_backward(
|
||||
tensor,
|
||||
send_metadata=False,
|
||||
metadata_recv=create_send_metadata(tensor),
|
||||
)
|
||||
assert recv_obj == tensor
|
||||
elif rank == 1:
|
||||
recv_obj = p2p.recv_forward(metadata_recv=create_send_metadata(tensor))
|
||||
recv_obj, _ = p2p.recv_forward(metadata_recv=create_send_metadata(tensor))
|
||||
assert recv_obj == tensor
|
||||
p2p.send_backward(tensor, send_metadata=False)
|
||||
|
||||
|
@@ -52,7 +52,7 @@ def check_stage_manager():
|
||||
# check p2p groups
|
||||
for prev, cur in zip(ranks_in_group[:-1], ranks_in_group[1:]):
|
||||
if rank in [prev, cur]:
|
||||
group = stage_manager.get_p2p_process_group(prev, cur)
|
||||
group = stage_manager.get_p2p_process_group()
|
||||
dist.barrier(group=group)
|
||||
|
||||
# check stage groups
|
||||
|
Reference in New Issue
Block a user