From 2b2dc1c86becfa253efbc7bf8a6e8d3dd39b08b8 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Fri, 10 Jun 2022 11:27:38 +0800 Subject: [PATCH] [pipeline] refactor the pipeline module (#1087) * [pipeline] refactor the pipeline module * polish code --- colossalai/builder/__init__.py | 11 +- colossalai/builder/builder.py | 176 ------------ colossalai/builder/pipeline.py | 266 ------------------ colossalai/nn/__init__.py | 1 - colossalai/nn/layer/wrapper/__init__.py | 3 +- colossalai/nn/layer/wrapper/lambda_wrapper.py | 36 --- colossalai/nn/model/__init__.py | 3 - colossalai/nn/model/model_from_config.py | 37 --- colossalai/pipeline/__init__.py | 4 + colossalai/pipeline/layer_sepc.py | 55 ++++ .../{utils/model => pipeline}/pipelinable.py | 184 ++++-------- colossalai/pipeline/utils.py | 207 ++++++++++++++ tests/test_config/test_load_config.py | 1 - .../test_cifar_with_data_pipeline_tensor.py | 20 +- .../test_pipelinable.py | 4 +- .../test_pipeline/model/__init__.py | 2 - .../test_pipeline/model/layers/__init__.py | 3 - .../test_pipeline/model/layers/basic_block.py | 64 ----- .../test_pipeline/model/layers/bottleneck.py | 69 ----- .../test_pipeline/model/layers/conv.py | 15 - .../test_pipeline/model/layers/reslayer.py | 63 ----- .../test_pipeline/model/resnet.py | 163 ----------- .../test_pipeline/resnet_config.py | 21 -- .../test_pipeline/test_partition.py | 43 --- .../test_pipeline/test_pipeline_schedule.py | 34 ++- .../test_checkpoint/test_checkpoint_1d.py | 2 +- .../test_checkpoint/test_checkpoint_2d.py | 2 +- .../test_checkpoint/test_checkpoint_2p5d.py | 2 +- .../test_checkpoint/test_checkpoint_3d.py | 2 +- 29 files changed, 366 insertions(+), 1127 deletions(-) delete mode 100644 colossalai/builder/pipeline.py delete mode 100644 colossalai/nn/layer/wrapper/lambda_wrapper.py delete mode 100644 colossalai/nn/model/__init__.py delete mode 100644 colossalai/nn/model/model_from_config.py create mode 100644 colossalai/pipeline/__init__.py create mode 100644 colossalai/pipeline/layer_sepc.py rename colossalai/{utils/model => pipeline}/pipelinable.py (64%) create mode 100644 colossalai/pipeline/utils.py rename tests/{test_utils => test_pipeline}/test_pipelinable.py (94%) delete mode 100644 tests/test_trainer/test_pipeline/model/__init__.py delete mode 100644 tests/test_trainer/test_pipeline/model/layers/__init__.py delete mode 100644 tests/test_trainer/test_pipeline/model/layers/basic_block.py delete mode 100644 tests/test_trainer/test_pipeline/model/layers/bottleneck.py delete mode 100644 tests/test_trainer/test_pipeline/model/layers/conv.py delete mode 100644 tests/test_trainer/test_pipeline/model/layers/reslayer.py delete mode 100644 tests/test_trainer/test_pipeline/model/resnet.py delete mode 100644 tests/test_trainer/test_pipeline/resnet_config.py delete mode 100644 tests/test_trainer/test_pipeline/test_partition.py diff --git a/colossalai/builder/__init__.py b/colossalai/builder/__init__.py index c4840c24a..3edb4b6bc 100644 --- a/colossalai/builder/__init__.py +++ b/colossalai/builder/__init__.py @@ -1,12 +1,5 @@ -from .builder import (build_schedule, build_lr_scheduler, build_model, - build_optimizer, build_layer, build_loss, build_hooks, - build_dataset, build_transform, build_data_sampler, - build_gradient_handler, build_ophooks) -from .pipeline import build_pipeline_model, build_pipeline_model_from_cfg +from .builder import build_from_config, build_from_registry, build_gradient_handler __all__ = [ - 'build_schedule', 'build_lr_scheduler', 'build_model', 'build_optimizer', - 'build_layer', 
'build_loss', 'build_hooks', 'build_dataset', - 'build_transform', 'build_data_sampler', 'build_gradient_handler', - 'build_pipeline_model', 'build_pipeline_model_from_cfg', 'build_ophooks' + 'build_gradient_handler', 'build_from_config', 'build_from_registry' ] diff --git a/colossalai/builder/builder.py b/colossalai/builder/builder.py index 812ab78d7..94cdba30c 100644 --- a/colossalai/builder/builder.py +++ b/colossalai/builder/builder.py @@ -2,7 +2,6 @@ # -*- encoding: utf-8 -*- import inspect -from collections.abc import Iterable from colossalai.registry import * @@ -64,84 +63,6 @@ def build_from_registry(config, registry: Registry): return obj - - -def build_layer(config): - """Returns a layer object of :class:`nn.Module` constructed from `config`. - - Args: - config (dict or :class:`colossalai.context.Config`): A python dict or - a :class:`colossalai.context.Config` object containing information - used in the construction of the ``LAYERS``. - - Returns: - An object of :class:`torch.nn.Module` - """ - return build_from_registry(config, LAYERS) - - -def build_loss(config): - """Returns a loss function object of :class:`torch.autograd.Function` constructed - from `config`. - - Args: - config (dict or :class:`colossalai.context.Config`): A python dict or - a :class:`colossalai.context.Config` object containing information - used in the construction of the ``LOSSES``. - - Returns: - An object of :class:`torch.nn.modules.loss._Loss` - """ - return build_from_registry(config, LOSSES) - - -def build_model(config): - """Returns a model object of :class:`nn.Module` constructed from `config`. - - Args: - config (dict or :class:`colossalai.context.Config`): A python dict or - a :class:`colossalai.context.Config` object containing information - used in the construction of the ``MODELS``. - - Returns: - An object of :class:`torch.nn.Module` - """ - return build_from_registry(config, MODELS) - - -def build_dataset(config): - """Returns a dataset object of :class:`torch.utils.data.Dataset` constructed - from `config`. - - Args: - config (dict or :class:`colossalai.context.Config`): A python dict or - a :class:`colossalai.context.Config` object containing information - used in the construction of the ``DATASETS``. - - Returns: - An object of :class:`torch.utils.data.Dataset` - """ - return build_from_registry(config, DATASETS) - - -def build_optimizer(config, model): - """Returns an optimizer object of :class:`torch.optim.Optimizer` constructed from `config`, - 'model' and 'params'. - - Args: - config (dict or :class:`colossalai.context.Config`): A python dict or - a :class:`colossalai.context.Config` object containing information - used in the construction of the ``OPTIMIZERS``. - model (:class:`nn.Module`): A model containing parameters for the optimizer - - Returns: - An object of :class:`torch.optim.Optimizer` - """ - config_ = config.copy() - config_['params'] = model.parameters() - return build_from_registry(config_, OPTIMIZERS) - - def build_gradient_handler(config, model, optimizer): """Returns a gradient handler object of :class:`BaseGradientHandler` constructed from `config`, `model` and `optimizer`. @@ -160,100 +81,3 @@ def build_gradient_handler(config, model, optimizer): config_['model'] = model config_['optimizer'] = optimizer return build_from_registry(config_, GRADIENT_HANDLER) - - -def build_hooks(config, trainer): - """Returns a hook object of :class:`BaseHook` constructed from `config` and `trainer`. 
- - Args: - config (dict or :class:`colossalai.context.Config`): A python dict or - a :class:`colossalai.context.Config` object containing information - used in the construction of the ``HOOKS``. - - Returns: - An object of :class:`colossalai.trainer.hooks.BaseHook` - """ - config_ = config.copy() - config_['trainer'] = trainer - return build_from_registry(config_, HOOKS) - - -def build_ophooks(config): - """Returns a hook object of :class:`BaseOpHook` constructed from `config`. - - Args: - config (dict or :class:`colossalai.context.Config`): A python dict or - a :class:`colossalai.context.Config` object containing information - used in the construction of the ``OPHOOKS``. - - Returns: - An object of :class:`colossalai.trainer.hooks.BaseOpHook` - """ - config_ = config.copy() - return build_from_registry(config_, OPHOOKS) - - -def build_transform(config): - """Returns a transformation object of :class:`torchvision.transforms` constructed - from `config`. - - Args: - config (dict or :class:`colossalai.context.Config`): A python dict or - a :class:`colossalai.context.Config` object containing information - used in the construction of the ``TRANSFORMS``. - - Returns: - An object of :class:`torchvision.transforms` - """ - return build_from_registry(config, TRANSFORMS) - - -def build_data_sampler(config, dataset): - """Returns a data sampler object of :class:`colossalai.nn.data.sampler.BaseSampler` - constructed from `config`. - - Args: - config (dict or :class:`colossalai.context.Config`): A python dict or - a :class:`colossalai.context.Config` object containing information - used in the construction of the ``DATA_SAMPLERS``. - dataset (:class:`torch.utils.data.Dataset`): An object of - :class:`torch.utils.data.Dataset` containing information - used in the construction of the return object - Returns: - An object of :class:`colossalai.utils.data_sampler.BaseSampler` - """ - config_ = config.copy() - config_['dataset'] = dataset - return build_from_registry(config_, DATA_SAMPLERS) - - -def build_lr_scheduler(config, optimizer): - """Returns a learning rate scheduler object of :class:`torch.optim.lr_scheduler` - constructed from `config`, `optimizer`, `total_steps` and `num_steps_per_epoch`. - - Args: - config (dict or :class:`colossalai.context.Config`): A python dict or - a :class:`colossalai.context.Config` object containing information - used in the construction of the ``lr_schedule``. - optimizer (:class:`torch.optim.Optimizer`): An optimizer object containing - parameters for the learning rate scheduler. - - Returns: - An object of :class:`torch.optim.lr_scheduler` - """ - config_ = config.copy() - config_['optimizer'] = optimizer - return build_from_registry(config_, LR_SCHEDULERS) - -def build_schedule(config): - """Returns a schedule of :class:`colossalai.engine.schedule.BaseSchedule`. - - Args: - config (dict or :class:`colossalai.context.Config`): A python dict or - a :class:`colossalai.context.Config` object containing information - used in the construction of the ``Schedule``. 
- - Returns: - An object of :class:`colossalai.engine.schedule.BaseSchedule` - """ - return build_from_registry(config, SCHEDULE) diff --git a/colossalai/builder/pipeline.py b/colossalai/builder/pipeline.py deleted file mode 100644 index 6027d34e6..000000000 --- a/colossalai/builder/pipeline.py +++ /dev/null @@ -1,266 +0,0 @@ -import copy -import heapq - -from colossalai.builder import build_model, build_layer -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.logging import get_dist_logger -import torch.nn as nn - - -def _binary_partition(weights, st, ed): - """Returns the binary partition position of `weights`, given the start - position `st` and the end position `ed`. - - Args: - weights (list): A python list to be binary partitioned - st (int): the start position of the binary partition - ed (int): the end position of the binary partition - - Returns: - int: the binary partition position of `weights` - """ - w_sum = weights[ed - 1] - prefix = 0 - if st > 0: - w_sum -= weights[st - 1] - prefix = weights[st - 1] - minimum = float("inf") - for idx in range(st + 1, ed): - front = weights[idx - 1] - prefix - diff = abs(w_sum - 2 * front) - if diff < minimum: - pos = idx - minimum = diff - - return st, pos, ed - - -def _heap_addition(weights, intervals, add_cnt): - """ - """ - - def _heap_push(heap, st, ed): - value = weights[ed - 1] - if st > 0: - value -= weights[st - 1] - heapq.heappush(heap, (-value, st, ed)) - - ret_intervals = [] - heap = [] - - for st, ed in intervals: - _heap_push(heap, st, ed) - - while add_cnt > 0: - _, st, ed = heapq.heappop(heap) - if ed - st == 1: - ret_intervals.append((st, ed)) - else: - l, m, r = _binary_partition(weights, st, ed) - _heap_push(heap, l, m) - _heap_push(heap, m, r) - add_cnt -= 1 - - while heap: - _, st, ed = heapq.heappop(heap) - ret_intervals.append((st, ed)) - - ret_intervals.sort() - return ret_intervals - - -def _calc_partitions(weights, value): - prev = 0 - prefix = 0 - num_block = 0 - intervals = [] - - for idx, w in enumerate(weights): - if weights[idx] - prefix > value: - intervals.append((prev, idx)) - prev = idx - prefix = weights[idx - 1] - num_block += 1 - - intervals.append((prev, len(weights))) - return num_block + 1, intervals - - -def _binary_search(weights, num): - length = len(weights) - prefix = [1 if w == 0 else w for w in weights] - for i in range(1, length): - prefix[i] += prefix[i - 1] - - lower_bound = max(weights) - upper_bound = prefix[length - 1] - - while upper_bound > lower_bound: - mid = (upper_bound + lower_bound) // 2 - number, _ = _calc_partitions(prefix, mid) - if number <= num: - upper_bound = mid - else: - lower_bound = mid + 1 - - num_block, intervals = _calc_partitions(prefix, upper_bound) - if num_block < num: - intervals = _heap_addition(prefix, intervals, num - num_block) - - return intervals - - -def partition_uniform(num_items, pipeline_parallel_size, num_chunks): - assert num_items % num_chunks == 0, \ - "Layer length should be divided by the number of chunks, otherwise parameter method is recomended" - - logger = get_dist_logger() - parts = [[] for _ in range(pipeline_parallel_size)] - partition_items = num_items // num_chunks - for idx in range(num_chunks): - base_idx = idx * partition_items - chunk_size = partition_items // pipeline_parallel_size - left = pipeline_parallel_size - partition_items % pipeline_parallel_size - if chunk_size == 0: - logger.warning("Some nodes in Pipeline have no requests") - - for p in 
range(pipeline_parallel_size): - st = base_idx - base_idx += chunk_size + (p >= left) - parts[p].append((st, base_idx)) - - return parts - - -def partition_balanced(weights, pipeline_parallel_size, num_chunks): - num_total = pipeline_parallel_size * num_chunks - num_items = len(weights) - if num_items <= num_total: - return partition_uniform(num_items, pipeline_parallel_size, num_chunks) - - intervals = _binary_search(weights, num_total) - - current = 0 - parts = [[] for _ in range(pipeline_parallel_size)] - for inter in intervals: - parts[current].append(inter) - current = (current + 1) % pipeline_parallel_size - - return parts - - -def count_layer_params(layers): - """Count the number of parameters in each layer - """ - param_counts = [0] * len(layers) - for idx, cfg in enumerate(layers): - layer = build_layer(cfg) - params = filter(lambda p: p.requires_grad, layer.parameters()) - param_counts[idx] = sum(p.numel() for p in params) - - return param_counts - - -def build_pipeline_model_from_cfg(config, - num_chunks: int = 1, - partition_method: str = 'parameter', - verbose: bool = False): - """An initializer to split the model into different stages for pipeline parallelism. - - An example for the model config is shown below. The class VisionTransformerFromConfig should - inherit colossalai.nn.model.ModelFromConfig to allow this initializer to build model from a sequence - of layer configurations. - - :: - - model_config = dict( - type='VisionTransformerFromConfig', - embedding_cfg=dict(...), - ... - ) - - Args: - config (dict): Configuration of the model. - num_chunks (int, optional): The number of chunks you want to have on the current stage. - This value should be 1 in most cases unless you are using virtual pipeline parallelism. - partition_method (str, optional): This parameter determines how you want to split your model - layers into stages, you can set it as 'layer' or 'parameter'. - verbose (bool, optional): Whether to print the logs. 
- """ - ori_model = build_model(config) - layers = ori_model.layers_cfg - layer_length = len(layers) - logger = get_dist_logger() - if verbose: - logger.info(f"The total length of layers is {layer_length}", ranks=[0]) - - pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE) - pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - - method = partition_method.lower() - # Make a partition - if method == 'layer': - num_layers = len(layers) - parts = partition_uniform(num_layers, pipeline_parallel_size, num_chunks) - elif method == 'parameter': - param_counts = count_layer_params(layers) - # print_rank_0(param_counts) - parts = partition_balanced(param_counts, pipeline_parallel_size, num_chunks) - else: - raise ValueError("Method should be a pre-set string in [layer, parameter]") - - # Display the partition - if verbose: - log_str = 'Layer allocation after partitioning: \n' - for stage in range(pipeline_parallel_size): - - num_layers = 0 - for st, ed in parts[stage]: - num_layers += ed - st - - log_str += f'\n===== stage={stage}, layers={num_layers} =====\n' - for st, ed in parts[stage]: - for idx, layer in enumerate(layers[st:ed]): - log_str += f'\t{idx + st:2d}: {layer}\n' - logger.info(log_str, ranks=[0]) - - # Save the partition - interval = parts[pipeline_rank] - - models = [] - for st, ed in interval: - model = copy.deepcopy(ori_model) - model.build_from_cfg(st, ed) - models.append(model) - - return nn.ModuleList(models) if len(models) > 1 else models[0] - - -def build_pipeline_model(layers: nn.Sequential, num_chunks: int = 1, verbose: bool = False): - """An intializer to split the model into different stages for pipeline parallelism. - Note that `layer` must be `torch.nn.Sequential`. - Args: - layers (`torch.nn.Sequential`): Layers of model - num_chunks: The number of chunks you want to have on the current stage. This value should be 1 - in most cases unless you are using virtual pipeline parallelism. - verbose (bool, optional): Whether to print the logs. 
- """ - pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE) - pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - partitions = partition_uniform(len(layers), pipeline_parallel_size, num_chunks) - module_list = [] - for start, end in partitions[pipeline_rank]: - module_list.append( - nn.Sequential(*[nn.Identity() for _ in range(start)], *layers[start:end], - *[nn.Identity() for _ in range(len(layers) - end)])) - if verbose: - logger = get_dist_logger() - logger.info(f'Total {len(layers)} layers', ranks=[0]) - for rank, part in enumerate(partitions): - log_str = f'===== stage={rank} =====\n' - for chunk, (start, end) in enumerate(part): - log_str += f'===== chunk={chunk}, layer=[{start}-{end}] =====\n' - log_str += '\n'.join([str(layer) for layer in layers[start:end]]) + '\n' - logger.info(log_str, ranks=[0]) - return nn.ModuleList(module_list) if len(module_list) > 1 else module_list[0] diff --git a/colossalai/nn/__init__.py b/colossalai/nn/__init__.py index 13ae5187f..91fc0da55 100644 --- a/colossalai/nn/__init__.py +++ b/colossalai/nn/__init__.py @@ -2,6 +2,5 @@ from .layer import * from .loss import * from .lr_scheduler import * from .metric import * -from .model import * from .optimizer import * from ._ops import * diff --git a/colossalai/nn/layer/wrapper/__init__.py b/colossalai/nn/layer/wrapper/__init__.py index 01f746f65..c7d90d887 100644 --- a/colossalai/nn/layer/wrapper/__init__.py +++ b/colossalai/nn/layer/wrapper/__init__.py @@ -1,4 +1,3 @@ -from .lambda_wrapper import LambdaWrapper from .pipeline_wrapper import PipelineSharedModuleWrapper -__all__ = ['LambdaWrapper', 'PipelineSharedModuleWrapper'] +__all__ = ['PipelineSharedModuleWrapper'] diff --git a/colossalai/nn/layer/wrapper/lambda_wrapper.py b/colossalai/nn/layer/wrapper/lambda_wrapper.py deleted file mode 100644 index 6b0d6e1e9..000000000 --- a/colossalai/nn/layer/wrapper/lambda_wrapper.py +++ /dev/null @@ -1,36 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import torch.nn as nn - -from colossalai.builder import build_layer -from colossalai.registry import LAYERS - - -@LAYERS.register_module -class LambdaWrapper(nn.Module): - """Wrap a function to nn.Module, which takes a config of layers and can fully access them. - - Args: - func (``Callable``): User customed function. - layers_cfg (dict, optional): Config of layers, defaults to None. 
- """ - - def __init__(self, func, layers_cfg: dict = None): - super().__init__() - self.func = func - self.layers = self._build_layers(layers_cfg) - - def _build_layers(self, layers_cfg: dict): - if layers_cfg is None: - return None - else: - layers = [] - - for cfg in layers_cfg: - layer = build_layer(cfg) - layers.append(layer) - return layers - - def forward(self, *args, **kwargs): - return self.func(self, *args, **kwargs) diff --git a/colossalai/nn/model/__init__.py b/colossalai/nn/model/__init__.py deleted file mode 100644 index 6ced17054..000000000 --- a/colossalai/nn/model/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .model_from_config import ModelFromConfig - -__all__ = ['ModelFromConfig'] diff --git a/colossalai/nn/model/model_from_config.py b/colossalai/nn/model/model_from_config.py deleted file mode 100644 index 24903ca36..000000000 --- a/colossalai/nn/model/model_from_config.py +++ /dev/null @@ -1,37 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -from abc import ABC, abstractmethod - -import torch.nn as nn - -from colossalai.builder import build_layer - - -class ModelFromConfig(nn.Module, ABC): - - def __init__(self): - super(ModelFromConfig, self).__init__() - self.layers = nn.ModuleList() - self.layers_cfg = [] - - def build_from_cfg(self, start=None, end=None): - assert hasattr(self, 'layers_cfg'), 'Cannot find attribute layers_cfg from the module, please check the ' \ - 'spelling and if you have initialized this variable' - if start is None: - start = 0 - if end is None: - end = len(self.layers_cfg) - for cfg in self.layers_cfg[start: end]: - layer = build_layer(cfg) - self.layers.append(layer) - - @abstractmethod - def init_weights(self): - pass - - def state_dict_for_save_checkpoint(self, destination=None, prefix='', - keep_vars=False): - """Use this function to override the state dict for - saving checkpoints.""" - return self.state_dict(destination, prefix, keep_vars) diff --git a/colossalai/pipeline/__init__.py b/colossalai/pipeline/__init__.py new file mode 100644 index 000000000..625bd7ef5 --- /dev/null +++ b/colossalai/pipeline/__init__.py @@ -0,0 +1,4 @@ +from .pipelinable import PipelinableContext, PipelinableModel +from .layer_sepc import LayerSpec + +__all__ = ['PipelinableModel', 'PipelinableContext', 'LayerSpec'] \ No newline at end of file diff --git a/colossalai/pipeline/layer_sepc.py b/colossalai/pipeline/layer_sepc.py new file mode 100644 index 000000000..7e9169eff --- /dev/null +++ b/colossalai/pipeline/layer_sepc.py @@ -0,0 +1,55 @@ +import torch +from colossalai.utils.model.utils import call_to_str + +class LayerSpec: + """ + + """ + + def __init__(self, typename, *module_args, **module_kwargs): + self.typename = typename + self.module_args = module_args + self.module_kwargs = module_kwargs + self.children = None + self._param_count = 0 + + if not issubclass(typename, torch.nn.Module): + raise RuntimeError('LayerSpec only supports torch.nn.Module types.') + + def __repr__(self): + return call_to_str(self.typename.__name__, self.module_args, self.module_kwargs) + + @property + def param_count(self): + return self._param_count + + def build(self): + """Build the stored specification.""" + + recovered_args = [] + for obj in self.module_args: + if isinstance(obj, LayerSpec): + obj = obj.build() + recovered_args.append(obj) + recovered_args = tuple(recovered_args) + + recovered_kwargs = {} + for k, v in self.module_kwargs.items(): + if isinstance(v, LayerSpec): + v = v.build() + recovered_kwargs[k] = v + + return 
self.typename(*recovered_args, **recovered_kwargs)
+
+    def set_children(self, children):
+        self.children = children
+
+    def count_params(self):
+        self._param_count = 0
+        layer = self.build()
+        for param in layer.parameters():
+            self._param_count += param.numel()
+        return self._param_count
+
+    def reset_param_count(self):
+        self._param_count = 0
\ No newline at end of file
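
A minimal sketch of how the new LayerSpec defers module construction may help here; the layer sizes below are made up for illustration::

    import torch.nn as nn
    from colossalai.pipeline import LayerSpec

    # nothing is instantiated here; only the class and its arguments are recorded
    spec = LayerSpec(nn.Linear, 256, 256)

    # count_params() builds a throwaway copy to count parameters:
    # 256 * 256 weights + 256 biases = 65792
    assert spec.count_params() == 65792

    # the real module is only created when build() is called
    layer = spec.build()
    assert isinstance(layer, nn.Linear)
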
diff --git a/colossalai/utils/model/pipelinable.py b/colossalai/pipeline/pipelinable.py
similarity index 64%
rename from colossalai/utils/model/pipelinable.py
rename to colossalai/pipeline/pipelinable.py
index 14d19a7ae..317ac0c21 100644
--- a/colossalai/utils/model/pipelinable.py
+++ b/colossalai/pipeline/pipelinable.py
@@ -1,26 +1,34 @@
 import torch
 import inspect
-from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses, call_to_str
-from colossalai.builder.pipeline import partition_uniform, partition_balanced
+from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
+from .utils import partition_uniform, partition_balanced, build_kwargs_for_function, build_kwargs_for_module, exec_func_with_kwargs, exec_funcs_with_kwargs
 from colossalai.nn.layer.utils import CheckpointModule
-from colossalai.tensor import ColoTensor
+from colossalai.tensor import ColoParameter
+from .layer_sepc import LayerSpec
 
 
 class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
+    """
+    A context manager to split the model into pipeline stages.
+    """
 
-    def __init__(self):
+    def __init__(self, policy: str = "balanced"):
         super().__init__()
         self._layer_spec_dict = {}
         self._root_children = None
         self._model = None
         self._layer_spec_list = []
         self._func_dict = {}
-        self._policy = "balanced"
+        self._policy = policy
 
     @property
     def policy(self):
         return self._policy
 
+    @policy.setter
+    def policy(self, policy: str):
+        self._policy = policy
+
     @property
     def layers_count(self):
         return len(self._layer_spec_list)
@@ -30,10 +38,9 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
         return len(self._func_dict)
 
     def _pre_context_exec(self):
-        """ 
+        """
        The Callback function when entering the context
        """
-
        # reserve rng states
        self.cpu_rng_state = torch.get_rng_state()
        self.cuda_rng_state = torch.cuda.get_rng_state()
@@ -52,35 +59,50 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
         The function to call at the end of the constructor of each module.
         NOTE() The module may be passed to this function multiple times.
         """
-        module_id = id(module)
+        # iterate over the positional arguments
+        # to check if an argument is a torch Module
+        # if found any torch Module, replace it with its layer spec
+        # for storage purpose
         modified_args = []
-        for obj in args:
-            if issubclass(obj.__class__, torch.nn.modules.module.Module):
-                obj = self._layer_spec_dict[id(obj)]
-            modified_args.append(obj)
+        for arg in args:
+            if isinstance(arg, torch.nn.Module):
+                arg = self._layer_spec_dict[id(arg)]
+            modified_args.append(arg)
 
+        # do the same for the keyword arguments
         modified_kwargs = {}
         for k, v in kwargs.items():
-            if issubclass(v.__class__, torch.nn.modules.module.Module):
+            if isinstance(v, torch.nn.Module):
                 v = self._layer_spec_dict[id(v)]
             # (lyl)TODO: analyse ColoTensor as well
             modified_kwargs[k] = v
 
-        modified_args = tuple(modified_args)
-
+        # keep track of the module children
+        # as torch.nn.Module.__init__ is called from inner module to outer module,
+        # the final value of self._model will be the outermost model
+        # e.g. if the model is torchvision.models.resnet18, then the final value of self._model
+        # will be the ``ResNet`` object.
         self._root_children = list(module.children())
         self._model = module
+
+        # store the children to keep the module hierarchy
         layer_spec = LayerSpec(module.__class__, *modified_args, **modified_kwargs)
         layer_spec.set_children(module.children())
+
+        # store the layer spec in this context
+        module_id = id(module)
         self._layer_spec_dict[module_id] = layer_spec
+
+        # convert all torch.nn.Parameter to colossalai.tensor.ColoParameter
         name_list = []
         for name, param in module.named_parameters():
-            if isinstance(param, ColoTensor):
+            if isinstance(param, ColoParameter):
                 continue
             name_list.append((name, param))
 
         for name, param in name_list:
             delattr(module, name)
-            setattr(module, name, ColoTensor.from_torch_tensor(param))
+            setattr(module, name, ColoParameter.from_torch_tensor(tensor=param.data, requires_grad=param.requires_grad))
 
     def to_layer_list(self, exec_seq=None):
         """
@@ -100,7 +122,6 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
                     if id(module) == id(child_in_container):
                         children_name.append(name)
                         break
-
             else:
                 self._layer_spec_list.append(layer_spec)
                 for name, module in self._model.named_modules():
@@ -110,10 +131,16 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
 
         else:
             front_funcs_list = []
+            named_modules = dict(self._model.named_modules())
             for index, element in enumerate(exec_seq):
                 if isinstance(element, str):
-                    module = dict(self._model.named_modules())[element]
+                    assert element in named_modules, f'Found invalid module name {element}, please check that the module name is spelled correctly.'
+
+                    # get the layer spec based on the module ID
+                    module = named_modules[element]
                     layer_spec = self._layer_spec_dict[id(module)]
+
+                    # check whether there are functions which should be executed before this module
                     if len(front_funcs_list) != 0:
                         func_key = (layer_spec, "front")
                         if func_key not in self._func_dict:
@@ -121,6 +148,7 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
                             for f in front_funcs_list:
                                 self._func_dict[func_key].append(f)
                             front_funcs_list = []
+
                     func_key = (layer_spec, "behind")
                     self._layer_spec_list.append(layer_spec)
                 elif isinstance(element, tuple) and element[1] == "front":
@@ -172,70 +200,6 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
 
         return pipeline_model
 
-    def load_policy(self, policy):
-        self._policy = policy
-
-
-def _build_kwargs_for_module(function, kw_dict):
-    """
-    Generally, the first argument of module.forward is an input tensor come from the previous layer.
-    Therefore, we just filter the kwargs from second element of the dictionary.
-    """
-    sig = inspect.signature(function)
-    if len(sig.parameters) <= 1:
-        return None
-    args_name_list = list(sig.parameters.keys())
-    kw_dict = {k: v for k, v in kw_dict.items() if k in args_name_list[1:]}
-    return kw_dict
-
-
-def _build_kwargs_for_function(function, kw_dict):
-    sig = inspect.signature(function)
-    kw_dict = {k: v for k, v in kw_dict.items() if k in sig.parameters}
-    if len(kw_dict) == 0:
-        return None
-    return kw_dict
-
-
-def _exec_func_with_kwargs(func, kw_dict, input_tensor, kwargs):
-    """
-    We suppose the callable object passed to to_layer_list method in two purpose:
-    a. use the callable object to modify input tensor, such as \
-        lambda x: torch.flatten(x, 1)
-    b. 
use the callable object to modify kwargs value, such as \ - def foo(attention_mask=None): - if attention_mask is not None: - batch_size = input_ids.shape[0] - attention_mask = attention_mask.view(batch_size, -1) - return attention_mask - """ - - if kw_dict is not None: - rst = func(**kw_dict) - if isinstance(rst, tuple): - for i, k in enumerate(kw_dict.keys()): - kwargs[k] = rst[i] - else: - for k in kw_dict.keys(): - kwargs[k] = rst - return input_tensor - return func(input_tensor) - - -def _exec_funcs_with_kwargs(func_dict, func_key, input_tensor, kwargs): - - assert func_key in func_dict, f"{func_key} is not in the function_dict." - funcs_to_exec = func_dict[func_key] - if isinstance(funcs_to_exec, list): - for f in funcs_to_exec: - f_kwargs = _build_kwargs_for_function(f, kwargs) - input_tensor = _exec_func_with_kwargs(f, f_kwargs, input_tensor, kwargs) - else: - f_kwargs = _build_kwargs_for_function(funcs_to_exec, kwargs) - input_tensor = _exec_func_with_kwargs(funcs_to_exec, f_kwargs, input_tensor, kwargs) - - return input_tensor - class PipelinableModel(torch.nn.Module): @@ -250,16 +214,16 @@ class PipelinableModel(torch.nn.Module): for module in self._module_list: if id(module) in self._front_func_dict: - input_tensor = _exec_funcs_with_kwargs(self._front_func_dict, id(module), input_tensor, kwargs) + input_tensor = exec_funcs_with_kwargs(self._front_func_dict, id(module), input_tensor, kwargs) if isinstance(module, CheckpointModule): forward_func = module._forward else: forward_func = module.forward if input_tensor is None: - module_kwargs = _build_kwargs_for_function(forward_func, kwargs) + module_kwargs = build_kwargs_for_function(forward_func, kwargs) else: - module_kwargs = _build_kwargs_for_module(forward_func, kwargs) + module_kwargs = build_kwargs_for_module(forward_func, kwargs) if module_kwargs is not None and input_tensor is not None: if isinstance(module, CheckpointModule): convert_kwargs_to_args = [] @@ -288,57 +252,9 @@ class PipelinableModel(torch.nn.Module): input_tensor = module(input_tensor) if id(module) in self._behind_func_dict: - input_tensor = _exec_funcs_with_kwargs(self._behind_func_dict, id(module), input_tensor, kwargs) + input_tensor = exec_funcs_with_kwargs(self._behind_func_dict, id(module), input_tensor, kwargs) return input_tensor -class LayerSpec: - def __init__(self, typename, *module_args, **module_kwargs): - self.typename = typename - self.module_args = module_args - self.module_kwargs = module_kwargs - self.children = None - self._param_count = 0 - - if not issubclass(typename, torch.nn.Module): - raise RuntimeError('LayerSpec only supports torch.nn.Module types.') - - def __repr__(self): - return call_to_str(self.typename.__name__, self.module_args, self.module_kwargs) - - @property - def param_count(self): - return self._param_count - - def build(self): - """Build the stored specification.""" - - recovered_args = [] - for obj in self.module_args: - if isinstance(obj, LayerSpec): - obj = obj.build() - recovered_args.append(obj) - recovered_args = tuple(recovered_args) - - recovered_kwargs = {} - for k, v in self.module_kwargs.items(): - if isinstance(v, LayerSpec): - v = v.build() - recovered_kwargs[k] = v - - return self.typename(*recovered_args, **recovered_kwargs) - - def set_children(self, children): - self.children = children - - def count_params(self): - self._param_count = 0 - layer = self.build() - for param in layer.parameters(): - self._param_count += param.numel() - return self._param_count - - def reset_param_count(self): - 
self._param_count = 0
diff --git a/colossalai/pipeline/utils.py b/colossalai/pipeline/utils.py
new file mode 100644
index 000000000..6d1ea73d5
--- /dev/null
+++ b/colossalai/pipeline/utils.py
@@ -0,0 +1,207 @@
+import heapq
+import inspect
+
+from colossalai.logging import get_dist_logger
+from typing import List
+
+def _binary_partition(weights: List, start: int, end: int):
+    """Returns the binary partition position of `weights`, given the start
+    position `start` and the end position `end`.
+
+    Args:
+        weights (list): A python list to be binary partitioned
+        start (int): the start position of the binary partition
+        end (int): the end position of the binary partition
+
+    Returns:
+        tuple: the start position, the binary partition position, and the end position of `weights`
+    """
+    w_sum = weights[end - 1]
+    prefix = 0
+    if start > 0:
+        w_sum -= weights[start - 1]
+        prefix = weights[start - 1]
+    minimum = float("inf")
+    for idx in range(start + 1, end):
+        front = weights[idx - 1] - prefix
+        diff = abs(w_sum - 2 * front)
+        if diff < minimum:
+            pos = idx
+            minimum = diff
+
+    return start, pos, end
+
+
+def _heap_addition(weights: List, intervals: List, add_cnt: int):
+    """Repeatedly split the heaviest interval in two until `add_cnt`
+    additional intervals have been created."""
+
+    def _heap_push(heap, st, ed):
+        value = weights[ed - 1]
+        if st > 0:
+            value -= weights[st - 1]
+        heapq.heappush(heap, (-value, st, ed))
+
+    ret_intervals = []
+    heap = []
+
+    for st, ed in intervals:
+        _heap_push(heap, st, ed)
+
+    while add_cnt > 0:
+        _, st, ed = heapq.heappop(heap)
+        if ed - st == 1:
+            ret_intervals.append((st, ed))
+        else:
+            l, m, r = _binary_partition(weights, st, ed)
+            _heap_push(heap, l, m)
+            _heap_push(heap, m, r)
+            add_cnt -= 1
+
+    while heap:
+        _, st, ed = heapq.heappop(heap)
+        ret_intervals.append((st, ed))
+
+    ret_intervals.sort()
+    return ret_intervals
+
+
+def _calc_partitions(weights, value):
+    prev = 0
+    prefix = 0
+    num_block = 0
+    intervals = []
+
+    for idx, w in enumerate(weights):
+        if weights[idx] - prefix > value:
+            intervals.append((prev, idx))
+            prev = idx
+            prefix = weights[idx - 1]
+            num_block += 1
+
+    intervals.append((prev, len(weights)))
+    return num_block + 1, intervals
+
+
+def _binary_search(weights, num):
+    length = len(weights)
+    prefix = [1 if w == 0 else w for w in weights]
+    for i in range(1, length):
+        prefix[i] += prefix[i - 1]
+
+    lower_bound = max(weights)
+    upper_bound = prefix[length - 1]
+
+    while upper_bound > lower_bound:
+        mid = (upper_bound + lower_bound) // 2
+        number, _ = _calc_partitions(prefix, mid)
+        if number <= num:
+            upper_bound = mid
+        else:
+            lower_bound = mid + 1
+
+    num_block, intervals = _calc_partitions(prefix, upper_bound)
+    if num_block < num:
+        intervals = _heap_addition(prefix, intervals, num - num_block)
+
+    return intervals
+
+
+def partition_uniform(num_items, pipeline_parallel_size, num_chunks):
+    assert num_items % num_chunks == 0, \
+        "The number of layers should be divisible by the number of chunks, otherwise the 'parameter' partition method is recommended"
+
+    logger = get_dist_logger()
+    parts = [[] for _ in range(pipeline_parallel_size)]
+    partition_items = num_items // num_chunks
+    for idx in range(num_chunks):
+        base_idx = idx * partition_items
+        chunk_size = partition_items // pipeline_parallel_size
+        left = pipeline_parallel_size - partition_items % pipeline_parallel_size
+        if chunk_size == 0:
+            logger.warning("Some pipeline stages are assigned no layers")
+
+        for p in range(pipeline_parallel_size):
+            st = base_idx
+            base_idx += chunk_size + (p >= left)
+            parts[p].append((st, base_idx))
+
+    return parts
+
+
+def partition_balanced(weights, pipeline_parallel_size, num_chunks):
+    num_total = pipeline_parallel_size * num_chunks
+    num_items = len(weights)
+    if num_items <= num_total:
+        return partition_uniform(num_items, pipeline_parallel_size, num_chunks)
+
+    intervals = _binary_search(weights, num_total)
+
+    current = 0
+    parts = [[] for _ in range(pipeline_parallel_size)]
+    for inter in intervals:
+        parts[current].append(inter)
+        current = (current + 1) % pipeline_parallel_size
+
+    return parts
+
+
+def build_kwargs_for_module(function, kw_dict):
+    """
+    Generally, the first argument of module.forward is an input tensor coming from the previous layer.
+    Therefore, we just filter the kwargs starting from the second parameter of the signature.
+    """
+    sig = inspect.signature(function)
+    if len(sig.parameters) <= 1:
+        return None
+    args_name_list = list(sig.parameters.keys())
+    kw_dict = {k: v for k, v in kw_dict.items() if k in args_name_list[1:]}
+    return kw_dict
+
+
+def build_kwargs_for_function(function, kw_dict):
+    sig = inspect.signature(function)
+    kw_dict = {k: v for k, v in kw_dict.items() if k in sig.parameters}
+    if len(kw_dict) == 0:
+        return None
+    return kw_dict
+
+
+def exec_func_with_kwargs(func, kw_dict, input_tensor, kwargs):
+    """
+    We assume the callable object passed to the to_layer_list method serves one of two purposes:
+    a. use the callable object to modify the input tensor, such as \
+        lambda x: torch.flatten(x, 1)
+    b. use the callable object to modify a kwargs value, such as \
+        def foo(attention_mask=None):
+            if attention_mask is not None:
+                batch_size = input_ids.shape[0]
+                attention_mask = attention_mask.view(batch_size, -1)
+            return attention_mask
+    """
+
+    if kw_dict is not None:
+        rst = func(**kw_dict)
+        if isinstance(rst, tuple):
+            for i, k in enumerate(kw_dict.keys()):
+                kwargs[k] = rst[i]
+        else:
+            for k in kw_dict.keys():
+                kwargs[k] = rst
+        return input_tensor
+    return func(input_tensor)
+
+
+def exec_funcs_with_kwargs(func_dict, func_key, input_tensor, kwargs):
+
+    assert func_key in func_dict, f"{func_key} is not in the function_dict."
+    funcs_to_exec = func_dict[func_key]
+    if isinstance(funcs_to_exec, list):
+        for f in funcs_to_exec:
+            f_kwargs = build_kwargs_for_function(f, kwargs)
+            input_tensor = exec_func_with_kwargs(f, f_kwargs, input_tensor, kwargs)
+    else:
+        f_kwargs = build_kwargs_for_function(funcs_to_exec, kwargs)
+        input_tensor = exec_func_with_kwargs(funcs_to_exec, f_kwargs, input_tensor, kwargs)
+
+    return input_tensor
\ No newline at end of file
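
For reference, a quick sketch of how the relocated partition helpers behave (the weights below are made-up per-layer parameter counts)::

    from colossalai.pipeline.utils import partition_uniform, partition_balanced

    # uniform split: 8 layers over 4 pipeline stages with 1 chunk per stage;
    # each stage receives one contiguous (start, end) index range of 2 layers
    parts = partition_uniform(8, pipeline_parallel_size=4, num_chunks=1)
    # parts == [[(0, 2)], [(2, 4)], [(4, 6)], [(6, 8)]]

    # balanced split: layers are weighted, so stages get similar total weight
    parts = partition_balanced([1, 4, 4, 1], pipeline_parallel_size=2, num_chunks=1)
    # parts == [[(0, 2)], [(2, 4)]], i.e. both stages carry a total weight of 5
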
diff --git a/tests/test_config/test_load_config.py b/tests/test_config/test_load_config.py
index 235f95ad3..550af2a4a 100644
--- a/tests/test_config/test_load_config.py
+++ b/tests/test_config/test_load_config.py
@@ -6,7 +6,6 @@
 from pathlib import Path
 
 import pytest
 
 from colossalai.context.config import Config
-from colossalai.builder import build_ophooks
 
 
 @pytest.mark.cpu
diff --git a/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py b/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py
index 415b2ddc7..1429bcfdd 100644
--- a/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py
+++ b/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py
@@ -17,11 +17,14 @@ from colossalai.logging import get_dist_logger
 from colossalai.nn import CrossEntropyLoss
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 from colossalai.utils import is_using_pp, get_dataloader
-from colossalai.utils.model.pipelinable import PipelinableContext
+from colossalai.pipeline.pipelinable import PipelinableContext
 from tqdm import tqdm
-
-from titans.dataloader.cifar10 import build_cifar
-from titans.model.vit import vit_tiny_patch4_32
+from torchvision.datasets import CIFAR10
+from torchvision.transforms import transforms
+try:
+    from titans.model.vit import vit_tiny_patch4_32
+except ImportError:
+    pass
 
 BATCH_SIZE = 4
 NUM_EPOCHS = 60
@@ -49,7 +52,14 @@ def run_trainer(rank, world_size, port):
 
     # craete dataloaders
     root = Path(os.environ['DATA'])
-    train_dataloader, test_dataloader = build_cifar(BATCH_SIZE, root, pad_if_needed=True, crop=32, resize=32)
+    transform_train = transforms.Compose([
+        transforms.RandomCrop(224, padding=4, pad_if_needed=True),
+        transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
+        transforms.ToTensor(),
+        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+    ])
+    train_dataset = CIFAR10(root=root, train=True, download=True, transform=transform_train)
+    train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE, pin_memory=True)
 
     # create loss function
     criterion = CrossEntropyLoss(label_smoothing=0.1)
diff --git a/tests/test_utils/test_pipelinable.py b/tests/test_pipeline/test_pipelinable.py
similarity index 94%
rename from tests/test_utils/test_pipelinable.py
rename to tests/test_pipeline/test_pipelinable.py
index 09d815632..c99a88550 100644
--- a/tests/test_utils/test_pipelinable.py
+++ b/tests/test_pipeline/test_pipelinable.py
@@ -1,7 +1,7 @@
 import torch
 import torch.multiprocessing as mp
 
-from colossalai.utils.model.pipelinable import PipelinableContext
+from colossalai.pipeline.pipelinable import PipelinableContext
 
 from colossalai.testing import rerun_on_exception
 
@@ -33,7 +33,7 @@ def run_pipelinable(rank):
         model = MLP()
 
     assert pipelinable.policy == "balanced"
-    pipelinable.load_policy("uniform")
+    pipelinable.policy = "uniform"
 
     assert pipelinable.policy == "uniform"
 
     pipelinable.to_layer_list()
diff --git a/tests/test_trainer/test_pipeline/model/__init__.py 
b/tests/test_trainer/test_pipeline/model/__init__.py deleted file mode 100644 index 2bf880f41..000000000 --- a/tests/test_trainer/test_pipeline/model/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .layers import * -from .resnet import VanillaResNet diff --git a/tests/test_trainer/test_pipeline/model/layers/__init__.py b/tests/test_trainer/test_pipeline/model/layers/__init__.py deleted file mode 100644 index aa553b737..000000000 --- a/tests/test_trainer/test_pipeline/model/layers/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .basic_block import ResNetBasicBlock -from .bottleneck import ResNetBottleneck -from .reslayer import ResLayer \ No newline at end of file diff --git a/tests/test_trainer/test_pipeline/model/layers/basic_block.py b/tests/test_trainer/test_pipeline/model/layers/basic_block.py deleted file mode 100644 index 320dac2fd..000000000 --- a/tests/test_trainer/test_pipeline/model/layers/basic_block.py +++ /dev/null @@ -1,64 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -from typing import Optional, Callable - -import torch.nn as nn -from torch import Tensor - -from colossalai.registry import LAYERS -from .conv import conv3x3 - - -@LAYERS.register_module -class ResNetBasicBlock(nn.Module): - """Basic ResNet block - """ - expansion: int = 1 - - def __init__( - self, - inplanes: int, - planes: int, - stride: int = 1, - downsample: Optional[nn.Module] = None, - groups: int = 1, - base_width: int = 64, - dilation: int = 1, - norm_layer: Optional[Callable[..., nn.Module]] = None - ) -> None: - super().__init__() - if norm_layer is None: - norm_layer = nn.BatchNorm2d - if groups != 1 or base_width != 64: - raise ValueError( - 'BasicBlock only supports groups=1 and base_width=64') - if dilation > 1: - raise NotImplementedError( - "Dilation > 1 not supported in BasicBlock") - # Both self.conv1 and self.downsample layers downsample the input when stride != 1 - self.conv1 = conv3x3(inplanes, planes, stride) - self.bn1 = norm_layer(planes) - self.relu = nn.ReLU(inplace=True) - self.conv2 = conv3x3(planes, planes) - self.bn2 = norm_layer(planes) - self.downsample = downsample - self.stride = stride - - def forward(self, x: Tensor) -> Tensor: - identity = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - - if self.downsample is not None: - identity = self.downsample(x) - - out += identity - out = self.relu(out) - - return out diff --git a/tests/test_trainer/test_pipeline/model/layers/bottleneck.py b/tests/test_trainer/test_pipeline/model/layers/bottleneck.py deleted file mode 100644 index d75f9534b..000000000 --- a/tests/test_trainer/test_pipeline/model/layers/bottleneck.py +++ /dev/null @@ -1,69 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -from typing import Optional, Callable - -import torch.nn as nn -from torch import Tensor - -from colossalai.registry import LAYERS -from .conv import conv3x3, conv1x1 - - -@LAYERS.register_module -class ResNetBottleneck(nn.Module): - # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) - # while original implementation places the stride at the first 1x1 convolution(self.conv1) - # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. - # This variant is also known as ResNet V1.5 and improves accuracy according to - # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
- - expansion: int = 4 - - def __init__( - self, - inplanes: int, - planes: int, - stride: int = 1, - downsample: Optional[nn.Module] = None, - groups: int = 1, - base_width: int = 64, - dilation: int = 1, - norm_layer: Optional[Callable[..., nn.Module]] = None - ) -> None: - super().__init__() - if norm_layer is None: - norm_layer = nn.BatchNorm2d - width = int(planes * (base_width / 64.)) * groups - # Both self.conv2 and self.downsample layers downsample the input when stride != 1 - self.conv1 = conv1x1(inplanes, width) - self.bn1 = norm_layer(width) - self.conv2 = conv3x3(width, width, stride, groups, dilation) - self.bn2 = norm_layer(width) - self.conv3 = conv1x1(width, planes * self.expansion) - self.bn3 = norm_layer(planes * self.expansion) - self.relu = nn.ReLU(inplace=True) - self.downsample = downsample - self.stride = stride - - def forward(self, x: Tensor) -> Tensor: - identity = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - out = self.relu(out) - - out = self.conv3(out) - out = self.bn3(out) - - if self.downsample is not None: - identity = self.downsample(x) - - out += identity - out = self.relu(out) - - return out diff --git a/tests/test_trainer/test_pipeline/model/layers/conv.py b/tests/test_trainer/test_pipeline/model/layers/conv.py deleted file mode 100644 index c918d94c4..000000000 --- a/tests/test_trainer/test_pipeline/model/layers/conv.py +++ /dev/null @@ -1,15 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import torch.nn as nn - - -def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: - """3x3 convolution with padding""" - return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, - padding=dilation, groups=groups, bias=False, dilation=dilation) - - -def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: - """1x1 convolution""" - return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) diff --git a/tests/test_trainer/test_pipeline/model/layers/reslayer.py b/tests/test_trainer/test_pipeline/model/layers/reslayer.py deleted file mode 100644 index 4e1b48c5e..000000000 --- a/tests/test_trainer/test_pipeline/model/layers/reslayer.py +++ /dev/null @@ -1,63 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import torch.nn as nn - -from colossalai.registry import LAYERS -from .conv import conv1x1 - - -@LAYERS.register_module -class ResLayer(nn.Module): - - def __init__(self, - block_type: str, - norm_layer_type: str, - inplanes: int, - planes: int, - blocks: int, - groups: int, - base_width: int, - stride: int = 1, - dilation: int = 1, - dilate: bool = False, - ): - super().__init__() - self.block = LAYERS.get_module(block_type) - self.norm_layer = LAYERS.get_module(norm_layer_type) - self.inplanes = inplanes - self.planes = planes - self.blocks = blocks - self.groups = groups - self.dilation = dilation - self.base_width = base_width - self.dilate = dilate - self.stride = stride - self.layer = self._make_layer() - - def _make_layer(self): - norm_layer = self.norm_layer - downsample = None - previous_dilation = self.dilation - if self.dilate: - self.dilation *= self.stride - self.stride = 1 - if self.stride != 1 or self.inplanes != self.planes * self.block.expansion: - downsample = nn.Sequential( - conv1x1(self.inplanes, self.planes * self.block.expansion, self.stride), - norm_layer(self.planes * self.block.expansion), - ) - - layers = [] - 
layers.append(self.block(self.inplanes, self.planes, self.stride, downsample, self.groups, - self.base_width, previous_dilation, norm_layer)) - self.inplanes = self.planes * self.block.expansion - for _ in range(1, self.blocks): - layers.append(self.block(self.inplanes, self.planes, groups=self.groups, - base_width=self.base_width, dilation=self.dilation, - norm_layer=norm_layer)) - - return nn.Sequential(*layers) - - def forward(self, x): - return self.layer(x) diff --git a/tests/test_trainer/test_pipeline/model/resnet.py b/tests/test_trainer/test_pipeline/model/resnet.py deleted file mode 100644 index 11d964943..000000000 --- a/tests/test_trainer/test_pipeline/model/resnet.py +++ /dev/null @@ -1,163 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -from typing import List, Optional - -import torch -import torch.nn as nn -from torch import Tensor - -from colossalai.registry import LAYERS -from colossalai.registry import MODELS -from colossalai.nn.model import ModelFromConfig - - -@MODELS.register_module -class VanillaResNet(ModelFromConfig): - """ResNet from - `"Deep Residual Learning for Image Recognition" `_. - """ - - def __init__( - self, - num_cls: int, - block_type: str, - layers: List[int], - norm_layer_type: str = 'BatchNorm2d', - in_channels: int = 3, - groups: int = 1, - width_per_group: int = 64, - zero_init_residual: bool = False, - replace_stride_with_dilation: Optional[List[bool]] = None, - dilations=(1, 1, 1, 1) - ) -> None: - super().__init__() - - self.inplanes = 64 - self.zero_init_residual = zero_init_residual - self.blocks = layers - self.block_expansion = LAYERS.get_module(block_type).expansion - self.dilations = dilations - self.reslayer_common_cfg = dict( - type='ResLayer', - block_type=block_type, - norm_layer_type=norm_layer_type, - groups=groups, - base_width=width_per_group - ) - - if replace_stride_with_dilation is None: - # each element in the tuple indicates if we should replace - # the 2x2 stride with a dilated convolution instead - replace_stride_with_dilation = [False, False, False] - - if len(replace_stride_with_dilation) != 3: - raise ValueError("replace_stride_with_dilation should be None " - "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) - - self.layers_cfg = [ - # conv1 - dict(type='Conv2d', - in_channels=in_channels, - out_channels=self.inplanes, - kernel_size=7, - stride=2, - padding=3, - bias=False), - # bn1 - dict( - type=norm_layer_type, - num_features=self.inplanes - ), - # relu - dict( - type='ReLU', - inplace=True - ), - # maxpool - dict( - type='MaxPool2d', - kernel_size=3, - stride=2, - padding=1 - ), - # layer 1 - dict( - inplanes=self.inplanes, - planes=64, - blocks=self.blocks[0], - dilation=self.dilations[0], - **self.reslayer_common_cfg - ), - # layer 2 - dict( - inplanes=64 * self.block_expansion, - planes=128, - blocks=self.blocks[1], - stride=2, - dilate=replace_stride_with_dilation[0], - dilation=self.dilations[1], - **self.reslayer_common_cfg - ), - # layer 3 - dict( - inplanes=128 * self.block_expansion, - planes=256, - blocks=layers[2], - stride=2, - dilate=replace_stride_with_dilation[1], - dilation=self.dilations[2], - **self.reslayer_common_cfg - ), - # layer 4 - dict( - inplanes=256 * self.block_expansion, - planes=512, - blocks=layers[3], stride=2, - dilate=replace_stride_with_dilation[2], - dilation=self.dilations[3], - **self.reslayer_common_cfg - ), - # avg pool - dict( - type='AdaptiveAvgPool2d', - output_size=(1, 1) - ), - # flatten - dict( - type='LambdaWrapper', - func=lambda mod, x: 
torch.flatten(x, 1) - ), - # linear - dict( - type='Linear', - in_features=512 * self.block_expansion, - out_features=num_cls - ) - ] - - def forward(self, x: Tensor): - for layer in self.layers: - x = layer(x) - return x - - def init_weights(self): - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_( - m.weight, mode='fan_out', nonlinearity='relu') - elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): - nn.init.constant_(m.weight, 1) - nn.init.constant_(m.bias, 0) - - # Zero-initialize the last BN in each residual branch, - # so that the residual branch starts with zeros, and each residual block behaves like an identity. - # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 - if self.zero_init_residual: - for m in self.modules(): - if isinstance(m, LAYERS.get_module('ResNetBottleneck')): - # type: ignore[arg-type] - nn.init.constant_(m.bn3.weight, 0) - elif isinstance(m, LAYERS.get_module('ResNetBasicBlock')): - # type: ignore[arg-type] - nn.init.constant_(m.bn2.weight, 0) diff --git a/tests/test_trainer/test_pipeline/resnet_config.py b/tests/test_trainer/test_pipeline/resnet_config.py deleted file mode 100644 index 59378b3dc..000000000 --- a/tests/test_trainer/test_pipeline/resnet_config.py +++ /dev/null @@ -1,21 +0,0 @@ -import os -import model -from pathlib import Path - -BATCH_SIZE = 128 -IMG_SIZE = 224 -DIM = 768 -NUM_CLASSES = 10 -NUM_ATTN_HEADS = 12 -NUM_MICRO_BATCHES = 2 - -# resnet 18 -model = dict(type='VanillaResNet', - block_type='ResNetBasicBlock', - layers=[2, 2, 2, 2], - num_cls=10) - -parallel = dict( - pipeline=dict(size=4), - tensor=dict(size=1, mode=None) -) diff --git a/tests/test_trainer/test_pipeline/test_partition.py b/tests/test_trainer/test_pipeline/test_partition.py deleted file mode 100644 index 88e3de352..000000000 --- a/tests/test_trainer/test_pipeline/test_partition.py +++ /dev/null @@ -1,43 +0,0 @@ -import os.path as osp - -import pytest -import torch -import torch.multiprocessing as mp - -from colossalai.builder.pipeline import build_pipeline_model_from_cfg -from colossalai.core import global_context -from colossalai.initialize import launch -from colossalai.logging import get_dist_logger -from functools import partial -from colossalai.utils import free_port -from colossalai.testing import rerun_on_exception - -DIR_PATH = osp.dirname(osp.realpath(__file__)) -CONFIG_PATH = osp.join(DIR_PATH, 'resnet_config.py') - - -def run_partition(rank, world_size, port): - launch(config=CONFIG_PATH, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - logger = get_dist_logger() - logger.info('finished initialization') - - # build model - model = build_pipeline_model_from_cfg(global_context.config.model, 1, verbose=True) - assert isinstance(model, torch.nn.Module) - logger.info('model is created') - - global_context.destroy() - logger.info('training finished') - torch.cuda.empty_cache() - - -@pytest.mark.dist -@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") -def test_partition(): - world_size = 4 - run_func = partial(run_partition, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_partition() diff --git a/tests/test_trainer/test_pipeline/test_pipeline_schedule.py b/tests/test_trainer/test_pipeline/test_pipeline_schedule.py index 88c669283..48f729658 100644 --- a/tests/test_trainer/test_pipeline/test_pipeline_schedule.py +++ 
b/tests/test_trainer/test_pipeline/test_pipeline_schedule.py @@ -8,27 +8,45 @@ from pathlib import Path import colossalai import pytest import torch +import torch.nn as nn import torch.multiprocessing as mp -from colossalai.builder import build_pipeline_model_from_cfg from colossalai.core import global_context as gpc -from colossalai.engine.schedule import PipelineSchedule +from colossalai.context import ParallelMode from colossalai.initialize import launch from colossalai.utils import free_port, get_dataloader, print_rank_0 from colossalai.testing import rerun_on_exception from torchvision import transforms from torchvision.datasets import CIFAR10 +from torchvision.models import resnet18 -BATCH_SIZE = 4 -DIR_PATH = osp.dirname(osp.realpath(__file__)) -CONFIG_PATH = osp.join(DIR_PATH, './resnet_config.py') +BATCH_SIZE = 8 +CONFIG=dict( + NUM_MICRO_BATCHES=2, + parallel = dict( + pipeline=dict(size=2), + tensor=dict(size=1, mode=None) + ) +) def run_schedule(rank, world_size, port): - launch(config=CONFIG_PATH, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') # build model - model = build_pipeline_model_from_cfg(gpc.config.model, 1) + model = resnet18(num_classes=10) + + if gpc.get_local_rank(ParallelMode.PIPELINE) == 0: + model = nn.Sequential(model.conv1, model.bn1, model.relu, model.maxpool, model.layer1, model.layer2) + elif gpc.get_local_rank(ParallelMode.PIPELINE) == 1: + + class Flatten(nn.Module): + + def forward(self, x): + return torch.flatten(x, 1) + + model = nn.Sequential(model.layer3, model.layer4, model.avgpool, Flatten(), model.fc) + print_rank_0('model is created') train_dataset = CIFAR10(root=Path(os.environ['DATA']), @@ -69,7 +87,7 @@ def run_schedule(rank, world_size, port): @pytest.mark.dist @rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") def test_pipeline_schedule(): - world_size = 4 + world_size = 2 run_func = partial(run_schedule, world_size=world_size, port=free_port()) mp.spawn(run_func, nprocs=world_size) diff --git a/tests/test_utils/test_checkpoint/test_checkpoint_1d.py b/tests/test_utils/test_checkpoint/test_checkpoint_1d.py index 8a5ec409b..66deda871 100644 --- a/tests/test_utils/test_checkpoint/test_checkpoint_1d.py +++ b/tests/test_utils/test_checkpoint/test_checkpoint_1d.py @@ -19,7 +19,7 @@ from colossalai.testing import rerun_on_exception, skip_if_not_enough_gpus def build_pipeline(model): - from colossalai.builder.pipeline import partition_uniform + from colossalai.pipeline.utils import partition_uniform pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) diff --git a/tests/test_utils/test_checkpoint/test_checkpoint_2d.py b/tests/test_utils/test_checkpoint/test_checkpoint_2d.py index bd5c46237..beadb27cb 100644 --- a/tests/test_utils/test_checkpoint/test_checkpoint_2d.py +++ b/tests/test_utils/test_checkpoint/test_checkpoint_2d.py @@ -19,7 +19,7 @@ from colossalai.testing import rerun_on_exception, skip_if_not_enough_gpus def build_pipeline(model): - from colossalai.builder.pipeline import partition_uniform + from colossalai.pipeline.utils import partition_uniform pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) diff --git a/tests/test_utils/test_checkpoint/test_checkpoint_2p5d.py b/tests/test_utils/test_checkpoint/test_checkpoint_2p5d.py index 
79dae487b..69bab53a6 100644 --- a/tests/test_utils/test_checkpoint/test_checkpoint_2p5d.py +++ b/tests/test_utils/test_checkpoint/test_checkpoint_2p5d.py @@ -19,7 +19,7 @@ from colossalai.testing import rerun_on_exception, skip_if_not_enough_gpus def build_pipeline(model): - from colossalai.builder.pipeline import partition_uniform + from colossalai.pipeline.utils import partition_uniform pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) diff --git a/tests/test_utils/test_checkpoint/test_checkpoint_3d.py b/tests/test_utils/test_checkpoint/test_checkpoint_3d.py index d2d938c04..ecbcb8630 100644 --- a/tests/test_utils/test_checkpoint/test_checkpoint_3d.py +++ b/tests/test_utils/test_checkpoint/test_checkpoint_3d.py @@ -19,7 +19,7 @@ from colossalai.testing import rerun_on_exception, skip_if_not_enough_gpus def build_pipeline(model): - from colossalai.builder.pipeline import partition_uniform + from colossalai.pipeline.utils import partition_uniform pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
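
To show how the relocated pieces fit together end to end, here is a rough usage sketch mirroring tests/test_pipeline/test_pipelinable.py; the two-layer MLP is a stand-in model, and the final call assumes the partition(num_chunks, pipeline_size, rank) signature is unchanged by this PR::

    import torch
    from colossalai.pipeline import PipelinableContext

    class MLP(torch.nn.Module):

        def __init__(self, dim: int = 16):
            super().__init__()
            self.linear1 = torch.nn.Linear(dim, dim)
            self.linear2 = torch.nn.Linear(dim, dim)

        def forward(self, x):
            return self.linear2(self.linear1(x))

    pipelinable = PipelinableContext()
    with pipelinable:                  # modules created here are recorded as layer specs
        model = MLP()

    pipelinable.policy = "uniform"     # the new property setter replaces load_policy()
    pipelinable.to_layer_list()        # flatten the recorded specs into a list

    # build the stage owned by pipeline rank 0 out of 2 stages
    # (assumed-unchanged signature: partition(num_chunks, pipeline_size, rank))
    stage = pipelinable.partition(num_chunks=1, pipeline_size=2, rank=0)
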