From 797a9dc5a9e801d7499b8667c3ef039a38aa15ba Mon Sep 17 00:00:00 2001
From: Ziyue Jiang
Date: Fri, 13 May 2022 20:29:50 +0800
Subject: [PATCH] add DistSpec for loss and test_model (#947)

---
 colossalai/tensor/_ops/__init__.py  |   2 +-
 colossalai/tensor/_ops/layernorm.py |   2 +-
 colossalai/tensor/_ops/loss.py      |   9 +-
 colossalai/tensor/dist_spec_mgr.py  |   6 +-
 colossalai/tensor/spec.py           |  15 +++-
 tests/test_tensor/test_model.py     | 124 +++++++++------------------
 6 files changed, 64 insertions(+), 94 deletions(-)

diff --git a/colossalai/tensor/_ops/__init__.py b/colossalai/tensor/_ops/__init__.py
index 2e09e15ba..e9ce2b1ff 100644
--- a/colossalai/tensor/_ops/__init__.py
+++ b/colossalai/tensor/_ops/__init__.py
@@ -1,6 +1,6 @@
 from .linear import colo_linear
 from .element_wise import *
 from .layernorm import colo_layernorm
-# from .loss import colo_cross_entropy
+from .loss import colo_cross_entropy
 from .embedding import colo_embedding
 from .addmm import colo_addmm
diff --git a/colossalai/tensor/_ops/layernorm.py b/colossalai/tensor/_ops/layernorm.py
index 4eeafc635..1879a0953 100644
--- a/colossalai/tensor/_ops/layernorm.py
+++ b/colossalai/tensor/_ops/layernorm.py
@@ -28,7 +28,7 @@ def colo_layernorm(types, args=(), kwargs=None, pg=None):
 
     if isinstance(input_tensor, ColoTensor):
         # TODO (ver217): check input dist spec
-        input_tensor.to_dist_spec(dist_spec.replicate())
+        input_tensor.to_dist_spec(dist_spec.replicate(input_tensor.spec.get_process_group()))
         input_tensor = input_tensor.torch_tensor()
     if isinstance(weight, ColoTensor):
         weight = weight.torch_tensor()
diff --git a/colossalai/tensor/_ops/loss.py b/colossalai/tensor/_ops/loss.py
index 89683d3aa..6243301fd 100644
--- a/colossalai/tensor/_ops/loss.py
+++ b/colossalai/tensor/_ops/loss.py
@@ -1,4 +1,4 @@
-from colossalai.tensor.spec import ShardPattern
+from colossalai.tensor.dist_spec import DistPlacementPattern
 import torch
 from colossalai.tensor.op_wrapper import colo_op_impl
 from colossalai.tensor import ColoTensor
@@ -27,12 +27,11 @@ def colo_cross_entropy(types, args=(), kwargs=None, pg=None):
     if isinstance(target, ColoTensor):
         target = target.torch_tensor()
 
-    if input_tensor.is_gathered():    # Input is gathered
-        # TODO(jzy) Shall we make the result of loss function a ColoTensor?
+    if input_tensor.spec.is_gathered():    # Input is gathered
         return ColoTensor.init_from_torch_tensor(torch.nn.functional.cross_entropy(
             input_tensor.torch_tensor(), target, weight))
-    elif input_tensor.has_spec() and input_tensor.shard_spec.num_action == 1:    # Single Model Parallel Applied
-        if input_tensor.shard_pattern == ShardPattern.Col:
+    elif input_tensor.has_spec() and input_tensor.spec.num_action == 1:    # Single Model Parallel Applied
+        if input_tensor.spec.is_1Dcol():
             return ColoTensor.init_from_torch_tensor(
                 VocabParallelCrossEntropyLoss1D()(input_tensor.torch_tensor(), target))
         else:
diff --git a/colossalai/tensor/dist_spec_mgr.py b/colossalai/tensor/dist_spec_mgr.py
index ba32d1bd1..ef4a1a359 100644
--- a/colossalai/tensor/dist_spec_mgr.py
+++ b/colossalai/tensor/dist_spec_mgr.py
@@ -53,7 +53,8 @@ class DistSpecManager:
 
     @staticmethod
     def _r2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
-        if old_dist_spec.process_group is not None and old_dist_spec.process_group != dist_spec.process_group:
+        if old_dist_spec.process_group is not None and old_dist_spec.process_group != dist_spec.process_group \
+            and dist_spec.process_group is not None:
             raise NotImplementedError
         return tensor
 
@@ -65,7 +66,8 @@ class DistSpecManager:
 
     @staticmethod
     def _s2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
-        if old_dist_spec.process_group != dist_spec.process_group:
+        if old_dist_spec.process_group != dist_spec.process_group \
+            and dist_spec.process_group is not None:
             raise NotImplementedError
         return DistSpecManager._gather(tensor, old_dist_spec)
 
diff --git a/colossalai/tensor/spec.py b/colossalai/tensor/spec.py
index ddb5401c6..d1b762f0b 100644
--- a/colossalai/tensor/spec.py
+++ b/colossalai/tensor/spec.py
@@ -1,7 +1,7 @@
 from enum import Enum
 from typing import List
 from colossalai.context.parallel_mode import ParallelMode
-from colossalai.tensor.dist_spec import _DistSpec
+from colossalai.tensor.dist_spec import _DistSpec, DistPlacementPattern
 
 
 class ComputePattern(Enum):
@@ -84,3 +84,16 @@ class TensorSpec(object):
 
     def get_process_group(self):
         return self.dist_spec.process_group
+
+    def get_placement(self):
+        return self.dist_spec.placement
+
+    def is_gathered(self):
+        return self.dist_spec.placement == DistPlacementPattern.REPLICATE \
+            or (len(self.dist_spec.num_partitions) == 1
+                and self.dist_spec.num_partitions[0] == 1) \
+            or (self.dist_spec.process_group.size() == 1)
+
+    def is_1Dcol(self):
+        return self.dist_spec.placement == DistPlacementPattern.SHARD \
+            and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == -1
\ No newline at end of file
diff --git a/tests/test_tensor/test_model.py b/tests/test_tensor/test_model.py
index 1fbcf29ab..9cebf72db 100644
--- a/tests/test_tensor/test_model.py
+++ b/tests/test_tensor/test_model.py
@@ -9,7 +9,8 @@
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils import ColoInitContext
-from colossalai.tensor import named_params_with_colotensor, TensorSpec, ComputePattern, ParallelAction, ColoTensor, ColoOptimizer
+from colossalai.tensor import named_params_with_colotensor, TensorSpec, ComputePattern, \
+    ParallelAction, ColoTensor, ColoOptimizer, dist_spec, DistSpecManager
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
@@ -85,6 +86,34 @@ def set_seed(seed):
     torch.cuda.manual_seed(seed)
     torch.backends.cudnn.deterministic = True
 
+def init_1d_row_linear(weight):
+    spec = TensorSpec(
+        dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
+        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)])
+    with DistSpecManager.no_grad():
+        weight.set_spec(spec)
+
+def init_1d_col_linear(weight, gather_out=True):
+    spec = TensorSpec(
+        dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
+        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol, parallel_mode=ParallelMode.PARALLEL_1D, \
+        gather_out=gather_out)])
+    with DistSpecManager.no_grad():
+        weight.set_spec(spec)
+
+def init_1d_row_embedding(weight):
+    spec = TensorSpec(
+        dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
+        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)])
+    with DistSpecManager.no_grad():
+        weight.set_spec(spec)
+
+def init_1d_col_embedding(weight):
+    spec = TensorSpec(
+        dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
+        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol, parallel_mode=ParallelMode.PARALLEL_1D)])
+    with DistSpecManager.no_grad():
+        weight.set_spec(spec)
 
 def run_1d_hybrid_tp(model_name):
     # A simple net with two stacked nn.Linear
