From 1aad903c1537eafb73fac1729b6df30b7006312f Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Tue, 12 Jul 2022 10:24:05 +0800
Subject: [PATCH] [tensor] redistribute among different process groups (#1247)

* make it faster

* [tensor] rename convert_to_dist -> redistribute

* [tensor] ShardSpec and ReplicaSpec

* [tensor] redistribute among diff pgs

* polish code
---
 colossalai/nn/_ops/addmm.py           |  1 -
 colossalai/nn/_ops/embedding.py       |  2 --
 colossalai/nn/_ops/linear.py          |  2 --
 colossalai/tensor/colo_tensor.py      | 24 ++++++++++++++++++++++--
 colossalai/tensor/process_group.py    |  8 ++++----
 tests/test_tensor/test_module_spec.py |  2 +-
 tests/test_tensor/test_op.py          |  1 -
 tests/test_tensor/test_tensor.py      | 25 +++++++++++++++++++++----
 8 files changed, 48 insertions(+), 17 deletions(-)

diff --git a/colossalai/nn/_ops/addmm.py b/colossalai/nn/_ops/addmm.py
index 7468aaae8..666483319 100644
--- a/colossalai/nn/_ops/addmm.py
+++ b/colossalai/nn/_ops/addmm.py
@@ -13,7 +13,6 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
     mat1 = mat1.redistribute(ShardSpec([-1], [mat2.get_tp_world_size()]))
-
     # Output:P
     partial_output = torch.mm(mat1, mat2)
     # Reduce(Output)
diff --git a/colossalai/nn/_ops/embedding.py b/colossalai/nn/_ops/embedding.py
index f577d1af3..9594f8a6e 100644
--- a/colossalai/nn/_ops/embedding.py
+++ b/colossalai/nn/_ops/embedding.py
@@ -14,7 +14,6 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
                          sparse: bool = False) -> ColoTensor:
     # embedding_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
     # Gather splitted lookup table
-
     input_tensor = input_tensor.redistribute(ReplicaSpec())
 
     output_parallel = F.embedding(input_tensor,
@@ -47,7 +46,6 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
     # Find index in this shard and mask those not here
     # Reduce all
     pg = weight.get_process_group()
-
     input_tensor = input_tensor.redistribute(ReplicaSpec())
 
     # tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
diff --git a/colossalai/nn/_ops/linear.py b/colossalai/nn/_ops/linear.py
index 6491f673e..255a3d27e 100644
--- a/colossalai/nn/_ops/linear.py
+++ b/colossalai/nn/_ops/linear.py
@@ -32,9 +32,7 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
     # All-Gather(Output)
     # Input:B
     compute_spec = weight.compute_spec
-
     input_tensor = input_tensor.redistribute(ReplicaSpec())
-
     input_parallel = reduce_grad(input_tensor, weight.get_process_group())
 
     output_parallel = F.linear(input_parallel, weight, bias)
diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py
index 8cfb316b2..cc2e8dee3 100644
--- a/colossalai/tensor/colo_tensor.py
+++ b/colossalai/tensor/colo_tensor.py
@@ -186,7 +186,28 @@ class ColoTensor(torch.Tensor):
             self.data = DistSpecManager.handle_trans_spec(self.data, self.dist_spec, dist_spec, self.process_group)
         self.dist_spec = dist_spec
 
-    def redistribute(self, dist_spec: _DistSpec) -> 'ColoTensor':
+    def redistribute(self, dist_spec: _DistSpec, pg: Optional[ProcessGroup] = None) -> 'ColoTensor':
+        """redistribute
+        Redistribute the tensor among processes. The rule is as follows:
+        1. If pg is None, redistribute the tensor payload among its TP process group, and keep the
+        DP process group replicated.
+        2. If pg is not None and not equal to the current process group:
+        First, convert the tensor to replicated among its TP process group.
+        Second, reset the process group to the new pg.
+        Third, convert the tensor (now replicated among both the TP and DP process groups) to the new dist_spec.
+
+        Args:
+            dist_spec (_DistSpec): the new dist spec.
+            pg (Optional[ProcessGroup], optional): the new process group. Defaults to None.
+
+        Returns:
+            ColoTensor: a redistributed ColoTensor.
+        """
+        if pg is not None and pg != self.get_process_group():
+            print('here _redistribute')
+            # if the pg is not equal, convert the current tensor to replicated
+            self._redistribute(ReplicaSpec())
+            self.process_group = pg
         ret = DistSpecManager.handle_trans_spec(self, self.dist_spec, dist_spec, self.process_group)
         return ColoTensor.from_torch_tensor(ret, ColoTensorSpec(self.process_group, dist_attr=dist_spec))
@@ -202,7 +223,6 @@ class ColoTensor(torch.Tensor):
         """
         return self.redistribute(ReplicaSpec())
 
-
     @staticmethod
     def from_torch_tensor(tensor: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> 'ColoTensor':
         tensor = tensor.as_subclass(ColoTensor)
diff --git a/colossalai/tensor/process_group.py b/colossalai/tensor/process_group.py
index 9a413ce33..1624638c4 100644
--- a/colossalai/tensor/process_group.py
+++ b/colossalai/tensor/process_group.py
@@ -117,13 +117,13 @@ class ProcessGroup:
         if not isinstance(obj, ProcessGroup):
             return False
         if self._rank != obj._rank:
-            assert False
+            return False
         if self._rank_list != obj._rank_list:
-            assert False
+            return False
         if self._tp_rank_list != obj._tp_rank_list:
-            assert False
+            return False
         if self._dp_rank_list != obj._dp_rank_list:
-            assert False
+            return False
         if self._tp_degree != obj._tp_degree:
             return False
         if self._dp_degree != obj._dp_degree:
diff --git a/tests/test_tensor/test_module_spec.py b/tests/test_tensor/test_module_spec.py
index 959204af6..a33af9c3c 100644
--- a/tests/test_tensor/test_module_spec.py
+++ b/tests/test_tensor/test_module_spec.py
@@ -164,7 +164,7 @@ def run_check_shared_param():
     # TODO(jiaruifang) optimize this line
     if not model.cls.predictions.bias.has_initialized:
         model.cls.predictions.bias.pg = pg
-        model.cls.predictions.bias.dist_spec = distspec.replicate()
+        model.cls.predictions.bias.dist_spec = ReplicaSpec()
         model.cls.predictions.bias.has_initialized = True
     model.cls.predictions.bias.set_tensor_spec(*col_spec)
     try:
diff --git a/tests/test_tensor/test_op.py b/tests/test_tensor/test_op.py
index 8f0ac55a0..8d3cf50ff 100644
--- a/tests/test_tensor/test_op.py
+++ b/tests/test_tensor/test_op.py
@@ -9,7 +9,6 @@ from colossalai.utils import get_current_device
 from torch.nn import Parameter
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
-from colossalai.tensor import distspec
 
 
 def _run_layer_norm():
diff --git a/tests/test_tensor/test_tensor.py b/tests/test_tensor/test_tensor.py
index 65aef9a25..addbf304d 100644
--- a/tests/test_tensor/test_tensor.py
+++ b/tests/test_tensor/test_tensor.py
@@ -5,7 +5,7 @@ from numpy import allclose
 
 import colossalai
 from colossalai.utils import free_port
-from colossalai.tensor import distspec, ColoTensorSpec
+from colossalai.tensor import ColoTensorSpec
 from colossalai.core import global_context as gpc
 import torch.multiprocessing as mp
 from colossalai.testing import rerun_if_address_is_in_use
@@ -85,7 +85,7 @@ def _run_tensor_shard_init(world_size):
     shard_attr = ShardSpec(dims=[0], num_partitions=[pg.tp_world_size()])
     tensor_spec = ColoTensorSpec(pg, dist_attr=shard_attr)
     t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
-    t.set_dist_spec(distspec.replicate())
+    t.set_dist_spec(ReplicaSpec())
     assert t.shape == torch.Size((4 * world_size, 5)), f"{t.shape} vs ({4 * world_size, 5})"
@@ -102,10 +102,26 @@ def _run_tensor_replicated_init(world_size):
 def _run_process_group(world_size):
     pg1 = ProcessGroup()
     pg2 = ProcessGroup()
-
     assert pg1 == pg2
 
 
+def _run_redistributed(world_size):
+    if world_size != 4:
+        return
+    pg1 = ProcessGroup(tp_degree=2, dp_degree=2)
+    pg2 = ProcessGroup(tp_degree=4, dp_degree=1)
+
+    spec1 = ColoTensorSpec(pg1)
+    t1 = ColoTensor.from_torch_tensor(torch.randn(2, 3, 4), spec1)
+    t1 = t1.redistribute(ShardSpec([0], [pg1.tp_world_size()]))
+    assert t1.is_sharded()
+    t1 = t1.redistribute(ShardSpec([-1], [pg2.tp_world_size()]), pg2)
+    assert t1.is_sharded()
+    pg3 = ProcessGroup(tp_degree=1, dp_degree=4)
+    t1 = t1.redistribute(ReplicaSpec(), pg3)
+    assert t1.is_replicate()
+
+
 def run_dist_tests(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     _run_tensor_shard_init(world_size)
@@ -115,6 +131,7 @@ def run_dist_tests(rank, world_size, port):
     _run_tensor_indexing()
     _run_operand(world_size)
     _run_wrapped_tensor_func()
+    _run_redistributed(world_size)
 
 
 @pytest.mark.dist
@@ -126,4 +143,4 @@ def test_dist_cases(world_size):
 
 
 if __name__ == '__main__':
-    test_dist_cases(1)
+    test_dist_cases(4)
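
For reference, below is a minimal usage sketch of the new redistribute(dist_spec, pg) signature, modeled on the _run_redistributed test added in this patch. It is not part of the patch itself; it assumes the function is called on every rank of a 4-process colossalai.launch context and that ColoTensor, ColoTensorSpec, ProcessGroup, ShardSpec and ReplicaSpec are importable from colossalai.tensor, as in the test file.

# Sketch (not part of the patch): exercising redistribute() across process
# groups, following the _run_redistributed test above. Assumed to run on each
# of 4 ranks after colossalai.launch.
import torch

from colossalai.tensor import (ColoTensor, ColoTensorSpec, ProcessGroup,
                               ReplicaSpec, ShardSpec)


def demo_redistribute_across_pgs():
    # Start under a 2-way tensor-parallel x 2-way data-parallel layout.
    pg_2x2 = ProcessGroup(tp_degree=2, dp_degree=2)
    t = ColoTensor.from_torch_tensor(torch.randn(2, 3, 4), ColoTensorSpec(pg_2x2))

    # Rule 1: pg is None, so only the layout inside the TP group changes.
    t = t.redistribute(ShardSpec([0], [pg_2x2.tp_world_size()]))
    assert t.is_sharded()

    # Rule 2: a different process group is passed, so the tensor is first
    # gathered back to a replica, the process group is replaced, and the
    # payload is then re-sharded under the new group.
    pg_4x1 = ProcessGroup(tp_degree=4, dp_degree=1)
    t = t.redistribute(ShardSpec([-1], [pg_4x1.tp_world_size()]), pg_4x1)
    assert t.is_sharded()

    # Replicate again under a pure data-parallel group.
    pg_1x4 = ProcessGroup(tp_degree=1, dp_degree=4)
    t = t.redistribute(ReplicaSpec(), pg_1x4)
    assert t.is_replicate()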
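
The ProcessGroup.__eq__ change (returning False instead of hitting assert False) is what lets redistribute compare groups via pg != self.get_process_group() without crashing. A small sketch of the comparison behavior follows, again assuming a launched 4-rank context; the constructor arguments mirror those used in the tests above.

# Sketch (not part of the patch): ProcessGroup comparison after the __eq__ fix.
# Assumes torch.distributed is initialized via colossalai.launch on 4 ranks.
from colossalai.tensor import ProcessGroup

pg_a = ProcessGroup()                          # default group over all launched ranks
pg_b = ProcessGroup()                          # identical configuration
pg_c = ProcessGroup(tp_degree=2, dp_degree=2)  # different TP/DP layout

assert pg_a == pg_b    # same rank lists and degrees compare equal
assert pg_a != pg_c    # previously this comparison raised AssertionError inside
                       # __eq__; it now simply evaluates to unequal, which is
                       # exactly what redistribute() relies on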