diff --git a/colossalai/zero/sharded_param/__init__.py b/colossalai/zero/sharded_param/__init__.py
index 6269429f8..95c9b0471 100644
--- a/colossalai/zero/sharded_param/__init__.py
+++ b/colossalai/zero/sharded_param/__init__.py
@@ -1,4 +1,4 @@
-from colossalai.zero.sharded_param.sharded_param import ShardedParam
 from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
+from colossalai.zero.sharded_param.sharded_param import ShardedParam, ShardedParamV2
 
-__all__ = ['ShardedParam', 'ShardedTensor']
+__all__ = ['ShardedParam', 'ShardedTensor', 'ShardedParamV2']
diff --git a/colossalai/zero/sharded_param/sharded_param.py b/colossalai/zero/sharded_param/sharded_param.py
index 472487d6a..61e9d9d32 100644
--- a/colossalai/zero/sharded_param/sharded_param.py
+++ b/colossalai/zero/sharded_param/sharded_param.py
@@ -6,6 +6,40 @@ import torch.distributed as dist
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.zero.sharded_model._zero3_utils import get_shard
+from colossalai.zero.sharded_param import ShardedTensor
+from typing import Union, Tuple, Optional
+import numpy
+
+
+
+class ShardedParamV2(object):
+
+    def __init__(self, param: torch.nn.Parameter, process_group: Optional[dist.ProcessGroup] = None) -> None:
+        self._data_sharded_tensor = ShardedTensor(param.data, process_group)
+        if param.requires_grad and param.grad is not None:
+            self._grad_sharded_tensor = ShardedTensor(param.grad, process_group)
+            param.grad = None
+        else:
+            self._grad_sharded_tensor = None
+
+        # make sure the shared param is the only owner of payload
+        param.data = torch.empty([], dtype=param.dtype, device=param.device)
+
+    @property
+    def data(self):
+        return self._data_sharded_tensor.payload
+
+    @data.setter
+    def data(self, t: torch.Tensor):
+        self._data_sharded_tensor.payload = t
+
+    @property
+    def grad(self):
+        return self._grad_sharded_tensor.payload
+
+    @grad.setter
+    def grad(self, t: torch.Tensor):
+        self._grad_sharded_tensor.payload = t
 
 
 class ShardedParam(object):
diff --git a/tests/test_zero_data_parallel/test_shard_param.py b/tests/test_zero_data_parallel/test_shard_param.py
index 876dd4953..640292f31 100644
--- a/tests/test_zero_data_parallel/test_shard_param.py
+++ b/tests/test_zero_data_parallel/test_shard_param.py
@@ -1,9 +1,11 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
+from copy import deepcopy
 from functools import partial
 
 import colossalai
+from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
 import pytest
 import torch
 import torch.multiprocessing as mp
@@ -11,7 +13,7 @@ from colossalai.zero.shard_utils import TensorShardStrategy
 from colossalai.zero.sharded_param import ShardedTensor, ShardedParam
 from colossalai.utils import free_port
 from colossalai.logging import get_dist_logger, disable_existing_loggers
-from tests.test_zero_data_parallel.common import Net, CONFIG
+from tests.test_zero_data_parallel.common import Net, CONFIG, allclose
 
 
 def run_shard_tensor(rank, world_size, port):
@@ -36,28 +38,33 @@ def test_shard_tensor():
     mp.spawn(run_func, nprocs=world_size)
 
 
-def run_init_shard_param(rank, world_size, port):
+def _run_shard_param_v2(rank, world_size, port):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    param = torch.nn.Parameter(data=torch.rand(2, 3))
-    sparam = ShardedParam(param, None, True)
-    payload = sparam.payload(torch.device('cuda'))
-    assert (list(payload.shape) == [3])
-    del sparam
 
-    param_shape = (2, 3)
-    sparam = ShardedParam(param_shape, process_group=None, is_sharded=True, device=torch.device('cpu'))
-    payload = sparam.payload(torch.device('cuda'))
-    assert (list(payload.shape) == [3])
+    param = torch.nn.Parameter(torch.randn(2, 3))
+    param_ref = deepcopy(param)
+    sparam = ShardedParamV2(param=param, process_group=None)
 
-    param_shape = (2, 3)
-    sparam = ShardedParam(param_shape, process_group=None, is_sharded=False, device=torch.device('cpu'))
-    payload = sparam.payload(torch.device('cuda'))
-    assert (list(payload.shape) == [2, 3])
+    allclose(sparam.data, param_ref.data)
+    assert (param.data.numel() == 1)
 
 
-def run_shard_param_check(rank, world_size, port):
+@pytest.mark.dist
+def test_shard_param_v2():
+    world_size = 2
+    run_func = partial(_run_shard_param_v2, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+def _run_test_shard_param(rank, world_size, port):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
 
+    param = torch.nn.Parameter(torch.randn(2, 3))
+    param_ref = deepcopy(param)
+    sparam = ShardedParamV2(param=param, process_group=None)
+    print(sparam.data)
+    print(param_ref.data)
+
     logger = get_dist_logger()
     model = Net()
 
@@ -77,12 +84,31 @@
 
 
 @pytest.mark.dist
-def test_shard_shape():
+def test_shard_param():
     world_size = 2
-    run_func = partial(run_shard_param_check, world_size=world_size, port=free_port())
+    run_func = partial(_run_test_shard_param, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
 
 
+def run_init_shard_param(rank, world_size, port):
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    param = torch.nn.Parameter(data=torch.rand(2, 3))
+    sparam = ShardedParam(param, None, True)
+    payload = sparam.payload(torch.device('cuda'))
+    assert (list(payload.shape) == [3])
+    del sparam
+
+    param_shape = (2, 3)
+    sparam = ShardedParam(param_shape, process_group=None, is_sharded=True, device=torch.device('cpu'))
+    payload = sparam.payload(torch.device('cuda'))
+    assert (list(payload.shape) == [3])
+
+    param_shape = (2, 3)
+    sparam = ShardedParam(param_shape, process_group=None, is_sharded=False, device=torch.device('cpu'))
+    payload = sparam.payload(torch.device('cuda'))
+    assert (list(payload.shape) == [2, 3])
+
+
 @pytest.mark.dist
 def test_init_shard_param():
     world_size = 2
@@ -92,5 +118,6 @@
 
 if __name__ == '__main__':
     test_shard_tensor()
-    test_shard_shape()
+    test_shard_param()
+    test_shard_param_v2()
     test_init_shard_param()
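For reference, a minimal usage sketch of the ShardedParamV2 API added above, based on the new test; it assumes colossalai.launch() has already initialized the distributed context, and it is an illustration rather than part of the patch.

import torch
from colossalai.zero.sharded_param import ShardedParamV2

# Wrap an ordinary parameter; its payload is moved into a ShardedTensor and
# param.data is replaced by an empty 0-dim placeholder tensor (numel() == 1).
param = torch.nn.Parameter(torch.randn(2, 3))
sparam = ShardedParamV2(param=param, process_group=None)

# The original values remain reachable through the sharded wrapper.
print(sparam.data.shape)     # torch.Size([2, 3])
print(param.data.numel())    # 1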