ColossalAI/tests/test_pipeline/test_p2p_communication.py
Edenzzzz 2a25a2aff7
[Feature] optimize PP overlap (#5735)
* update to fully overlap, still debugging

* improve interface

* fixed deadlock bug

* debug NaN loss

* (experimental) use one comm group for send_fw_recv_fw to fix NaN

* cleaned up interfaces; use one batch p2p for all

* clean up; removed the double p2p batch case

* p2p test passed

* improve overlap: send fwd before backward

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* tentatively use 2 p2p batches

* remove two p2p batches

* fix typos

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove pp.sh

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: root <root@notebook-c55824c0-7742-45e8-9591-c855bb77ad29-0.notebook-c55824c0-7742-45e8-9591-c855bb77ad29.colossal-ai.svc.cluster.local>
2024-06-26 14:48:02 +08:00

84 lines
2.5 KiB
Python

import pytest
import torch
import torch.distributed as dist
import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.testing import rerun_if_address_is_in_use, spawn
# Passed to ProcessGroupMesh and to spawn(); check_p2p_communication only has
# branches for ranks 0 and 1.
WORLD_SIZE = 2
def check_p2p_communication():
    """Exercise PipelineP2PCommunication between rank 0 and rank 1.

    Three phases run back to back. The calls on one rank are paired with the
    calls on the other, so their relative order is kept exactly as written.

    1. Rank 0 sends every payload forward, then for each payload sends it
       forward while receiving the mirrored payload backward. Rank 1 receives
       and checks each payload, then sends the mirrored payload backward and
       receives the forward one.
    2. The same pattern in the other direction: rank 1 sends backward and
       rank 0 receives.
    3. One tensor round trip where the sending side passes
       ``send_metadata=False`` and the receiving side supplies
       ``metadata_recv`` built by ``create_send_metadata``.

    The payloads are a string, a tensor, a list holding the tensor and a dict
    holding the tensor. Ranks other than 0 and 1 take no branch.

    Raises:
        AssertionError: If a received object does not compare equal to the
            expected payload.
    """
    mesh = ProcessGroupMesh(WORLD_SIZE)
    manager = PipelineStageManager(mesh, 0)
    comm = PipelineP2PCommunication(manager, overlap_p2p=False)
    rank = dist.get_rank()

    tensor = torch.ones(1, device=get_accelerator().get_current_device())
    payloads = ["tensor", tensor, [tensor], {"tensor": tensor}]
    # Pair the i-th payload with the i-th payload counted from the end.
    mirrored_pairs = list(zip(payloads, reversed(payloads)))

    # Phase 1: rank 0 drives forward sends, rank 1 answers with backward sends.
    if rank == 0:
        for payload in payloads:
            comm.send_forward(payload)
        for payload, mirrored in mirrored_pairs:
            # Each receive call returns a 2-tuple; only the first item is checked.
            received, _ = comm.send_forward_recv_backward(payload, send_first=False)
            assert received == mirrored
    elif rank == 1:
        for payload in payloads:
            received, _ = comm.recv_forward()
            assert received == payload
        for payload, mirrored in mirrored_pairs:
            comm.send_backward(mirrored)
            received, _ = comm.recv_forward()
            assert received == payload

    # Phase 2: rank 1 drives backward sends, rank 0 answers with forward sends.
    if rank == 1:
        for payload in payloads:
            comm.send_backward(payload)
        for payload, mirrored in mirrored_pairs:
            received, _ = comm.send_backward_recv_forward(payload, send_first=True)
            assert received == mirrored
    elif rank == 0:
        for payload in payloads:
            received, _ = comm.recv_backward()
            assert received == payload
        for payload, mirrored in mirrored_pairs:
            received, _ = comm.send_forward_recv_backward(mirrored, send_first=False)
            assert received == payload

    # Phase 3: tensor round trip with send_metadata=False on the sender and
    # caller-supplied metadata on the receiver.
    # NOTE(review): presumably this skips the metadata exchange - confirm
    # against PipelineP2PCommunication.
    if rank == 0:
        received, _ = comm.send_forward_recv_backward(
            tensor,
            send_metadata=False,
            metadata_recv=create_send_metadata(tensor),
        )
        assert received == tensor
    elif rank == 1:
        received, _ = comm.recv_forward(metadata_recv=create_send_metadata(tensor))
        assert received == tensor
        comm.send_backward(tensor, send_metadata=False)
def run_dist(rank, world_size, port):
    """Launch ColossalAI for this process, then run the P2P check.

    Args:
        rank: Forwarded to ``colossalai.launch`` as ``rank``.
        world_size: Forwarded to ``colossalai.launch`` as ``world_size``.
        port: Forwarded to ``colossalai.launch`` as ``port``; the host is
            always ``"localhost"``.
    """
    launch_options = {
        "rank": rank,
        "world_size": world_size,
        "port": port,
        "host": "localhost",
    }
    colossalai.launch(**launch_options)
    check_p2p_communication()
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_pipeline_p2p():
    """Run ``run_dist`` through ``spawn`` with ``WORLD_SIZE`` as its argument."""
    process_count = WORLD_SIZE
    spawn(run_dist, process_count)
# Allow running this file directly as a script, without the pytest runner.
if __name__ == "__main__":
    test_pipeline_p2p()