From ab95ec9aea1c061c2c31a77df0f3f40db62eedcf Mon Sep 17 00:00:00 2001 From: Jiarui Fang Date: Fri, 6 May 2022 12:57:14 +0800 Subject: [PATCH] [Tensor] init ColoParameter (#914) --- colossalai/tensor/__init__.py | 3 +- colossalai/tensor/colo_parameter.py | 28 ++++++++++++ colossalai/tensor/colo_tensor.py | 47 ++++++++------------- colossalai/tensor/const.py | 6 +++ colossalai/utils/model/colo_init_context.py | 9 ++-- tests/test_tensor/test_model.py | 28 +++++++++--- 6 files changed, 77 insertions(+), 44 deletions(-) create mode 100644 colossalai/tensor/colo_parameter.py create mode 100644 colossalai/tensor/const.py diff --git a/colossalai/tensor/__init__.py b/colossalai/tensor/__init__.py index 7d3d168bf..6fb00800a 100644 --- a/colossalai/tensor/__init__.py +++ b/colossalai/tensor/__init__.py @@ -2,11 +2,12 @@ from .spec import ComputePattern, ParallelAction, TensorSpec, ShardPattern from .op_wrapper import ( colo_op_impl,) from .colo_tensor import ColoTensor +from .colo_parameter import ColoParameter from .utils import convert_parameter, named_params_with_colotensor from ._ops import * from .optim.colo_optimizer import ColoOptimizer __all__ = [ 'ColoTensor', 'convert_parameter', 'colo_op_impl', 'ComputePattern', 'TensorSpec', 'ParallelAction', - 'named_params_with_colotensor', 'ShardPattern', 'ColoOptimizer' + 'named_params_with_colotensor', 'ShardPattern', 'ColoOptimizer', 'ColoParameter' ] diff --git a/colossalai/tensor/colo_parameter.py b/colossalai/tensor/colo_parameter.py new file mode 100644 index 000000000..9662affb6 --- /dev/null +++ b/colossalai/tensor/colo_parameter.py @@ -0,0 +1,28 @@ +from .colo_tensor import ColoTensor +from .const import TensorType +import torch + + +class ColoParameter(ColoTensor): + r"""A kind of ColoTensor to be considered as a module parameter. 
+ + """ + + def __init__(self, *args, **kargs): + super().__init__(*args, **kargs) + self._type = TensorType.MODEL + + def __new__(cls, *args, **kwargs): + t = super(ColoParameter, cls).__new__(cls) + t._type = TensorType.MODEL + return t + + @staticmethod + def init_from_torch_tensor(tensor: torch.Tensor, save_payload=True) -> 'ColoParameter': + colo_p = ColoParameter(*tensor.size(), + dtype=tensor.dtype, + requires_grad=tensor.requires_grad, + pin_memory=tensor.is_pinned(), + device=tensor.device, + torch_tensor=tensor if save_payload else torch.empty(0)) + return colo_p diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py index 5416e1662..5a7c06d64 100644 --- a/colossalai/tensor/colo_tensor.py +++ b/colossalai/tensor/colo_tensor.py @@ -7,12 +7,7 @@ from colossalai.core import global_context as gpc from colossalai.nn.layer.utils import divide from colossalai.tensor import TensorSpec, ComputePattern, ShardPattern from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward, gather_forward_split_backward -from enum import Enum - - -class TensorType(Enum): - MODEL = 0 - NONMODEL = 1 # mainly activations +from .const import TensorType class ColoTensor(object): @@ -26,17 +21,14 @@ class ColoTensor(object): def __new__(cls, *args, **kwargs): return super(ColoTensor, cls).__new__(cls) - def __init__( - self, - *size: Tuple[int], - dtype=None, - requires_grad=False, - pin_memory=False, - device=None, - torch_tensor=torch.empty(0), - shard_spec: TensorSpec = TensorSpec(), - is_model_data: bool = False, - ): + def __init__(self, + *size: Tuple[int], + dtype=None, + requires_grad=False, + pin_memory=False, + device=None, + torch_tensor=torch.empty(0), + shard_spec: TensorSpec = TensorSpec()): self._size = size self._dtype = dtype self._requires_grad = requires_grad @@ -45,10 +37,7 @@ class ColoTensor(object): self._torch_tensor = torch_tensor self._shard_spec = shard_spec self._shard_pattern = ShardPattern.NA - if is_model_data: - self._type = TensorType.MODEL - else: - self._type = TensorType.NONMODEL + self._type = TensorType.NONMODEL def __getitem__(self, key): return ColoTensor.init_from_torch_tensor(self.torch_tensor()[key]) @@ -97,14 +86,13 @@ class ColoTensor(object): return product(self._size) @staticmethod - def init_from_torch_tensor(tensor: torch.Tensor, save_payload=True, is_model_data=False) -> 'ColoTensor': + def init_from_torch_tensor(tensor: torch.Tensor, save_payload=True) -> 'ColoTensor': colo_t = ColoTensor(*tensor.size(), dtype=tensor.dtype, requires_grad=tensor.requires_grad, pin_memory=tensor.is_pinned(), device=tensor.device, - torch_tensor=tensor if save_payload else torch.empty(0), - is_model_data=is_model_data) + torch_tensor=tensor if save_payload else torch.empty(0)) return colo_t def del_torch_tensor(self, save_shape=False) -> None: @@ -143,12 +131,11 @@ class ColoTensor(object): self.gather() # Model Parameters if self._shard_spec.num_action == 1: - parallel_action = self._shard_spec.get_action_by_compute_pattern( - self._shard_spec.compute_patterns[0]) + parallel_action = self._shard_spec.get_action_by_compute_pattern(self._shard_spec.compute_patterns[0]) if parallel_action.compute_pattern in [ComputePattern.TP1DRow_Linear, \ ComputePattern.TP1DCol_Embedding]: self._shard_1d(parallel_action=parallel_action, dim=-1) - self._shard_pattern = ShardPattern.Col # We bind our ComputePattern on weight, which has to be transposed when linear(). 
+ self._shard_pattern = ShardPattern.Col # We bind our ComputePattern on weight, which has to be transposed when linear(). elif parallel_action.compute_pattern in [ComputePattern.TP1DCol_Linear, \ ComputePattern.TP1DRow_Embedding]: self._shard_1d(parallel_action=parallel_action, dim=0) @@ -157,7 +144,7 @@ class ColoTensor(object): raise NotImplementedError def gather(self): - assert self.is_activation(), 'Currently we only support gather Activation ColoTensor.' + assert not self.is_model_data(), 'Currently we only support gather Activation ColoTensor.' assert not self.is_gathered(), 'Only sharded ColoTensor can be gathered.' parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.DP) if self._shard_pattern == ShardPattern.Row: @@ -174,8 +161,8 @@ class ColoTensor(object): def has_spec(self) -> bool: return self._shard_spec is not None and self._shard_spec.num_action > 0 - def is_activation(self) -> bool: - return self._type == TensorType.NONMODEL + def is_model_data(self) -> bool: + return self._type == TensorType.MODEL def _shard_1d(self, parallel_action, dim=-1): num_partition = gpc.get_world_size(parallel_action.parallel_mode) diff --git a/colossalai/tensor/const.py b/colossalai/tensor/const.py new file mode 100644 index 000000000..356e8ecc8 --- /dev/null +++ b/colossalai/tensor/const.py @@ -0,0 +1,6 @@ +from enum import Enum + + +class TensorType(Enum): + MODEL = 0 + NONMODEL = 1 # mainly activations diff --git a/colossalai/utils/model/colo_init_context.py b/colossalai/utils/model/colo_init_context.py index 7efa0c338..5853e369d 100644 --- a/colossalai/utils/model/colo_init_context.py +++ b/colossalai/utils/model/colo_init_context.py @@ -1,6 +1,6 @@ from .utils import InsertPostInitMethodToModuleSubClasses import torch -from colossalai.tensor import ColoTensor +from colossalai.tensor import ColoTensor, ColoParameter import types from torch import nn @@ -100,10 +100,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses): tensor_detached = param.to(self._device).detach() tensor_detached.requires_grad = requires_grad - setattr( - module, name, - ColoTensor.init_from_torch_tensor(tensor=tensor_detached, - save_payload=save_torch_payload, - is_model_data=True)) + setattr(module, name, + ColoParameter.init_from_torch_tensor(tensor=tensor_detached, save_payload=save_torch_payload)) ColoModulize(module) diff --git a/tests/test_tensor/test_model.py b/tests/test_tensor/test_model.py index b56892e6d..2fd10de06 100644 --- a/tests/test_tensor/test_model.py +++ b/tests/test_tensor/test_model.py @@ -38,17 +38,23 @@ def run_1d_col_tp(): model = model_builder(checkpoint=True) parallel_action_list_row = [ - ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow_Linear, parallel_mode=ParallelMode.PARALLEL_1D) + ParallelAction(priority=1, + compute_pattern=ComputePattern.TP1DRow_Linear, + parallel_mode=ParallelMode.PARALLEL_1D) ] spec_row = TensorSpec(parallel_action_list_row) parallel_action_list_col = [ - ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol_Linear, parallel_mode=ParallelMode.PARALLEL_1D) + ParallelAction(priority=1, + compute_pattern=ComputePattern.TP1DCol_Linear, + parallel_mode=ParallelMode.PARALLEL_1D) ] spec_col = TensorSpec(parallel_action_list_col) parallel_action_list_embedding_col = [ - ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol_Embedding, parallel_mode=ParallelMode.PARALLEL_1D) + ParallelAction(priority=1, + compute_pattern=ComputePattern.TP1DCol_Embedding, + 
parallel_mode=ParallelMode.PARALLEL_1D) ] spec_embedding_col = TensorSpec(parallel_action_list_embedding_col) @@ -125,6 +131,9 @@ def test_model_parameters(): param_cnt += 1 assert param_cnt == 5 + for name, colo_p in model.colo_named_parameters(): + assert colo_p.is_model_data() + param_cnt = 0 for name, p in model.named_parameters(recurse=False): param_cnt += 1 @@ -175,12 +184,16 @@ def run_1d_row_tp(): model = model_builder(checkpoint=True) parallel_action_list = [ - ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow_Linear, parallel_mode=ParallelMode.PARALLEL_1D) + ParallelAction(priority=1, + compute_pattern=ComputePattern.TP1DRow_Linear, + parallel_mode=ParallelMode.PARALLEL_1D) ] spec = TensorSpec(parallel_action_list) parallel_action_list_embedding_row = [ - ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow_Embedding, parallel_mode=ParallelMode.PARALLEL_1D) + ParallelAction(priority=1, + compute_pattern=ComputePattern.TP1DRow_Embedding, + parallel_mode=ParallelMode.PARALLEL_1D) ] spec_embedding_row = TensorSpec(parallel_action_list_embedding_row) @@ -243,6 +256,7 @@ def run_dist(rank, world_size, port): run_1d_row_tp() run_1d_col_tp() + @pytest.mark.dist @parameterize('world_size', [1, 4]) @rerun_if_address_is_in_use() @@ -252,6 +266,6 @@ def test_simple_net(world_size): if __name__ == '__main__': - test_simple_net() - # test_model_parameters() + # test_simple_net() + test_model_parameters() # test_colo_optimizer()
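
For reviewers, a minimal usage sketch of the new class (not part of the patch; it only assumes the export added to colossalai/tensor/__init__.py above):

import torch
from colossalai.tensor import ColoParameter

# Wrap an existing torch.Tensor. With save_payload=True (the default in
# init_from_torch_tensor) the original tensor is kept as the payload.
weight = torch.randn(4, 4, requires_grad=True)
colo_w = ColoParameter.init_from_torch_tensor(weight)

# Unlike a plain ColoTensor, a ColoParameter is tagged TensorType.MODEL,
# so the renamed is_model_data() helper returns True.
assert colo_w.is_model_data()
assert colo_w.torch_tensor() is weight    # payload preserved because save_payload=True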
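
The colo_init_context.py change means parameters created under ColoInitContext now come back as ColoParameter rather than plain ColoTensor. A rough sketch of how that surfaces in user code, mirroring the assertion added to tests/test_tensor/test_model.py; SimpleNet is a stand-in module, and the device keyword is assumed from the use of self._device in the diff:

import torch
import torch.nn as nn
from colossalai.utils.model.colo_init_context import ColoInitContext

class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)

# Parameters registered while the context is active are detached and replaced
# by ColoParameter.init_from_torch_tensor(...) during module post-initialization.
with ColoInitContext(device=torch.device('cpu')):
    model = SimpleNet()

for name, colo_p in model.colo_named_parameters():
    assert colo_p.is_model_data()    # every parameter is now a ColoParameter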