mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-14 05:33:23 +00:00
[communication] add p2p_v2.py to support communication with List[Any] (#1407)
* support p2p communication with any type of object | pass test * reconstruct pipeline schedule with p2p_v2.py(support communication with List[Any]) | pass test * [communication] add p2p_v2.py to support communication with List[Any] * Delete _pipeline_schedule_v2.py * Delete test_cifar_with_data_pipeline_tensor_v2.py * [engin/schedule] use p2p_v2 to recontruct pipeline_schedule * [communication] remove print code * [communication] remove print code
This commit is contained in:
54
tests/test_comm/test_boardcast_send_recv_v2.py
Normal file
54
tests/test_comm/test_boardcast_send_recv_v2.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from functools import partial
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.communication.p2p_v2 import _send_object, _recv_object, init_process_group
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.utils import free_port, get_current_device
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
|
||||
disable_existing_loggers()
|
||||
world_size = 4
|
||||
CONFIG = dict(parallel=dict(pipeline=world_size))
|
||||
torch.manual_seed(123)
|
||||
|
||||
|
||||
def check_layer(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl', verbose=False)
|
||||
rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||
|
||||
if rank == 0:
|
||||
obj = [torch.randn(3,)]
|
||||
_send_object(obj, 1)
|
||||
|
||||
if rank == 1:
|
||||
_recv_object(0)
|
||||
|
||||
if rank == 2:
|
||||
_recv_object(3)
|
||||
|
||||
if rank == 3:
|
||||
obj = [torch.randn(3,)]
|
||||
_send_object(obj, 2)
|
||||
|
||||
gpc.destroy()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_object_list_p2p():
|
||||
disable_existing_loggers()
|
||||
run_func = partial(check_layer, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_object_list_p2p()
|
132
tests/test_comm/test_object_list_p2p_v2.py
Normal file
132
tests/test_comm/test_object_list_p2p_v2.py
Normal file
@@ -0,0 +1,132 @@
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.communication.p2p_v2 import send_forward, recv_forward, send_backward, recv_backward, init_process_group
|
||||
from colossalai.context import ParallelMode, Initializer_Pipeline
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.utils import free_port, get_current_device
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
|
||||
disable_existing_loggers()
|
||||
|
||||
# config
|
||||
world_size = 4
|
||||
CONFIG = dict(parallel=dict(pipeline=4))
|
||||
torch.manual_seed(123)
|
||||
use_scatter_gather_tensors = False
|
||||
|
||||
# data
|
||||
torch.manual_seed(123)
|
||||
LIST_LENGTH = 3
|
||||
TENSOR_SIZE = torch.Size((3, 3))
|
||||
TENSOR_SIZE_LIST = [TENSOR_SIZE for i in range(LIST_LENGTH)]
|
||||
data = torch.rand(3, 3)
|
||||
data_list = [torch.rand(3, 3) for i in range(LIST_LENGTH)]
|
||||
grad = torch.rand(3, 3)
|
||||
grad_list = [torch.rand(3, 3) for i in range(LIST_LENGTH)]
|
||||
|
||||
|
||||
def check_send_recv_forward():
|
||||
disable_existing_loggers()
|
||||
local_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||
|
||||
if local_rank == 0:
|
||||
device = torch.device('cuda:0')
|
||||
data_to_send = data.to(device)
|
||||
data_list_to_send = []
|
||||
for data_in_list in data_list:
|
||||
data_list_to_send.append(data_in_list.to(device))
|
||||
|
||||
send_forward(data_to_send, scatter_gather_tensors=use_scatter_gather_tensors)
|
||||
send_forward(data_list_to_send, scatter_gather_tensors=use_scatter_gather_tensors)
|
||||
|
||||
elif local_rank == 1:
|
||||
device = torch.device('cuda:1')
|
||||
|
||||
data_recv = recv_forward(TENSOR_SIZE, scatter_gather_tensors=use_scatter_gather_tensors)
|
||||
data_list_recv = recv_forward(TENSOR_SIZE_LIST, scatter_gather_tensors=use_scatter_gather_tensors)
|
||||
|
||||
data_to_check = data.to(device)
|
||||
|
||||
assert data_recv.equal(data_to_check)
|
||||
|
||||
for data_recv, data_send in zip(data_list_recv, data_list):
|
||||
data_to_check = data_send.to(device)
|
||||
data_recv = data_recv.to(device)
|
||||
assert data_recv.equal(data_to_check)
|
||||
|
||||
|
||||
def check_send_recv_backward():
|
||||
disable_existing_loggers()
|
||||
if gpc.get_local_rank(ParallelMode.PIPELINE) == 0:
|
||||
device = torch.device('cuda:0')
|
||||
grad_recv = recv_backward(TENSOR_SIZE)
|
||||
grad_list_recv = recv_backward(TENSOR_SIZE_LIST)
|
||||
|
||||
grad_to_check = grad.to(device)
|
||||
grad_recv = grad_recv[0].to(device)
|
||||
|
||||
assert grad_recv.equal(grad_to_check)
|
||||
for grad_recv, grad_send in zip(grad_list_recv, grad_list):
|
||||
grad_recv = grad_recv.to(device)
|
||||
grad_to_check = grad_send.to(device)
|
||||
assert grad_recv.equal(grad_to_check)
|
||||
else:
|
||||
device = torch.device('cuda:1')
|
||||
grad_to_send = grad.to(device)
|
||||
grad_list_to_send = []
|
||||
for grad_in_list in grad_list:
|
||||
grad_list_to_send.append(grad_in_list.to(device))
|
||||
send_backward(grad_to_send)
|
||||
send_backward(grad_list_to_send)
|
||||
|
||||
|
||||
def check_small_pipeline():
|
||||
disable_existing_loggers()
|
||||
# make sure the rank is 4
|
||||
assert gpc.world_size == 4, "make sure to set world size to 4 to start the training process"
|
||||
local_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||
if local_rank == 0:
|
||||
obj = [1, torch.randn(2, 2).cuda(), None]
|
||||
send_forward(obj)
|
||||
elif local_rank == 1:
|
||||
obj = recv_forward()
|
||||
send_forward(obj)
|
||||
elif local_rank == 2:
|
||||
obj = recv_forward()
|
||||
send_forward(obj)
|
||||
elif local_rank == 3:
|
||||
obj = recv_forward()
|
||||
else:
|
||||
pass
|
||||
|
||||
|
||||
def check_layer(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
|
||||
disable_existing_loggers()
|
||||
# check_send_recv_forward()
|
||||
check_small_pipeline()
|
||||
|
||||
gpc.destroy()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_object_list_p2p():
|
||||
disable_existing_loggers()
|
||||
run_func = partial(check_layer, world_size=world_size, port=free_port())
|
||||
disable_existing_loggers()
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
disable_existing_loggers()
|
||||
test_object_list_p2p()
|
Reference in New Issue
Block a user