diff --git a/colossalai/zero/sharded_param/__init__.py b/colossalai/zero/sharded_param/__init__.py
index 527cf11d6..6269429f8 100644
--- a/colossalai/zero/sharded_param/__init__.py
+++ b/colossalai/zero/sharded_param/__init__.py
@@ -1,3 +1,4 @@
-from .sharded_param import ShardedParam
+from colossalai.zero.sharded_param.sharded_param import ShardedParam
+from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
 
-__all__ = ['ShardedParam']
+__all__ = ['ShardedParam', 'ShardedTensor']
diff --git a/colossalai/zero/sharded_param/sharded_param.py b/colossalai/zero/sharded_param/sharded_param.py
index f7363d0a5..56f2382af 100644
--- a/colossalai/zero/sharded_param/sharded_param.py
+++ b/colossalai/zero/sharded_param/sharded_param.py
@@ -56,6 +56,13 @@ class ShardedParam(object):
         """
         return self._param_payload.to(target_device)
 
+    def set_payload(self, data: torch.Tensor):
+        r"""
+        set payload as data
+        """
+        assert self._param_payload.numel() == data.numel()
+        self._param_payload.copy_(data)
+
     def shard(self):
         r"""
         Distributed the payload of param to all processes.
diff --git a/colossalai/zero/sharded_param/sharded_tensor.py b/colossalai/zero/sharded_param/sharded_tensor.py
new file mode 100644
index 000000000..19e2715d6
--- /dev/null
+++ b/colossalai/zero/sharded_param/sharded_tensor.py
@@ -0,0 +1,67 @@
+import torch
+import torch.distributed as dist
+from colossalai.zero.sharded_model._zero3_utils import get_shard
+from typing import Optional
+
+
+class ShardedTensor(object):
+
+    def __init__(self, tensor: torch.Tensor, process_group: Optional[dist.ProcessGroup] = None) -> None:
+        r"""
+        A tensor sharded in multiple processes.
+        """
+        self._payload = tensor
+        self.process_group = process_group
+        self.world_size = dist.get_world_size(self.process_group)
+        self.local_rank = dist.get_rank(self.process_group)
+        self._is_sharded = False
+        self._payload = tensor
+
+        self._origin_shape = tensor.shape
+        self._origin_numel = tensor.numel()
+        self._origin_dtype = tensor.dtype
+
+    @property
+    def is_sharded(self):
+        return self._is_sharded
+
+    @property
+    def payload(self):
+        return self._payload
+
+    @payload.setter
+    def payload(self, tensor):
+        self._payload.copy_(tensor)
+
+    @property
+    def dtype(self):
+        return self._origin_dtype
+
+    @property
+    def shape(self):
+        return self._payload.shape
+
+    def shard(self):
+        if self._is_sharded:
+            return
+        self._payload, _ = get_shard(self._payload, self.local_rank, self.world_size)
+        self._is_sharded = True
+
+    def gather(self):
+        if not self._is_sharded:
+            return
+
+        buffer_list = []
+        payload_numel = self._payload.numel()
+        for i in range(self.world_size):
+            if i == self.local_rank:
+                buffer_list.append(self._payload.cuda())
+            else:
+                buffer_list.append(torch.zeros(payload_numel).cuda())
+
+        torch.distributed.all_gather(buffer_list,
+                                     buffer_list[self.local_rank],
+                                     group=self.process_group,
+                                     async_op=False)
+        self._payload = torch.narrow(torch.cat(buffer_list), 0, 0, self._origin_numel).view(self._origin_shape)
+        self._is_sharded = False
diff --git a/tests/test_zero_data_parallel/test_shard_param.py b/tests/test_zero_data_parallel/test_shard_param.py
index 642cd7f2b..4f6eb52b2 100644
--- a/tests/test_zero_data_parallel/test_shard_param.py
+++ b/tests/test_zero_data_parallel/test_shard_param.py
@@ -7,12 +7,38 @@ import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
-from colossalai.zero.sharded_param import ShardedParam
+from colossalai.zero.sharded_param import ShardedTensor, ShardedParam
 from colossalai.utils import free_port
 from colossalai.logging import get_dist_logger, disable_existing_loggers
 from tests.test_zero_data_parallel.common import Net, CONFIG
 
 
+def run_shard_tensor(rank, world_size, port):
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    t = ShardedTensor(tensor=torch.randn(world_size * 2, 3))
+
+    assert list(t.shape) == [world_size * 2, 3]
+    t.shard()
+    # The shape is flattened
+    assert list(t.shape) == [6]
+    # Do nothing
+    t.shard()
+    assert list(t.shape) == [6]
+
+    t.gather()
+    assert list(t.shape) == [world_size * 2, 3]
+
+    t.payload = torch.zeros(world_size * 2, 3)
+    assert torch.sum(t.payload).cpu() == 0
+
+
+@pytest.mark.dist
+def test_shard_tensor():
+    world_size = 2
+    run_func = partial(run_shard_tensor, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
 def run_init_shard_param(rank, world_size, port):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     param = torch.nn.Parameter(data=torch.rand(2, 3))
@@ -68,5 +94,6 @@ def test_init_shard_param():
 
 
 if __name__ == '__main__':
+    test_shard_tensor()
     test_shard_shape()
     test_init_shard_param()
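
For reference, the shard/gather round trip exercised by run_shard_tensor can be followed without a process group: with world_size = 2, the (4, 3) payload holds 12 elements, shard() leaves each rank with a flat 6-element chunk (the [6] shape asserted in the test), and gather() concatenates the per-rank chunks, narrows back to the original numel and restores the original view, mirroring the torch.narrow(...).view(...) line in gather(). The snippet below is a minimal single-process sketch of that arithmetic only; it substitutes plain torch.chunk for the real get_shard helper and the torch.distributed collective, so it is illustrative rather than equivalent.

import torch

# Single-process illustration of the flatten/chunk/reassemble arithmetic behind
# ShardedTensor.shard()/gather(); get_shard and all_gather are replaced by local ops.
world_size = 2
origin = torch.randn(world_size * 2, 3)            # shape (4, 3), 12 elements
origin_shape, origin_numel = origin.shape, origin.numel()

# "shard": flatten and split into one equal chunk per rank
# (padding would be needed if numel were not divisible by world_size).
chunks = list(origin.flatten().chunk(world_size))  # each chunk holds 6 elements
assert list(chunks[0].shape) == [6]                # matches the assert in run_shard_tensor

# "gather": concatenate every rank's chunk, drop any padding, restore the view.
gathered = torch.narrow(torch.cat(chunks), 0, 0, origin_numel).view(origin_shape)
assert torch.equal(gathered, origin)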