@@ -106,84 +135,35 @@ def run_1d_hybrid_tp(model_name):
             p2.data.copy_(p1.data)
 
     if 'bert' == model_name:
-        parallel_action_list_row = [
-            ParallelAction(priority=1,
-                           compute_pattern=ComputePattern.TP1DRow_Linear,
-                           parallel_mode=ParallelMode.PARALLEL_1D)
-        ]
-        spec_linear_row = TensorSpec(parallel_action_list_row)
-
-        parallel_action_list_embedding_col = [
-            ParallelAction(priority=1,
-                           compute_pattern=ComputePattern.TP1DCol_Embedding,
-                           parallel_mode=ParallelMode.PARALLEL_1D)
-        ]
-        spec_embedding_col = TensorSpec(parallel_action_list_embedding_col)
-
-        parallel_action_list_embedding_row = [
-            ParallelAction(priority=1,
-                           compute_pattern=ComputePattern.TP1DRow_Embedding,
-                           parallel_mode=ParallelMode.PARALLEL_1D)
-        ]
-        spec_embedding_row = TensorSpec(parallel_action_list_embedding_row)
-
         for name, p in model.colo_named_parameters():
             if not isinstance(p, ColoTensor):
                 continue
             # print(name)
             # num_class = type_vocab_size = 2 | (8, 2)
             if 'classifier' in name and 'weight' in name:
-                p.set_spec(spec_linear_row)
+                init_1d_row_linear(p)
             # num_class = vocab_size = 30524 | (30524, 8)
             if 'word_embeddings' in name and 'weight' in name:
-                p.set_spec(spec_embedding_row)
+                init_1d_row_embedding(p)
             # num_class = seq_len = 512 | (512, 8)
             if 'position_embeddings' in name and 'weight' in name:
