[Feature] optimize PP overlap (#5735)

* update to fully overlap, still debugging

* improve interface

* fixed deadlock bug

* debug NaN loss

* (experimental) use one comm group for send_fw_recv_fw to fix NaN

* cleaned up interfaces; use one batch p2p for all

* clean up; removed the double p2p batch case

* p2p test passsed

* improve overlap: send fwd before backward

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* tentatively use 2 p2p batches

* remove two p2p batches

* fix typos

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove pp.sh

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: root <root@notebook-c55824c0-7742-45e8-9591-c855bb77ad29-0.notebook-c55824c0-7742-45e8-9591-c855bb77ad29.colossal-ai.svc.cluster.local>
This commit is contained in:
Edenzzzz
2024-06-26 14:48:02 +08:00
committed by GitHub
parent 4ccaaaab63
commit 2a25a2aff7
9 changed files with 457 additions and 358 deletions

View File

@@ -15,8 +15,7 @@ WORLD_SIZE = 2
def check_p2p_communication():
pg_mesh = ProcessGroupMesh(WORLD_SIZE)
stage_manager = PipelineStageManager(pg_mesh, 0)
p2p = PipelineP2PCommunication(stage_manager)
p2p = PipelineP2PCommunication(stage_manager, overlap_p2p=False)
rank = dist.get_rank()
tensor = torch.ones(1, device=get_accelerator().get_current_device())
@@ -31,41 +30,40 @@ def check_p2p_communication():
for obj in data:
p2p.send_forward(obj)
for i in range(len(data)):
recv_obj = p2p.send_forward_recv_backward(data[i], send_prior_fallback=False)
recv_obj, _ = p2p.send_forward_recv_backward(data[i], send_first=False)
assert recv_obj == data[-(i + 1)]
elif rank == 1:
for obj in data:
recv_obj = p2p.recv_forward()
recv_obj, _ = p2p.recv_forward()
assert recv_obj == obj
for i in range(len(data)):
p2p.send_backward(data[-(i + 1)])
recv_obj = p2p.recv_forward()
recv_obj, _ = p2p.recv_forward()
assert recv_obj == data[i]
if rank == 1:
for obj in data:
p2p.send_backward(obj)
for i in range(len(data)):
recv_obj = p2p.send_backward_recv_forward(data[i], send_prior_fallback=True)
recv_obj, _ = p2p.send_backward_recv_forward(data[i], send_first=True)
assert recv_obj == data[-(i + 1)]
elif rank == 0:
for obj in data:
recv_obj = p2p.recv_backward()
recv_obj, _ = p2p.recv_backward()
assert recv_obj == obj
for i in range(len(data)):
recv_obj = p2p.recv_backward()
p2p.send_forward(data[-(i + 1)])
recv_obj, _ = p2p.send_forward_recv_backward(data[-(i + 1)], send_first=False)
assert recv_obj == data[i]
if rank == 0:
recv_obj = p2p.send_forward_recv_backward(
recv_obj, _ = p2p.send_forward_recv_backward(
tensor,
send_metadata=False,
metadata_recv=create_send_metadata(tensor),
)
assert recv_obj == tensor
elif rank == 1:
recv_obj = p2p.recv_forward(metadata_recv=create_send_metadata(tensor))
recv_obj, _ = p2p.recv_forward(metadata_recv=create_send_metadata(tensor))
assert recv_obj == tensor
p2p.send_backward(tensor, send_metadata=False)

View File

@@ -52,7 +52,7 @@ def check_stage_manager():
# check p2p groups
for prev, cur in zip(ranks_in_group[:-1], ranks_in_group[1:]):
if rank in [prev, cur]:
group = stage_manager.get_p2p_process_group(prev, cur)
group = stage_manager.get_p2p_process_group()
dist.barrier(group=group)
# check stage groups