diff --git a/colossalai/tensor/__init__.py b/colossalai/tensor/__init__.py
index 950f1f3a0..12f572ed7 100644
--- a/colossalai/tensor/__init__.py
+++ b/colossalai/tensor/__init__.py
@@ -9,11 +9,11 @@ from .optim.colo_optimizer import ColoOptimizer
 from . import distspec
 from .dist_spec_mgr import DistSpecManager
 from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module
-from .modules import ColoLinear
+from .modules import ColoLinear, ColoEmbedding
 
 __all__ = [
     'ColoTensor', 'convert_parameter', 'colo_op_impl', 'ComputePattern', 'TensorSpec', 'ParallelAction',
     'named_params_with_colotensor', 'ColoOptimizer', 'ColoParameter', 'distspec', 'DistSpecManager',
     'register_colo_module', 'is_colo_module', 'get_colo_module', 'init_colo_module', 'check_colo_module',
-    'ColoLinear'
+    'ColoLinear', 'ColoEmbedding'
 ]
diff --git a/colossalai/tensor/colo_parameter.py b/colossalai/tensor/colo_parameter.py
index 8c09f088c..8d99a6a02 100644
--- a/colossalai/tensor/colo_parameter.py
+++ b/colossalai/tensor/colo_parameter.py
@@ -26,6 +26,13 @@ class ColoParameter(ColoTensor):
         self._type = TensorType.MODEL
         self._graph_node = None
 
+        # a list of modules that share this ColoParameter.
+        self._shared_param_modules = []
+
+    @property
+    def shared_param_modules(self):
+        return self._shared_param_modules
+
     @staticmethod
     def from_torch_tensor(tensor: torch.Tensor,
                           requires_grad: bool = True,
@@ -36,3 +43,4 @@ class ColoParameter(ColoTensor):
 
     def __repr__(self):
         return f'ColoParameter: {torch.Tensor.__repr__(self)}'
+
diff --git a/colossalai/tensor/distspec.py b/colossalai/tensor/distspec.py
index 905bf975e..4ded8fe45 100644
--- a/colossalai/tensor/distspec.py
+++ b/colossalai/tensor/distspec.py
@@ -30,6 +30,12 @@ class _DistSpec:
                 return False
         return True
 
+    def __repr__(self) -> str:
+        res = "\nDistSpec:\n\t"
+        for attr in dir(self):
+            if not attr.startswith('__'):
+                res += f'{attr}: {str(getattr(self, attr))}\n\t'
+        return res
 
 def replicate(process_group: Optional[ProcessGroup] = None) -> _DistSpec:
     # process_group=None means global process group
diff --git a/colossalai/tensor/module_utils.py b/colossalai/tensor/module_utils.py
index 6f449aa95..89c285e39 100644
--- a/colossalai/tensor/module_utils.py
+++ b/colossalai/tensor/module_utils.py
@@ -18,7 +18,6 @@ def get_colo_module(module: torch.nn.Module):
     global _COLOSSAL_MODULES
     if is_colo_module(module):
         colo_module = _COLOSSAL_MODULES[type(module)]
-        colo_module.register()
         return colo_module
     else:
         return None
@@ -43,6 +42,7 @@
                 continue
 
         if compute_pattern is not None:
+            colo_module.register(compute_pattern)
             if not colo_module.has_compute_pattern(compute_pattern):
                 raise Exception(f'Invalid ColoParameter spec: ComputePattern {compute_pattern} in {module} is not allowed.')
 
@@ -65,28 +65,34 @@
                     break
             if match_specs == False:
                 raise Exception(f'Invalid ColoParameter spec: Params in {module} are incorrectly sharded.')
-
    if recursive == True:
        for submodule in module.children():
            check_colo_module(submodule, recursive=True)
 
 
-def init_colo_module(module: torch.nn.Module, parallel_action: ParallelAction, recursive=True, label='default'):
+def init_colo_module(module: torch.nn.Module, parallel_action: ParallelAction, recursive=True, mode='default'):
    compute_pattern = parallel_action.compute_pattern
    if is_colo_module(module):
        # for each param
        # set DistSpec and ParallelAction
        colo_module = get_colo_module(module)
-        if not colo_module.has_compute_pattern_with_label(compute_pattern, label=label):
+        colo_module.register(compute_pattern)
+        if not colo_module.has_compute_pattern_with_mode(compute_pattern, mode=mode):
            raise NotImplementedError
-        for param_name, dist_spec in colo_module.get_dist_specs_with_label(compute_pattern, label=label).items():
+        # a set of modules that have at least one param updated during init.
+        # these modules need to be re-checked so that all their params still match a valid compute pattern.
+        modules_update_param = {module}
+        for param_name, dist_spec in colo_module.get_dist_specs_with_mode(compute_pattern, mode=mode).items():
            if dist_spec is None:
                continue
            param = module.get_parameter(param_name)
            if isinstance(param, ColoParameter):
                spec = TensorSpec(dist_spec, parallel_action)
                param.set_spec(spec)
-        check_colo_module(module, recursive=False)
+                for mod in param.shared_param_modules:
+                    modules_update_param.add(mod)
+        for mod in modules_update_param:
+            check_colo_module(mod, recursive=False)
    if recursive == True:
        for submodule in module.children():
-            init_colo_module(submodule, parallel_action, recursive=True, label=label)
+            init_colo_module(submodule, parallel_action, recursive=True, mode=mode)
\ No newline at end of file
diff --git a/colossalai/tensor/modules/__init__.py b/colossalai/tensor/modules/__init__.py
index 15f10534e..3d6a0e69b 100644
--- a/colossalai/tensor/modules/__init__.py
+++ b/colossalai/tensor/modules/__init__.py
@@ -1,2 +1,3 @@
 from .colo_module import ColoModule
-from .linear import ColoLinear
\ No newline at end of file
+from .linear import ColoLinear
+from .embedding import ColoEmbedding
\ No newline at end of file
diff --git a/colossalai/tensor/modules/colo_module.py b/colossalai/tensor/modules/colo_module.py
index c0aa37e48..ecdfc1a59 100644
--- a/colossalai/tensor/modules/colo_module.py
+++ b/colossalai/tensor/modules/colo_module.py
@@ -21,14 +21,14 @@
     def _register_shard_params(self, params: List[str]):
         self._shard_params = params
 
-    def _register_allowed_patterns(self, compute_pattern: ComputePattern, dist_specs: Dict[str, _DistSpec], label='default'):
+    def _register_allowed_patterns(self, compute_pattern: ComputePattern, dist_specs: Dict[str, _DistSpec], mode='default'):
         assert list(dist_specs.keys()).sort() == self._shard_params.sort(), 'Every registered param should have dist_spec.'
         if not compute_pattern in self._allowed_patterns:
             self._allowed_patterns[compute_pattern] = {}
-        self._allowed_patterns[compute_pattern][label] = dist_specs
+        self._allowed_patterns[compute_pattern][mode] = dist_specs
 
-    def _set_default(self, compute_pattern: ComputePattern, target_label):
-        self._allowed_patterns[compute_pattern]['default'] = self._allowed_patterns[compute_pattern][target_label]
+    def _set_default(self, compute_pattern: ComputePattern, target_mode):
+        self._allowed_patterns[compute_pattern]['default'] = self._allowed_patterns[compute_pattern][target_mode]
 
     def has_compute_pattern(self, compute_pattern: ComputePattern):
         return compute_pattern in self._allowed_patterns
@@ -37,15 +37,15 @@
         assert self.has_compute_pattern(compute_pattern)
         return self._allowed_patterns[compute_pattern]
 
-    def has_compute_pattern_with_label(self, compute_pattern: ComputePattern, label='default'):
-        return compute_pattern in self._allowed_patterns and label in self._allowed_patterns[compute_pattern]
+    def has_compute_pattern_with_mode(self, compute_pattern: ComputePattern, mode='default'):
+        return compute_pattern in self._allowed_patterns and mode in self._allowed_patterns[compute_pattern]
 
-    def get_dist_specs_with_label(self, compute_pattern: ComputePattern, label='default'):
-        assert self.has_compute_pattern_with_label(compute_pattern, label)
-        return self._allowed_patterns[compute_pattern][label]
+    def get_dist_specs_with_mode(self, compute_pattern: ComputePattern, mode='default'):
+        assert self.has_compute_pattern_with_mode(compute_pattern, mode)
+        return self._allowed_patterns[compute_pattern][mode]
 
     def get_param_names(self):
         return self._shard_params
 
-    def register(self):
+    def register(self, compute_pattern):
         raise NotImplementedError
\ No newline at end of file
diff --git a/colossalai/tensor/modules/embedding.py b/colossalai/tensor/modules/embedding.py
new file mode 100644
index 000000000..b48193971
--- /dev/null
+++ b/colossalai/tensor/modules/embedding.py
@@ -0,0 +1,36 @@
+from .colo_module import ColoModule
+from colossalai.tensor import ComputePattern, distspec
+from colossalai.core import global_context as gpc
+from colossalai.context.parallel_mode import ParallelMode
+
+class ColoEmbedding(ColoModule):
+    def __init__(self):
+        super(ColoEmbedding, self).__init__()
+        self._register_shard_params(['weight'])
+
+    def register(self, compute_pattern):
+        if not compute_pattern in self._allowed_patterns:
+            if ComputePattern.TP1D == compute_pattern:
+                self._set_TP1D()
+
+    def _set_TP1D(self):
+        # TP1D Row Embedding
+        _compute_pattern = ComputePattern.TP1D
+        self._register_allowed_patterns(
+            compute_pattern=_compute_pattern,
+            dist_specs={
+                'weight': distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
+            },
+            mode='row',
+        )
+
+        # TP1D Col Embedding
+        self._register_allowed_patterns(
+            compute_pattern=_compute_pattern,
+            dist_specs={
+                'weight': distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
+            },
+            mode='col',
+        )
+
+        self._set_default(compute_pattern=_compute_pattern, target_mode='row')
\ No newline at end of file
diff --git a/colossalai/tensor/modules/linear.py b/colossalai/tensor/modules/linear.py
index 1ff22fdd9..239cbeb1e 100644
--- a/colossalai/tensor/modules/linear.py
+++ b/colossalai/tensor/modules/linear.py
@@ -7,12 +7,11 @@ class ColoLinear(ColoModule):
     def __init__(self):
         super(ColoLinear, self).__init__()
         self._register_shard_params(['weight', 'bias'])
-        self._register = False
 
-    def register(self):
-        if self._register == False:
-            self._set_TP1D()
-            self._register = True
+    def register(self, compute_pattern):
+        if not compute_pattern in self._allowed_patterns:
+            if ComputePattern.TP1D == compute_pattern:
+                self._set_TP1D()
 
     def _set_TP1D(self):
         # TP1D Row Linear
@@ -23,7 +22,7 @@ class ColoLinear(ColoModule):
                 'weight': distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
                 'bias': None
             },
-            label='row',
+            mode='row',
         )
 
         # TP1D Col Linear
@@ -33,7 +32,7 @@ class ColoLinear(ColoModule):
                 'weight': distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
                 'bias': distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)])
             },
-            label='col',
+            mode='col',
         )
 
-        self._set_default(compute_pattern=_compute_pattern, target_label='row')
+        self._set_default(compute_pattern=_compute_pattern, target_mode='row')
diff --git a/colossalai/utils/model/colo_init_context.py b/colossalai/utils/model/colo_init_context.py
index 53aabc3f2..0f87cc289 100644
--- a/colossalai/utils/model/colo_init_context.py
+++ b/colossalai/utils/model/colo_init_context.py
@@ -1,7 +1,7 @@
 from .utils import InsertPostInitMethodToModuleSubClasses
 import torch
 from colossalai.tensor import ColoTensor, ColoParameter, register_colo_module, init_colo_module, \
-    ColoLinear
+    ColoLinear, ColoEmbedding
 import types
 
 from torch import nn
@@ -137,7 +137,12 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         torch.nn.Module.__setattr__ = _setattr_with_colotensor
         torch.nn.Module.register_parameter = _register_parameter_with_colotensor
         torch.nn.Module.get_parameter = _get_parameter_with_colotensor
+
+        self._register_colo_modules()
+
+    def _register_colo_modules(self):
         register_colo_module(torch.nn.Linear, ColoLinear())
+        register_colo_module(torch.nn.Embedding, ColoEmbedding())
 
     def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
         """
@@ -179,5 +184,6 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
                 replaced_tensors[param] = colo_param
                 delattr(submodule, param_name)
                 setattr(submodule, param_name, colo_param)
+                colo_param.shared_param_modules.append(submodule)
 
         ColoModulize(module)
diff --git a/tests/test_tensor/test_module_spec.py b/tests/test_tensor/test_module_spec.py
index 478aa815a..771662571 100644
--- a/tests/test_tensor/test_module_spec.py
+++ b/tests/test_tensor/test_module_spec.py
@@ -15,21 +15,21 @@ import torch.nn.functional as F
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.core import global_context as gpc
-from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager, register_colo_module, init_colo_module, ColoLinear
+from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager, register_colo_module, init_colo_module, check_colo_module
 
 from _utils import tensor_equal, tensor_shard_equal, set_seed
 from tests.components_to_test.registry import non_distributed_component_funcs
 
-def run_simplenet_with_spec(label):
-    get_components_func = non_distributed_component_funcs.get_callable('simple_net')
+def run_model_with_spec(mode, model_name):
+    get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
     rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
     set_seed(1)
     with ColoInitContext(device=get_current_device()):
-        model = model_builder(checkpoint=True)
+        model = model_builder(checkpoint=False)
 
     if rank == 0:
-        model_seq = model_builder(checkpoint=True)
+        model_seq = model_builder(checkpoint=False)
         model_seq = model_seq.cuda()
 
         # Make two models have the same init params
@@ -37,7 +37,19 @@
             p2.data.copy_(p1.data)
 
     parallel_action = ParallelAction(ComputePattern.TP1D)
-    init_colo_module(model, parallel_action, recursive=True, label=label)
+    # Not every parameter dimension in Bert is divisible by the tensor parallel size 4.
+    # e.g. sharding every layer in 'row' mode is invalid, because the first dim of the classifier is the number of classes, which is 2.
+    if 'bert' == model_name:
+        if 'col' == mode:
+            init_colo_module(model.bert.embeddings, parallel_action, recursive=True, mode=mode)
+            init_colo_module(model.bert.encoder, parallel_action, recursive=True, mode=mode)
+            init_colo_module(model.classifier, parallel_action, recursive=True, mode='row')
+        elif 'row' == mode:
+            init_colo_module(model.bert.embeddings, parallel_action, recursive=True, mode='col')
+            init_colo_module(model.bert.encoder, parallel_action, recursive=True, mode=mode)
+            init_colo_module(model.classifier, parallel_action, recursive=True, mode=mode)
+    elif 'simple_net' == model_name:
+        init_colo_module(model, parallel_action, recursive=True, mode=mode)
     model = model.cuda()
 
     for i, (data, label) in enumerate(train_dataloader):
@@ -91,14 +103,14 @@
         if i > 3:
             break
 
 
-def run_linear_with_spec(label):
+def run_linear_with_spec(mode):
     with ColoInitContext(device=get_current_device()):
         model = torch.nn.Linear(4, 8)
     model_handy = copy(model)
 
     parallel_action = ParallelAction(ComputePattern.TP1D)
-    init_colo_module(model, parallel_action, recursive=True, label=label)
+    init_colo_module(model, parallel_action, recursive=True, mode=mode)
 
     x = torch.rand(2, 4).cuda()
     out = model(x)
@@ -110,28 +122,79 @@
     assert tensor_shard_equal(model.weight.grad, model_handy.weight.grad)
     assert tensor_shard_equal(model.bias.grad, model_handy.bias.grad)
 
 
+def run_check_shared_param():
+    from transformers import BertForMaskedLM, BertConfig
+    hidden_dim = 8
+    num_head = 4
+    sequence_length = 12
+    num_layer = 2
+    vocab_size = 24
-def run_dist(rank, world_size, port, func):
+    config = BertConfig(vocab_size=vocab_size,
+                        hidden_size=hidden_dim,
+                        intermediate_size=hidden_dim * 4,
+                        num_attention_heads=num_head,
+                        max_position_embeddings=sequence_length,
+                        num_hidden_layers=num_layer,
+                        hidden_dropout_prob=0.,
+                        attention_probs_dropout_prob=0.)
+    with ColoInitContext(lazy_memory_allocate=False, device=get_current_device()):
+        model = BertForMaskedLM(config)
+
+    model = model.cuda()
+    parallel_action = ParallelAction(ComputePattern.TP1D)
+    # model.cls.predictions.decoder and model.cls.predictions share the bias, so they should have the same spec
+    assert len(model.cls.predictions.decoder.bias.shared_param_modules) == 2
+    # Both modules are Linear, so 'row' mode is allowed for both; this should pass the check.
+    init_colo_module(model, parallel_action, recursive=True, mode='row')
+    # The check should detect the following, because the weight cannot be sharded as 'row' while the bias is sharded as 'col'.
+    col_spec = TensorSpec(
+        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
+        ParallelAction(ComputePattern.TP1D))
+    model.cls.predictions.bias.set_spec(col_spec)
+    try:
+        check_colo_module(model.cls.predictions.decoder, recursive=False)
+    except Exception as e:
+        assert 'incorrectly sharded' in str(e)
+
+def run_dist(rank, world_size, port):
     config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    func('col')
-    func('row')
-    func('default')
+    run_linear_with_spec('col')
+    run_linear_with_spec('row')
+def run_dist_model(rank, world_size, port):
+    config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
+    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    for model_name in ['simple_net', 'bert']:
+        run_model_with_spec('col', model_name)
+        run_model_with_spec('row', model_name)
+
+def run_dist_check(rank, world_size, port):
+    config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
+    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_check_shared_param()
 
 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 4])
 @rerun_if_address_is_in_use()
 def test_module_linear_1d(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port(), func=run_linear_with_spec)
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
 
 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 4])
 @rerun_if_address_is_in_use()
-def test_module_simplenet(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port(), func=run_simplenet_with_spec)
+def test_module_model(world_size):
+    run_func = partial(run_dist_model, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [1, 2])
+@rerun_if_address_is_in_use()
+def test_module_check(world_size):
+    run_func = partial(run_dist_check, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
 
 if __name__ == '__main__':
-    test_module_simplenet(4)
\ No newline at end of file
+    test_module_check(2)
\ No newline at end of file
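
A minimal usage sketch of the flow these changes converge on, mirroring run_linear_with_spec and run_model_with_spec above. It assumes colossalai.launch has already set up a 1D tensor-parallel group (as in run_dist); the toy model and the import paths for ColoInitContext and get_current_device are illustrative assumptions, not part of this diff.

import torch
from colossalai.tensor import ComputePattern, ParallelAction, init_colo_module
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext

# Build the model under ColoInitContext: parameters become ColoParameters,
# nn.Linear / nn.Embedding are registered as ColoLinear / ColoEmbedding, and each
# parameter records the modules that share it in shared_param_modules.
with ColoInitContext(device=get_current_device()):
    model = torch.nn.Sequential(torch.nn.Embedding(24, 8), torch.nn.Linear(8, 8))

# Shard every registered submodule recursively. mode='row' / mode='col' selects the
# dist specs registered in _set_TP1D; modules that share an updated parameter are
# re-validated through check_colo_module.
parallel_action = ParallelAction(ComputePattern.TP1D)
init_colo_module(model, parallel_action, recursive=True, mode='row')
model = model.cuda()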