-                p.set_spec(spec_embedding_row)
+                init_1d_row_embedding(p)
             # num_class = type_vocab_size = 2 | (2, 8)
             if 'token_type_embeddings' in name and 'weight' in name:
-                p.set_spec(spec_embedding_col)
+                init_1d_col_embedding(p)
     elif "simple_net" == model_name:
-        parallel_action_list_row = [
-            ParallelAction(priority=1,
-                           compute_pattern=ComputePattern.TP1DRow_Linear,
-                           parallel_mode=ParallelMode.PARALLEL_1D)
-        ]
-        spec_row = TensorSpec(parallel_action_list_row)
-
-        parallel_action_list_col = [
-            ParallelAction(priority=1,
-                           compute_pattern=ComputePattern.TP1DCol_Linear,
-                           parallel_mode=ParallelMode.PARALLEL_1D),
-        ]
-        spec_col = TensorSpec(parallel_action_list_col)
-
-        parallel_action_list_classifier_col = [
-            ParallelAction(priority=1,
-                           compute_pattern=ComputePattern.TP1DCol_Linear,
-                           parallel_mode=ParallelMode.PARALLEL_1D,
-                           gather_out=False),
-        ]
-        spec_classifier_col = TensorSpec(parallel_action_list_classifier_col)
-
-        parallel_action_list_embedding_col = [
-            ParallelAction(priority=1,
-                           compute_pattern=ComputePattern.TP1DCol_Embedding,
-                           parallel_mode=ParallelMode.PARALLEL_1D)
-        ]
-        spec_embedding_col = TensorSpec(parallel_action_list_embedding_col)
         # A naive way to set spec for all weights in Linear
         for name, p in model.colo_named_parameters():
             if not isinstance(p, ColoTensor):
                 continue
             if 'embed' in name and 'weight' in name:
-                p.set_spec(spec_embedding_col)
+                init_1d_col_embedding(p)
             if 'proj1' in name and ('weight' in name or 'bias' in name):
-                p.set_spec(spec_col)
+                init_1d_col_linear(p)
             if 'proj2' in name and 'weight' in name:
-                p.set_spec(spec_row)
+                init_1d_row_linear(p)
             if 'classifier' in name and ('weight' in name or 'bias' in name):
-                p.set_spec(spec_classifier_col)
+                init_1d_col_linear(p, gather_out=False)
 
     model = model.cuda()
     colo_optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1)
@@ -251,8 +231,6 @@ def run_1d_hybrid_tp(model_name):
             break
 
 
-# FIXME (ver217): enable this test
-@pytest.mark.skip
 # Test the overrided parameters() and named_parameters() member functions
 def test_model_parameters():
     # build a module with 2 Linear, 4 parameters in total.
@@ -285,8 +263,6 @@ def test_model_parameters():
     assert param_cnt == 2
 
 
-# FIXME (ver217): enable this test
-@pytest.mark.skip
 def test_colo_optimizer():
     get_components_func = non_distributed_component_funcs.get_callable('simple_net')
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@@ -329,29 +305,14 @@ def run_1d_row_tp(model_name: str):
     if rank == 0:
         model_torch = model_builder(checkpoint=True)
         model_torch = model_torch.cuda()
-
-    parallel_action_list = [
-        ParallelAction(priority=1,
-                       compute_pattern=ComputePattern.TP1DRow_Linear,
-                       parallel_mode=ParallelMode.PARALLEL_1D)
-    ]
-    spec = TensorSpec(parallel_action_list)
-
-    parallel_action_list_embedding_row = [
-        ParallelAction(priority=1,
-                       compute_pattern=ComputePattern.TP1DRow_Embedding,
-                       parallel_mode=ParallelMode.PARALLEL_1D)
-    ]
-    spec_embedding_row = TensorSpec(parallel_action_list_embedding_row)
-
     # A naive way to set spec for all weights in Linear
     for name, p in model.colo_named_parameters():
         if not isinstance(p, ColoTensor):
             continue
         if 'weight' in name and 'LayerNorm' not in name and 'ln' not in name and 'embed' not in name:
-            p.set_spec(spec)
+            init_1d_row_linear(p)
         if 'embed' in name and 'weight' in name:
-            p.set_spec(spec_embedding_row)
+            init_1d_row_embedding(p)
 
     model = model.cuda()
 
@@ -434,9 +395,6 @@ def run_model_dist(rank, world_size, port):
     for name in ['bert', 'simple_net']:
         run_1d_hybrid_tp(name)
 
-
-# FIXME (ver217): enable this test
-@pytest.mark.skip
 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 4])
 # @parameterize('world_size', [1, 4])
@@ -454,8 +412,6 @@ def run_pretrain_load_dist(rank, world_size, port):
 
 # The test case has to download huggingface pretrained models from the internet
 # So we manually trigger the test.
-# FIXME (ver217): enable this test
-@pytest.mark.skip
 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 4])
 @rerun_if_address_is_in_use()