mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-24 14:33:20 +00:00
[communication] add p2p_v2.py to support communication with List[Any] (#1407)
* support p2p communication with any type of object | pass test
* reconstruct pipeline schedule with p2p_v2.py (support communication with List[Any]) | pass test
* [communication] add p2p_v2.py to support communication with List[Any]
* Delete _pipeline_schedule_v2.py
* Delete test_cifar_with_data_pipeline_tensor_v2.py
* [engin/schedule] use p2p_v2 to recontruct pipeline_schedule
* [communication] remove print code
* [communication] remove print code
This commit is contained in:
parent
1590f59908
commit
44fd3c83ab
268
colossalai/communication/p2p_v2.py
Normal file
@@ -0,0 +1,268 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from typing import List, Tuple, Union, Any
import pickle
import io

import torch
import torch.distributed as dist
from torch.distributed import distributed_c10d as c10d
from torch.distributed import ProcessGroupNCCL

from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc

TensorShape = Union[torch.Size, List[int], Tuple[int]]
_pg_manager = {}
_unpickler = pickle.Unpickler
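# _pg_manager caches one process group per pair of adjacent pipeline stages, keyed
# by the ordered tuple (lower_rank, higher_rank); groups are created lazily by
# init_process_group() the first time _acquire_pair_group_handle() is called.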
def init_process_group():
    """Initialize a process group via dist.new_group() for every pair of adjacent pipeline stages.

    Args:
        None

    Returns:
        None
    """
    world_size = gpc.get_world_size(ParallelMode.PIPELINE)
    for i in range(world_size - 1):
        _pg_manager[(i, i + 1)] = dist.new_group([i, i + 1])

def _acquire_pair_group_handle(first_rank: int, second_rank: int) -> ProcessGroupNCCL:
    """Get the handle of the process group consisting of the two given ranks.

    Args:
        first_rank (int): first rank in the pair
        second_rank (int): second rank in the pair

    Returns:
        :class:`ProcessGroupNCCL`: the handle of the group consisting of the given two ranks
    """
    if len(_pg_manager) == 0:
        init_process_group()
    if first_rank > second_rank:
        first_rank, second_rank = second_rank, first_rank
    pair_key = (first_rank, second_rank)
    return _pg_manager[pair_key]

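# For example, _acquire_pair_group_handle(2, 3) and _acquire_pair_group_handle(3, 2)
# return the same cached group, since the pair key is normalized to ascending order.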
def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -> object:
    """Transform a tensor back into an object by unpickling it.

    The device tag embedded in the byte stream is rewritten to the current
    device before unpickling, so the object is restored on the local GPU.

    Args:
        tensor (:class:`torch.tensor`): tensor to be unpickled
        tensor_size (:class:`torch.Size`): number of valid payload bytes in the tensor

    Returns:
        Any: the unpickled object
    """
    buf = tensor.numpy().tobytes()[:tensor_size]
    if b'cuda' in buf:
        buf_array = bytearray(buf)
        device_index = torch.cuda.current_device()
        # patch the digit after b'cuda:' to the local device index (48 == ord('0'))
        buf_array[buf_array.find(b'cuda') + 5] = 48 + device_index
        buf = bytes(buf_array)

    io_bytes = io.BytesIO(buf)
    byte_pickler = _unpickler(io_bytes)
    unpickle = byte_pickler.load()

    return unpickle

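# Illustration of the patch above (hypothetical values): pickling a tensor on
# 'cuda:0' embeds the literal bytes b'cuda:0' in the stream; on a receiver whose
# current device is cuda:1, the byte at find(b'cuda') + 5 is rewritten from
# ord('0') to ord('1'), so torch deserializes the tensor onto cuda:1 directly.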
def _broadcast_object_list(object_list: List[Any], src: int, dst: int, device=None):
    """This is a modified version of broadcast_object_list in torch.distributed.
    The only difference is that each object is moved to the correct device after
    it is unpickled. If local_rank == src, object_list is broadcast to rank dst;
    otherwise, object_list is updated in place with the data received from rank src.

    Args:
        object_list (List[Any]): list of objects to broadcast
        src (int): source rank of the broadcast
        dst (int): destination rank of the broadcast
        device (:class:`torch.device`): device used for the broadcast; defaults to the current device
    """
    group = _acquire_pair_group_handle(src, dst)

    if c10d._rank_not_in_group(group):
        c10d._warn_not_in_group("broadcast_object_list")
        return

    local_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
    # Serialize object_list elements to tensors on src rank.
    if local_rank == src:
        tensor_list, size_list = zip(*[c10d._object_to_tensor(obj) for obj in object_list])
        object_sizes_tensor = torch.cat(size_list)
    else:
        object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long)

    is_nccl_backend = c10d._check_for_nccl_backend(group)
    current_device = None

    if device is not None:
        if is_nccl_backend and device.type != "cuda":
            raise ValueError("device type must be cuda for nccl backend")
        current_device = device
    else:
        current_device = torch.device("cpu")
        if is_nccl_backend:
            current_device = torch.device("cuda", torch.cuda.current_device())
    if is_nccl_backend:
        object_sizes_tensor = object_sizes_tensor.to(current_device)

    # Broadcast object sizes
    c10d.broadcast(object_sizes_tensor, src=src, group=group, async_op=False)

    # Concatenate and broadcast serialized object tensors
    if local_rank == src:
        object_tensor = torch.cat(tensor_list)
    else:
        object_tensor = torch.empty(    # type: ignore[call-overload]
            torch.sum(object_sizes_tensor).item(),    # type: ignore[arg-type]
            dtype=torch.uint8,
        )

    if is_nccl_backend:
        object_tensor = object_tensor.to(current_device)

    c10d.broadcast(object_tensor, src=src, group=group, async_op=False)

    # Deserialize objects using their stored sizes.
    offset = 0

    if local_rank != src:
        for i, obj_size in enumerate(object_sizes_tensor):
            obj_view = object_tensor[offset:offset + obj_size]
            obj_view = obj_view.type(torch.uint8)
            if obj_view.device != torch.device("cpu"):
                obj_view = obj_view.cpu()
            offset += obj_size
            # unpickle
            unpickle_object = _cuda_safe_tensor_to_object(obj_view, obj_size)

            # move tensors that were pickled on a different device onto the local GPU
            if isinstance(unpickle_object,
                          torch.Tensor) and unpickle_object.device.index != torch.cuda.current_device():
                unpickle_object = unpickle_object.cuda()

            object_list[i] = unpickle_object

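# Wire protocol (summary): each call runs two collectives on the pair group
# {src, dst} -- first a torch.long tensor holding every object's pickled byte
# size, then one flat torch.uint8 tensor with all payloads concatenated.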
def _send_object(object: Any, dst: int) -> None:
    """Send any (picklable) object to the dst rank.

    Args:
        object (Any): object to be sent
        dst (int): rank of the destination

    Returns:
        None
    """
    local_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
    # handler = _acquire_pair_group_handle(local_rank, dst)

    # wrap in a list if the object is a bare tensor
    if isinstance(object, torch.Tensor):
        object = [object]

    # broadcast the length first
    # TODO : more elegant ? P.S. reduce a _broadcast_object_list
    _broadcast_object_list([len(object)], local_rank, dst)
    # then broadcast safely
    _broadcast_object_list(object, local_rank, dst)

def _recv_object(src: int) -> Any:
    """Receive any (picklable) object from the src rank.

    Args:
        src (int): source rank of the data. The local rank will receive data from the src rank.

    Returns:
        Any: Object received from src.
    """
    local_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
    # handler = _acquire_pair_group_handle(local_rank, src)
    # receive the length first
    length = [0]
    _broadcast_object_list(length, src, local_rank)

    # then create a receive buffer of length[0] slots and broadcast into it
    object = [None] * length[0]
    _broadcast_object_list(object, src, local_rank)

    # a single-element list is unwrapped back to the bare object
    if length[0] == 1:
        object = object[0]

    return object

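# Note: _send_object and _recv_object must be called in matching pairs -- both
# sides first exchange the list length, then the payload, over the same pair
# group; an unmatched call on either side can hang the pipeline.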
def recv_forward(prev_rank: int = None) -> Any:
    """Copy the forward output from the previous stage in the pipeline as the input of this stage.

    Args:
        prev_rank (int, optional): The rank of the source of the input. Defaults to the previous pipeline stage.

    Returns:
        Any: The input object or list of input objects.
    """
    if gpc.is_pipeline_first_stage():
        input_tensor = None
    else:
        if prev_rank is None:
            prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
        input_tensor = _recv_object(prev_rank)

    return input_tensor

def recv_backward(next_rank: int = None) -> Any:
    """Copy the gradient from the next stage in the pipeline as the input gradient of this stage.

    Args:
        next_rank (int, optional): The rank of the source of the gradient. Defaults to the next pipeline stage.

    Returns:
        Any: The input gradient tensor or gradient tensor list.
    """
    if gpc.is_pipeline_last_stage():
        output_tensor_grad = None
    else:
        if next_rank is None:
            next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
        output_tensor_grad = _recv_object(next_rank)

    return output_tensor_grad

def send_forward(output_object: Any, next_rank: int = None) -> None:
    """Send the output object to the next stage in the pipeline.

    Args:
        output_object (Any): Object to be sent.
        next_rank (int, optional): The rank of the recipient. Defaults to the next pipeline stage.
    """
    if not gpc.is_pipeline_last_stage():
        if next_rank is None:
            next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
        _send_object(output_object, next_rank)

def send_backward(input_object: Any, prev_rank: int = None) -> None:
    """Send the gradient object to the previous stage in the pipeline.

    Args:
        input_object (Any): Object to be sent.
        prev_rank (int, optional): The rank of the recipient. Defaults to the previous pipeline stage.
    """
    if not gpc.is_pipeline_first_stage():
        if prev_rank is None:
            prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
        _send_object(input_object, prev_rank)
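A minimal usage sketch (not part of the commit): assuming a pipeline-parallel context already initialized via colossalai.initialize.launch, as the tests below do, each stage can exchange arbitrary object lists with its neighbours:

import torch
from colossalai.communication.p2p_v2 import send_forward, recv_forward, send_backward, recv_backward

obj = recv_forward()                              # None on the first stage
output = [1, torch.randn(2, 2).cuda(), None]      # any picklable objects
send_forward(output)                              # no-op on the last stage
grad = recv_backward()                            # None on the last stage
send_backward(grad)                               # no-op on the first stage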
54
tests/test_comm/test_boardcast_send_recv_v2.py
Normal file
@@ -0,0 +1,54 @@
from functools import partial
from typing import List

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.communication.p2p_v2 import _send_object, _recv_object, init_process_group
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.utils import free_port, get_current_device
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.logging import disable_existing_loggers

disable_existing_loggers()
world_size = 4
CONFIG = dict(parallel=dict(pipeline=world_size))
torch.manual_seed(123)

def check_layer(rank, world_size, port):
    disable_existing_loggers()
    launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl', verbose=False)
    rank = gpc.get_local_rank(ParallelMode.PIPELINE)

    # stage 0 sends a tensor list to stage 1; stage 3 sends one to stage 2
    if rank == 0:
        obj = [torch.randn(3,)]
        _send_object(obj, 1)

    if rank == 1:
        _recv_object(0)

    if rank == 2:
        _recv_object(3)

    if rank == 3:
        obj = [torch.randn(3,)]
        _send_object(obj, 2)

    gpc.destroy()
    torch.cuda.empty_cache()

@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_object_list_p2p():
    disable_existing_loggers()
    run_func = partial(check_layer, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_object_list_p2p()
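The test can also be launched directly with python tests/test_comm/test_boardcast_send_recv_v2.py; it spawns 4 processes on the NCCL backend, so 4 CUDA devices are expected.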
132
tests/test_comm/test_object_list_p2p_v2.py
Normal file
@@ -0,0 +1,132 @@
from functools import partial

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.communication.p2p_v2 import send_forward, recv_forward, send_backward, recv_backward, init_process_group
from colossalai.context import ParallelMode, Initializer_Pipeline
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.utils import free_port, get_current_device
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.logging import disable_existing_loggers

disable_existing_loggers()

# config
world_size = 4
CONFIG = dict(parallel=dict(pipeline=4))
torch.manual_seed(123)
use_scatter_gather_tensors = False

# data
torch.manual_seed(123)
LIST_LENGTH = 3
TENSOR_SIZE = torch.Size((3, 3))
TENSOR_SIZE_LIST = [TENSOR_SIZE for i in range(LIST_LENGTH)]
data = torch.rand(3, 3)
data_list = [torch.rand(3, 3) for i in range(LIST_LENGTH)]
grad = torch.rand(3, 3)
grad_list = [torch.rand(3, 3) for i in range(LIST_LENGTH)]

def check_send_recv_forward():
    disable_existing_loggers()
    local_rank = gpc.get_local_rank(ParallelMode.PIPELINE)

    if local_rank == 0:
        device = torch.device('cuda:0')
        data_to_send = data.to(device)
        data_list_to_send = []
        for data_in_list in data_list:
            data_list_to_send.append(data_in_list.to(device))

        # p2p_v2 infers shapes from the pickled objects, so the v1-style shape
        # and scatter_gather_tensors arguments are no longer needed
        send_forward(data_to_send)
        send_forward(data_list_to_send)

    elif local_rank == 1:
        device = torch.device('cuda:1')

        data_recv = recv_forward()
        data_list_recv = recv_forward()

        data_to_check = data.to(device)

        assert data_recv.equal(data_to_check)

        for data_recv, data_send in zip(data_list_recv, data_list):
            data_to_check = data_send.to(device)
            data_recv = data_recv.to(device)
            assert data_recv.equal(data_to_check)

def check_send_recv_backward():
    disable_existing_loggers()
    if gpc.get_local_rank(ParallelMode.PIPELINE) == 0:
        device = torch.device('cuda:0')
        grad_recv = recv_backward()
        grad_list_recv = recv_backward()

        grad_to_check = grad.to(device)
        grad_recv = grad_recv.to(device)

        assert grad_recv.equal(grad_to_check)
        for grad_recv, grad_send in zip(grad_list_recv, grad_list):
            grad_recv = grad_recv.to(device)
            grad_to_check = grad_send.to(device)
            assert grad_recv.equal(grad_to_check)
    else:
        device = torch.device('cuda:1')
        grad_to_send = grad.to(device)
        grad_list_to_send = []
        for grad_in_list in grad_list:
            grad_list_to_send.append(grad_in_list.to(device))
        send_backward(grad_to_send)
        send_backward(grad_list_to_send)

def check_small_pipeline():
    disable_existing_loggers()
    # make sure the world size is 4
    assert gpc.world_size == 4, "make sure to set world size to 4 to start the training process"
    local_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
    if local_rank == 0:
        # a mixed-type list (int, CUDA tensor, None) -- the case p2p_v2 exists for
        obj = [1, torch.randn(2, 2).cuda(), None]
        send_forward(obj)
    elif local_rank == 1:
        obj = recv_forward()
        send_forward(obj)
    elif local_rank == 2:
        obj = recv_forward()
        send_forward(obj)
    elif local_rank == 3:
        obj = recv_forward()
    else:
        pass

def check_layer(rank, world_size, port):
    disable_existing_loggers()
    launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    disable_existing_loggers()
    # check_send_recv_forward()
    check_small_pipeline()

    gpc.destroy()
    torch.cuda.empty_cache()


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_object_list_p2p():
    disable_existing_loggers()
    run_func = partial(check_layer, world_size=world_size, port=free_port())
    disable_existing_loggers()
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    disable_existing_loggers()
    test_object_list_p2p()