diff --git a/colossalai/nn/parallel/layers/module_utils.py b/colossalai/nn/parallel/layers/module_utils.py
index 09969b4cc..38d128cc7 100644
--- a/colossalai/nn/parallel/layers/module_utils.py
+++ b/colossalai/nn/parallel/layers/module_utils.py
@@ -88,7 +88,7 @@ def init_colo_module(module: torch.nn.Module,
     compute_pattern = compute_spec.compute_pattern
     if is_colo_module(module):
         # for each param
-        # set DistSpec and ComputeSpec
+        # set its process_group, dist_spec and compute_spec
         colo_module = get_colo_module(module)
         colo_module.register(compute_pattern, pg)
         if not colo_module.has_compute_pattern_with_mode(compute_pattern, mode=mode):
@@ -101,6 +101,7 @@
                 continue
             param = module.get_parameter(param_name)
             if isinstance(param, ColoParameter):
+                param.set_process_group(pg)
                 param.set_dist_spec(dist_spec)
                 param.compute_spec = compute_spec
                 for mod in param.shared_param_modules:
diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py
index 31aedebd3..f5f0b2505 100644
--- a/colossalai/tensor/colo_tensor.py
+++ b/colossalai/tensor/colo_tensor.py
@@ -18,7 +18,7 @@ def _get_my_nowrap_functions() -> Set[Callable]:
         Tensor._base.__get__,
         Tensor.grad.__get__,
         Tensor._grad.__get__,
-        Tensor.data.__get__,    # make .data returns torch.Tensor rather than ColoTensor
+        Tensor.data.__get__,  # make .data returns torch.Tensor rather than ColoTensor
     }
 
 
@@ -121,11 +121,13 @@ class ColoTensor(torch.Tensor):
             RuntimeError:
         """
         assert isinstance(pg, ProcessGroup), f"pg as type {type(pg)} is invalid"
-        if self.process_group.tp_world_size() != 1:
-            raise RuntimeError("can not set_process_group on a ColoTensor whose process_group has tp world group")
-
-        if self.dist_spec.placement.value != 'r':
-            raise RuntimeError("can not set_process_group on a ColoTensor whose dist spec is not REPLICATE")
+        # if the new pg is the same as the old pg, just returns
+        if self.process_group == pg:
+            return
+        assert self.process_group.tp_world_size() == 1, \
+            "Can not set_process_group on a ColoTensor whose process_group has tp world group"
+        assert self.dist_spec.placement.value == 'r', \
+            "Can not set_process_group on a ColoTensor whose dist spec is not REPLICATE"
 
         self.process_group = pg
 
@@ -290,17 +292,17 @@
 
     def is_replicate(self):
         return self.dist_spec.placement == DistPlacementPattern.REPLICATE \
-            or (len(self.dist_spec.num_partitions) == 1
-                and self.dist_spec.num_partitions[0] == 1) \
-            or (self.process_group.tp_world_size() == 1)
+               or (len(self.dist_spec.num_partitions) == 1
+                   and self.dist_spec.num_partitions[0] == 1) \
+               or (self.process_group.tp_world_size() == 1)
 
     def is_shard_1dcol(self):
         return self.dist_spec.placement == DistPlacementPattern.SHARD \
-            and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == -1
+               and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == -1
 
     def is_shard_1drow(self):
         return self.dist_spec.placement == DistPlacementPattern.SHARD \
-            and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0
+               and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0
 
     def is_sharded(self):
         return self.dist_spec.placement == DistPlacementPattern.SHARD
diff --git a/tests/test_tensor/test_module_spec.py b/tests/test_tensor/test_module_spec.py
index a33af9c3c..b51d9df42 100644
--- a/tests/test_tensor/test_module_spec.py
+++ b/tests/test_tensor/test_module_spec.py
@@ -1,11 +1,11 @@
-from copy import copy
+from copy import deepcopy
 
 import pytest
 from functools import partial
 
 import torch
 import torch.multiprocessing as mp
 
-from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, ShardSpec, ReplicaSpec
+from colossalai.tensor import ColoTensor, ComputePattern, ComputeSpec, ShardSpec, ColoTensorSpec
 from colossalai.nn.parallel.layers import init_colo_module, check_colo_module
 from _utils import tensor_equal, tensor_shard_equal, set_seed
@@ -112,21 +112,25 @@ def run_linear_with_spec(mode):
     with ColoInitContext(device=get_current_device()):
         model = torch.nn.Linear(4, 8)
 
-    model_handy = copy(model)
+    model_handy = deepcopy(model)
     world_size = torch.distributed.get_world_size()
     pg = ProcessGroup(tp_degree=world_size)
     compute_spec = ComputeSpec(ComputePattern.TP1D)
     init_colo_module(model, compute_spec, pg=pg, recursive=True, mode=mode)
 
     x = torch.rand(2, 4).cuda()
+    colo_x = ColoTensor.from_torch_tensor(x, ColoTensorSpec(pg))
+
     out = model(x)
-    colo_out = model_handy(x)
+    colo_out = model_handy(colo_x)
     assert tensor_equal(out, colo_out)
+
     grad = torch.rand_like(out)
     out.backward(grad)
     colo_out.backward(grad)
-    assert tensor_shard_equal(model.weight.grad, model_handy.weight.grad, pg.tp_local_rank(), pg.tp_world_size())
-    assert tensor_shard_equal(model.bias.grad, model_handy.bias.grad, pg.tp_local_rank(), pg.tp_world_size())
+
+    assert tensor_shard_equal(model_handy.weight.grad, model.weight.grad, pg.tp_local_rank(), pg.tp_world_size())
+    assert tensor_shard_equal(model_handy.bias.grad, model.bias.grad, pg.tp_local_rank(), pg.tp_world_size())
 
 
 def run_check_shared_param():
@@ -196,7 +200,7 @@ def run_dist_check(rank, world_size, port):
 
 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 4])
-@pytest.mark.skip("under development lazy init ColoParameter in Context")
+@pytest.mark.skip("for higher testing speed")
 @rerun_if_address_is_in_use()
 def test_module_linear_1d(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
@@ -205,7 +209,7 @@
 
 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 4])
-@pytest.mark.skip("under development lazy init ColoParameter in Context")
+@pytest.mark.skip("for higher testing speed")
 @rerun_if_address_is_in_use()
 def test_module_model(world_size):
     run_func = partial(run_dist_model, world_size=world_size, port=free_port())
@@ -214,7 +218,7 @@
 
 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 2])
-@pytest.mark.skip("under development lazy init ColoParameter in Context")
+@pytest.mark.skip("for higher testing speed")
 @rerun_if_address_is_in_use()
 def test_module_check(world_size):
     run_func = partial(run_dist_check, world_size=world_size, port=free_port())
@@ -222,4 +226,4 @@
 
 
 if __name__ == '__main__':
-    test_module_check(2)
+    test_module_linear_1d(4)
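
Below is a minimal usage sketch (not part of the patch) of the set_process_group behaviour changed above. It assumes a distributed environment has already been launched so that ProcessGroup can be constructed, and it reuses only APIs that appear in this diff (ColoTensor.from_torch_tensor, ColoTensorSpec, ProcessGroup(tp_degree=...), set_process_group); the ProcessGroup import path is assumed to match the test module.

# Usage sketch only -- assumes torch.distributed is already initialized
# (e.g. via colossalai.launch) so that ProcessGroup can be built.
import torch

from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup

pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
t = ColoTensor.from_torch_tensor(torch.rand(2, 4), ColoTensorSpec(pg))

# With this patch, re-assigning the tensor's current pg is a no-op:
# set_process_group returns early instead of hitting the REPLICATE /
# tp_world_size() == 1 checks, which is what lets init_colo_module call
# param.set_process_group(pg) unconditionally for every parameter.
t.set_process_group(pg)

# Switching to a genuinely different pg is still guarded by the two asserts:
# the tensor must be replicated and its old pg must have tp_world_size() == 1.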