diff --git a/colossalai/__init__.py b/colossalai/__init__.py
index 854d941bc..e7ea7d65a 100644
--- a/colossalai/__init__.py
+++ b/colossalai/__init__.py
@@ -1,4 +1,4 @@
-from .initialize import init_dist, initialize
-from .nn import *
+from .initialize import (initialize, launch, launch_from_openmpi,
+ launch_from_slurm, launch_from_torch, get_default_parser)
__version__ = '0.0.1'
diff --git a/colossalai/amp/__init__.py b/colossalai/amp/__init__.py
new file mode 100644
index 000000000..268eced66
--- /dev/null
+++ b/colossalai/amp/__init__.py
@@ -0,0 +1,32 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+from .amp_type import AMP_TYPE
+from colossalai.context import Config
+import torch.nn as nn
+from torch.optim import Optimizer
+from torch.nn.modules.loss import _Loss
+from .torch_amp import convert_to_torch_amp
+from .apex_amp import convert_to_apex_amp
+from .naive_amp import convert_to_naive_amp
+
+
+def convert_to_amp(model: nn.Module,
+ optimizer: Optimizer,
+ criterion: _Loss,
+ mode: AMP_TYPE,
+ amp_config: Config = None):
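+ """Wraps the given model, optimizer and criterion with the AMP implementation
+ selected by ``mode`` (torch, apex or naive) and returns the wrapped objects.
+ """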
+ assert isinstance(mode, AMP_TYPE), \
+ f'expected the argument mode to be AMP_TYPE, but got {type(mode)}'
+
+ if amp_config is None:
+ amp_config = Config()
+
+ if mode == AMP_TYPE.TORCH:
+ model, optimizer, criterion = convert_to_torch_amp(model, optimizer, criterion, amp_config)
+ elif mode == AMP_TYPE.APEX:
+ model, optimizer = convert_to_apex_amp(model, optimizer, amp_config)
+ elif mode == AMP_TYPE.NAIVE:
+ model, optimizer = convert_to_naive_amp(model, optimizer, amp_config)
+
+ return model, optimizer, criterion
diff --git a/colossalai/engine/amp/amp_type.py b/colossalai/amp/amp_type.py
similarity index 83%
rename from colossalai/engine/amp/amp_type.py
rename to colossalai/amp/amp_type.py
index 7f7c5a659..6f322f866 100644
--- a/colossalai/engine/amp/amp_type.py
+++ b/colossalai/amp/amp_type.py
@@ -7,4 +7,4 @@ from enum import Enum
class AMP_TYPE(Enum):
APEX = 'apex'
TORCH = 'torch'
- PARALLEL = 'parallel'
+ NAIVE = 'naive'
diff --git a/colossalai/amp/apex_amp/__init__.py b/colossalai/amp/apex_amp/__init__.py
new file mode 100644
index 000000000..2d0ff9771
--- /dev/null
+++ b/colossalai/amp/apex_amp/__init__.py
@@ -0,0 +1,15 @@
+from .apex_amp import ApexAMPOptimizer
+import torch.nn as nn
+from torch.optim import Optimizer
+import apex.amp as apex_amp
+
+
+def convert_to_apex_amp(model: nn.Module,
+ optimizer: Optimizer,
+ amp_config):
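+ """Wraps the model and optimizer with apex.amp using the settings in ``amp_config``
+ and returns the wrapped objects.
+ """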
+ model, optimizer = apex_amp.initialize(model, optimizer, **amp_config)
+ optimizer = ApexAMPOptimizer(optimizer)
+ return model, optimizer
+
+
+__all__ = ['convert_to_apex_amp', 'ApexAMPOptimizer']
diff --git a/colossalai/amp/apex_amp/apex_amp.py b/colossalai/amp/apex_amp/apex_amp.py
new file mode 100644
index 000000000..d44478364
--- /dev/null
+++ b/colossalai/amp/apex_amp/apex_amp.py
@@ -0,0 +1,23 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+import torch.nn as nn
+try:
+ import apex.amp as apex_amp
+except ImportError:
+ pass
+from torch import Tensor
+
+from colossalai.nn.optimizer import ColossalaiOptimizer
+from colossalai.utils import clip_grad_norm_fp32
+
+
+class ApexAMPOptimizer(ColossalaiOptimizer):
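+ """Scales the loss with apex.amp before the backward pass and clips gradients
+ on the apex master parameters.
+ """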
+
+ def backward(self, loss: Tensor):
+ with apex_amp.scale_loss(loss, self.optim) as scaled_loss:
+ scaled_loss.backward()
+
+ def clip_grad_norm(self, model: nn.Module, max_norm: float):
+ if max_norm > 0:
+ clip_grad_norm_fp32(apex_amp.master_params(self.optim), max_norm)
diff --git a/colossalai/amp/naive_amp/__init__.py b/colossalai/amp/naive_amp/__init__.py
new file mode 100644
index 000000000..e3a49c7e8
--- /dev/null
+++ b/colossalai/amp/naive_amp/__init__.py
@@ -0,0 +1,20 @@
+import torch.nn as nn
+from torch.optim import Optimizer
+from colossalai.utils import is_no_pp_or_last_stage
+
+from .naive_amp import NaiveAMPOptimizer, NaiveAMPModel
+
+
+def convert_to_naive_amp(model: nn.Module,
+ optimizer: Optimizer,
+ amp_config):
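+ """Wraps the model with ``NaiveAMPModel`` (fp16 forward, casting outputs back to
+ fp32 only on the last or only pipeline stage) and the optimizer with ``NaiveAMPOptimizer``.
+ """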
+ if is_no_pp_or_last_stage():
+ model = NaiveAMPModel(model, output_to_fp32=True)
+ else:
+ model = NaiveAMPModel(model, output_to_fp32=False)
+
+ optimizer = NaiveAMPOptimizer(optimizer, **amp_config)
+ return model, optimizer
+
+
+__all__ = ['convert_to_naive_amp', 'NaiveAMPOptimizer']
diff --git a/colossalai/nn/optimizer/fp16_optimizer.py b/colossalai/amp/naive_amp/_fp16_optimizer.py
similarity index 96%
rename from colossalai/nn/optimizer/fp16_optimizer.py
rename to colossalai/amp/naive_amp/_fp16_optimizer.py
index c64a732c9..d917a97bc 100644
--- a/colossalai/nn/optimizer/fp16_optimizer.py
+++ b/colossalai/amp/naive_amp/_fp16_optimizer.py
@@ -12,11 +12,9 @@ from torch.optim import Optimizer
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
-from colossalai.logging import get_global_dist_logger
-from colossalai.registry import OPTIMIZER_WRAPPERS
-from colossalai.utils import print_rank_0
-from ._utils import copy_tensor_parallel_attributes, clip_grad_norm_fp32, count_zeros_fp32
-from ..multi_tensor_apply import multi_tensor_applier
+from colossalai.logging import get_dist_logger
+from colossalai.utils import (print_rank_0, copy_tensor_parallel_attributes,
+ clip_grad_norm_fp32, count_zeros_fp32, multi_tensor_applier)
def _zero_grad_group_helper(group, set_to_none):
@@ -92,7 +90,7 @@ class DynamicGradScaler:
self._growth_tracker = 0
self._hysteresis_tracker = self.hysteresis
- self._logger = get_global_dist_logger()
+ self._logger = get_dist_logger()
@property
def scale(self):
@@ -113,7 +111,7 @@ class DynamicGradScaler:
if self._hysteresis_tracker <= 0:
self._scale = torch.max(self._scale * self.backoff_factor,
self.min_scale)
- self._logger.info(f'overflow occurs, loss scale is adjusted to {self._scale}')
+ self._logger.info(f'overflow occurs, loss scale is adjusted to {self._scale}', ranks=[0])
else:
# If there is no nan/inf, increment the growth tracker.
self._growth_tracker += 1
@@ -125,10 +123,10 @@ class DynamicGradScaler:
# and scale up the loss scale.
if self._max_scale is not None and self._scale >= self._max_scale:
self._logger.info(
- f'Current loss scale {self._scale} has reached the max scale {self._max_scale} allowed')
+ f'Current loss scale {self._scale} has reached the max scale {self._max_scale} allowed', ranks=[0])
else:
self._scale = self._scale * self.growth_factor
- self._logger.info(f'no consecutive overflow, loss scale is adjusted to {self._scale}')
+ self._logger.info(f'no consecutive overflow, loss scale is adjusted to {self._scale}', ranks=[0])
def state_dict(self):
state_dict = {}
@@ -145,7 +143,6 @@ class DynamicGradScaler:
self._max_scale = state_dict['max_scale']
-@OPTIMIZER_WRAPPERS.register_module
class FP16Optimizer(Optimizer):
"""Float16 optimizer for fp16 and bf16 data types.
@@ -184,13 +181,13 @@ class FP16Optimizer(Optimizer):
max_scale: int = 2 ** 32):
# default args for compatibility
bf16 = False
- params_have_main_grad = False
+ params_have_main_grad = True
# have a defaults for compatibility with pytorch optim
self.defaults = optimizer.defaults
# log config
- self._logger = get_global_dist_logger()
+ self._logger = get_dist_logger()
self._logger.info(f"\n========= FP16 Optimizer Config =========\n"
f"Optimizer: {optimizer.__class__.__name__}\n"
f"clip_grad = {clip_grad}\n"
@@ -328,6 +325,7 @@ class FP16Optimizer(Optimizer):
else:
if model_param.grad is not None:
main_param.grad = model_param.grad.float()
+
# For fp32 grads, we need to reset the grads to main grad.
if self.params_have_main_grad:
for model_group in self.fp32_from_fp32_groups:
@@ -387,10 +385,6 @@ class FP16Optimizer(Optimizer):
@torch.no_grad()
def step(self):
- # for param_group in self.float16_groups:
- # for param in param_group:
- # print(param.grad is None)
-
# Copy gradients from model params to main params.
self._copy_model_grads_to_main_grads()
diff --git a/colossalai/amp/naive_amp/naive_amp.py b/colossalai/amp/naive_amp/naive_amp.py
new file mode 100644
index 000000000..dd0b88b44
--- /dev/null
+++ b/colossalai/amp/naive_amp/naive_amp.py
@@ -0,0 +1,65 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+from typing import Union, List, Any, Dict
+from torch.optim import Optimizer
+import torch.cuda.amp as torch_amp
+
+from colossalai.nn.optimizer import ColossalaiOptimizer
+from ._fp16_optimizer import FP16Optimizer
+
+
+class NaiveAMPOptimizer(ColossalaiOptimizer):
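+ """Wraps an optimizer with ``FP16Optimizer`` to perform loss scaling and the fp16 update step."""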
+
+ def __init__(self, optim: Optimizer, *args, **kwargs):
+ optim = FP16Optimizer(optimizer=optim, *args, **kwargs)
+ super().__init__(optim)
+
+ def backward(self, loss: Tensor):
+ loss = self.optim.scale_loss(loss)
+ loss.backward()
+
+ def step(self):
+ self.optim.step()
+
+ def clip_grad_norm(self, model: nn.Module, max_norm: float):
+ pass
+
+
+class NaiveAMPModel(nn.Module):
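+ """Converts the wrapped model to fp16, casts fp32 tensor inputs to fp16 and,
+ if ``output_to_fp32`` is True, casts fp16 outputs back to fp32.
+ """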
+
+ def __init__(self,
+ model: nn.Module,
+ output_to_fp32: bool = True):
+ super().__init__()
+ self.model = model.half()
+ self._output_to_fp32 = output_to_fp32
+
+ def _convert_to_fp16(self, input_: Any):
+ if isinstance(input_, Tensor) and input_.dtype == torch.float32:
+ input_ = input_.half()
+ return input_
+
+ def _convert_to_fp32(self, input_: Any):
+ if isinstance(input_, Tensor) and input_.dtype == torch.float16:
+ input_ = input_.float()
+ return input_
+
+ def forward(self, *args, **kwargs):
+ if args:
+ args = [self._convert_to_fp16(arg) for arg in args]
+ if kwargs:
+ for k, v in kwargs.items():
+ kwargs[k] = self._convert_to_fp16(v)
+
+ out = self.model(*args, **kwargs)
+
+ if self._output_to_fp32:
+ if isinstance(out, Tensor):
+ out = self._convert_to_fp32(out)
+ elif isinstance(out, (tuple, list)):
+ out = [self._convert_to_fp32(val) for val in out]
+ return out
diff --git a/colossalai/amp/torch_amp/__init__.py b/colossalai/amp/torch_amp/__init__.py
new file mode 100644
index 000000000..b3c5b0c5b
--- /dev/null
+++ b/colossalai/amp/torch_amp/__init__.py
@@ -0,0 +1,18 @@
+import torch.nn as nn
+from torch.optim import Optimizer
+from torch.nn.modules.loss import _Loss
+from colossalai.context import Config
+from .torch_amp import TorchAMPOptimizer, TorchAMPModel, TorchAMPLoss
+
+
+def convert_to_torch_amp(model: nn.Module,
+ optimizer: Optimizer,
+ criterion: _Loss,
+ amp_config: Config):
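+ """Wraps the model, optimizer and criterion with torch.cuda.amp autocast and
+ gradient scaling.
+ """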
+ model = TorchAMPModel(model)
+ optimizer = TorchAMPOptimizer(optimizer, **amp_config)
+ criterion = TorchAMPLoss(criterion)
+ return model, optimizer, criterion
+
+
+__all__ = ['convert_to_torch_amp', 'TorchAMPModel', 'TorchAMPLoss', 'TorchAMPOptimizer']
diff --git a/colossalai/engine/amp/grad_scaler.py b/colossalai/amp/torch_amp/_grad_scaler.py
similarity index 99%
rename from colossalai/engine/amp/grad_scaler.py
rename to colossalai/amp/torch_amp/_grad_scaler.py
index 7859d132d..7e79ecab8 100644
--- a/colossalai/engine/amp/grad_scaler.py
+++ b/colossalai/amp/torch_amp/_grad_scaler.py
@@ -1,4 +1,8 @@
-# modified from https://github.com/pytorch/pytorch/blob/master/torch/cuda/amp/grad_scaler.p
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+# modified from https://github.com/pytorch/pytorch/blob/master/torch/cuda/amp/grad_scaler.py
+# to support tensor parallel
+
import torch
from collections import defaultdict, abc
import warnings
diff --git a/colossalai/amp/torch_amp/torch_amp.py b/colossalai/amp/torch_amp/torch_amp.py
new file mode 100644
index 000000000..396360184
--- /dev/null
+++ b/colossalai/amp/torch_amp/torch_amp.py
@@ -0,0 +1,54 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+import torch.nn as nn
+import torch.cuda.amp as torch_amp
+
+from torch import Tensor
+from torch.nn.modules.loss import _Loss
+from torch.optim import Optimizer
+from ._grad_scaler import GradScaler
+
+from colossalai.nn.optimizer import ColossalaiOptimizer
+from colossalai.utils import clip_grad_norm_fp32
+
+
+class TorchAMPOptimizer(ColossalaiOptimizer):
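+ """Applies torch.cuda.amp gradient scaling to the backward pass and optimizer step."""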
+
+ def __init__(self, optim: Optimizer, *args, **kwargs):
+ super().__init__(optim)
+ self.scaler = GradScaler(*args, **kwargs)
+
+ def backward(self, loss: Tensor):
+ self.scaler.scale(loss).backward()
+
+ def step(self):
+ self.scaler.step(self.optim)
+ self.scaler.update()
+
+ def clip_grad_norm(self, model: nn.Module, max_norm: float):
+ if max_norm > 0.0:
+ self.scaler.unscale_(self.optim)
+ clip_grad_norm_fp32(model.parameters(), max_norm)
+
+
+class TorchAMPModel(nn.Module):
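+ """Runs the wrapped model's forward pass under torch.cuda.amp autocast."""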
+
+ def __init__(self, model: nn.Module) -> None:
+ super().__init__()
+ self.model = model
+
+ @torch_amp.autocast()
+ def forward(self, *args, **kwargs):
+ return self.model(*args, **kwargs)
+
+
+class TorchAMPLoss(nn.Module):
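+ """Computes the wrapped loss under torch.cuda.amp autocast."""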
+
+ def __init__(self, loss: _Loss):
+ super().__init__()
+ self.loss = loss
+
+ @torch_amp.autocast()
+ def forward(self, *args, **kwargs):
+ return self.loss(*args, **kwargs)
diff --git a/colossalai/builder/__init__.py b/colossalai/builder/__init__.py
index 2ae194132..6c1105a2d 100644
--- a/colossalai/builder/__init__.py
+++ b/colossalai/builder/__init__.py
@@ -1,10 +1,10 @@
-from .builder import (build_schedule, build_lr_scheduler, build_model, build_optimizer, build_optimizer_wrapper,
- build_layer, build_loss, build_hooks, build_dataset, build_transform, build_data_sampler,
+from .builder import (build_schedule, build_lr_scheduler, build_model, build_optimizer, build_layer,
+ build_loss, build_hooks, build_dataset, build_transform, build_data_sampler,
build_gradient_handler)
-from .pipeline import ModelInitializer
+from .pipeline import PipelineModelInitializer
__all__ = [
- 'build_schedule', 'build_lr_scheduler', 'build_model', 'build_optimizer', 'build_optimizer_wrapper',
+ 'build_schedule', 'build_lr_scheduler', 'build_model', 'build_optimizer',
'build_layer', 'build_loss', 'build_hooks', 'build_dataset', 'build_transform', 'build_data_sampler',
- 'build_gradient_handler', 'ModelInitializer'
+ 'build_gradient_handler', 'PipelineModelInitializer'
]
diff --git a/colossalai/builder/builder.py b/colossalai/builder/builder.py
index c32ad3b39..6e8e24551 100644
--- a/colossalai/builder/builder.py
+++ b/colossalai/builder/builder.py
@@ -106,7 +106,7 @@ def build_dataset(config):
return build_from_registry(config, DATASETS)
-def build_optimizer(config, model, params: Iterable = None, need_module=False):
+def build_optimizer(config, model):
"""Returns an optimizer object of :class:`torch.optim.Optimizer` constructed from `config`,
'model' and 'params'.
@@ -115,23 +115,12 @@ def build_optimizer(config, model, params: Iterable = None, need_module=False):
:type config: dict or :class:`colossalai.context.Config`
:param model: A model containing parameters for the optimizer
:type model: :class:`nn.Module`
- :param params: A dict containing parameters for the optimizer
- :type params: dict, optional
- :param need_module: Indicates whether the optimizer needs a module
- :type params: bool, optional
- :raises AssertionError: Raises an AssertionError if both `model` and `params` are None
:return: An object of :class:`torch.optim.Optimizer`
:rtype: :class:`torch.optim.Optimizer`
"""
- assert model is not None or params is not None, 'arguments model and params can not both be None'
- if need_module:
- config['module'] = model
- elif model is not None:
- config['params'] = model.parameters()
- elif params is not None:
- config['params'] = params
-
- return build_from_registry(config, OPTIMIZERS)
+ config_ = config.copy()
+ config_['params'] = model.parameters()
+ return build_from_registry(config_, OPTIMIZERS)
def build_gradient_handler(config, model, optimizer):
@@ -149,8 +138,9 @@ def build_gradient_handler(config, model, optimizer):
:rtype: :class:`BaseGradientHandler`
"""
config_ = config.copy()
- mod_type = config_.pop('type')
- return GRADIENT_HANDLER.get_module(mod_type)(model, optimizer, **config_)
+ config_['model'] = model
+ config_['optimizer'] = optimizer
+ return build_from_registry(config_, GRADIENT_HANDLER)
def build_hooks(config, trainer):
@@ -164,8 +154,9 @@ def build_hooks(config, trainer):
:return: An object of :class:`BaseHook`
:rtype: :class:`BaseHook`
"""
- config['trainer'] = trainer
- return build_from_registry(config, HOOKS)
+ config_ = config.copy()
+ config_['trainer'] = trainer
+ return build_from_registry(config_, HOOKS)
def build_transform(config):
@@ -195,32 +186,8 @@ def build_data_sampler(config, dataset):
:rtype: :class:`colossalai.nn.data.sampler.BaseSampler`
"""
config_ = config.copy()
- mod_type = config_.pop('type')
- return SAMPLERS.get_module(mod_type)(dataset, **config_)
-
-
-def build_optimizer_wrapper(config, optimizer, model=None):
- """Returns an optimizer wrapper object of :class:`torch.optim.Optimizer` constructed
- from `config`, `model` and `optimizer`.
-
- :param config: A python dict or a :class:`colossalai.context.Config` object
- containing information used in the construction of the return object
- :type config: dict or :class:`colossalai.context.Config`
- :param optimizer: An optimizer object containing parameters for the gradient handler
- :type optimizer: :class:`torch.optim.Optimizer`
- :param model: A model containing parameters for the gradient handler
- :type model: :class:`nn.Module`, optional
- :return: An object of :class:`torch.optim.Optimizer`
- :rtype: :class:`torch.optim.Optimizer`
- """
- config_ = config.copy()
- mod_type = config_.pop('type')
-
- # LSG: special treatment for zeor level 3
- if mod_type == 'ZeroRedundancyOptimizer_Level_3':
- return OPTIMIZER_WRAPPERS.get_module(mod_type)(model, optimizer, **config_)
- else:
- return OPTIMIZER_WRAPPERS.get_module(mod_type)(optimizer, **config_)
+ config_['dataset'] = dataset
+ return build_from_registry(config_, DATA_SAMPLERS)
def build_lr_scheduler(config, optimizer):
@@ -241,8 +208,8 @@ def build_lr_scheduler(config, optimizer):
:rtype: :class:`torch.optim.lr_scheduler`
"""
config_ = config.copy()
- mod_type = config_.pop('type')
- return LR_SCHEDULERS.get_module(mod_type)(optimizer, **config_)
+ config_['optimizer'] = optimizer
+ return build_from_registry(config_, LR_SCHEDULERS)
def build_schedule(config):
diff --git a/colossalai/builder/pipeline.py b/colossalai/builder/pipeline.py
index caf5c8472..5a568a909 100644
--- a/colossalai/builder/pipeline.py
+++ b/colossalai/builder/pipeline.py
@@ -4,7 +4,7 @@ import heapq
from colossalai.builder import build_model, build_layer
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
-from colossalai.logging import get_global_dist_logger
+from colossalai.logging import get_dist_logger
from colossalai.utils import set_to_cuda
@@ -111,21 +111,21 @@ def _binary_search(weights, num):
return intervals
-def _partition_uniform(num_items, num_parts, num_chunks):
+def _partition_uniform(num_items, pipeline_parallel_size, num_chunks):
assert num_items % num_chunks == 0, \
"Layer length should be divided by the number of chunks, otherwise parameter method is recomended"
- logger = get_global_dist_logger()
- parts = [[] for _ in range(num_parts)]
+ logger = get_dist_logger()
+ parts = [[] for _ in range(pipeline_parallel_size)]
partition_items = num_items // num_chunks
for idx in range(num_chunks):
base_idx = idx * partition_items
- chunk_size = partition_items // num_parts
- left = num_parts - partition_items % num_parts
+ chunk_size = partition_items // pipeline_parallel_size
+ left = pipeline_parallel_size - partition_items % pipeline_parallel_size
if chunk_size == 0:
logger.warning("Some nodes in Pipeline have no requests")
- for p in range(num_parts):
+ for p in range(pipeline_parallel_size):
st = base_idx
base_idx += chunk_size + (p >= left)
parts[p].append((st, base_idx))
@@ -133,34 +133,34 @@ def _partition_uniform(num_items, num_parts, num_chunks):
return parts
-def _partition_balanced(weights, num_parts, num_chunks):
- num_total = num_parts * num_chunks
+def _partition_balanced(weights, pipeline_parallel_size, num_chunks):
+ num_total = pipeline_parallel_size * num_chunks
num_items = len(weights)
if num_items <= num_total:
- return _partition_uniform(num_items, num_parts, num_chunks)
+ return _partition_uniform(num_items, pipeline_parallel_size, num_chunks)
intervals = _binary_search(weights, num_total)
current = 0
- parts = [[] for _ in range(num_parts)]
+ parts = [[] for _ in range(pipeline_parallel_size)]
for inter in intervals:
parts[current].append(inter)
- current = (current + 1) % num_parts
+ current = (current + 1) % pipeline_parallel_size
return parts
-class ModelInitializer():
+class PipelineModelInitializer():
def __init__(self, config, num_chunks, verbose=False):
self.num_chunks = num_chunks
self.ori_model = build_model(config)
self.layers = self.ori_model.layers_cfg
layer_length = len(self.layers)
self.verbose = verbose
- self._logger = get_global_dist_logger()
+ self._logger = get_dist_logger()
self._logger.info(f"The total length of layers is {layer_length}", ranks=[0])
- def model_initialize(self, partition_method='parameter'):
+ def initialize(self, partition_method='parameter'):
+ # Some space for initializing communication groups
self._interval = None
self._partition_layers(method=partition_method)
@@ -198,7 +198,7 @@ class ModelInitializer():
for st, ed in self.parts[stage]:
for idx, layer in enumerate(self.layers[st: ed]):
log_str += f'\t{idx + st:2d}: {layer}\n'
- self._logger.info(log_str)
+ self._logger.info(log_str, ranks=[0])
# Save the partition
self._interval = self.parts[pipeline_rank]
diff --git a/colossalai/communication/__init__.py b/colossalai/communication/__init__.py
index 4241bff4b..5da045326 100644
--- a/colossalai/communication/__init__.py
+++ b/colossalai/communication/__init__.py
@@ -1,4 +1,4 @@
-from .collective import all_gather, reduce_scatter, scatter
+from .collective import all_gather, reduce_scatter, all_reduce
from .p2p import (send_forward, send_forward_recv_forward, send_backward_recv_forward,
send_backward, send_backward_recv_backward, send_forward_recv_backward,
send_forward_backward_recv_forward_backward, recv_forward, recv_backward)
@@ -6,7 +6,7 @@ from .ring import ring_forward
from .utils import send_tensor_meta, recv_tensor_meta
__all__ = [
- 'all_gather', 'reduce_scatter', 'scatter',
+ 'all_gather', 'reduce_scatter', 'all_reduce',
'send_forward', 'send_forward_recv_forward', 'send_forward_backward_recv_forward_backward',
'send_backward', 'send_backward_recv_backward', 'send_backward_recv_forward',
'send_forward_recv_backward', 'recv_backward', 'recv_forward',
diff --git a/colossalai/communication/collective.py b/colossalai/communication/collective.py
index 6db799c99..5778028ea 100644
--- a/colossalai/communication/collective.py
+++ b/colossalai/communication/collective.py
@@ -11,7 +11,7 @@ from colossalai.utils import get_current_device
def all_gather(tensor: Tensor, dim: int,
- parallel_mode: ParallelMode) -> Tensor:
+ parallel_mode: ParallelMode, async_op=False) -> Tensor:
"""Gathers all tensors from the parallel group and concatenates them in a
specific dimension.
@@ -26,18 +26,28 @@ def all_gather(tensor: Tensor, dim: int,
"""
depth = gpc.get_world_size(parallel_mode)
temp = tensor.clone()
- shape = list(temp.shape)
- shape[dim] *= depth
- out = torch.empty(shape, dtype=temp.dtype, device=get_current_device())
- out = list(torch.chunk(out, depth, dim=dim))
- out = [val.contiguous() for val in out]
- dist.all_gather(out, temp, group=gpc.get_group(parallel_mode))
- out = torch.cat(out, dim=dim)
- return out
+ # shape = list(temp.shape)
+ # shape[dim] *= depth
+ # out = torch.zeros(shape, dtype=temp.dtype, device=get_current_device())
+ # out = list(torch.chunk(out, depth, dim=dim))
+ # out = [val.contiguous() for val in out]
+ shape = [1] * len(tensor.shape)
+ shape[dim] = depth
+ out = tensor.repeat(shape)
+ out = list(map(lambda x: x.contiguous(), torch.chunk(out, depth, dim=dim)))
+ op = dist.all_gather(tensor_list=out,
+ tensor=temp,
+ group=gpc.get_group(parallel_mode),
+ async_op=async_op)
+ # out = torch.cat(out, dim=dim)
+ if async_op:
+ return out, op
+ else:
+ return out
def reduce_scatter(tensor: Tensor, dim: int,
- parallel_mode: ParallelMode) -> Tensor:
+ parallel_mode: ParallelMode, async_op=False) -> Tensor:
"""Reduces all tensors then scatters it in a specific dimension to all
members in the parallel group.
@@ -51,34 +61,52 @@ def reduce_scatter(tensor: Tensor, dim: int,
:rtype: Tensor
"""
depth = gpc.get_world_size(parallel_mode)
- temp = list(torch.chunk(tensor, depth, dim=dim))
- temp = [val.contiguous() for val in temp]
- out = torch.empty(temp[0].shape,
- dtype=temp[0].dtype,
- device=get_current_device())
- dist.reduce_scatter(output=out,
- input_list=temp,
- group=gpc.get_group(parallel_mode))
- return out
+ # temp = list(torch.chunk(tensor, depth, dim=dim))
+ # temp = [val.contiguous() for val in temp]
+ # out = torch.zeros(temp[0].shape,
+ # dtype=temp[0].dtype,
+ # device=get_current_device())
+ temp = list(map(lambda x: x.contiguous(), torch.chunk(tensor, depth, dim=dim)))
+ out = temp[0].clone()
+ op = dist.reduce_scatter(output=out,
+ input_list=temp,
+ group=gpc.get_group(parallel_mode),
+ async_op=async_op)
+ if async_op:
+ return out, op
+ else:
+ return out
-def scatter(tensor: Tensor, src: int, dim: int,
- parallel_mode: ParallelMode) -> Tensor:
- """Scatters in a specific dimension from source rank to all ranks in
- the parallel group.
+def all_reduce(tensor: Tensor,
+ parallel_mode: ParallelMode,
+ async_op=False) -> Tensor:
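+ """All-reduces the tensor in place across the given parallel group and returns it;
+ if ``async_op`` is True, the asynchronous work handle is returned as well.
+ """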
+ op = dist.all_reduce(tensor,
+ group=gpc.get_group(parallel_mode),
+ async_op=async_op)
+ if async_op:
+ return tensor, op
+ else:
+ return tensor
+
+
+# def scatter(tensor: Tensor, src: int, dim: int,
+# parallel_mode: ParallelMode) -> Tensor:
+# """Scatters in a specific dimension from source rank to all ranks in
+# the parallel group.
- :param tensor: Tensor to be scattered
- :param dim: The dimension scattering in
- :param parallel_mode: Parallel group mode used in this communication
- :type tensor: Tensor
- :type dim: int
- :type parallel_mode: ParallelMode
- :return: The tensor generated by scatter
- :rtype: Tensor
- """
- depth = gpc.get_world_size(parallel_mode)
- temp = tensor.clone()
- dist.broadcast(temp, src=src, group=gpc.get_group(parallel_mode))
- rank = gpc.get_local_rank(parallel_mode)
- out = torch.chunk(temp, depth, dim=dim)[rank].contiguous()
- return out
+# :param tensor: Tensor to be scattered
+# :param dim: The dimension scattering in
+# :param parallel_mode: Parallel group mode used in this communication
+# :type tensor: Tensor
+# :type dim: int
+# :type parallel_mode: ParallelMode
+# :return: The tensor generated by scatter
+# :rtype: Tensor
+# """
+# depth = gpc.get_world_size(parallel_mode)
+# temp = tensor.clone()
+# dist.broadcast(temp, src=src, group=gpc.get_group(parallel_mode))
+# rank = gpc.get_local_rank(parallel_mode)
+# out = torch.chunk(temp, depth, dim=dim)[rank].contiguous()
+# return out
diff --git a/colossalai/communication/p2p.py b/colossalai/communication/p2p.py
index 7e761e180..3eb94ac60 100644
--- a/colossalai/communication/p2p.py
+++ b/colossalai/communication/p2p.py
@@ -17,8 +17,6 @@ def _communicate(tensor_send_next=None,
recv_next_shape=None,
prev_rank=None,
next_rank=None,
- up_group=None,
- down_group=None,
dtype=None):
"""
Adapted from megatron.p2p_communication.
@@ -59,60 +57,44 @@ def _communicate(tensor_send_next=None,
if prev_rank is None:
prev_rank = gpc.get_prev_global_rank(
ParallelMode.PIPELINE)
- if up_group is None:
- up_group = gpc.get_group(ParallelMode.PIPELINE_PREV)
if tensor_send_next is not None or recv_next:
if next_rank is None:
next_rank = gpc.get_next_global_rank(
ParallelMode.PIPELINE)
- if down_group is None:
- down_group = gpc.get_group(ParallelMode.PIPELINE_NEXT)
# rank = dist.get_rank()
rank = gpc.get_global_rank()
ops = []
if tensor_send_prev is not None:
- send_prev_op = dist.broadcast(tensor_send_prev,
- src=rank,
- group=up_group,
- async_op=True)
+ send_prev_op = dist.P2POp(dist.isend, tensor_send_prev, prev_rank)
ops.append(send_prev_op)
if tensor_recv_prev is not None:
- recv_prev_op = dist.broadcast(tensor_recv_prev,
- src=prev_rank,
- group=up_group,
- async_op=True)
+ recv_prev_op = dist.P2POp(dist.irecv, tensor_recv_prev, prev_rank)
ops.append(recv_prev_op)
if tensor_recv_next is not None:
- recv_next_op = dist.broadcast(tensor_recv_next,
- src=next_rank,
- group=down_group,
- async_op=True)
+ recv_next_op = dist.P2POp(dist.irecv, tensor_recv_next, next_rank)
ops.append(recv_next_op)
if tensor_send_next is not None:
- send_next_op = dist.broadcast(tensor_send_next,
- src=rank,
- group=down_group,
- async_op=True)
+ send_next_op = dist.P2POp(dist.isend, tensor_send_next, next_rank)
ops.append(send_next_op)
- for req in ops:
- req.wait()
+ if len(ops) > 0:
+ reqs = dist.batch_isend_irecv(ops)
+ for req in reqs:
+ req.wait()
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()
return tensor_recv_prev, tensor_recv_next
-def recv_forward(input_tensor_shape, prev_rank=None, up_group=None):
+def recv_forward(input_tensor_shape, prev_rank=None):
"""Receives the input tensor from the previous member in pipeline.
-
+
:param input_tensor_shape: The shape of the tensor to be received
:param prev_rank: The rank of the source of the tensor
- :param up_group: Communication group including the previous member in pipeline parallel group
:type input_tensor_shape: torch.Size
:type prev_rank: int, optional
- :type up_group: ProcessGroup, optional
:return: The input tensor in forward step
:rtype: Tensor
"""
@@ -121,20 +103,17 @@ def recv_forward(input_tensor_shape, prev_rank=None, up_group=None):
else:
input_tensor, _ = _communicate(recv_prev=True,
recv_prev_shape=input_tensor_shape,
- prev_rank=prev_rank,
- up_group=up_group)
+ prev_rank=prev_rank)
return input_tensor
-def recv_backward(output_grad_shape, next_rank=None, down_group=None):
+def recv_backward(output_grad_shape, next_rank=None):
"""Receives the grad tensor from the next member in pipeline.
-
+
:param output_grad_shape: The shape of the tensor to be received
:param next_rank: The rank of the source of the tensor
- :param down_group: Communication group including the next member in pipeline parallel group
:type output_grad_shape: torch.Size
:type next_rank: int, optional
- :type down_group: ProcessGroup, optional
:return: The grad of output tensor in forward step
:rtype: Tensor
"""
@@ -143,56 +122,44 @@ def recv_backward(output_grad_shape, next_rank=None, down_group=None):
else:
_, output_tensor_grad = _communicate(recv_next=True,
recv_next_shape=output_grad_shape,
- next_rank=next_rank,
- down_group=down_group)
+ next_rank=next_rank)
return output_tensor_grad
-def send_forward(output_tensor,
- next_rank=None,
- down_group=None):
+def send_forward(output_tensor, next_rank=None):
"""Sends the input tensor to the next member in pipeline.
-
+
:param output_tensor: Tensor to be sent
:param next_rank: The rank of the recipient of the tensor
- :param down_group: Communication group including the next member in pipeline parallel group
:type output_tensor: Tensor
:type next_rank: int, optional
- :type down_group: ProcessGroup, optional
"""
if not gpc.is_last_rank(ParallelMode.PIPELINE):
_communicate(tensor_send_next=output_tensor,
- next_rank=next_rank,
- down_group=down_group)
+ next_rank=next_rank)
-def send_backward(input_tensor_grad,
- prev_rank=None,
- up_group=None):
+def send_backward(input_tensor_grad, prev_rank=None):
"""Sends the grad tensor to the previous member in pipeline.
-
+
:param input_tensor_grad: Tensor to be sent
:param prev_rank: The rank of the recipient of the tensor
- :param up_group: Communication group including the previous member in pipeline parallel group
:type input_tensor_grad: Tensor
:type prev_rank: int, optional
- :type up_group: ProcessGroup, optional
"""
if not gpc.is_first_rank(ParallelMode.PIPELINE):
_communicate(tensor_send_prev=input_tensor_grad,
- prev_rank=prev_rank,
- up_group=up_group)
+ prev_rank=prev_rank)
def send_forward_recv_backward(output_tensor,
output_grad_shape,
recv_next=True,
- next_rank=None,
- down_group=None):
+ next_rank=None):
"""Batched communication operation. Sends the input tensor to the
next member in pipeline, while receiving the grad tensor from the
next member in pipeline.
-
+
:param output_tensor: Tensor to be sent
:param output_grad_shape: The shape of the tensor to be received
:type output_tensor: Tensor
@@ -206,20 +173,18 @@ def send_forward_recv_backward(output_tensor,
_, output_tensor_grad = _communicate(tensor_send_next=output_tensor,
recv_next=recv_next,
recv_next_shape=output_grad_shape,
- next_rank=next_rank,
- down_group=down_group)
+ next_rank=next_rank)
return output_tensor_grad
def send_backward_recv_forward(input_tensor_grad,
input_tensor_shape,
recv_prev=True,
- prev_rank=None,
- up_group=None):
+ prev_rank=None):
"""Batched communication operation. Sends the grad tensor to the
previous member in pipeline, while receiving the input tensor from the
previous member in pipeline.
-
+
:param input_tensor_grad: Tensor to be sent
:param input_tensor_shape: The shape of the tensor to be received
:type input_tensor_grad: Tensor
@@ -233,8 +198,7 @@ def send_backward_recv_forward(input_tensor_grad,
input_tensor, _ = _communicate(tensor_send_prev=input_tensor_grad,
recv_prev=recv_prev,
recv_prev_shape=input_tensor_shape,
- prev_rank=prev_rank,
- up_group=up_group)
+ prev_rank=prev_rank)
return input_tensor
@@ -242,13 +206,11 @@ def send_forward_recv_forward(output_tensor,
input_tensor_shape,
recv_prev=True,
prev_rank=None,
- next_rank=None,
- up_group=None,
- down_group=None):
+ next_rank=None):
"""Batched communication operation. Sends the input tensor to the
next member in pipeline, while receiving the input tensor from the
previous member in pipeline.
-
+
:param output_tensor: Tensor to be sent
:param input_tensor_shape: The shape of the tensor to be received
:type output_tensor: Tensor
@@ -260,9 +222,7 @@ def send_forward_recv_forward(output_tensor,
recv_prev=recv_prev,
recv_prev_shape=input_tensor_shape,
prev_rank=prev_rank,
- next_rank=next_rank,
- up_group=up_group,
- down_group=down_group)
+ next_rank=next_rank)
return input_tensor
@@ -270,13 +230,11 @@ def send_backward_recv_backward(input_tensor_grad,
output_grad_shape,
recv_next=True,
prev_rank=None,
- next_rank=None,
- up_group=None,
- down_group=None):
+ next_rank=None):
"""Batched communication operation. Sends the grad tensor to the
previous member in pipeline, while receiving the grad tensor from the
next member in pipeline.
-
+
:param input_tensor_grad: Tensor to be sent
:param output_grad_shape: The shape of the tensor to be received
:type input_tensor_grad: Tensor
@@ -288,9 +246,7 @@ def send_backward_recv_backward(input_tensor_grad,
recv_next=recv_next,
recv_next_shape=output_grad_shape,
prev_rank=prev_rank,
- next_rank=next_rank,
- up_group=up_group,
- down_group=down_group)
+ next_rank=next_rank)
return output_tensor_grad
@@ -301,13 +257,11 @@ def send_forward_backward_recv_forward_backward(output_tensor,
recv_prev=True,
recv_next=True,
prev_rank=None,
- next_rank=None,
- up_group=None,
- down_group=None):
+ next_rank=None):
"""Batched communication operation. Sends the input tensor to the next and
the grad tensor to the previous, while receiving the grad tensor from the
next and the input tensor from the previous.
-
+
:param output_tensor: Tensor sent to the next
:param input_tensor_grad: Tensor sent to the previous
:param input_tensor_shape: The shape of the tensor received from the previous
@@ -327,7 +281,5 @@ def send_forward_backward_recv_forward_backward(output_tensor,
recv_prev_shape=input_tensor_shape,
recv_next_shape=output_grad_shape,
prev_rank=prev_rank,
- next_rank=next_rank,
- up_group=up_group,
- down_group=down_group)
+ next_rank=next_rank)
return input_tensor, output_tensor_grad
diff --git a/colossalai/communication/utils.py b/colossalai/communication/utils.py
index d6d7dc091..a8dc0da1a 100644
--- a/colossalai/communication/utils.py
+++ b/colossalai/communication/utils.py
@@ -6,7 +6,7 @@ from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device
-def send_tensor_meta(tensor, need_meta=True, down_group=None):
+def send_tensor_meta(tensor, need_meta=True, next_rank=None):
"""Sends tensor meta information before sending a specific tensor.
Since the recipient must know the shape of the tensor in p2p communications,
meta information of the tensor should be sent before communications. This function
@@ -14,31 +14,34 @@ def send_tensor_meta(tensor, need_meta=True, down_group=None):
:param tensor: Tensor to be sent
:param need_meta: If False, meta information won't be sent
- :param down_group: Communication group including the next member in pipeline parallel group
+ :param next_rank: The rank of the next member in pipeline parallel group
:type tensor: Tensor
:type need_meta: bool, optional
- :type down_group: ProcessGroup, optional
+ :type next_rank: int
:return: False
:rtype: bool
"""
if need_meta:
- rank = gpc.get_global_rank()
-
- if down_group is None:
- down_group = gpc.get_group(ParallelMode.PIPELINE_NEXT)
+ if next_rank is None:
+ next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()}
send_shape = torch.tensor(tensor.size(), **tensor_kwargs)
send_ndims = torch.tensor(len(tensor.size()), **tensor_kwargs)
-
- dist.broadcast(send_ndims, src=rank, group=down_group)
- dist.broadcast(send_shape, src=rank, group=down_group)
+ ops = [
+ dist.P2POp(dist.isend, send_ndims, next_rank),
+ dist.P2POp(dist.isend, send_shape, next_rank)
+ ]
+ reqs = dist.batch_isend_irecv(ops)
+ for req in reqs:
+ req.wait()
+ torch.cuda.synchronize()
return False
-def recv_tensor_meta(tensor_shape, prev_rank=None, up_group=None):
+def recv_tensor_meta(tensor_shape, prev_rank=None):
"""Recieves tensor meta information before recieving a specific tensor.
Since the recipient must know the shape of the tensor in p2p communications,
meta information of the tensor should be recieved before communications. This function
@@ -46,27 +49,21 @@ def recv_tensor_meta(tensor_shape, prev_rank=None, up_group=None):
:param tensor_shape: The shape of the tensor to be received
:param prev_rank: The rank of the source of the tensor
- :param up_group: Communication group including the previous member in pipeline parallel group
:type tensor_shape: torch.Size
:type prev_rank: int, optional
- :type up_group: ProcessGroup, optional
:return: The shape of the tensor to be received
:rtype: torch.Size
"""
if tensor_shape is None:
if prev_rank is None:
- prev_rank = gpc.get_prev_global_rank(
- ParallelMode.PIPELINE)
- if up_group is None:
- up_group = gpc.get_group(ParallelMode.PIPELINE_PREV)
+ prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()}
recv_ndims = torch.empty((), **tensor_kwargs)
- dist.broadcast(recv_ndims, src=prev_rank, group=up_group)
-
+ dist.recv(recv_ndims, prev_rank)
recv_shape = torch.empty(recv_ndims, **tensor_kwargs)
- dist.broadcast(recv_shape, src=prev_rank, group=up_group)
+ dist.recv(recv_shape, prev_rank)
tensor_shape = torch.Size(recv_shape)
diff --git a/colossalai/constants.py b/colossalai/constants.py
index 073dd2d2a..874c53d72 100644
--- a/colossalai/constants.py
+++ b/colossalai/constants.py
@@ -25,7 +25,11 @@ TESSERACT_DEP = 'TESSERACT_DEP'
# 3D parallel
DEPTH_3D = 'DEPTH_3D'
+INPUT_GROUP_3D = 'PARALLEL_3D_INPUT'
+WEIGHT_GROUP_3D = 'PARALLEL_3D_WEIGHT'
+OUTPUT_GROUP_3D = 'PARALLEL_3D_OUTPUT'
# Tensor parallel attributes
IS_TENSOR_PARALLEL = 'is_tensor_parallel'
-TENSOR_PARALLEL_ATTRIBUTES = [IS_TENSOR_PARALLEL]
+NUM_PARTITIONS = 'num_partitions'
+TENSOR_PARALLEL_ATTRIBUTES = [IS_TENSOR_PARALLEL, NUM_PARTITIONS]
diff --git a/colossalai/context/__init__.py b/colossalai/context/__init__.py
index 3009779c8..ac1408773 100644
--- a/colossalai/context/__init__.py
+++ b/colossalai/context/__init__.py
@@ -1,5 +1,5 @@
-from .config import Config
+from .config import Config, ConfigException
from .parallel_context import ParallelContext
-from .parallel_context import ParallelMode
+from .parallel_mode import ParallelMode
from .process_group_initializer import *
from .random import *
diff --git a/colossalai/context/_utils.py b/colossalai/context/_utils.py
deleted file mode 100644
index a770ea7b4..000000000
--- a/colossalai/context/_utils.py
+++ /dev/null
@@ -1,70 +0,0 @@
-import math
-
-
-def set_parallel_size(obj, config: dict, key: str, attr_name: str):
- if key in config:
- ele = config[key]
- if isinstance(ele, int):
- setattr(obj, attr_name, ele)
- elif isinstance(ele, dict):
- setattr(obj, attr_name, ele['size'])
- else:
- raise NotImplementedError(
- f"Parallel configuration does not support this kind of argument, please use int or dict"
- )
-
-
-def add_tensor_pg(pg_init, mode, size, depth=None):
- if mode == '1d':
- pg_init.append(dict(
- type='Initializer1D',
- parallel_size=size
- ))
- elif mode == '2d':
- dim = math.floor(math.sqrt(size))
- pg_init.append(dict(
- type='Initializer2D_Col',
- summa_dim=dim
- ))
- pg_init.append(dict(
- type='Initializer2D_Row',
- summa_dim=dim
- ))
- elif mode == '2.5d':
- dim = math.floor(math.sqrt(size // depth))
- pg_init.append(dict(
- type='Initializer_Tesseract_ROW',
- tesseract_dim=dim,
- tesseract_dep=depth
- ))
- pg_init.append(dict(
- type='Initializer_Tesseract_COL',
- tesseract_dim=dim,
- tesseract_dep=depth
- ))
- pg_init.append(dict(
- type='Initializer_Tesseract_DEP',
- tesseract_dim=dim,
- tesseract_dep=depth
- ))
- pg_init.append(dict(
- type='Initializer_Tesseract_XZ',
- tesseract_dim=dim,
- tesseract_dep=depth
- ))
- elif mode == '3d':
- dim = math.floor(math.pow(size, 1.0 / 3.0) + 0.5)
- pg_init.append(dict(
- type='ParallelInitializer3D_Input',
- depth=dim
- ))
- pg_init.append(dict(
- type='ParallelInitializer3D_Weight',
- depth=dim
- ))
- pg_init.append(dict(
- type='ParallelInitializer3D_Output',
- depth=dim
- ))
- else:
- raise NotImplementedError("This kind of tensor splitting has not been implemented yet")
diff --git a/colossalai/context/config.py b/colossalai/context/config.py
index 52a375aa1..5943aa7ed 100644
--- a/colossalai/context/config.py
+++ b/colossalai/context/config.py
@@ -97,3 +97,7 @@ class Config(dict):
sys.path.pop(0)
return config
+
+
+class ConfigException(Exception):
+ pass
diff --git a/colossalai/context/parallel_context.py b/colossalai/context/parallel_context.py
index 5a7a0bfb9..4f8e9f807 100644
--- a/colossalai/context/parallel_context.py
+++ b/colossalai/context/parallel_context.py
@@ -1,7 +1,6 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
-import os
import random
from typing import Union
@@ -11,8 +10,8 @@ import torch.distributed as dist
from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING
from colossalai.context.config import Config
+from colossalai.logging import get_dist_logger
from colossalai.registry import DIST_GROUP_INITIALIZER
-from ._utils import set_parallel_size
from .parallel_mode import ParallelMode
from .random import add_seed, get_seeds, set_mode
@@ -21,11 +20,24 @@ class ParallelContext:
"""This class provides interface functions for users to get the parallel context,
such as the global rank, the local rank, the world size, etc. of each device.
- :param args: The distributed arguments in the system
- :type args: dict
"""
- def __init__(self, args=None):
+ __instance = None
+
+ @staticmethod
+ def get_instance():
+ if ParallelContext.__instance is None:
+ ParallelContext()
+ return ParallelContext.__instance
+
+ def __init__(self):
+ # create a singleton instance
+ if ParallelContext.__instance is not None:
+ raise Exception(
+ 'ParallelContext is a singleton class, you should get the instance by colossalai.core.global_context')
+ else:
+ ParallelContext.__instance = self
+
# distributed settings
self._global_ranks = dict()
self._local_ranks = dict()
@@ -34,7 +46,6 @@ class ParallelContext:
self._ranks_in_group = dict()
# load config from file
- self._dist_args = args
self._config = None
# default 3D parallel args, will be overwritten during process group intialization
@@ -43,10 +54,22 @@ class ParallelContext:
self.pipeline_parallel_size = 1
self.tensor_parallel_size = 1
+ # logging
+ self._verbose = False
+ self._logger = get_dist_logger()
+
@property
def config(self):
return self._config
+ @property
+ def verbose(self):
+ return self._verbose
+
+ @verbose.setter
+ def verbose(self, verbose_: bool):
+ self._verbose = verbose_
+
def load_config(self, config: Union[dict, str]):
"""Loads the configuration from either a dict or a file.
@@ -62,14 +85,6 @@ class ParallelContext:
else:
raise TypeError("Invalid type for config, only dictionary or string is supported")
- def set_dist_args(self, args):
- """Sets the distributed arguments.
-
- :param args: The distributed arguments in the system
- :type args: dict
- """
- self._dist_args = args
-
@staticmethod
def _check_parallel_mode(parallel_mode: ParallelMode):
assert isinstance(parallel_mode, ParallelMode)
@@ -268,32 +283,36 @@ class ParallelContext:
self._check_parallel_mode(parallel_mode)
self._ranks_in_group[parallel_mode] = ranks
- def init_global_dist(self, addr=None, port=None):
- """Initializes the global distributed environment.
-
- :param addr: The IP address of the current device
- :type addr: str, optional
- :param port: The port to be used in the system of the current device
- :type port: int, optional
+ def init_global_dist(self,
+ rank: int,
+ world_size: int,
+ backend: str,
+ host: str,
+ port: int
+ ):
+ """Initializes the global distributed environment
+ :param rank: rank for the default process group
+ :type rank: int
+ :param world_size: world size of the default process group
+ :type world_size: int
+ :param host: the master address for distributed training
+ :type host: str
+ :param port: the master port for distributed training
+ :type port: int
+ :param backend: backend for torch.distributed
+ :type backend: str
"""
- # get config
- rank = self._dist_args.local_rank
- world_size = self._dist_args.world_size
- # default env config, overwrite by exporting
- # them in your bash script
- addr = os.getenv('MASTER_ADDR', 'localhost') if addr is None else addr
- port = os.getenv('MASTER_PORT', '8008') if port is None else port
- init_method = f'tcp://{addr}:{port}'
-
- dist.init_process_group(backend=self._dist_args.backend,
- rank=rank,
+ # initialize the default process group
+ init_method = f'tcp://{host}:{port}'
+ dist.init_process_group(rank=rank,
world_size=world_size,
+ backend=backend,
init_method=init_method)
# None will give the default global process group for pytorch dist operations
self._register_dist(rank, world_size, None,
list(range(world_size)), ParallelMode.GLOBAL)
- self._global_ranks[ParallelMode.GLOBAL] = rank
+ self.add_global_rank(ParallelMode.GLOBAL, rank)
def _register_dist(self, local_rank, world_size,
process_group, ranks_in_group, mode):
@@ -312,7 +331,20 @@ class ParallelContext:
pps = self.pipeline_parallel_size
tps = self.tensor_parallel_size
ws = self.world_size
- assert ws == dps * pps * tps, f"Expected the world size {ws} to be equal to data parallel size ({dps}) * pipeline parallel size ({pps}) * tensor parallel size ({tps})"
+ assert ws == dps * pps * \
+ tps, f"Expected the world size {ws} to be equal to data parallel size ({dps}) * pipeline parallel size ({pps}) * tensor parallel size ({tps})"
+
+ def _set_parallel_size_from_config(self, config: dict, key: str, attr_name: str):
+ if key in config:
+ ele = config[key]
+ if isinstance(ele, int):
+ setattr(self, attr_name, ele)
+ elif isinstance(ele, dict):
+ setattr(self, attr_name, ele['size'])
+ else:
+ raise NotImplementedError(
+ f"Parallel configuration does not support this kind of argument, please use int or dict"
+ )
def init_parallel_groups(self):
"""Initializes the parallel groups.
@@ -325,21 +357,20 @@ class ParallelContext:
world_size = self.get_world_size(ParallelMode.GLOBAL)
self.world_size = world_size
- assert hasattr(self.config, 'parallel'), 'Expected the field parallel to be present in the config file'
-
# set parallel size as attributes for global context
- parallel_config = self.config.parallel
- set_parallel_size(self, parallel_config, 'pipeline',
- 'pipeline_parallel_size')
- set_parallel_size(self, parallel_config, 'tensor',
- 'tensor_parallel_size')
+ parallel_config = self.config.get('parallel', None)
+ if parallel_config is not None:
+ self._set_parallel_size_from_config(parallel_config, 'pipeline', 'pipeline_parallel_size')
+ self._set_parallel_size_from_config(parallel_config, 'tensor', 'tensor_parallel_size')
# the user should not set the data parallel size manually
# instead, it should be calculated based on other parallel config
self.data_parallel_size = self.world_size // (self.pipeline_parallel_size * self.tensor_parallel_size)
# get the tensor parallel mode and check
- tensor_parallel_mode = parallel_config['tensor'].get('mode', None)
+ tensor_parallel_mode = None
+ if parallel_config is not None and 'tensor' in parallel_config and 'mode' in parallel_config['tensor']:
+ tensor_parallel_mode = parallel_config['tensor']['mode']
assert tensor_parallel_mode in ALLOWED_MODES, f"mode in the parallel config must be set to one of {ALLOWED_MODES}"
self.check_sanity()
@@ -400,23 +431,21 @@ class ParallelContext:
# destroy global process group
dist.destroy_process_group()
- def set_device(self):
+ def set_device(self, device_ordinal: int = None):
"""Sets distributed processes to be bound to devices.
"""
- devices_per_node = torch.cuda.device_count()
global_rank = self.get_global_rank()
- device = global_rank % devices_per_node
- torch.cuda.set_device(device)
- print(f'process rank {global_rank} is bound to device {device}')
+ if device_ordinal is None:
+ devices_per_node = torch.cuda.device_count()
+ device_ordinal = global_rank % devices_per_node
- def set_seed(self):
+ torch.cuda.set_device(device_ordinal)
+ if self._verbose:
+ self._logger.info(f'process rank {global_rank} is bound to device {device_ordinal}')
+
+ def set_seed(self, seed: int):
"""Sets seeds for all random libraries.
"""
- if hasattr(self.config, 'seed'):
- seed = getattr(self.config, 'seed')
- else:
- seed = 2 # default seed
-
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
@@ -444,11 +473,18 @@ class ParallelContext:
seeds = get_seeds()
seed_str = ', '.join([f'{k}: {v}' for k, v in seeds.items()])
- print(f"initialized seed on rank {global_rank}, "
- f"numpy: {seed}, python random: {seed}, {seed_str},"
- f"the default parallel seed is {ParallelMode.DATA}.", flush=True)
+ if self._verbose:
+ self._logger.info(
+ f"initialized seed on rank {global_rank}, "
+ f"numpy: {seed}, python random: {seed}, {seed_str},"
+ f"the default parallel seed is {ParallelMode.DATA}.",
+ ranks=[0])
else:
- print(f"initialized seed on rank {global_rank}, "
- f"numpy: {seed}, python random: {seed}, pytorch: {seed}", flush=True)
- print('WARNING: CUDA is not available, thus CUDA RNG cannot be used to track CUDA random number states',
- flush=True)
+ if self._verbose:
+ self._logger.info(
+ f"initialized seed on rank {global_rank}, "
+ f"numpy: {seed}, python random: {seed}, pytorch: {seed}",
+ ranks=[0])
+ self._logger.info(
+ 'WARNING: CUDA is not available, thus CUDA RNG cannot be used to track CUDA random number states',
+ ranks=[0])
diff --git a/colossalai/context/process_group_initializer/initializer_1d.py b/colossalai/context/process_group_initializer/initializer_1d.py
index 784480a72..1b487aba1 100644
--- a/colossalai/context/process_group_initializer/initializer_1d.py
+++ b/colossalai/context/process_group_initializer/initializer_1d.py
@@ -4,7 +4,6 @@
import torch.distributed as dist
from colossalai.context import Config
-from colossalai.core import global_context as gpc
from colossalai.registry import DIST_GROUP_INITIALIZER
from .process_group_initializer import ProcessGroupInitializer
from ..parallel_mode import ParallelMode
diff --git a/colossalai/context/process_group_initializer/initializer_2p5d.py b/colossalai/context/process_group_initializer/initializer_2p5d.py
index cacfdc590..ab8fe3573 100644
--- a/colossalai/context/process_group_initializer/initializer_2p5d.py
+++ b/colossalai/context/process_group_initializer/initializer_2p5d.py
@@ -8,7 +8,6 @@ import torch.distributed as dist
from colossalai.constants import TESSERACT_DIM, TESSERACT_DEP
from colossalai.context import Config
-from colossalai.core import global_context as gpc
from colossalai.registry import DIST_GROUP_INITIALIZER
from .process_group_initializer import ProcessGroupInitializer
from ..parallel_mode import ParallelMode
@@ -42,8 +41,6 @@ class Initializer_2p5D_ROW(ProcessGroupInitializer):
tesseract_dep: int,
*args):
super(Initializer_2p5D_ROW, self).__init__(*args)
-
- self.tensor_parallel_size = gpc.tensor_parallel_size
self.num_group = self.world_size // self.tensor_parallel_size
self.tesseract_dep = tesseract_dep
self.tesseract_dim = tesseract_dim
@@ -66,7 +63,7 @@ class Initializer_2p5D_ROW(ProcessGroupInitializer):
for j in range(self.tesseract_dim):
for k in range(self.tesseract_dep):
ranks = [h * self.tensor_parallel_size + i + self.tesseract_dim * (
- j + self.tesseract_dim * k) for i in range(self.tesseract_dim)]
+ j + self.tesseract_dim * k) for i in range(self.tesseract_dim)]
group = dist.new_group(ranks)
if self.rank in ranks:
@@ -81,13 +78,12 @@ class Initializer_2p5D_ROW(ProcessGroupInitializer):
class Initializer_2p5D_Col(ProcessGroupInitializer):
'''2p5d tensor parallel initialization among cols.
'''
+
def __init__(self,
tesseract_dim: int,
tesseract_dep: int,
*args):
super(Initializer_2p5D_Col, self).__init__(*args)
-
- self.tensor_parallel_size = gpc.tensor_parallel_size
self.num_group = self.world_size // self.tensor_parallel_size
self.tesseract_dep = tesseract_dep
self.tesseract_dim = tesseract_dim
@@ -110,7 +106,7 @@ class Initializer_2p5D_Col(ProcessGroupInitializer):
for i in range(self.tesseract_dim):
for k in range(self.tesseract_dep):
ranks = [h * self.tensor_parallel_size + i + self.tesseract_dim * (
- j + self.tesseract_dim * k) for j in range(self.tesseract_dim)]
+ j + self.tesseract_dim * k) for j in range(self.tesseract_dim)]
group = dist.new_group(ranks)
if self.rank in ranks:
@@ -125,13 +121,12 @@ class Initializer_2p5D_Col(ProcessGroupInitializer):
class Initializer_2p5D_Dep(ProcessGroupInitializer):
'''2p5D tensor parallel initialization among depths.
'''
+
def __init__(self,
tesseract_dim: int,
tesseract_dep: int,
*args):
super(Initializer_2p5D_Dep, self).__init__(*args)
-
- self.tensor_parallel_size = gpc.tensor_parallel_size
self.num_group = self.world_size // self.tensor_parallel_size
self.tesseract_dep = tesseract_dep
self.tesseract_dim = tesseract_dim
@@ -154,7 +149,7 @@ class Initializer_2p5D_Dep(ProcessGroupInitializer):
for i in range(self.tesseract_dim):
for j in range(self.tesseract_dim):
ranks = [h * self.tensor_parallel_size + i + self.tesseract_dim * (
- j + self.tesseract_dim * k) for k in range(self.tesseract_dep)]
+ j + self.tesseract_dim * k) for k in range(self.tesseract_dep)]
group = dist.new_group(ranks)
if self.rank in ranks:
@@ -170,13 +165,12 @@ class Initializer_2p5D_Dep(ProcessGroupInitializer):
class Initializer_2p5D_XZ(ProcessGroupInitializer):
'''2p5d tensor parallel initialization among cols times dep.
'''
+
def __init__(self,
tesseract_dim: int,
tesseract_dep: int,
*args):
super(Initializer_2p5D_XZ, self).__init__(*args)
-
- self.tensor_parallel_size = gpc.tensor_parallel_size
self.num_group = self.world_size // self.tensor_parallel_size
self.tesseract_dep = tesseract_dep
self.tesseract_dim = tesseract_dim
@@ -198,8 +192,8 @@ class Initializer_2p5D_XZ(ProcessGroupInitializer):
for h in range(self.num_group):
for i in range(self.tesseract_dim):
ranks = [h * self.tensor_parallel_size + i + self.tesseract_dim * (
- j + self.tesseract_dim * k) for k in range(self.tesseract_dep) for j in
- range(self.tesseract_dim)]
+ j + self.tesseract_dim * k) for k in range(self.tesseract_dep) for j in
+ range(self.tesseract_dim)]
group = dist.new_group(ranks)
if self.rank in ranks:
diff --git a/colossalai/context/process_group_initializer/initializer_3d.py b/colossalai/context/process_group_initializer/initializer_3d.py
index 391230767..464049193 100644
--- a/colossalai/context/process_group_initializer/initializer_3d.py
+++ b/colossalai/context/process_group_initializer/initializer_3d.py
@@ -5,7 +5,7 @@ import math
import os
import torch.distributed as dist
-from colossalai.constants import DEPTH_3D
+from colossalai.constants import DEPTH_3D, INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D
from colossalai.registry import DIST_GROUP_INITIALIZER
from ..parallel_mode import ParallelMode
@@ -18,7 +18,7 @@ def _check_depth_env_var(depth):
if env_depth:
assert int(env_depth) == depth, \
- 'SUMMA_DIM has been set in the current environment and ' \
+ 'DEPTH_3D has been set in the current environment and ' \
'does not match the value passed to this initializer'
else:
os.environ[DEPTH_3D] = str(depth)
@@ -43,6 +43,7 @@ class Initializer_3D_Input(ProcessGroupInitializer):
process_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_3D_INPUT
+ os.environ[INPUT_GROUP_3D] = INPUT_GROUP_3D
for h in range(self.num_group):
for i in range(self.depth):
@@ -82,6 +83,7 @@ class Initializer_3D_Weight(ProcessGroupInitializer):
process_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_3D_WEIGHT
+ os.environ[WEIGHT_GROUP_3D] = WEIGHT_GROUP_3D
for h in range(self.num_group):
for k in range(self.depth):
@@ -121,6 +123,7 @@ class Initializer_3D_Output(ProcessGroupInitializer):
process_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_3D_OUTPUT
+ os.environ[OUTPUT_GROUP_3D] = OUTPUT_GROUP_3D
for h in range(self.num_group):
for i in range(self.depth):
diff --git a/colossalai/core.py b/colossalai/core.py
index 39453e4a0..ff3034791 100644
--- a/colossalai/core.py
+++ b/colossalai/core.py
@@ -3,14 +3,4 @@
from colossalai.context import ParallelContext
-global_context = ParallelContext()
-
-
-def set_global_context(context: ParallelContext):
- '''Reset global context to be identical to a given :class:ParallelContext.
-
- :param context: Parallel context to generate our global parallel context.
- :type context: ParallelContext
- '''
- global global_context
- global_context = context
+global_context = ParallelContext.get_instance()
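A minimal sketch of what the simplified core module implies for callers, assuming ParallelContext.get_instance() keeps returning a process-wide singleton:

# global_context is now just the ParallelContext singleton, so any later call
# to ParallelContext.get_instance() returns the same object.
from colossalai.context import ParallelContext
from colossalai.core import global_context as gpc

assert gpc is ParallelContext.get_instance()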
diff --git a/colossalai/engine/__init__.py b/colossalai/engine/__init__.py
index 7e5592236..73ccb094e 100644
--- a/colossalai/engine/__init__.py
+++ b/colossalai/engine/__init__.py
@@ -1,7 +1,5 @@
from ._base_engine import Engine
from .gradient_handler import *
-from .schedule import *
-from .amp import *
__all__ = ['Engine']
diff --git a/colossalai/engine/_base_engine.py b/colossalai/engine/_base_engine.py
index a99aa91e7..8a3f6eac3 100644
--- a/colossalai/engine/_base_engine.py
+++ b/colossalai/engine/_base_engine.py
@@ -1,17 +1,17 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
+
+import torch
+from typing import List
from torch.nn import Module
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from colossalai.builder import build_gradient_handler
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.logging import get_global_dist_logger
-from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
- ZeroRedundancyOptimizer_Level_3)
-from .schedule import BaseSchedule
+from colossalai.logging import get_dist_logger
+from colossalai.utils import is_using_ddp, is_using_pp
+from torch import Tensor
class Engine:
@@ -20,74 +20,40 @@ class Engine:
It controls an iteration in training.
:param model: The neural network model
+ :type model: ``torch.nn.Module``
:param optimizer: Optimizer for updating the parameters
- :param step_schedule: Running schedule in :meth:`step`
- :param gradient_accumulation: Steps of gradient accumulation
+ :type optimizer: ``torch.optim.Optimizer``
+ :param criterion: Loss function for calculating loss
+ :type criterion: ``torch.nn.modules.loss._Loss``
:param gradient_clipping: The norm of gradient clipping
- :type model: Module
- :type optimizer: Optimizer
- :type step_schedule: BaseSchedule, optional
- :type gradient_accumulation: int, optional
:type gradient_clipping: float, optional
+ :param verbose: whether to display log info
+ :type verbose: bool
"""
def __init__(self,
model: Module,
optimizer: Optimizer,
criterion: _Loss,
- step_schedule: BaseSchedule,
- gradient_handlers: list = None,
- gradient_accumulation: int = 1,
- gradient_clipping: float = 0.0,
+ gradient_handlers: List = None,
+ clip_grad_norm: float = 0.0,
+ verbose: bool = True
):
self._model = model
self._optimizer = optimizer
self._criterion = criterion
- self._schedule = step_schedule
-
- # schedule initialize
- self._schedule.initialize(model, optimizer)
+ self._clip_grad_norm = clip_grad_norm
+ self._verbose = verbose
+ self._logger = get_dist_logger()
# state
self.training = True # default
- # gradient accumulation
- assert gradient_accumulation > 0, 'gradient accumulation size must be larger than 0'
- self._grad_accum_size = gradient_accumulation
- self._grad_clip = gradient_clipping
- self._logger = get_global_dist_logger()
-
# build gradient handler
- self._gradient_handlers = []
-
- if gradient_handlers is not None:
- assert isinstance(gradient_handlers, list), \
- f'argument gradient_handler_cfg expected type list, ' \
- f'but got type {type(gradient_handlers)}'
- elif isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2,
- ZeroRedundancyOptimizer_Level_3)):
- gradient_handlers = [dict(type='ZeROGradientHandler')]
- self._logger.info(
- "Training with zero is detected, ZeROGradientHandler is automatically "
- "added even though not specified in the configuration",
- ranks=[0])
- elif gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(
- ParallelMode.DATA) > 1:
- gradient_handlers = [dict(type='DataParallelGradientHandler')]
- self._logger.info(
- "Data parallel training is detected, DataParallelGradientHandler is automatically "
- "added even though not specified in the configuration",
- ranks=[0])
-
- if gradient_handlers is None:
- self._logger.warning(
- "No gradient handler is set up, please make sure you do not need "
- "to all-reduce the gradients after a training step.",
- ranks=[0])
+ if gradient_handlers:
+ self._gradient_handlers = gradient_handlers
else:
- for cfg in gradient_handlers:
- handler = build_gradient_handler(cfg, model, optimizer)
- self._gradient_handlers.append(handler)
+ self._gradient_handlers = []
@property
def model(self):
@@ -105,11 +71,27 @@ class Engine:
def schedule(self):
return self._schedule
- @property
- def gradient_accumulation(self):
- return self._grad_accum_size
+ def zero_grad(self):
+ self.optimizer.zero_grad()
- def handle_gradient(self):
+ def step(self):
+ self._all_reduce_gradients()
+ self.optimizer.clip_grad_norm(self.model, self._clip_grad_norm)
+ self.optimizer.step()
+
+ def backward(self, loss: Tensor):
+ return self.optimizer.backward(loss)
+
+ def backward_by_grad(self, tensor, grad):
+ return self.optimizer.backward_by_grad(tensor, grad)
+
+ def calc_loss(self, *args, **kwargs):
+ return self.criterion(*args, **kwargs)
+
+ def __call__(self, *args, **kwargs):
+ return self.model(*args, **kwargs)
+
+ def _all_reduce_gradients(self):
"""Handles all-reduce operations of gradients across different parallel groups.
"""
for handler in self._gradient_handlers:
@@ -126,51 +108,3 @@ class Engine:
"""
self.training = False
self._model.eval()
-
- def step(self,
- data_iter,
- is_last_iteration: bool = False,
- return_loss=True):
- """A running step based on the schedule. Usually, it runs a training or
- evaluation over a batch of dataset.
-
- :param data_iter: Data iterator of the dataset
- :param is_last_iteration: If True, this iteration is the last iteration in the epoch
- :param return_loss: loss will be returned if True
- :type data_iter: Iterator
- :type is_last_iteration: bool, optional
- :type return_loss: bool, optional
- :return: (output, lablel, loss)
- """
- if self.training:
- self._optimizer.zero_grad()
-
- # differentiate training and eval with grad accum
- if self.training:
- for i in range(self._grad_accum_size):
- output, label, loss = self._schedule.forward_backward_step(
- data_iter, self._model, self._criterion, self._optimizer,
- forward_only=False,
- grad_accum_size=self._grad_accum_size,
- return_loss=return_loss)
-
- if i == self._grad_accum_size - 1:
- # all reduce gradients
- self.handle_gradient()
- self._schedule.optimizer_step(self._model, self._optimizer, self._grad_clip)
- else:
- output, label, loss = self._schedule.forward_backward_step(
- data_iter, self._model, self._criterion, self._optimizer,
- forward_only=True,
- grad_accum_size=1,
- return_loss=return_loss)
-
- # consume the remaining dataset left out due to gradient accumulation
- if is_last_iteration:
- while True:
- try:
- _ = next(data_iter)
- except StopIteration:
- break
-
- return output, label, loss
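A hedged sketch of how a training iteration looks with the slimmed-down Engine; `engine`, `data` and `label` are placeholders obtained elsewhere (for example from colossalai.initialize and a dataloader), not names introduced by this patch.

engine.zero_grad()                     # delegates to optimizer.zero_grad()
output = engine(data)                  # __call__ forwards to the wrapped model
loss = engine.criterion(output, label)
engine.backward(loss)                  # optimizer-aware backward (AMP/ZeRO safe)
engine.step()                          # all-reduce grads, clip, then optimizer.step()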
diff --git a/colossalai/engine/amp/__init__.py b/colossalai/engine/amp/__init__.py
deleted file mode 100644
index 927d5cf09..000000000
--- a/colossalai/engine/amp/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from .grad_scaler import GradScaler
-from .amp_type import AMP_TYPE
diff --git a/colossalai/engine/schedule/__init__.py b/colossalai/engine/schedule/__init__.py
index dba95469b..a885a672e 100644
--- a/colossalai/engine/schedule/__init__.py
+++ b/colossalai/engine/schedule/__init__.py
@@ -1,5 +1,5 @@
from ._base_schedule import BaseSchedule
-from ._no_pipeline import NoPipelineSchedule
-from ._pipeline import PipelineSchedule
+from ._pipeline_schedule import PipelineSchedule
+from ._non_pipeline_schedule import NonPipelineSchedule
-__all__ = ['BaseSchedule', 'NoPipelineSchedule', 'PipelineSchedule']
+__all__ = ['BaseSchedule', 'PipelineSchedule', 'NonPipelineSchedule']
diff --git a/colossalai/engine/schedule/_base_schedule.py b/colossalai/engine/schedule/_base_schedule.py
index 0583ccbf3..e28690cb0 100644
--- a/colossalai/engine/schedule/_base_schedule.py
+++ b/colossalai/engine/schedule/_base_schedule.py
@@ -5,8 +5,10 @@ from abc import ABC, abstractmethod
import torch
-from colossalai.core import global_context as gpc
-from colossalai.logging import get_global_dist_logger
+from torch import Tensor
+from typing import Iterable, Union, List, Callable
+from .._base_engine import Engine
+from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
@@ -18,8 +20,9 @@ class BaseSchedule(ABC):
control of FP16 in class schedule.
"""
- def __init__(self):
- self.logger = get_global_dist_logger()
+ def __init__(self, batch_data_process_func: Callable = None):
+ self.logger = get_dist_logger()
+ self.batch_data_process_func = batch_data_process_func
@staticmethod
def _move_tensor(element):
@@ -35,6 +38,11 @@ class BaseSchedule(ABC):
data = data.to(get_current_device()).detach()
return data
+ def _to_list(self, data):
+ if torch.is_tensor(data):
+ return [data]
+ return data
+
def load_batch(self, data_iter):
"""Loads a batch from data iterator. It returns the data and labels which are
already on the same GPU as the model.
@@ -44,46 +52,34 @@ class BaseSchedule(ABC):
"""
if data_iter is None:
raise RuntimeError('Dataloader is not defined.')
- data, label = next(data_iter)
+ batch_data = next(data_iter)
+
+ if self.batch_data_process_func:
+ data, label = self.batch_data_process_func(batch_data)
+ else:
+ data, label = batch_data
+
+ data, label = self._to_list(data), self._to_list(label)
return self._move_to_device(data), self._move_to_device(label)
- def initialize(self, model, optimizer):
- """Initializes the model and the optimizer before training.
- This is often used in FP16 training.
-
- :param model: The neural network model
- :param optimizer: Optimizer for updating the parameters
+ def pre_processing(self, engine: Engine):
+ """To perform actions before running the schedule.
"""
- return model, optimizer
+ pass
@abstractmethod
def forward_backward_step(self,
- data_iter,
- model,
- criterion,
- optimizer=None,
- forward_only=False,
- grad_accum_size: int = 1,
- return_loss=True):
+ engine: Engine,
+ data_iter: Iterable,
+ forward_only: bool,
+ return_loss: bool = True
+ ):
"""The process function over a batch of dataset for training or evaluation.
- :param data_iter: Data iterator of the dataset
- :param model: Model used in training or evaluation
- :param optimizer: Optimizer used in training or evaluation
- :param criterion: Loss function
+ :param engine: Colossalai training engine
+ :param inputs: input data
+ :param labels: ground truth
:param forward_only: If True, the process won't include backward
- :param grad_accum_size: Steps of gradient accumulation
:param return_loss: If False, the loss won't be returned
"""
- pass
-
- @abstractmethod
- def optimizer_step(self, model, optimizer, grad_clipping: float = 0.0):
- """Updates the parameters with the optimizer.
-
- :param model: The neural network model
- :param optimizer: Optimizer for updating the parameters
- :param grad_clipping: The norm of gradient clipping
- :type grad_clipping: float, optional
- """
- pass
+ pass
\ No newline at end of file
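A short sketch of the new batch_data_process_func hook used by BaseSchedule.load_batch; the dict keys are illustrative assumptions about a user dataloader, not part of this patch.

def split_batch(batch):
    # map a custom batch layout onto the (data, label) pair load_batch expects
    return batch['input_ids'], batch['labels']

# any concrete schedule built on BaseSchedule can receive it, e.g.
# schedule = NonPipelineSchedule(batch_data_process_func=split_batch)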
diff --git a/colossalai/engine/schedule/_no_pipeline.py b/colossalai/engine/schedule/_no_pipeline.py
deleted file mode 100644
index 4f38e6cda..000000000
--- a/colossalai/engine/schedule/_no_pipeline.py
+++ /dev/null
@@ -1,188 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-try:
- import apex.amp as apex_amp
-except:
- pass
-
-try:
- import torch.cuda.amp as torch_amp
-except:
- pass
-
-from typing import Iterable
-
-import torch.nn as nn
-from torch.optim import Optimizer
-
-from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
- ZeroRedundancyOptimizer_Level_3)
-from colossalai.nn.optimizer._utils import clip_grad_norm_fp32
-from ._base_schedule import BaseSchedule
-from ._utils import convert_to_fp16, convert_to_fp32
-from ..amp import AMP_TYPE, GradScaler
-
-
-class NoPipelineSchedule(BaseSchedule):
- """A helper schedule class for no pipeline parallelism running environment.
- During one process, it loads a batch of dataset and feeds it to the model.
- After getting the output and calculating the loss, it will use :meth:`step`
- to update the parameters if it is in training mode.
-
- :param amp_type: The type of automatic mixed precision
- :param amp_config: The configuration of automatic mixed procision
- :type amp_type: AMP_TYPE
- :type amp_config: dict
- """
-
- def __init__(
- self,
- amp_type: AMP_TYPE = None,
- amp_config: dict = None,
- ):
- super().__init__()
-
- # mixed precision training
- assert amp_type is None or isinstance(amp_type, AMP_TYPE), \
- 'unrecognised value for argument fp16, it can only be None, torch or apex'
-
- self.use_zero_level_2_3 = False
-
- if amp_type is not None:
- self.fp16 = True
- self.amp_type = amp_type
-
- if amp_config is not None:
- assert isinstance(amp_config, dict), \
- f'expected argument fp16_config to be type dictionary, but got {type(amp_config)}'
-
- if self.amp_type == AMP_TYPE.TORCH:
- # torch apex
- if amp_config is None:
- amp_config = dict()
- self.amp_cfg = amp_config
- elif self.amp_type == AMP_TYPE.APEX:
- # apex amp
- if amp_config is None:
- amp_config = dict(opt_level='O2')
- self.logger.warning(
- 'apex is deprecated, please consider using torch.cuda.amp instead.'
- )
- self.amp_cfg = amp_config
- elif self.amp_type == AMP_TYPE.PARALLEL:
- # use fp16 optimizer for tensor parallel training
- if amp_config is None:
- amp_config = dict()
- self.amp_cfg = amp_config
- else:
- self.fp16 = False
- self.amp_type = None
-
- def initialize(self, model: nn.Module, optimizer: Optimizer):
- if isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2,
- ZeroRedundancyOptimizer_Level_3)):
- self.use_zero_level_2_3 = True
- assert self.amp_type != AMP_TYPE.PARALLEL, \
- 'ZeRO Level 2 and 3 are mutually exclusive with AMP_TYPE.PARALLEL'
-
- if self.fp16:
- if self.amp_type == AMP_TYPE.TORCH:
- self._torch_amp_scaler = GradScaler(**self.amp_cfg)
- elif self.amp_type == AMP_TYPE.APEX:
- model, optimizer = apex_amp.initialize(model, optimizer, **self.amp_cfg)
-
- return model, optimizer
-
- def forward_backward_step(self,
- data_iter: Iterable,
- model: nn.Module,
- criterion: nn.modules.loss._Loss,
- optimizer: Optimizer = None,
- forward_only: bool = False,
- grad_accum_size: int = 1,
- return_loss: bool = True):
- """The process function that loads loads a batch of dataset and feeds it to the model.
- The returned labels and loss will None if :attr:`return_loss` is False.
-
- :param data_iter: Data iterator of the dataloader, e.g. iter(dataloader)
- :param model: Model for training and inference
- :param criterion: Loss function for training
- :param optimizer: Optimizer used for training
- :param forward_only: If True, the model is run for the forward pass, else back propagation will be executed
- :param grad_accum_size: The number of iterations for gradient accumulation
- :param return_loss: Loss will be returned if True
- :type data_iter: Iterator
- :type model: torch.nn.Module
- :type criterion: torch.nn.modules.loss._Loss
- :type optimizer: torch.optim.Optimizer
- :type forward_only: bool, optional
- :type grad_accum_size: int
- :type return_loss: bool, optional
- :return: (output, label, loss)
- """
- assert forward_only or return_loss, \
- 'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
-
- data, label = self.load_batch(data_iter)
- loss = None
-
- # forward
- if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
- with torch_amp.autocast():
- output = model(*data)
- if not isinstance(output, (tuple, list)):
- output = (output,)
- if return_loss:
- loss = criterion(*output, *label)
- else:
- if self.use_zero_level_2_3 or self.amp_type == AMP_TYPE.PARALLEL:
- data = convert_to_fp16(data)
-
- output = model(*data)
-
- if self.use_zero_level_2_3 or self.amp_type == AMP_TYPE.PARALLEL:
- output = convert_to_fp32(output)
-
- if not isinstance(output, (tuple, list)):
- output = (output,)
- if return_loss:
- loss = criterion(*output, *label)
-
- loss /= grad_accum_size
-
- if not forward_only:
- # backward
- if self.use_zero_level_2_3:
- optimizer.backward(loss)
- elif self.fp16:
- if self.amp_type == AMP_TYPE.APEX:
- with apex_amp.scale_loss(loss, optimizer) as scaled_loss:
- scaled_loss.backward()
- elif self.amp_type == AMP_TYPE.TORCH:
- self._torch_amp_scaler.scale(loss).backward()
- elif self.amp_type == AMP_TYPE.PARALLEL:
- loss = optimizer.scale_loss(loss)
- loss.backward()
- # scale back to display the original value in logs
- loss.div_(optimizer.grad_scaler.scale)
- else:
- loss.backward()
-
- if return_loss:
- return output, label, loss * grad_accum_size
- else:
- return output, None, None
-
- def optimizer_step(self, model: nn.Module, optimizer: Optimizer, grad_clipping: float = 0.0):
- # step optimizer
- if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
- if grad_clipping > 0.0:
- self._torch_amp_scaler.unscale_(optimizer)
- clip_grad_norm_fp32(model.parameters(), grad_clipping)
- self._torch_amp_scaler.step(optimizer)
- self._torch_amp_scaler.update()
- else:
- if not self.fp16 and not self.use_zero_level_2_3 and grad_clipping > 0.0:
- clip_grad_norm_fp32(model.parameters(), grad_clipping)
- optimizer.step()
diff --git a/colossalai/engine/schedule/_non_pipeline_schedule.py b/colossalai/engine/schedule/_non_pipeline_schedule.py
new file mode 100644
index 000000000..01e681941
--- /dev/null
+++ b/colossalai/engine/schedule/_non_pipeline_schedule.py
@@ -0,0 +1,61 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+from typing import Iterable
+
+import torch
+
+import torch.nn as nn
+from colossalai.engine import Engine
+from torch.optim import Optimizer
+from ._base_schedule import BaseSchedule
+from colossalai.utils import conditional_context
+
+
+class NonPipelineSchedule(BaseSchedule):
+ """A helper schedule class for no pipeline parallelism running environment.
+ During one process, it loads a batch of dataset and feeds it to the model.
+ After getting the output and calculating the loss, it will use :meth:`step`
+ to update the parameters if it is in training mode.
+ :param batch_data_process_func: The preprocessing function which receives a batch of data, and it will be executed in load_batch
+ :type batch_data_process_func: Callable, optional
+ """
+
+ def forward_backward_step(self,
+ engine: Engine,
+ data_iter: Iterable,
+ forward_only: bool = False,
+ return_loss: bool = True):
+ """The process function that loads loads a batch of dataset and feeds it to the model.
+ The returned labels and loss will None if :attr:`return_loss` is False.
+ :param engine: Model for training and inference
+ :param data_iter: Data iterator of the dataloader, e.g. iter(dataloader)
+ :param forward_only: If True, the model is run for the forward pass, else back propagation will be executed
+ :param return_loss: Loss will be returned if True
+ :type engine: colossalai.engine.Engine
+ :type data_iter: Iterator
+ :type forward_only: bool, optional
+ :type return_loss: bool, optional
+ :return: (output, label, loss)
+ """
+ assert forward_only or return_loss, \
+ "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
+ data, label = self.load_batch(data_iter)
+
+ # forward
+ with conditional_context(torch.no_grad(), enable=forward_only):
+ output = engine(*data)
+ if not isinstance(output, (tuple, list)):
+ output = (output,)
+ if return_loss:
+ loss = engine.criterion(*output, *label)
+
+ if not forward_only:
+ engine.backward(loss)
+
+ if return_loss:
+ return output, label, loss
+ else:
+ return output, None, None
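A hedged sketch of driving NonPipelineSchedule for one iteration; `engine` and `train_dataloader` are assumed to come from colossalai.initialize.

from colossalai.engine.schedule import NonPipelineSchedule

schedule = NonPipelineSchedule()
schedule.pre_processing(engine)          # no-op in the non-pipeline case

engine.zero_grad()
output, label, loss = schedule.forward_backward_step(
    engine, iter(train_dataloader), forward_only=False, return_loss=True)
engine.step()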
diff --git a/colossalai/engine/schedule/_pipeline.py b/colossalai/engine/schedule/_pipeline_schedule.py
similarity index 77%
rename from colossalai/engine/schedule/_pipeline.py
rename to colossalai/engine/schedule/_pipeline_schedule.py
index 6defea93d..f0bc04427 100644
--- a/colossalai/engine/schedule/_pipeline.py
+++ b/colossalai/engine/schedule/_pipeline_schedule.py
@@ -10,12 +10,12 @@ from torch import Tensor
from colossalai.communication import *
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
-from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
- ZeroRedundancyOptimizer_Level_3)
+from colossalai.amp.naive_amp import NaiveAMPModel
+from colossalai.zero import (ZeroRedundancyOptimizer_Level_2,
+ ZeroRedundancyOptimizer_Level_3)
from colossalai.utils import get_current_device
from ._base_schedule import BaseSchedule
-from ._utils import convert_to_fp16
-from ..amp import AMP_TYPE
+from colossalai.amp import AMP_TYPE
def squeeze(x: Union[Tensor, tuple, list]):
@@ -28,32 +28,25 @@ def squeeze(x: Union[Tensor, tuple, list]):
class PipelineSchedule(BaseSchedule):
"""A helper schedule class for pipeline parallelism running environment.
It uses non-interleaved 1F1B strategy. Other properties are similar as
- :class:`NoPipelineSchedule`.
+ :class:`NonPipelineSchedule`.
:param num_microbatches: The number of microbatches
:param amp_type: The type of automatic mixed precision
:param amp_config: The configuration of automatic mixed precision
+ :param sync_data: If set to `True`, will sync data every batch over pipeline stages
:type num_microbatches: int
:type amp_type: AMP_TYPE
:type amp_config: dict
+ :type sync_data: bool
"""
def __init__(self,
num_microbatches,
- amp_type: AMP_TYPE = None,
- amp_config: dict = None):
+ sync_data: bool = True):
super().__init__()
self.num_microbatches = num_microbatches
- self.data_sync = True # close after making sure data is identical
-
- # amp
- # LSGL: amp_config is not used, but leave here for future extension
- self.amp_type = amp_type
- self.amp_config = amp_config
-
- if self.amp_type is not None:
- assert self.amp_type == AMP_TYPE.PARALLEL, 'We only support AMP_TYPE.PARALLEL for pipeline training for now'
+ self.sync_data = sync_data
def _move_to_device(self, data):
if isinstance(data, (
@@ -67,30 +60,37 @@ class PipelineSchedule(BaseSchedule):
return data
def _sync_data(self):
+ reqs = []
if gpc.is_first_rank(ParallelMode.PIPELINE):
src_rank = gpc.get_global_rank()
- dist.broadcast(
+ reqs.append(dist.broadcast(
tensor=self.batch_data,
src=src_rank,
- group=gpc.get_group(ParallelMode.PIPELINE_PREV)
- )
- dist.broadcast(
+ group=gpc.get_group(ParallelMode.PIPELINE_PREV),
+ async_op=True
+ ))
+ reqs.append(dist.broadcast(
tensor=self.batch_label,
src=src_rank,
- group=gpc.get_group(ParallelMode.PIPELINE_PREV)
- )
+ group=gpc.get_group(ParallelMode.PIPELINE_PREV),
+ async_op=True
+ ))
if gpc.is_last_rank(ParallelMode.PIPELINE):
src_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
- dist.broadcast(
+ reqs.append(dist.broadcast(
tensor=self.batch_data,
src=src_rank,
- group=gpc.get_group(ParallelMode.PIPELINE_NEXT)
- )
- dist.broadcast(
+ group=gpc.get_group(ParallelMode.PIPELINE_NEXT),
+ async_op=True
+ ))
+ reqs.append(dist.broadcast(
tensor=self.batch_label,
src=src_rank,
- group=gpc.get_group(ParallelMode.PIPELINE_NEXT)
- )
+ group=gpc.get_group(ParallelMode.PIPELINE_NEXT),
+ async_op=True
+ ))
+ for req in reqs:
+ req.wait()
# Pipeline schedule just puts data in memory
def load_batch(self, data_iter):
@@ -104,7 +104,7 @@ class PipelineSchedule(BaseSchedule):
assert batch_size % self.num_microbatches == 0, \
"Batch size should divided by the number of microbatches"
self.microbatch_size = batch_size // self.num_microbatches
- if self.data_sync:
+ if self.sync_data:
self._sync_data()
def _get_data_slice(self, tensor):
@@ -116,21 +116,20 @@ class PipelineSchedule(BaseSchedule):
self.batch_pos += self.microbatch_size
return (data,), (label,)
- def initialize(self, model, optimizer):
- if isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)):
+ def pre_processing(self, engine):
+ if isinstance(engine.optimizer, (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)):
raise TypeError(
"Pipeline schedule is currently not compatible with ZeRO Level 2 and Level 3"
)
# LSG: set default dtype to fp16 for communication
- if self.amp_type == AMP_TYPE.PARALLEL:
+ if isinstance(engine.model, NaiveAMPModel):
torch.set_default_dtype(torch.half)
- self.logger.info(
+ self.logger.warning(
'default tensor dtype is set to torch.half for fp16 training',
ranks=[0])
- def forward_step(self, model, criterion, input_tensor, return_tensors,
- grad_accum_size, return_loss=True):
+ def forward_step(self, engine, input_tensor, return_tensors, return_loss=True):
"""Forward step for passed-in model. If it is the first stage, the input tensor
is obtained from data_iterator, otherwise the passed-in input_tensor is used.
Returns output tensor. This is a helper function and can be ignored by users.
@@ -138,17 +137,16 @@ class PipelineSchedule(BaseSchedule):
if input_tensor is None:
input_tensor, label = self.load_micro_batch()
- if self.amp_type == AMP_TYPE.PARALLEL:
- input_tensor = convert_to_fp16(input_tensor)
input_tensor = squeeze(input_tensor)
- output_tensor = model(input_tensor)
+ output_tensor = engine(input_tensor)
output_tensor = squeeze(output_tensor)
if gpc.is_last_rank(ParallelMode.PIPELINE):
if return_loss:
input_tensor, label = self.load_micro_batch()
- loss_reduced = criterion(output_tensor, *label) \
- / (self.num_microbatches * grad_accum_size)
+ loss_reduced = engine.criterion(output_tensor, *label) \
+ / self.num_microbatches
+
return_tensors.append(
tuple((output_tensor, label[0], loss_reduced)))
return loss_reduced
@@ -159,7 +157,7 @@ class PipelineSchedule(BaseSchedule):
else:
return output_tensor
- def backward_step(self, optimizer, input_tensor, output_tensor, output_tensor_grad):
+ def backward_step(self, engine, input_tensor, output_tensor, output_tensor_grad):
"""Backward step through the passed-in output tensor. If it is the last stage, the
output_tensor_grad is None, otherwise it is the gradients with respect to stage's output tensor.
Returns the gradients with respect to the input tensor (None if first stage).
@@ -171,9 +169,10 @@ class PipelineSchedule(BaseSchedule):
input_tensor.retain_grad()
# Backward pass.
- if output_tensor_grad is None and self.amp_type == AMP_TYPE.PARALLEL:
- output_tensor = optimizer.scale_loss(output_tensor)
- torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
+ if output_tensor_grad is None:
+ engine.backward(output_tensor)
+ else:
+ engine.backward_by_grad(output_tensor, output_tensor_grad)
# Collect the grad of the input_tensor.
input_tensor_grad = None
@@ -183,12 +182,9 @@ class PipelineSchedule(BaseSchedule):
return input_tensor_grad
def forward_backward_step(self,
+ engine,
data_iter,
- model,
- criterion,
- optimizer=None,
forward_only=False,
- grad_accum_size: int = 1,
return_loss=True):
"""Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
Returns a tuple with losses if the last stage, an empty tuple otherwise.
@@ -226,9 +222,8 @@ class PipelineSchedule(BaseSchedule):
ft_shape = recv_tensor_meta(ft_shape)
input_tensor = recv_forward(ft_shape)
output_tensor = self.forward_step(
- model, criterion,
- input_tensor, return_tensors,
- grad_accum_size, return_loss=return_loss
+ engine, input_tensor, return_tensors,
+ return_loss=return_loss
)
if not gpc.is_last_rank(ParallelMode.PIPELINE):
bt_shape = output_tensor.shape
@@ -252,9 +247,8 @@ class PipelineSchedule(BaseSchedule):
last_iteration = (i == (num_microbatches_remaining - 1))
output_tensor = self.forward_step(
- model, criterion,
- input_tensor, return_tensors,
- grad_accum_size, return_loss=return_loss
+ engine, input_tensor, return_tensors,
+ return_loss=return_loss
)
if forward_only:
send_forward(output_tensor)
@@ -276,7 +270,7 @@ class PipelineSchedule(BaseSchedule):
output_tensor = output_tensors.pop(0)
input_tensor_grad = self.backward_step(
- optimizer,
+ engine,
input_tensor, output_tensor,
output_tensor_grad
)
@@ -297,7 +291,7 @@ class PipelineSchedule(BaseSchedule):
output_tensor_grad = recv_backward(bt_shape)
input_tensor_grad = self.backward_step(
- optimizer,
+ engine,
input_tensor, output_tensor,
output_tensor_grad
)
@@ -309,11 +303,8 @@ class PipelineSchedule(BaseSchedule):
output, label, loss = tuple(map(list, zip(*return_tensors)))
return (torch.cat(output, dim=0),
torch.cat(label, dim=0),
- sum(loss) * grad_accum_size)
+ sum(loss))
else:
return tuple((torch.cat(return_tensors, dim=0), None, None))
else:
return tuple((None, None, None))
-
- def optimizer_step(self, model, optimizer, grad_clipping: float = 0.0):
- optimizer.step()
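A hedged sketch of the refactored PipelineSchedule: AMP is no longer configured on the schedule but inferred from the engine's model (e.g. NaiveAMPModel), and the per-microbatch loss is averaged over num_microbatches internally. `engine` and `train_dataloader` are placeholders.

from colossalai.engine.schedule import PipelineSchedule

schedule = PipelineSchedule(num_microbatches=4, sync_data=True)
schedule.pre_processing(engine)   # sets default dtype to fp16 when the model is a NaiveAMPModel

engine.zero_grad()
output, label, loss = schedule.forward_backward_step(
    engine, iter(train_dataloader), forward_only=False, return_loss=True)
engine.step()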
diff --git a/colossalai/engine/schedule/_utils.py b/colossalai/engine/schedule/_utils.py
deleted file mode 100644
index cdfd0246c..000000000
--- a/colossalai/engine/schedule/_utils.py
+++ /dev/null
@@ -1,27 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-from typing import Union, List
-
-from torch import Tensor
-
-
-def convert_to_fp16(data: Union[Tensor, List[Tensor]]):
- if isinstance(data, Tensor):
- ret = data.half()
- elif isinstance(data, (list, tuple)):
- ret = [val.half() for val in data]
- else:
- raise TypeError(f"Expected argument 'data' to be a Tensor or a list/tuple of Tensor, but got {type(data)}")
- return ret
-
-
-def convert_to_fp32(data: Union[Tensor, List[Tensor]]):
- if isinstance(data, Tensor):
- ret = data.float()
- elif isinstance(data, (list, tuple)):
- ret = [val.float() for val in data]
- else:
- raise TypeError(f"Expected argument 'data' to be a Tensor or a list/tuple of Tensor, but got {type(data)}")
- return ret
-
diff --git a/colossalai/initialize.py b/colossalai/initialize.py
index 6806d86eb..5d7087841 100644
--- a/colossalai/initialize.py
+++ b/colossalai/initialize.py
@@ -3,377 +3,326 @@
import argparse
import pprint
-import random
-from pathlib import Path
-from typing import Callable, Iterable, Optional, Union
-from typing import Tuple
-
+import os
+from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
import numpy as np
import torch
-from torch.utils.data import DataLoader
+import torch.nn as nn
-from colossalai.engine import AMP_TYPE, NoPipelineSchedule, PipelineSchedule
+from pathlib import Path
+from typing import Iterable, Union, Optional, Tuple, List, Dict
+
+from colossalai.amp import convert_to_amp, AMP_TYPE
+from colossalai.context import Config, ParallelMode, ConfigException
+from colossalai.core import global_context as gpc
from colossalai.engine import Engine
-from colossalai.logging import get_global_dist_logger, init_global_dist_logger
-from colossalai.nn import DataParallelSampler
-from colossalai.nn.model.base_model import BaseModel
-from .builder import (ModelInitializer, build_dataset, build_loss,
- build_model, build_optimizer,
- build_optimizer_wrapper, build_schedule)
-from .context import Config, ParallelMode
-from .core import global_context as gpc
-from .utils import get_current_device, sync_model_param_in_dp
+from colossalai.logging import get_dist_logger
+from colossalai.utils import (accumulate_gradient, get_current_device,
+ sync_model_param_in_dp, is_using_ddp, is_using_pp)
+from colossalai.zero import convert_to_zero, ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3
+from colossalai.builder.builder import build_gradient_handler
+from torch.optim.optimizer import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
+from torch.utils.data import DataLoader
+from torch.nn.modules.loss import _Loss
+from torch.nn.parallel import DistributedDataParallel as DDP
-def parse_args():
+def get_default_parser():
'''Reads user command line and uses an argument parser to parse the input arguments.
Input arguments include configuration, host, port, world size, local rank, backend for torch.distributed.
- :return: call the parse arguments function
+ :return: the parser with the default arguments; the user may add customized arguments into this parser
:rtype: Namespace
'''
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, help='path to the config file')
parser.add_argument('--host',
type=str,
- default=None,
help='the master address for distributed training')
parser.add_argument('--port',
- type=str,
- default=None,
+ type=int,
help='the master port for distributed training')
- parser.add_argument('--world_size', type=int, help='world size for ')
+ parser.add_argument('--world_size', type=int, help='world size for distributed training')
+ parser.add_argument('--rank', type=int, help='rank for the default process group')
parser.add_argument('--local_rank',
type=int,
- help='rank for the default process group')
+ help='local rank on the node')
parser.add_argument('--backend',
type=str,
default='nccl',
- help='backend for torch.distributed')
- return parser.parse_args()
+ help='backend for distributed communication')
+ return parser
-def init_dist(config: Union[str, dict] = None,
- local_rank: int = None,
- world_size: int = None,
- host: str = None,
- port: str = None,
- backend: str = None):
+def launch(config: Union[str, Path, Config, Dict],
+ rank: int,
+ world_size: int,
+ host: str,
+ port: int,
+ backend: str = 'nccl',
+ local_rank: int = None,
+ seed: int = 1024,
+ verbose: bool = True):
'''This function first parses the configuration arguments, using :func:parse_args() in case one of the input arguments are not given.
- Then initialize and set distributed environment by calling global_context's functions.
+ Then initializes and sets the distributed environment by calling global_context's functions.
:param config: config file or config file path are both acceptable
- :type config: Union[str, dict], optional
- :param local_rank: rank for the default process group, defaults to None
+ :type config: Union[str, dict, Config]
+ :param rank: rank for the default process group
+ :type rank: int
+ :param world_size: world size of the default process group
+ :type world_size: int
+ :param host: the master address for distributed training
+ :type host: str
+ :param port: the master port for distributed training
+ :type port: int
+ :param backend: backend for torch.distributed
+ :type backend: str
+ :param local_rank: rank for the process on the node and is used to set the default CUDA device,
+ defaults to None. If local_rank = None, the default device ordinal will be calculated automatically
:type local_rank: int, optional
- :param world_size: world size of GPUs, defaults to None
- :type world_size: int, optional
- :param host: the master address for distributed training, defaults to None
- :type host: str, optional
- :param port: the master port for distributed training, defaults to None
- :type port: str, optional
- :param backend: backend for torch.distributed, defaults to None
- :type backend: str, optional
:raises Exception: raise exception when config type is wrong
'''
- args = [config, local_rank, world_size, host, port, backend]
- arg_given = [arg is not None for arg in args]
-
- if not all(arg_given):
- args = parse_args()
-
- if config is None:
- config = args.config
- if local_rank is None:
- local_rank = args.local_rank
- if world_size is None:
- world_size = args.world_size
- if host is None:
- host = args.host
- if port is None:
- port = args.port
- if backend is None:
- backend = args.backend
- args = Config(
- dict(config=config,
- host=host,
- port=port,
- world_size=world_size,
- local_rank=local_rank,
- backend=backend))
-
- # set distributed settings
- dist_args = Config(
- dict(local_rank=args.local_rank,
- world_size=args.world_size,
- backend=args.backend))
-
- gpc.set_dist_args(dist_args)
+ gpc.verbose = verbose
# set config
- if isinstance(args.config, dict):
- cfg = args.config
- elif isinstance(args.config, (str, Path)):
- cfg = Config.from_file(args.config)
- else:
- raise Exception('Config type error: {}'.format(type(args.config)))
- gpc.load_config(cfg)
+ assert isinstance(config, (Config, str, Path, dict)), \
+ f'expected argument config to be Config, str or Path, but got {type(config)}'
+ if not isinstance(config, Config) and isinstance(config, dict):
+ config = Config(config)
+ if isinstance(config, (str, Path)):
+ config = Config.from_file(config)
+ gpc.load_config(config)
- # init dist groups
- gpc.init_global_dist(args.host, args.port)
+ # init default process group
+ gpc.init_global_dist(rank, world_size, backend, host, port)
+
+ # init process groups for different parallel modes from config
gpc.init_parallel_groups()
- # init dist logger
- init_global_dist_logger()
-
# set cuda device
if torch.cuda.is_available():
- gpc.set_device()
+ # if local rank is not given, calculate automatically
+ gpc.set_device(local_rank)
+
+ gpc.set_seed(seed)
+
+ if verbose:
+ logger = get_dist_logger()
+ logger.info(f'Distributed environment is initialized, '
+ f'data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, '
+ f'tensor parallel size: {gpc.tensor_parallel_size}', ranks=[0])
-def get_dataloader(dataset, seed=1024, add_sampler_if_possible=False, **kwargs):
- '''Set up a deterministic dataloader (also configure seed workers, samplers and whether shuffle or not)
-
- .. note: when pipeline parallel is enabled, shuffle cannot be True
- as it will result in mismatch between input data on the 1st
- stage and label on the last stage
-
- :param dataset: a :class:utils.data.dataset dataset
- :param seed: random worker seed, defaults to 1024
- :type seed: int, optional
- :param add_sampler_if_possible: [description], defaults to False
- :type add_sampler_if_possible: bool, optional
- :return: a :class:utils.data.dataset dataloader
- :rtype: torch.utils.data.dataset
- '''
- _kwargs = kwargs.copy()
- if 'shuffle' in _kwargs:
- shuffle = _kwargs.pop('shuffle')
- else:
- shuffle = False
-
- if add_sampler_if_possible and gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1:
- sampler = DataParallelSampler(dataset, shuffle=shuffle)
- else:
- sampler = None
-
- # Deterministic dataloader
- def seed_worker(worker_id):
- worker_seed = seed
- np.random.seed(worker_seed)
- torch.manual_seed(worker_seed)
- random.seed(worker_seed)
-
- if sampler is None:
- return DataLoader(dataset,
- worker_init_fn=seed_worker,
- shuffle=shuffle,
- **_kwargs)
- else:
- return DataLoader(dataset,
- sampler=sampler,
- worker_init_fn=seed_worker,
- **_kwargs)
+def launch_from_slurm(config: Union[str, Path, Config, Dict],
+ host: str,
+ port: int,
+ backend: str = 'nccl',
+ seed: int = 1024,
+ verbose: bool = True):
+ rank = int(os.environ['SLURM_PROCID'])
+ world_size = int(os.environ['SLURM_NPROCS'])
+ launch(config=config,
+ rank=rank,
+ world_size=world_size,
+ host=host,
+ port=port,
+ backend=backend,
+ seed=seed,
+ verbose=verbose)
-def initialize(config: Union[str, dict] = None,
- local_rank: int = None,
- world_size: int = None,
- host: str = None,
- port: str = None,
- backend: str = None,
- train_dataloader: Optional[Union[Iterable, Callable]] = None,
- test_dataloader: Optional[Union[Iterable, Callable]] = None,
+def launch_from_openmpi(config: Union[str, Path, Config, Dict],
+ host: str,
+ port: int,
+ backend: str = 'nccl',
+ seed: int = 1024,
+ verbose: bool = True):
+ rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
+ local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
+ world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
+ launch(config=config,
+ local_rank=local_rank,
+ rank=rank,
+ world_size=world_size,
+ host=host,
+ port=port,
+ backend=backend,
+ seed=seed,
+ verbose=verbose)
+
+
+def launch_from_torch(config: Union[str, Path, Config, Dict],
+ host: str,
+ port: int,
+ backend: str = 'nccl',
+ seed: int = 1024,
+ verbose: bool = True):
+ rank = int(os.environ['RANK'])
+ local_rank = int(os.environ['LOCAL_RANK'])
+ world_size = int(os.environ['WORLD_SIZE'])
+ launch(config=config,
+ local_rank=local_rank,
+ rank=rank,
+ world_size=world_size,
+ host=host,
+ port=port,
+ backend=backend,
+ seed=seed,
+ verbose=verbose)
+
+
+def initialize(model: Union[nn.Module, List[nn.Module]],
+ optimizer: Union[Optimizer, List[Optimizer]],
+ criterion: Union[_Loss, List[_Loss]],
+ train_dataloader: Optional[Union[Iterable, List[Iterable]]] = None,
+ test_dataloader: Optional[Union[Iterable, List[Iterable]]] = None,
+ lr_scheduler: _LRScheduler = None,
+ verbose: bool = True
) -> Tuple[Engine, DataLoader, DataLoader]:
- '''Core function that initializes distributed environment, logger, cudnn, data, model, loss function, optimizer, and lr_scheduler(their configs are in gpc.config).
+ ''' Core function to wrap the essential training components with our functionality based on the config which is loaded into gpc.config.
- :param config: config file or config file path are both acceptable
- :type config: Union[str, dict], optional
- :param local_rank: rank for the default process group, defaults to None
- :type local_rank: int, optional
- :param world_size: world size of GPUs, defaults to None
- :type world_size: int, optional
- :param host: the master address for distributed training, defaults to None
- :type host: str, optional
- :param port: the master port for distributed training, defaults to None
- :type port: str, optional
- :param backend: backend for torch.distributed, defaults to None
- :type backend: str, optional
- :param train_dataloader: If None, the config is used to build a dataloder; Else, it should be a dataloader object or a function with no arguments which can build a dataloader, defaults to None
- :type train_dataloader: Optional[Union[Iterable, Callable]], optional
- :param test_dataloader: If None, the config is used to build a dataloder; Else, it should be a dataloader object or a function with no arguments which can build a dataloader, defaults to None
- :type test_dataloader: Optional[Union[Iterable, Callable]], optional
- :return: (engine, train_dataloader, test_dataloader, criterion)
+ :param model: your model instance
+ :type model: a single or a list of ``torch.nn.Module`` objects
+ :param optimizer: your optimizer instance
+ :type optimizer: a single or a list of ``torch.optim.optimizer.Optimizer`` objects
+ :param criterion: your criterion instance
+ :type criterion: a single or a list of ``torch.nn.modules.loss._Loss`` objects
+ :param train_dataloader: dataloaders for training data
+ :type train_dataloader: a single or a list of ``torch.utils.data.DataLoader`` objects, defaults to None
+ :param test_dataloader: dataloaders for testing data
+ :type test_dataloader: a single or a list of ``torch.utils.data.DataLoader`` objects, defaults to None
+ :return: (engine, train_dataloader, test_dataloader, lr_scheduler)
:rtype: tuple
'''
- # initialize distributed environment
- init_dist(config=config,
- local_rank=local_rank,
- world_size=world_size,
- host=host,
- port=port,
- backend=backend)
+ # get logger
+ logger = get_dist_logger()
+ gpc.verbose = verbose
- # init logger
- logger = get_global_dist_logger()
- logger.info(f'Distributed environment is initialized, '
- f'data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, '
- f'tensor parallel size: {gpc.tensor_parallel_size}', ranks=[0])
+ # get config from gpc
+ config = gpc.config
# print config
- logger.info(f"\n========== Your Config ========\n"
- f"{pprint.pformat(gpc.config)}\n"
- f"================================", ranks=[0])
+ if verbose:
+ logger.info(f"\n========== Your Config ========\n"
+ f"{pprint.pformat(gpc.config)}\n"
+ f"================================\n", ranks=[0])
# cudnn
- cudnn_benchmark = gpc.config.get('cudnn_benchmark', True)
- cudnn_deterministic = gpc.config.get('cudnn_deterministic', False)
+ cudnn_benchmark = config.get('cudnn_benchmark', True)
+ cudnn_deterministic = config.get('cudnn_deterministic', False)
torch.backends.cudnn.benchmark = cudnn_benchmark
torch.backends.cudnn.deterministic = cudnn_deterministic
- logger.info(
- f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])
+ if verbose:
+ logger.info(
+ f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])
- # set seed, cuda seed is only set when cuda is avail
- gpc.set_seed()
-
- # return_items = list()
-
- # check fp16 and zero
- should_convert_model_to_half = False
- should_wrap_fp16_optimizer = False
- should_wrap_zero_optimizer_level_2_3 = False
-
- if hasattr(gpc.config, 'fp16'):
- fp16_mode = gpc.config.fp16.mode
- if fp16_mode == AMP_TYPE.PARALLEL:
- should_convert_model_to_half = True
- should_wrap_fp16_optimizer = True
-
- if hasattr(gpc.config, 'zero'):
- should_wrap_zero_optimizer_level_2_3 = True
- zero_type = gpc.config.zero.type
- if zero_type in ['ZeroRedundancyOptimizer_Level_2', 'ZeroRedundancyOptimizer_Level_3']:
- should_convert_model_to_half = True
- assert not should_wrap_fp16_optimizer, \
- 'AMP_TYPE.PARALLEL is mutually exclusive with zero level 2 and 3'
-
- # build model
- logger.info('Building model ...', ranks=[0])
- assert hasattr(
- gpc.config, 'model'), "Build error: configuration 'model' is missing"
- if gpc.pipeline_parallel_size > 1:
- model = ModelInitializer(gpc.config.model, 1, verbose=True)
- model = model.model_initialize()
- else:
- model = build_model(gpc.config.model)
- if isinstance(model, BaseModel):
- model.build_from_cfg()
- model = model.to(get_current_device())
+ # first sync model across dp ranks
+ model.to(get_current_device())
sync_model_param_in_dp(model)
- logger.info('Model is created', ranks=[0])
- if should_convert_model_to_half:
- model = model.half()
- logger.info("Model is cast to fp16", ranks=[0])
+ # check amp and zero
+ fp16_cfg = gpc.config.get('fp16', None)
+ zero_cfg = gpc.config.get('zero', None)
- # training data
- if callable(train_dataloader):
- logger.info(
- f'Build train data loader from {train_dataloader}', ranks=[0])
- train_dataloader = train_dataloader()
- if train_dataloader is None and hasattr(gpc.config, 'train_data'):
- logger.info('Preparing data ...', ranks=[0])
- # assert hasattr(gpc.config, 'train_data'), "Build error: configuration 'train_data' is missing."
- train_dataset = build_dataset(gpc.config.train_data.dataset)
- logger.info('Train dataset is ready.', ranks=[0])
+ if fp16_cfg is not None and fp16_cfg.mode is not None and zero_cfg is not None:
+ raise ConfigException(
+ "It is not allowed to set fp16 and zero configuration in your config file at the same time")
- train_dataloader = get_dataloader(train_dataset,
- gpc.config.get('seed', 1024),
- True,
- **gpc.config.train_data.dataloader,
- )
- logger.info(
- f'Loaded {len(train_dataset)} samples in {len(train_dataloader)} batches for training', ranks=[0])
+ # initialize amp
+ amp_mode = None
+ if fp16_cfg is not None and fp16_cfg.mode is not None:
+ cfg_ = fp16_cfg.copy()
+ amp_mode = cfg_.pop('mode')
+ model, optimizer, criterion = convert_to_amp(model=model,
+ optimizer=optimizer,
+ criterion=criterion,
+ mode=amp_mode,
+ amp_config=cfg_)
- if callable(test_dataloader):
- logger.info(
- f'Build test data loader from {test_dataloader}', ranks=[0])
- test_dataloader = test_dataloader()
- # testing data, allowed to be None
- if test_dataloader is None and hasattr(gpc.config, 'test_data'):
- test_dataset = build_dataset(gpc.config.test_data.dataset)
- test_dataloader = get_dataloader(
- test_dataset, add_sampler_if_possible=True, **gpc.config.test_data.dataloader)
- logger.info(
- f'Loaded {len(test_dataset)} samples in {len(test_dataloader)} batches for testing', ranks=[0])
+ if zero_cfg is not None:
+ cfg_ = zero_cfg.copy()
+ level = cfg_.pop('level')
+ model, optimizer = convert_to_zero(model=model,
+ optimizer=optimizer,
+ level=level,
+ zero_config=cfg_
+ )
- # build loss function
- assert hasattr(gpc.config, 'loss'), \
- 'Build error: configuration \'loss\' is missing.'
- criterion = build_loss(gpc.config.loss)
- logger.info('Loss function is created', ranks=[0])
-
- # build optimizer
- assert hasattr(gpc.config, 'optimizer'), \
- "Build error: configuration 'optimizer' is missing."
- optim_type = gpc.config.optimizer.type
- is_pytorch_native_zero_level_1 = optim_type == 'ZeroRedundancyOptimizer'
- if is_pytorch_native_zero_level_1:
- original_cfg_copy = gpc.config.optimizer.copy()
- original_cfg_copy.pop('type')
- cfg = dict(type=optim_type, process_group=gpc.get_group(
- ParallelMode.DATA), **original_cfg_copy)
- optimizer = build_optimizer(cfg, model)
+ # gradient handler
+ gradient_handler_cfg = gpc.config.get('gradient_handler', None)
+ if gradient_handler_cfg is None:
+ # if gradient handler is not specified in the configuration file,
+ # check in the following order
+ # 1. if optimizer is ZERO, then use zero grad handler
+ # 2. if dp size is larger than 1 and pipeline is not used, use pytorch ddp
+ # 3. if using pipeline and dp size larger than 1, use data parallel grad handler
+ if isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2,
+ ZeroRedundancyOptimizer_Level_3)):
+ gradient_handler_cfg = [dict(type='ZeROGradientHandler')]
+ if verbose:
+ logger.info(
+ "Training with zero is detected, ZeROGradientHandler is automatically "
+ "added even though not specified in the configuration",
+ ranks=[0])
+ elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE:
+ model = DDP(model, process_group=gpc.get_group(ParallelMode.DATA))
+ if verbose:
+ logger.info(
+ 'Model is using torch.nn.parallel.DistributedDataParallel', ranks=[0])
+ elif is_using_ddp():
+ gradient_handler_cfg = [dict(type='DataParallelGradientHandler')]
+ if verbose:
+ logger.info(
+ "Data parallel training is detected when using pipeline parallel, DataParallelGradientHandler is automatically "
+ "added even though not specified in the configuration",
+ ranks=[0])
else:
- optimizer = build_optimizer(gpc.config.optimizer, model)
+ if not isinstance(gradient_handler_cfg, list):
+ raise ConfigException(
+ f"expected gradient_handler in the configuration file to be a list but got {type(gradient_handler_cfg)}")
- if should_wrap_zero_optimizer_level_2_3:
- optimizer = build_optimizer_wrapper(gpc.config.zero, optimizer, model)
-
- if should_wrap_fp16_optimizer:
- # replace the field mode with type
- fp16_cfg = gpc.config.fp16.copy()
- amp_type = fp16_cfg.pop('mode')
- assert amp_type == AMP_TYPE.PARALLEL, 'FP Optimizer should only be used for AMP_TYPE.PARALLEL'
- fp16_cfg['type'] = 'FP16Optimizer'
- optimizer = build_optimizer_wrapper(fp16_cfg, optimizer)
- logger.info('Optimizer is created', ranks=[0])
-
- # build schedule and engine
- if hasattr(gpc.config, 'fp16'):
- amp_type = gpc.config.fp16.mode
- amp_cfg = gpc.config.fp16.copy()
- amp_cfg.pop('mode')
+ if gradient_handler_cfg is None:
+ gradient_handlers = None
+ if verbose and not isinstance(model, DDP):
+ logger.warning(
+ "No PyTorch DDP or gradient handler is set up, please make sure you do not need "
+ "to all-reduce the gradients after a training step.",
+ ranks=[0])
else:
- amp_type = None
- amp_cfg = None
+ gradient_handlers = [build_gradient_handler(cfg, model, optimizer) for cfg in gradient_handler_cfg]
- engine_cfg = gpc.config.get('engine', dict())
- schedule_cfg = engine_cfg.pop('schedule', None)
+ # check if optimizer is ColossalaiOptimizer
+ if not isinstance(optimizer, (ColossalaiOptimizer, ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)):
+ optimizer = ColossalaiOptimizer(optim=optimizer)
- schedule_type = None
- if schedule_cfg is not None:
- schedule_type = schedule_cfg.get('type', None)
+ # gradient accumulation
+ grad_accum_size = gpc.config.get('gradient_accumulation', None)
+ if grad_accum_size is not None:
+ optimizer, train_dataloader, gradient_handlers, lr_scheduler = accumulate_gradient(model=model,
+ optimizer=optimizer,
+ dataloader=train_dataloader,
+ accumulate_size=grad_accum_size,
+ gradient_handlers=gradient_handlers,
+ lr_scheduler=lr_scheduler)
- if schedule_type is not None:
- # run customized schedule
- schedule_cfg['amp_type'] = amp_type
- schedule_cfg['amp_config'] = amp_cfg
- schedule = build_schedule(schedule_cfg)
- elif gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
- assert schedule_cfg is not None, \
- "Config 'engine.schedule' not found in your configuration file for pipeline parallel training"
- schedule = PipelineSchedule(
- amp_type=amp_type, amp_config=amp_cfg, **schedule_cfg.copy())
- else:
- schedule = NoPipelineSchedule(amp_type=amp_type, amp_config=amp_cfg)
+ # clip grad norm
+ clip_grad_norm = gpc.config.get('clip_grad_norm', 0.0)
+ if clip_grad_norm > 0:
+ if zero_cfg is not None:
+ raise ConfigException(
+ "clip_grad_norm should be specified with zero, you should specify clip_grad in zero configuration")
+ elif fp16_cfg is not None and fp16_cfg.mode == AMP_TYPE.NAIVE:
+ raise ConfigException(
+ "clip_grad_norm should be specified with AMP_TYPE.NAIVE, you should specify clip_grad in fp16 configuration")
engine = Engine(
model=model,
optimizer=optimizer,
criterion=criterion,
- step_schedule=schedule,
- **gpc.config.get('engine', dict())
+ gradient_handlers=gradient_handlers,
+ clip_grad_norm=clip_grad_norm
)
- return engine, train_dataloader, test_dataloader
+ return engine, train_dataloader, test_dataloader, lr_scheduler
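A hedged end-to-end sketch of the new entry points; `model`, `optimizer`, `criterion` and `train_dataloader` are user-provided placeholders, and the config path comes from the default parser.

import colossalai

parser = colossalai.get_default_parser()
args = parser.parse_args()

# under torch.distributed.launch / torchrun, RANK, LOCAL_RANK and WORLD_SIZE
# are read from the environment by launch_from_torch
colossalai.launch_from_torch(config=args.config, host=args.host, port=args.port)

engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    train_dataloader=train_dataloader)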
diff --git a/colossalai/logging/__init__.py b/colossalai/logging/__init__.py
index 71657557f..5ee86c45f 100644
--- a/colossalai/logging/__init__.py
+++ b/colossalai/logging/__init__.py
@@ -1,26 +1,10 @@
-from colossalai.core import global_context as gpc
from .logging import DistributedLogger
-__all__ = ['get_global_dist_logger', 'get_dist_logger', 'DistributedLogger', 'init_global_dist_logger']
-
-_GLOBAL_LOGGER: DistributedLogger = None
+__all__ = ['get_dist_logger', 'DistributedLogger']
-def get_dist_logger(name, level='INFO', root_path: str = None, mode='a'):
- return DistributedLogger(name=name, level=level, root_path=root_path, mode=mode)
-
-
-def get_global_dist_logger():
- assert _GLOBAL_LOGGER is not None, 'Global distributed logger is not initialized'
- return _GLOBAL_LOGGER
-
-
-def init_global_dist_logger():
- rank = gpc.get_global_rank()
- if hasattr(gpc.config, 'logging'):
- logger = get_dist_logger(name=f'rank_{rank}', **gpc.config.logging)
- else:
- logger = get_dist_logger(name=f'rank_{rank}', level='INFO')
- global _GLOBAL_LOGGER
- assert _GLOBAL_LOGGER is None, 'Global distributed logger has already been initialized'
- _GLOBAL_LOGGER = logger
+def get_dist_logger(name='root'):
+ """Get logger instance based on name. The DistributedLogger will create singleton instances,
+ which means that only one logger instance is created per name.
+ """
+ return DistributedLogger.get_instance(name=name)
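A short sketch of the singleton logger API, assuming the target log directory already exists; the path is illustrative.

from colossalai.logging import get_dist_logger

logger = get_dist_logger()              # repeated calls with the same name return the same instance
logger.set_level('INFO')
logger.log_to_file('./logs')            # writes ./logs/rank_<rank>.log
logger.info('distributed environment is ready', ranks=[0])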
diff --git a/colossalai/logging/logging.py b/colossalai/logging/logging.py
index b8a79c491..69f799f8b 100644
--- a/colossalai/logging/logging.py
+++ b/colossalai/logging/logging.py
@@ -1,11 +1,13 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
+import colossalai
import logging
from pathlib import Path
+from typing import Union
from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
+
_FORMAT = 'colossalai - %(name)s - %(asctime)s %(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=_FORMAT)
@@ -16,40 +18,92 @@ class DistributedLogger:
:param name: The name of the logger
:type name: str
- :param level: The threshold for the logger. Logging messages which are less severe than `level`
- will be ignored
- :type level: str
- :param root_path: The root path where logs are stored
- :type root_path: str, optional
- :param mode: The mode that the file is opened in. Defaults to 'a'
- :type mode: str, optional
"""
- def __init__(self, name, level='INFO', root_path: str = None, mode='a'):
- self._logger = logging.getLogger(name)
+ __instances = dict()
+
+ @staticmethod
+ def get_instance(name: str):
+ """Get the unique single logger instance based on name.
+ :param name: The name of the logger
+ :type name: str
+ :return: a DistributedLogger object
+ :rtype: DistributedLogger
+ """
+ if name in DistributedLogger.__instances:
+ return DistributedLogger.__instances[name]
+ else:
+ logger = DistributedLogger(name=name)
+ return logger
+
+ def __init__(self, name):
+ if name in DistributedLogger.__instances:
+ raise Exception('A logger with the same name has already been created; use colossalai.logging.get_dist_logger to retrieve it')
+ else:
+ self._name = name
+ self._logger = logging.getLogger(name)
+ DistributedLogger.__instances[name] = self
+
+ @staticmethod
+ def _check_valid_logging_level(level: str):
+ assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR'], 'found invalid logging level'
+
+ def set_level(self, level: str):
+ """Set the logging level
+ :param level: can only be INFO, DEBUG, WARNING and ERROR
+ :type level: str
+ """
+ self._check_valid_logging_level(level)
self._logger.setLevel(getattr(logging, level))
- if root_path is not None:
- log_root_path = Path(root_path)
- # create path if not exists
- log_root_path.mkdir(parents=True, exist_ok=True)
- log_path = log_root_path.joinpath(f'{name}.log')
- file_handler = logging.FileHandler(log_path, mode)
- file_handler.setLevel(getattr(logging, level))
- formatter = logging.Formatter(_FORMAT)
- file_handler.setFormatter(formatter)
- self._logger.addHandler(file_handler)
+ def log_to_file(self,
+ path: Union[str, Path],
+ mode: str = 'a',
+ level: str = 'INFO',
+ suffix: str = None):
+ """Save the logs to file
+ :param path: the directory where the log files are saved
+ :type path: a string or pathlib.Path object
+ :param mode: the mode to write log into the file
+ :type mode: str
+ :param level: can only be INFO, DEBUG, WARNING and ERROR
+ :type level: str
+ :param suffix: a suffix appended to the log file name
+ :type suffix: str, optional
+ """
+ assert isinstance(path, (str, Path)), \
+ f'expected argument path to be type str or Path, but got {type(path)}'
+ self._check_valid_logging_level(level)
+ if isinstance(path, str):
+ path = Path(path)
+
+ # name the log file after the global rank (rank 0 if the context is not yet initialized)
+ if not colossalai.core.global_context.is_initialized(ParallelMode.GLOBAL):
+ rank = 0
+ else:
+ rank = colossalai.core.global_context.get_global_rank()
+
+ if suffix is not None:
+ log_file_name = f'rank_{rank}_{suffix}.log'
+ else:
+ log_file_name = f'rank_{rank}.log'
+ path = path.joinpath(log_file_name)
+
+ # add file handler
+ file_handler = logging.FileHandler(path, mode)
+ file_handler.setLevel(getattr(logging, level))
+ formatter = logging.Formatter(_FORMAT)
+ file_handler.setFormatter(formatter)
+ self._logger.addHandler(file_handler)
def _log(self, level, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
if ranks is None:
getattr(self._logger, level)(message)
else:
- local_rank = gpc.get_local_rank(parallel_mode)
+ local_rank = colossalai.core.global_context.get_local_rank(parallel_mode)
if local_rank in ranks:
getattr(self._logger, level)(message)
def info(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
- """Stores an info log message.
+ """Log an info message.
:param message: The message to be logged
:type message: str
@@ -61,7 +115,7 @@ class DistributedLogger:
self._log('info', message, parallel_mode, ranks)
def warning(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
- """Stores a warning log message.
+ """Log a warning message.
:param message: The message to be logged
:type message: str
@@ -73,7 +127,7 @@ class DistributedLogger:
self._log('warning', message, parallel_mode, ranks)
def debug(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
- """Stores a debug log message.
+ """Log a debug message.
:param message: The message to be logged
:type message: str
@@ -85,7 +139,7 @@ class DistributedLogger:
self._log('debug', message, parallel_mode, ranks)
def error(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
- """Stores an error log message.
+ """Log an error message.
:param message: The message to be logged
:type message: str
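File logging moved from the constructor to an explicit log_to_file call, which appends a per-rank file name to the given directory. A hedged sketch, assuming the distributed context has been launched and that the target directory already exists (the new code no longer creates it):

from pathlib import Path
from colossalai.logging import get_dist_logger

logger = get_dist_logger('train')
logger.set_level('INFO')
Path('./logs').mkdir(parents=True, exist_ok=True)
# writes to ./logs/rank_<global_rank>_epoch0.log (rank 0 if the global context is not initialized)
logger.log_to_file('./logs', mode='a', level='INFO', suffix='epoch0')
logger.info('checkpoint saved', ranks=[0])   # only global rank 0 emits this line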
diff --git a/colossalai/nn/__init__.py b/colossalai/nn/__init__.py
index 69fd61594..c612b631a 100644
--- a/colossalai/nn/__init__.py
+++ b/colossalai/nn/__init__.py
@@ -1,4 +1,3 @@
-from .data import *
from .layer import *
from .loss import *
from .lr_scheduler import *
diff --git a/colossalai/nn/data/__init__.py b/colossalai/nn/data/__init__.py
deleted file mode 100644
index d94afe2da..000000000
--- a/colossalai/nn/data/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .caltech101_dataset import Caltech101Dataset
-from .cifar10_dataset import CIFAR10Dataset
-from .sampler import *
diff --git a/colossalai/nn/data/_utils.py b/colossalai/nn/data/_utils.py
deleted file mode 100644
index 08d77e0da..000000000
--- a/colossalai/nn/data/_utils.py
+++ /dev/null
@@ -1,14 +0,0 @@
-import numpy as np
-
-
-def pil_img_to_numpy(pil_img):
- """convert a PIL image to numpy nd-array
-
- :param pil_img: a PIL image
- :type pil_img: PIL.Image
- :return: a nd-array
- :rtype: numpy.ndarray
- """
- np_img = np.array(pil_img)
- np_img = np.rollaxis(np_img, 2) # HWC to CHW
- return np_img
diff --git a/colossalai/nn/data/base_dataset.py b/colossalai/nn/data/base_dataset.py
deleted file mode 100644
index 730b37649..000000000
--- a/colossalai/nn/data/base_dataset.py
+++ /dev/null
@@ -1,17 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-from abc import ABC
-
-from torch.utils.data import Dataset
-from torchvision.transforms import transforms
-
-from colossalai.builder import build_transform
-
-
-class BaseDataset(Dataset, ABC):
-
- def __init__(self, transform_pipeline: list):
- transform_list = [build_transform(cfg) for cfg in transform_pipeline]
- transform = transforms.Compose(transform_list)
- self._transform_pipeline = transform
diff --git a/colossalai/nn/data/caltech101_dataset.py b/colossalai/nn/data/caltech101_dataset.py
deleted file mode 100644
index b1dc89b68..000000000
--- a/colossalai/nn/data/caltech101_dataset.py
+++ /dev/null
@@ -1,43 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-import torch.distributed as dist
-from torchvision.datasets import Caltech101
-
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.registry import DATASETS
-from .base_dataset import BaseDataset
-
-
-@DATASETS.register_module
-class Caltech101Dataset(BaseDataset):
- """`Caltech 101 `_ Dataset.
-
- :param transform_pipeline: A list of functions' config, which takes in an PIL image
- and returns a transformed version
- :type transform_pipeline: list
- """
-
- def __init__(self, transform_pipeline: list, *args, **kwargs):
- super().__init__(transform_pipeline)
- if gpc.is_initialized(ParallelMode.GLOBAL) and gpc.get_global_rank() != 0:
- dist.barrier()
- self._dataset = Caltech101(
- transform=self._transform_pipeline, *args, **kwargs)
- if gpc.is_initialized(ParallelMode.GLOBAL) and gpc.get_global_rank() == 0:
- dist.barrier()
-
- def __len__(self):
- return len(self._dataset)
-
- def __getitem__(self, item):
- """
-
- :param item: Index
- :type item: int
- :return: ((image,), (target,)) where the type of target specified by target_type.
- :rtype: tuple
- """
- img, label = self._dataset.__getitem__(item)
- return (img,), (label,)
diff --git a/colossalai/nn/data/cifar10_dataset.py b/colossalai/nn/data/cifar10_dataset.py
deleted file mode 100644
index a0ce139a2..000000000
--- a/colossalai/nn/data/cifar10_dataset.py
+++ /dev/null
@@ -1,44 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-import torch.distributed as dist
-from torchvision.datasets import CIFAR10
-
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.registry import DATASETS
-from .base_dataset import BaseDataset
-
-
-@DATASETS.register_module
-class CIFAR10Dataset(BaseDataset):
- """`CIFAR10 `_ Dataset.
-
- :param transform_pipeline: A list of functions' config, which takes in an PIL image
- and returns a transformed version
- :type transform_pipeline: list
- """
-
- def __init__(self, transform_pipeline: list, *args, **kwargs):
- super().__init__(transform_pipeline)
- if gpc.is_initialized(ParallelMode.GLOBAL) and gpc.get_global_rank() != 0:
- dist.barrier()
- self._dataset = CIFAR10(transform=self._transform_pipeline,
- *args,
- **kwargs)
- if gpc.is_initialized(ParallelMode.GLOBAL) and gpc.get_global_rank() == 0:
- dist.barrier()
-
- def __len__(self):
- return len(self._dataset)
-
- def __getitem__(self, item):
- """
-
- :param item: Index
- :type item: int
- :return: ((image,), (target,)) where the type of target specified by target_type.
- :rtype: tuple
- """
- img, label = self._dataset.__getitem__(item)
- return (img,), (label,)
diff --git a/colossalai/nn/data/sampler/__init__.py b/colossalai/nn/data/sampler/__init__.py
deleted file mode 100644
index 471add313..000000000
--- a/colossalai/nn/data/sampler/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from .base_sampler import BaseSampler
-from .data_parallel_sampler import DataParallelSampler
-
-__all__ = ['BaseSampler', 'DataParallelSampler']
diff --git a/colossalai/nn/init.py b/colossalai/nn/init.py
new file mode 100644
index 000000000..057cc008d
--- /dev/null
+++ b/colossalai/nn/init.py
@@ -0,0 +1,33 @@
+import math
+
+from torch import Tensor
+from torch.nn import init as init
+
+
+def init_weight_(tensor: Tensor, fan_in: int, fan_out: int = None, init_method: str = 'torch'):
+ if init_method == 'torch':
+ a = math.sqrt(5)
+ nonlinearity = 'leaky_relu'
+ std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
+ bound = math.sqrt(3.0) * std
+ init.uniform_(tensor, -bound, bound)
+ elif init_method == 'jax':
+ std = math.sqrt(2.0 / float(fan_in + fan_out))
+ a = math.sqrt(3.0) * std
+ init.uniform_(tensor, -a, a)
+ elif init_method == 'jax_embed':
+ std = math.sqrt(1.0 / fan_in)
+ init.trunc_normal_(tensor, std=std / .87962566103423978)
+ elif init_method == 'zero':
+ init.zeros_(tensor)
+
+def init_bias_(tensor: Tensor, fan_in: int, init_method: str = 'torch'):
+ if init_method == 'torch':
+ bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+ init.uniform_(tensor, -bound, bound)
+ elif init_method == 'jax':
+ init.normal_(tensor, std=1e-6)
+ elif init_method == 'jax_embed':
+ init.trunc_normal_(tensor, std=.02)
+ elif init_method == 'zero':
+ init.zeros_(tensor)
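These helpers centralize the 'torch', 'jax', 'jax_embed' and 'zero' initialization schemes used by the parallel layers below. A small usage sketch on a plain nn.Linear (shapes are illustrative):

import torch.nn as nn
from colossalai.nn.init import init_weight_, init_bias_

linear = nn.Linear(128, 64)
# re-initialize with the jax scheme: uniform weights scaled by fan_in + fan_out,
# near-zero normal bias
init_weight_(linear.weight, fan_in=128, fan_out=64, init_method='jax')
init_bias_(linear.bias, fan_in=128, init_method='jax')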
diff --git a/colossalai/nn/layer/__init__.py b/colossalai/nn/layer/__init__.py
index 1456a8a56..e56d8bffe 100644
--- a/colossalai/nn/layer/__init__.py
+++ b/colossalai/nn/layer/__init__.py
@@ -1,9 +1,8 @@
+from .fused_bias_gelu import bias_gelu_impl
from .parallel_1d import *
from .parallel_2d import *
from .parallel_2p5d import *
from .parallel_3d import *
from .parallel_sequence import *
-from .parallel_vision_transformer import *
-from .vanilla_resnet import *
-from .vanilla_vision_transformer import *
+from .non_parallel_layers import *
from .wrapper import *
diff --git a/colossalai/nn/layer/_common_utils.py b/colossalai/nn/layer/_common_utils.py
index 69f63ea5a..759b09003 100644
--- a/colossalai/nn/layer/_common_utils.py
+++ b/colossalai/nn/layer/_common_utils.py
@@ -2,40 +2,14 @@
# -*- encoding: utf-8 -*-
import math
-
+import collections.abc
+from itertools import repeat
+import numpy as np
+from colossalai.utils.common import print_rank_0
import torch
-from torch import Tensor
-from torch import nn
+from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS
from colossalai.utils import checkpoint
-
-from colossalai.constants import IS_TENSOR_PARALLEL
-
-
-def divide(numerator, denominator):
- """ only allow exact division """
- assert numerator % denominator == 0, \
- '{} is not divisible by {}'.format(numerator, denominator)
- return numerator // denominator
-
-
-def gelu(x: Tensor) -> Tensor:
- """Implementation of the gelu activation function.
- For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
- 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
- """
- return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
-
-
-def swish(x: Tensor) -> Tensor:
- return x * torch.sigmoid(x)
-
-
-ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
-
-
-def set_tensor_parallel_attribute(param):
- if not hasattr(param, IS_TENSOR_PARALLEL):
- setattr(param, IS_TENSOR_PARALLEL, True)
+from torch import Tensor, nn
class CheckpointModule(nn.Module):
@@ -44,15 +18,15 @@ class CheckpointModule(nn.Module):
self.checkpoint = checkpoint
self._use_checkpoint = checkpoint
- def _forward(self, *args):
+ def _forward(self, *args, **kwargs):
raise NotImplementedError(
'CheckpointModule should implement _forward method instead of origin forward')
- def forward(self, *args):
+ def forward(self, *args, **kwargs):
if self._use_checkpoint:
- return checkpoint(self._forward, *args)
+ return checkpoint(self._forward, *args, **kwargs)
else:
- return self._forward(*args)
+ return self._forward(*args, **kwargs)
def train(self, mode: bool = True):
self._use_checkpoint = self.checkpoint
@@ -61,3 +35,38 @@ class CheckpointModule(nn.Module):
def eval(self):
self._use_checkpoint = False
return super().eval()
+
+def divide(numerator, denominator):
+ """ only allow exact division """
+ assert numerator % denominator == 0, \
+ '{} is not divisible by {}'.format(numerator, denominator)
+ return numerator // denominator
+
+
+def swish(x: Tensor) -> Tensor:
+ return x * torch.sigmoid(x)
+
+
+ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}
+
+
+def set_tensor_parallel_attribute_by_size(param, size):
+ setattr(param, IS_TENSOR_PARALLEL, True)
+ setattr(param, NUM_PARTITIONS, size // np.prod(param.shape))
+
+
+def set_tensor_parallel_attribute_by_partition(param, num_partitions):
+ setattr(param, IS_TENSOR_PARALLEL, True)
+ setattr(param, NUM_PARTITIONS, num_partitions)
+
+# From PyTorch internals
+def _ntuple(n):
+ def parse(x):
+ if isinstance(x, collections.abc.Iterable):
+ return x
+ return tuple(repeat(x, n))
+
+ return parse
+
+
+to_2tuple = _ntuple(2)
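divide, to_2tuple and the tensor-parallel attribute setters are shared by every parallel layer in this PR. A quick sketch of their behavior:

import torch
from colossalai.nn.layer._common_utils import (divide, to_2tuple,
                                               set_tensor_parallel_attribute_by_partition)

divide(12, 4)          # 3; asserts on inexact division
to_2tuple(7)           # (7, 7)
to_2tuple((7, 14))     # returned unchanged, it is already an iterable

w = torch.nn.Parameter(torch.empty(64, 32))
# marks the parameter as tensor-parallel and records its partition count,
# so gradient handlers and norm clipping can treat it as a shard
set_tensor_parallel_attribute_by_partition(w, num_partitions=4)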
diff --git a/colossalai/nn/layer/fused_bias_gelu.py b/colossalai/nn/layer/fused_bias_gelu.py
new file mode 100644
index 000000000..e92041534
--- /dev/null
+++ b/colossalai/nn/layer/fused_bias_gelu.py
@@ -0,0 +1,35 @@
+# adapted from Megatron-LM
+# https://github.com/NVIDIA/Megatron-LM/blob/b31e1296354e979722627a6c4dedafe19b51fa97/megatron/model/fused_bias_gelu.py
+
+import torch
+
+@torch.jit.script
+def bias_gelu(bias, y):
+ x = bias + y
+ return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
+
+# gradient of tanh approximation of gelu
+# gradient of actual gelu is:
+# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
+@torch.jit.script
+def bias_gelu_back(g, bias, y):
+ x = bias + y
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+ # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
+ ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
+ return ff*g
+
+class GeLUFunction(torch.autograd.Function):
+ @staticmethod
+ # bias is an optional argument
+ def forward(ctx, input, bias):
+ ctx.save_for_backward(input, bias)
+ return bias_gelu(bias, input)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, bias = ctx.saved_tensors
+ tmp = bias_gelu_back(grad_output, bias, input)
+ return tmp, tmp
+
+bias_gelu_impl = GeLUFunction.apply
\ No newline at end of file
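bias_gelu_impl fuses the bias addition with a tanh-approximated GeLU and ships a handwritten backward. A minimal numerical check, assuming the symbol is imported from colossalai.nn.layer as re-exported above; the bias is given the same shape as the input to keep the handwritten backward's gradient shapes simple:

import torch
import torch.nn.functional as F
from colossalai.nn.layer import bias_gelu_impl

x = torch.randn(4, 8, requires_grad=True)
b = torch.randn(4, 8, requires_grad=True)
out = bias_gelu_impl(x, b)
ref = F.gelu(x + b)                         # exact erf-based GeLU for reference
assert torch.allclose(out, ref, atol=1e-2)  # tanh approximation is close, not exact
out.sum().backward()                        # exercises bias_gelu_back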
diff --git a/colossalai/nn/layer/non_parallel_layers/__init__.py b/colossalai/nn/layer/non_parallel_layers/__init__.py
new file mode 100644
index 000000000..6a9883141
--- /dev/null
+++ b/colossalai/nn/layer/non_parallel_layers/__init__.py
@@ -0,0 +1,8 @@
+from ._vit import (ViTBlock, VanillaViTAttention, VanillaViTBlock, VanillaViTDropPath,
+ VanillaViTHead, VanillaViTMLP, VanillaViTPatchEmbedding)
+
+
+__all__ = [
+ 'ViTBlock', 'VanillaViTAttention', 'VanillaViTBlock', 'VanillaViTDropPath',
+ 'VanillaViTHead', 'VanillaViTMLP', 'VanillaViTPatchEmbedding'
+]
diff --git a/colossalai/nn/layer/vanilla_vision_transformer/layers.py b/colossalai/nn/layer/non_parallel_layers/_vit.py
similarity index 88%
rename from colossalai/nn/layer/vanilla_vision_transformer/layers.py
rename to colossalai/nn/layer/non_parallel_layers/_vit.py
index 6f7ec4c7c..59a12fee2 100644
--- a/colossalai/nn/layer/vanilla_vision_transformer/layers.py
+++ b/colossalai/nn/layer/non_parallel_layers/_vit.py
@@ -1,23 +1,47 @@
-import collections.abc
-from itertools import repeat
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
import torch
from torch import nn as nn
+from colossalai.builder import build_layer
from colossalai.registry import LAYERS
+from .._common_utils import to_2tuple
-# From PyTorch internals
-def _ntuple(n):
- def parse(x):
- if isinstance(x, collections.abc.Iterable):
- return x
- return tuple(repeat(x, n))
+@LAYERS.register_module
+class ViTBlock(nn.Module):
+ """Vision Transformer block
- return parse
+ :param attention_cfg: config of attention layer
+ :type attention_cfg: dict
+ :param droppath_cfg: config of drop path
+ :type droppath_cfg: dict
+ :param mlp_cfg: config of MLP layer
+ :type mlp_cfg: dict
+ :param norm_cfg: config of normalization layer
+ :type norm_cfg: dict
+ """
+ def __init__(self,
+ attention_cfg: dict,
+ droppath_cfg: dict,
+ mlp_cfg: dict,
+ norm_cfg: dict,
+ ):
+ super().__init__()
+ self.norm1 = build_layer(norm_cfg)
+ self.attn = build_layer(attention_cfg)
+ self.drop_path = build_layer(
+ droppath_cfg) if droppath_cfg['drop_path'] > 0. else nn.Identity()
+ self.norm2 = build_layer(norm_cfg)
+ self.mlp = build_layer(mlp_cfg)
-to_2tuple = _ntuple(2)
+ def forward(self, x):
+ x = x + self.drop_path(self.attn(self.norm1(x)))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
@LAYERS.register_module
diff --git a/colossalai/nn/layer/parallel_1d/__init__.py b/colossalai/nn/layer/parallel_1d/__init__.py
index 9e7df549f..85272d7c0 100644
--- a/colossalai/nn/layer/parallel_1d/__init__.py
+++ b/colossalai/nn/layer/parallel_1d/__init__.py
@@ -1,5 +1,11 @@
from .layers import Linear1D_Col, Linear1D_Row
+from .layers import MixedFusedLayerNorm1D as LayerNorm1D
+from ._transformer import TransformerMLP1D, TransformerSelfAttention1D, TransformerLayer1D
+from ._vit import ViTMLP1D, ViTSelfAttention1D, ViTHead1D, ViTPatchEmbedding1D, ViTTokenFuser1D, ViTHead
+
+
__all__ = [
- 'Linear1D_Col', 'Linear1D_Row',
+ 'Linear1D_Col', 'Linear1D_Row', 'ViTMLP1D', 'ViTSelfAttention1D', 'ViTHead1D', 'ViTPatchEmbedding1D', 'ViTTokenFuser1D',
+ 'TransformerMLP1D', 'TransformerSelfAttention1D', 'TransformerLayer1D', 'LayerNorm1D', 'ViTHead'
]
diff --git a/colossalai/nn/layer/parallel_1d/_operation.py b/colossalai/nn/layer/parallel_1d/_operation.py
new file mode 100644
index 000000000..aee28926a
--- /dev/null
+++ b/colossalai/nn/layer/parallel_1d/_operation.py
@@ -0,0 +1,34 @@
+import torch
+
+try:
+ import fused_mix_prec_layer_norm_cuda
+except:
+ fused_mix_prec_layer_norm_cuda = None
+
+
+class FusedLayerNormAffineFunction1D(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, input, weight, bias, normalized_shape, eps):
+ ctx.normalized_shape = normalized_shape
+ ctx.eps = eps
+ input_ = input.contiguous()
+ weight_ = weight.contiguous()
+ bias_ = bias.contiguous()
+ output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(
+ input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
+ ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
+ return output
+
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input_, weight_, bias_, mean, invvar = ctx.saved_tensors
+ grad_input = grad_weight = grad_bias = None
+ grad_input, grad_weight, grad_bias \
+ = fused_mix_prec_layer_norm_cuda.backward_affine(
+ grad_output.contiguous(), mean, invvar,
+ input_, ctx.normalized_shape,
+ weight_, bias_, ctx.eps)
+
+ return grad_input, grad_weight, grad_bias, None, None
\ No newline at end of file
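fused_mix_prec_layer_norm_cuda comes from Apex's optional mixed-precision layer-norm extension, so the import is wrapped in try/except and the module may be None. A hedged fallback sketch (not part of this PR) that degrades to PyTorch's reference layer norm when the extension is missing:

import torch.nn.functional as F
from colossalai.nn.layer.parallel_1d._operation import FusedLayerNormAffineFunction1D
try:
    import fused_mix_prec_layer_norm_cuda
except ImportError:
    fused_mix_prec_layer_norm_cuda = None

def layer_norm_affine(x, weight, bias, normalized_shape, eps):
    # hypothetical helper: use the fused kernel when available, otherwise the
    # built-in implementation with the same affine parameters
    if fused_mix_prec_layer_norm_cuda is not None:
        return FusedLayerNormAffineFunction1D.apply(x, weight, bias, normalized_shape, eps)
    return F.layer_norm(x, normalized_shape, weight, bias, eps)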
diff --git a/colossalai/nn/layer/parallel_1d/_transformer.py b/colossalai/nn/layer/parallel_1d/_transformer.py
new file mode 100644
index 000000000..90a8d740e
--- /dev/null
+++ b/colossalai/nn/layer/parallel_1d/_transformer.py
@@ -0,0 +1,220 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.nn.init as init
+import math
+from torch import Tensor
+from torch.nn.parameter import Parameter
+from typing import Tuple
+
+from colossalai.context import seed, ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.registry import LAYERS
+from colossalai.utils import get_current_device
+from .._common_utils import divide, ACT2FN
+from .._parallel_utilities import reduce_grad, reduce_input, gather_forward_split_backward, \
+ split_forward_gather_backward
+from ..base_layer import ParallelLayer
+from .layers import Linear1D_Col, Linear1D_Row
+from .layers import MixedFusedLayerNorm1D as LayerNorm1D
+
+@LAYERS.register_module
+class TransformerMLP1D(ParallelLayer):
+ """MLP.
+ MLP will take the input with h hidden state, project it to 4*h
+ hidden dimension, perform nonlinear transformation, and project the
+ state back into h hidden dimension.
+ """
+
+ def __init__(self,
+ in_features: int,
+ mlp_ratio: int = 4.0,
+ act_func: str = 'gelu',
+ dropout_prob: float = 0.,
+ dtype=None,
+ skip_bias_add: bool = False
+ ):
+ super(TransformerMLP1D, self).__init__()
+ self.in_features = in_features
+ self.mlp_ratio = mlp_ratio
+ self.skip_bias_add = skip_bias_add
+ # Project to h * mlp_ratio.
+ self.dense_1 = Linear1D_Col(
+ self.in_features,
+ int(self.mlp_ratio * self.in_features),
+ bias=not skip_bias_add,
+ dtype=dtype,
+ gather_output = False,
+ )
+
+ assert act_func in ACT2FN.keys(), f'Invalid value for argument act_func, ' \
+ f'activation function can only be {list(ACT2FN.keys())}'
+ self.activation_func = ACT2FN[act_func]
+
+ # Project back to h.
+ self.dense_2 = Linear1D_Row(
+ int(self.mlp_ratio * self.in_features),
+ self.in_features,
+ bias=not skip_bias_add,
+ dtype=dtype,
+ parallel_input = True,
+ )
+ self.dropout = nn.Dropout(dropout_prob)
+ # self.layernorm = LayerNorm1D(in_features, dtype=dtype)
+ self.layernorm = nn.LayerNorm(in_features, dtype=dtype)
+ def forward(self, x):
+ if self.skip_bias_add:
+ intermediate_output, _ = self.dense_1(x)
+ else:
+ intermediate_output = self.dense_1(x)
+
+ intermediate_output = self.activation_func(intermediate_output)
+
+ if self.skip_bias_add:
+ output, _ = self.dense_2(intermediate_output)
+ else:
+ output = self.dense_2(intermediate_output)
+
+ with seed(ParallelMode.TENSOR):
+ output = self.dropout(output)
+ output = self.layernorm(x + output)
+ return output
+
+@LAYERS.register_module
+class TransformerSelfAttention1D(ParallelLayer):
+ """Self attention layer for 1D parallel Transformer
+
+ :param hidden_size: hidden size
+ :type hidden_size: int
+ :param num_attention_heads: number of attention heads
+ :type num_attention_heads: int
+ :param attention_dropout_prob: dropout probability for attention layer
+ :type attention_dropout_prob: float
+ :param hidden_dropout_prob: dropout probability for hidden layer
+ :type hidden_dropout_prob: float
+ :param dtype: dtype of parameters, defaults to None
+ :type dtype: torch.dtype, optional
+ """
+
+ def __init__(self,
+ hidden_size: int,
+ num_attention_heads: int,
+ attention_dropout_prob: float,
+ hidden_dropout_prob: float,
+ dtype=None,
+ ):
+
+ super().__init__()
+
+ self.hidden_size = hidden_size
+
+ self.num_attention_heads = divide(num_attention_heads, gpc.tensor_parallel_size)
+ self.attention_head_size = divide(hidden_size, num_attention_heads)
+ self.hidden_size_per_partition = divide(hidden_size, gpc.tensor_parallel_size)
+
+ self.query_key_value = Linear1D_Col(
+ hidden_size,
+ 3 * hidden_size,
+ dtype=dtype,
+ )
+ self.attention_dropout = nn.Dropout(attention_dropout_prob)
+ self.dense = Linear1D_Row(
+ hidden_size,
+ hidden_size,
+ dtype=dtype,
+ parallel_input=True,
+ )
+ self.dropout = nn.Dropout(hidden_dropout_prob)
+
+ # need to re-enable torch grad to enable fused optimization.
+ # self.layernorm = LayerNorm1D(
+ # hidden_size,
+ # dtype=dtype)
+ self.layernorm = nn.LayerNorm(
+ hidden_size,
+ dtype=dtype)
+
+ def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
+ query_key_value = self.query_key_value(hidden_states)
+ new_qkv_shape = query_key_value.shape[:-1] + \
+ (self.num_attention_heads, 3 * self.attention_head_size)
+ query_key_value = query_key_value.view(new_qkv_shape)
+ query_key_value = query_key_value.permute((0, 2, 1, 3))
+ query_layer, key_layer, value_layer = torch.chunk(
+ query_key_value, 3, dim=-1)
+
+ attention_scores = torch.matmul(
+ query_layer, key_layer.transpose(-1, -2))
+ attention_scores = attention_scores / \
+ math.sqrt(self.attention_head_size)
+ attention_scores = attention_scores + attention_mask
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
+ with seed(ParallelMode.TENSOR):
+ attention_probs = self.attention_dropout(attention_probs)
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+ context_layer = context_layer.permute((0, 2, 1, 3)).contiguous()
+ new_context_layer_shape = context_layer.size()[
+ :-2] + (self.hidden_size_per_partition,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ output = self.dense(context_layer)
+ with seed(ParallelMode.TENSOR):
+ output = self.dropout(output)
+ attention_output = self.layernorm(hidden_states + output)
+
+ return attention_output
+
+@LAYERS.register_module
+class TransformerLayer1D(ParallelLayer):
+ """Transformer layer which contains a self-attention layer and a MLP layer
+
+ :param hidden_size: hidden size
+ :type hidden_size: int
+ :param num_attention_heads: number of attention heads
+ :type num_attention_heads: int
+ :param act_func: activation function, defaults to 'gelu'
+ :type act_func: str, optional
+ :param mlp_ratio: hidden size of MLP divided by embedding dim, defaults to 4.0
+ :type mlp_ratio: float, optional
+ :param attention_dropout_prob: dropout probability for attention layer, defaults to 0.
+ :type attention_dropout_prob: float, optional
+ :param hidden_dropout_prob: dropout probability for hidden layers, defaults to 0.
+ :type hidden_dropout_prob: float, optional
+ :param dtype: dtype of parameters, defaults to None
+ :type dtype: torch.dtype, optional
+ """
+
+ def __init__(self,
+ hidden_size: int,
+ num_attention_heads: int,
+ act_func: str = 'gelu',
+ mlp_ratio: float = 4.0,
+ attention_dropout_prob: float = 0.,
+ hidden_dropout_prob: float = 0.,
+ dtype=None,
+ ):
+ super().__init__()
+
+ self.attention = TransformerSelfAttention1D(
+ hidden_size=hidden_size,
+ num_attention_heads=num_attention_heads,
+ attention_dropout_prob=attention_dropout_prob,
+ hidden_dropout_prob=hidden_dropout_prob,
+ dtype=dtype,
+ )
+ self.mlp = TransformerMLP1D(
+ in_features=hidden_size,
+ dropout_prob=hidden_dropout_prob,
+ act_func=act_func,
+ mlp_ratio=mlp_ratio,
+ dtype=dtype,
+ )
+
+ def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
+ attention_output = self.attention(hidden_states, attention_mask)
+ output = self.mlp(attention_output)
+ return output
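TransformerLayer1D simply chains the 1D-parallel self-attention and MLP defined above. A hypothetical usage sketch, assuming colossalai.launch has already created a 1D tensor-parallel group and that hidden_size and num_attention_heads are divisible by its size:

import torch
from colossalai.nn.layer.parallel_1d import TransformerLayer1D

layer = TransformerLayer1D(hidden_size=768, num_attention_heads=12,
                           mlp_ratio=4.0, attention_dropout_prob=0.1,
                           hidden_dropout_prob=0.1)
hidden = torch.randn(2, 196, 768, device='cuda')
mask = torch.zeros(2, 1, 1, 196, device='cuda')   # additive mask, 0 keeps a position
out = layer(hidden, mask)                         # keeps the [batch, seq, hidden] shape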
diff --git a/colossalai/nn/layer/parallel_1d/_utils.py b/colossalai/nn/layer/parallel_1d/_utils.py
index 00d221e78..3e1afa186 100644
--- a/colossalai/nn/layer/parallel_1d/_utils.py
+++ b/colossalai/nn/layer/parallel_1d/_utils.py
@@ -13,3 +13,6 @@ def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank):
def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
per_partition_vocab_size = divide(global_vocab_size, world_size)
return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank)
+
+
+
diff --git a/colossalai/nn/layer/parallel_1d/_vit.py b/colossalai/nn/layer/parallel_1d/_vit.py
new file mode 100644
index 000000000..dca3d1768
--- /dev/null
+++ b/colossalai/nn/layer/parallel_1d/_vit.py
@@ -0,0 +1,411 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+import math
+from colossalai import context
+
+import torch
+from torch import nn as nn, Tensor, distributed as dist
+from torch.nn.init import _calculate_fan_in_and_fan_out
+
+from colossalai.context import seed, ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.nn.layer._common_utils import divide, ACT2FN
+from colossalai.registry import LAYERS
+from colossalai.utils import checkpoint
+from colossalai.utils import get_current_device
+from .layers import Linear1D_Col, Linear1D_Row
+from ..base_layer import ParallelLayer
+from .._common_utils import to_2tuple
+from ..fused_bias_gelu import bias_gelu_impl
+
+
+@LAYERS.register_module
+class ViTMLP1D(ParallelLayer):
+ """MLP layer for 1D parallel Vision Transformer
+
+ :param in_features: size of each input sample
+ :type in_features: int
+ :param mlp_ratio: hidden size of MLP divided by embedding dim
+ :type mlp_ratio: int
+ :param act_func: activation function, defaults to 'gelu'
+ :type act_func: str, optional
+ :param dropout_prob: dropout probability, defaults to 0.
+ :type dropout_prob: float, optional
+ :param dtype: The dtype of parameters, defaults to None
+ :type dtype: torch.dtype, optional
+ :param checkpoint: whether to checkpoint the layer, defaults to False
+ :type checkpoint: bool, optional
+ """
+
+ def __init__(self,
+ in_features: int,
+ mlp_ratio: int,
+ act_func: str = 'gelu',
+ dropout_prob: float = 0.,
+ dtype=None,
+ checkpoint: bool = False,
+ skip_bias_add: bool = False,
+ weight_init='torch'
+ ):
+ super().__init__()
+
+ self.in_features = in_features
+ self.mlp_ratio = mlp_ratio
+ self.checkpoint = checkpoint
+ self.skip_bias_add = skip_bias_add
+ assert weight_init in ('torch', 'jax')
+
+ if act_func == 'fused_gelu':
+ self.act = bias_gelu_impl
+ skip_dense_1_add_bias = True
+ else:
+ self.act = ACT2FN[act_func]
+ skip_dense_1_add_bias = False
+
+ # Project to mlp_ratio * h.
+ self.dense_1 = Linear1D_Col(
+ self.in_features,
+ int(self.mlp_ratio * self.in_features),
+ dtype=dtype,
+ gather_output=False,
+ skip_bias_add=skip_dense_1_add_bias,
+ init_weight=weight_init,
+ init_bias=weight_init
+ )
+
+ # Project back to h.
+ self.dense_2 = Linear1D_Row(
+ int(self.mlp_ratio * self.in_features),
+ self.in_features,
+ dtype=dtype,
+ parallel_input=True,
+ init_weight=weight_init, init_bias=weight_init
+ )
+
+ self.dropout = nn.Dropout(dropout_prob)
+
+ def _forward(self, hidden_states: Tensor) -> Tensor:
+ if self.act == bias_gelu_impl:
+ intermediate_output, bias = self.dense_1(hidden_states)
+ intermediate_output = self.act(intermediate_output, bias)
+ else:
+ intermediate_output = self.dense_1(hidden_states)
+ intermediate_output = self.act(intermediate_output)
+
+ with seed(ParallelMode.TENSOR):
+ intermediate_output = self.dropout(intermediate_output)
+ output = self.dense_2(intermediate_output)
+ output = self.dropout(output)
+ return output
+
+ def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
+ return checkpoint(self._forward, hidden_states)
+
+ def forward(self, hidden_states: Tensor) -> Tensor:
+ if self.checkpoint:
+ return self._checkpoint_forward(hidden_states)
+ else:
+ return self._forward(hidden_states)
+
+
+@LAYERS.register_module
+class ViTSelfAttention1D(ParallelLayer):
+ """Self-attention layer for 1D parallel Vision Transformer
+
+ :param hidden_size: hidden size
+ :type hidden_size: int
+ :param num_attention_heads: number of attention heads
+ :type num_attention_heads: int
+ :param attention_dropout_prob: dropout probability for attention layers
+ :type attention_dropout_prob: float
+ :param hidden_dropout_prob: dropout probability for hidden layers
+ :type hidden_dropout_prob: float
+ :param dtype: dtype of parameters, defaults to None
+ :type dtype: torch.dtype, optional
+ :param checkpoint: whether to checkpoint the layer, defaults to False
+ :type checkpoint: bool, optional
+ """
+
+ def __init__(self,
+ hidden_size: int,
+ num_attention_heads: int,
+ attention_dropout_prob: float,
+ hidden_dropout_prob: float,
+ dtype=None,
+ checkpoint: bool = False,
+ weight_init='torch'
+ ):
+ super().__init__()
+
+ self.hidden_size = hidden_size
+ self.attention_head_size = divide(hidden_size, num_attention_heads)
+ self.num_attention_heads_per_partition = divide(num_attention_heads, gpc.tensor_parallel_size)
+ self.hidden_size_per_partition = divide(hidden_size, gpc.tensor_parallel_size)
+
+ self.checkpoint = checkpoint
+ assert weight_init in ('torch', 'jax')
+ if weight_init == 'jax':
+ init_bias = 'zero'
+ else:
+ init_bias = weight_init
+
+ self.query_key_value = Linear1D_Col(
+ hidden_size,
+ 3 * hidden_size,
+ dtype=dtype,
+ init_weight=weight_init,
+ init_bias=init_bias
+ )
+ self.attention_dropout = nn.Dropout(attention_dropout_prob)
+ self.dense = Linear1D_Row(
+ hidden_size,
+ hidden_size,
+ dtype=dtype,
+ parallel_input=True,
+ init_weight=weight_init, init_bias=init_bias
+ )
+ self.dropout = nn.Dropout(hidden_dropout_prob)
+ self.softmax = nn.Softmax(dim=-1)
+
+ def _forward(self, hidden_states: Tensor) -> Tensor:
+ query_key_value = self.query_key_value(hidden_states)
+ new_qkv_shape = query_key_value.shape[:-1] + \
+ (self.num_attention_heads_per_partition, 3 * self.attention_head_size)
+ query_key_value = query_key_value.view(new_qkv_shape)
+ query_key_value = query_key_value.permute((0, 2, 1, 3))
+ query_layer, key_layer, value_layer = torch.chunk(
+ query_key_value, 3, dim=-1)
+
+ attention_scores = torch.matmul(
+ query_layer, key_layer.transpose(-1, -2))
+ attention_scores = attention_scores / \
+ math.sqrt(self.attention_head_size)
+
+ attention_probs = self.softmax(attention_scores)
+
+ with seed(ParallelMode.TENSOR):
+ attention_probs = self.attention_dropout(attention_probs)
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+ context_layer = context_layer.transpose(1, 2)
+ new_context_layer_shape = context_layer.size()[
+ :-2] + (self.hidden_size_per_partition,)
+ context_layer = context_layer.reshape(new_context_layer_shape)
+ output = self.dense(context_layer)
+ output = self.dropout(output)
+
+ return output
+
+ def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
+ return checkpoint(self._forward, hidden_states)
+
+ def forward(self, hidden_states: Tensor) -> Tensor:
+ if self.checkpoint:
+ return self._checkpoint_forward(hidden_states)
+ else:
+ return self._forward(hidden_states)
+
+
+@LAYERS.register_module
+class ViTHead1D(ParallelLayer):
+ """Output layer for 1D parallel Vision Transformer
+
+ :param hidden_size: hidden size
+ :type hidden_size: int
+ :param num_classes: number of classes
+ :type num_classes: int
+ :param dtype: dtype of parameters, defaults to None
+ :type dtype: torch.dtype, optional
+ """
+
+ def __init__(self,
+ hidden_size,
+ num_classes,
+ dtype=None,
+ weight_init='torch'
+ ):
+ super().__init__()
+
+ assert weight_init in ('torch', 'jax')
+ if weight_init == 'jax':
+ init_weight = 'zero'
+ init_bias = 'zero'
+ else:
+ init_weight = weight_init
+ init_bias = weight_init
+
+ self.linear = Linear1D_Col(
+ hidden_size,
+ num_classes,
+ dtype=dtype,
+ gather_output=True,
+ init_weight=init_weight,
+ init_bias=init_bias
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ x = x[:, 0]
+ x = self.linear(x)
+ return x
+
+
+@LAYERS.register_module
+class ViTHead(ParallelLayer):
+ """Output layer for 1D parallel Vision Transformer
+
+ :param hidden_size: hidden size
+ :type hidden_size: int
+ :param num_classes: number of classes
+ :type num_classes: int
+ :param dtype: dtype of parameters, defaults to None
+ :type dtype: torch.dtype, optional
+ """
+
+ def __init__(self,
+ hidden_size,
+ num_classes,
+ dtype=None,
+ ):
+ super().__init__()
+ self.linear = nn.Linear(
+ hidden_size,
+ num_classes,
+ dtype=dtype
+ )
+ self._broadcast_linear_params()
+
+ def _broadcast_linear_params(self) -> None:
+ self.to(get_current_device())
+ ranks = gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)
+
+ dist.broadcast(self.linear.weight, src=ranks[0],
+ group=gpc.get_group(ParallelMode.PARALLEL_1D))
+ dist.broadcast(self.linear.bias, src=ranks[0],
+ group=gpc.get_group(ParallelMode.PARALLEL_1D))
+
+ def forward(self, x: Tensor) -> Tensor:
+ x = x[:, 0]
+ x = self.linear(x)
+ return x
+
+
+@LAYERS.register_module
+class ViTPatchEmbedding1D(ParallelLayer):
+ """ 2D Image to Patch Embedding
+
+ :param img_size: image size
+ :type img_size: int
+ :param patch_size: patch size
+ :type patch_size: int
+ :param embed_dim: dimension of embedding
+ :type embed_dim: int
+ :param in_chans: number of channels of input image, defaults to 3
+ :type in_chans: int, optional
+ :param flatten: whether to flatten output tensor, defaults to True
+ :type flatten: bool, optional
+ """
+
+ def __init__(self,
+ img_size,
+ patch_size,
+ embed_dim,
+ in_chans=3,
+ flatten=True,
+ weight_init='torch'):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.grid_size = (img_size[0] // patch_size[0],
+ img_size[1] // patch_size[1])
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
+ self.flatten = flatten
+ self.embed_dim = embed_dim
+
+ self.proj = nn.Conv2d(in_chans,
+ self.embed_dim,
+ kernel_size=patch_size,
+ stride=patch_size
+ )
+
+ if weight_init == 'jax':
+ fan_in, _ = _calculate_fan_in_and_fan_out(self.proj.weight)
+ std = math.sqrt(1.0 / fan_in)
+ nn.init.trunc_normal_(self.proj.weight, std=std / .87962566103423978)
+ nn.init.zeros_(self.proj.bias)
+
+ # sync
+ self._broadcast_conv_params()
+
+ def _broadcast_conv_params(self) -> None:
+ self.to(get_current_device())
+ ranks = gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)
+
+ dist.broadcast(self.proj.weight, src=ranks[0],
+ group=gpc.get_group(ParallelMode.PARALLEL_1D))
+ dist.broadcast(self.proj.bias, src=ranks[0],
+ group=gpc.get_group(ParallelMode.PARALLEL_1D))
+
+ def forward(self, x: Tensor) -> Tensor:
+ B, C, H, W = x.shape
+ assert H == self.img_size[0] and W == self.img_size[1], \
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+ x = self.proj(x)
+ if self.flatten:
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
+ return x
+
+
+@LAYERS.register_module
+class ViTTokenFuser1D(ParallelLayer):
+ """
+ Fuses the cls token and position embedding into the input
+
+ :param img_size: image size
+ :type img_size: int
+ :param patch_size: patch size
+ :type patch_size: int
+ :param embed_dim: dimension of embedding
+ :type embed_dim: int
+ :param drop_rate: dropout probability, defaults to 0.
+ :type drop_rate: float, optional
+ """
+
+ def __init__(self,
+ img_size,
+ patch_size,
+ embed_dim,
+ drop_rate=0.
+ ):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.grid_size = (img_size[0] // patch_size[0],
+ img_size[1] // patch_size[1])
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
+ self.embed_dim = embed_dim
+
+ self.cls_token = nn.Parameter(torch.zeros(
+ 1, 1, self.embed_dim))
+ self.pos_embed = nn.Parameter(torch.empty(
+ 1, self.num_patches + 1, self.embed_dim))
+ nn.init.trunc_normal_(self.pos_embed, std=.02)
+
+ # move to cuda before broadcast
+ self.to(get_current_device())
+ dist.broadcast(self.pos_embed,
+ src=gpc.get_ranks_in_group(ParallelMode.TENSOR)[0],
+ group=gpc.get_group(ParallelMode.TENSOR))
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ def forward(self, x: Tensor) -> Tensor:
+ cls_token = self.cls_token.expand(x.shape[0], -1, -1)
+ x = torch.cat((cls_token, x), dim=1)
+ x = self.pos_drop(x + self.pos_embed)
+ return x.contiguous()
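The new 1D ViT layers are meant to be composed through configs, but constructing them directly illustrates the data flow. A hypothetical sketch (hyper-parameters are illustrative and a tensor-parallel group must already be initialized, since the patch embedding and token fuser broadcast their parameters at construction); a full model would additionally wrap attention and MLP in residual blocks such as ViTBlock:

from colossalai.nn.layer.parallel_1d import (ViTPatchEmbedding1D, ViTTokenFuser1D,
                                             ViTSelfAttention1D, ViTMLP1D, ViTHead1D)

embed = ViTPatchEmbedding1D(img_size=224, patch_size=16, embed_dim=768)              # [B,3,224,224] -> [B,196,768]
fuser = ViTTokenFuser1D(img_size=224, patch_size=16, embed_dim=768, drop_rate=0.1)   # adds cls token + pos embed
attn  = ViTSelfAttention1D(hidden_size=768, num_attention_heads=12,
                           attention_dropout_prob=0.0, hidden_dropout_prob=0.1)
mlp   = ViTMLP1D(in_features=768, mlp_ratio=4)
head  = ViTHead1D(hidden_size=768, num_classes=1000)                                 # reads the cls token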
diff --git a/colossalai/nn/layer/parallel_1d/layers.py b/colossalai/nn/layer/parallel_1d/layers.py
index 572eca777..796e04386 100644
--- a/colossalai/nn/layer/parallel_1d/layers.py
+++ b/colossalai/nn/layer/parallel_1d/layers.py
@@ -1,24 +1,30 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
+import math
+import numbers
import torch
+import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch import Tensor
from torch.nn.parameter import Parameter
from typing import Tuple
+import importlib
-from colossalai.context.parallel_mode import ParallelMode
+from colossalai.context import seed, ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import LAYERS
from colossalai.utils import get_current_device
-from .._common_utils import divide
+from ._operation import FusedLayerNormAffineFunction1D
+from .._common_utils import divide, set_tensor_parallel_attribute_by_partition
from .._parallel_utilities import reduce_grad, reduce_input, gather_forward_split_backward, \
split_forward_gather_backward
from ..base_layer import ParallelLayer
+@LAYERS.register_module
class Linear1D_Col(ParallelLayer):
"""Linear layer with column parallelism.
@@ -44,23 +50,29 @@ class Linear1D_Col(ParallelLayer):
output_size: int,
bias: bool = True,
dtype: torch.dtype = None,
- gather_output: bool = False):
+ gather_output: bool = False,
+ skip_bias_add: bool = False,
+ init_weight='torch',
+ init_bias='torch'
+ ):
super().__init__()
# Keep input parameters
- self.input_size = in_features
- self.output_size = output_size
+ self.in_features = in_features
+ self.out_features = output_size
self.gather_output = gather_output
- self.skip_bias_add = not bias
+ self.skip_bias_add = skip_bias_add
- world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
- self.output_size_per_partition = divide(output_size, world_size)
+ if skip_bias_add and not bias:
+ raise ValueError('cannot skip bias addition if bias is None')
+
+ self.output_size_per_partition = divide(output_size, gpc.tensor_parallel_size)
# Parameters.
# Initialize weight.
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
self.weight = Parameter(torch.empty(
- self.output_size_per_partition, self.input_size,
+ self.output_size_per_partition, self.in_features,
**factory_kwargs))
if bias:
@@ -72,6 +84,45 @@ class Linear1D_Col(ParallelLayer):
self.bias.zero_()
else:
self.register_parameter('bias', None)
+ with seed(ParallelMode.TENSOR):
+ self.reset_parameters(init_weight, init_bias)
+ self._set_tensor_parallel_attributes()
+
+ def reset_parameters(self, init_weight, init_bias) -> None:
+ assert init_weight in ('torch', 'jax', 'zero')
+ assert init_bias in ('torch', 'jax', 'zero')
+ # setting
+ fan_in, fan_out = self.in_features, self.out_features
+
+ # init weight
+ if init_weight == 'torch':
+ a = math.sqrt(5)
+ nonlinearity = 'leaky_relu'
+ std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
+ bound = math.sqrt(3.0) * std
+ init.uniform_(self.weight, -bound, bound)
+ elif init_weight == 'jax':
+ std = math.sqrt(2.0 / float(fan_in + fan_out))
+ a = math.sqrt(3.0) * std
+ init.uniform_(self.weight, -a, a)
+ elif init_weight == 'zero':
+ init.zeros_(self.weight)
+
+ # init bias
+ if self.bias is not None:
+ if init_bias == 'torch':
+ bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+ init.uniform_(self.bias, -bound, bound)
+ elif init_bias == 'jax':
+ init.normal_(self.bias, std=1e-6)
+ elif init_bias == 'zero':
+ init.zeros_(self.bias)
+
+ def _set_tensor_parallel_attributes(self):
+ num_partition = gpc.get_world_size(ParallelMode.TENSOR)
+ set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
+ if self.bias is not None:
+ set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
# Set up backprop all-reduce.
@@ -104,7 +155,7 @@ class Linear1D_Row(ParallelLayer):
:type bias: bool, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
- :param parallel_input: If set to ``False``, it's assumed that the input is splitted, defaults to False
+ :param parallel_input: If set to ``True``, it's assumed that the input is already split across the parallel ranks, defaults to False
:type parallel_input: bool, optional
"""
@@ -113,7 +164,10 @@ class Linear1D_Row(ParallelLayer):
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
- parallel_input: bool = False
+ parallel_input: bool = False,
+ skip_bias_add: bool = False,
+ init_weight='torch',
+ init_bias='torch'
):
super().__init__()
@@ -121,11 +175,13 @@ class Linear1D_Row(ParallelLayer):
self.in_features = in_features
self.out_features = out_features
self.parallel_input = parallel_input
- self.skip_bias_add = not bias
+ self.skip_bias_add = skip_bias_add
+
+ if skip_bias_add and not bias:
+ raise ValueError('cannot skip bias addition if bias is None')
# Divide the weight matrix along the last dimension.
- world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
- self.input_size_per_partition = divide(in_features, world_size)
+ self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size)
# Parameters.
# Initialize weight.
@@ -146,9 +202,46 @@ class Linear1D_Row(ParallelLayer):
self.bias.zero_()
else:
self.register_parameter('bias', None)
+ with seed(ParallelMode.TENSOR):
+ self.reset_parameters(init_weight, init_bias)
+ self._set_tensor_parallel_attributes()
- def reset_parameters(self) -> None:
- init.xavier_normal_(self.weight)
+ def reset_parameters(self, init_weight, init_bias) -> None:
+ assert init_weight in ('torch', 'jax', 'zero')
+ assert init_bias in ('torch', 'jax', 'zero')
+ # setting
+ fan_in, fan_out = self.in_features, self.out_features
+
+ # init weight
+ if init_weight == 'torch':
+ a = math.sqrt(5)
+ nonlinearity = 'leaky_relu'
+ std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
+ bound = math.sqrt(3.0) * std
+ init.uniform_(self.weight, -bound, bound)
+ elif init_weight == 'jax':
+ std = math.sqrt(2.0 / float(fan_in + fan_out))
+ a = math.sqrt(3.0) * std
+ init.uniform_(self.weight, -a, a)
+ elif init_weight == 'zero':
+ init.zeros_(self.weight)
+
+ # init bias
+ if self.bias is not None:
+ if init_bias == 'torch':
+ bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+ init.uniform_(self.bias, -bound, bound)
+ elif init_bias == 'jax':
+ init.normal_(self.bias, std=1e-6)
+ elif init_bias == 'zero':
+ init.zeros_(self.bias)
+ dist.broadcast(self.bias,
+ src=gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0],
+ group=gpc.get_group(ParallelMode.PARALLEL_1D))
+
+ def _set_tensor_parallel_attributes(self):
+ num_partition = gpc.get_world_size(ParallelMode.TENSOR)
+ set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
def forward(self, input_: Tensor) -> Tensor:
# Set up backprop all-reduce.
@@ -163,4 +256,29 @@ class Linear1D_Row(ParallelLayer):
if not self.skip_bias_add:
output = output + self.bias
- return output
+ return output
+ else:
+ return output, self.bias
+
+
+@LAYERS.register_module
+class MixedFusedLayerNorm1D(torch.nn.Module):
+
+ def __init__(self, normalized_shape, eps=1e-5):
+ super(MixedFusedLayerNorm1D, self).__init__()
+
+ if isinstance(normalized_shape, numbers.Integral):
+ normalized_shape = (normalized_shape,)
+ self.normalized_shape = torch.Size(normalized_shape)
+ self.eps = eps
+ self.weight = Parameter(torch.Tensor(*normalized_shape))
+ self.bias = Parameter(torch.Tensor(*normalized_shape))
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ init.ones_(self.weight)
+ init.zeros_(self.bias)
+
+ def forward(self, input):
+ return FusedLayerNormAffineFunction1D.apply(
+ input, self.weight, self.bias, self.normalized_shape, self.eps)
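skip_bias_add is the contract that feeds the fused bias-GeLU path: instead of adding the bias, the layer hands it back to the caller. A hedged sketch, assuming a 1D tensor-parallel group is initialized:

import torch
from colossalai.nn.layer.parallel_1d import Linear1D_Col

fc = Linear1D_Col(768, 3072, skip_bias_add=True, gather_output=False)
x = torch.randn(2, 196, 768, device='cuda')
out, bias = fc(x)           # bias is returned, not added, so a fused kernel
                            # such as bias_gelu_impl can consume it afterwards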
diff --git a/colossalai/nn/layer/parallel_2d/_operation.py b/colossalai/nn/layer/parallel_2d/_operation.py
index d9ecf2fad..c3722b43e 100644
--- a/colossalai/nn/layer/parallel_2d/_operation.py
+++ b/colossalai/nn/layer/parallel_2d/_operation.py
@@ -20,7 +20,6 @@ def matmul_2d(a,
col_parallel_mode=ParallelMode.PARALLEL_2D_COL,
):
"""Matrix multiplication for 2D parallelism
-
:param a: matrix :math:`A`
:type a: torch.tensor
:param b: matrix :math:`B`
@@ -86,25 +85,30 @@ class Matmul_AB_2D(torch.autograd.Function):
ctx.save_for_backward(A, B)
A_shape = A.shape
- A = A.reshape((-1, A_shape[-1]))
+ A = A.reshape((-1, A_shape[-1])).contiguous()
B_shape = B.shape
- B = B.reshape((-1, B_shape[-1]))
+ B = B.reshape((-1, B_shape[-1])).contiguous()
C_shape = (A.shape[0], B.shape[-1])
C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device())
- for i in range(summa_dim):
- A_temp = A.clone()
- B_temp = B.clone()
- src_a = i + summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
- pipeline_parallel_rank * tensor_parallel_size
- dist.broadcast(A_temp, src=src_a,
- group=gpc.get_group(row_parallel_mode))
- src_b = col_rank + summa_dim * i + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
- pipeline_parallel_rank * tensor_parallel_size
- dist.broadcast(B_temp, src=src_b,
- group=gpc.get_group(col_parallel_mode))
- torch.addmm(C, A_temp, B_temp, out=C)
+ A_list = [torch.empty_like(A) for _ in range(gpc.get_world_size(row_parallel_mode)-1)]
+ B_list = [torch.empty_like(B) for _ in range(gpc.get_world_size(col_parallel_mode)-1)]
+ A_list.insert(gpc.get_local_rank(row_parallel_mode), A)
+ B_list.insert(gpc.get_local_rank(col_parallel_mode), B)
+ op_a = dist.all_gather(A_list, A, group=gpc.get_group(row_parallel_mode), async_op=True)
+ op_a.wait()
+ op_b = dist.all_gather(B_list, B, group=gpc.get_group(col_parallel_mode), async_op=True)
+ for op in [op_a, op_b]:
+ op.wait()
+ for i in range(summa_dim):
+ src_a = i + summa_dim * row_rank
+ src_b = i + summa_dim * col_rank
+ src_a = src_a % summa_dim
+ src_b = src_b % summa_dim
+ A_temp = A_list[src_a]
+ B_temp = B_list[src_b]
+ torch.addmm(C, A_temp, B_temp, out=C)
out = C.reshape(out_shape)
if ctx:
@@ -499,36 +503,61 @@ class _LayerNorm_2D(torch.autograd.Function):
# return input_grad, None, None, None, None, None
-class _ViT_Split_Input_2D(torch.autograd.Function):
+class AllGatherLast(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
inputs: Tensor,
- batch_size: int,
summa_dim: int,
col_parallel_mode: ParallelMode) -> Tensor:
- # inputs: [b, s, h/q]
- # output: [b/q, s, h/q]
-
- ctx.BATCH_SIZE = batch_size
ctx.summa_dim = summa_dim
- ctx.col_parallel_mode = col_parallel_mode
- row_rank = gpc.get_local_rank(col_parallel_mode)
- output = torch.chunk(inputs, summa_dim, dim=0)[row_rank]
- output = output.clone()
- return output
+ ctx.row_rank = gpc.get_local_rank(col_parallel_mode)
+
+ last_dim = summa_dim * inputs.size(-1)
+ outputs_shape = (last_dim,) + inputs.shape[:-1]
+ outputs = torch.empty(
+ outputs_shape, dtype=inputs.dtype, device=get_current_device())
+ dist.all_gather(
+ list(outputs.chunk(summa_dim, dim=0)),
+ inputs.permute(2, 0, 1).contiguous(),
+ group=gpc.get_group(col_parallel_mode)
+ )
+ outputs = outputs.permute(1, 2, 0).contiguous()
+ return outputs
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
- # output_grad: [b/q, s, h/q]
- # grads: [b, s, h/q]
- grads_shape = (ctx.BATCH_SIZE,) + output_grad.shape[1:]
- grads = torch.empty(grads_shape,
- dtype=output_grad.dtype,
- device=get_current_device())
- dist.all_gather(list(grads.chunk(ctx.summa_dim, dim=0)),
- output_grad.contiguous(),
- group=gpc.get_group(ctx.col_parallel_mode))
- return grads, None, None, None
+ grad = output_grad.chunk(ctx.summa_dim, dim=-1)[ctx.row_rank]
+ return grad.contiguous(), None, None
+
+
+class SplitFirst(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float16)
+ def forward(ctx: Any,
+ inputs: Tensor,
+ summa_dim: int,
+ col_parallel_mode: ParallelMode) -> Tensor:
+ ctx.summa_dim = summa_dim
+ ctx.batch_size = inputs.size(0)
+ ctx.para_mode = col_parallel_mode
+ row_rank = gpc.get_local_rank(col_parallel_mode)
+
+ outputs = inputs.chunk(summa_dim, dim=0)[row_rank]
+ return outputs
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
+ grad_shape = (ctx.batch_size,) + output_grad.shape[1:]
+ grad = torch.empty(
+ grad_shape, dtype=output_grad.dtype, device=get_current_device())
+ dist.all_gather(
+ list(grad.chunk(ctx.summa_dim, dim=0)),
+ output_grad.contiguous(),
+ group=gpc.get_group(ctx.para_mode)
+ )
+ return grad, None, None
diff --git a/colossalai/nn/layer/parallel_2d/_vit.py b/colossalai/nn/layer/parallel_2d/_vit.py
index 211de1e9f..70734b345 100644
--- a/colossalai/nn/layer/parallel_2d/_vit.py
+++ b/colossalai/nn/layer/parallel_2d/_vit.py
@@ -5,19 +5,21 @@ import math
import torch
from torch import nn as nn, Tensor, distributed as dist
+from torch.nn.init import _calculate_fan_in_and_fan_out
from colossalai.context import seed, ParallelMode
-from colossalai.core import global_context as gpc
from colossalai.nn.layer._common_utils import divide, ACT2FN
from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization, get_summa_dim_from_env
-from colossalai.nn.layer.vanilla_vision_transformer.layers import to_2tuple
+
from colossalai.registry import LAYERS
from colossalai.utils import checkpoint
from colossalai.utils import get_current_device
-from ._operation import _ViT_Split_Input_2D
+from colossalai.core import global_context as gpc
+from ._operation import AllGatherLast, SplitFirst
from .layers import Linear2D
-from .._common_utils import set_tensor_parallel_attribute
+from .._common_utils import set_tensor_parallel_attribute_by_partition, to_2tuple
from ..base_layer import ParallelLayer
+from ..fused_bias_gelu import bias_gelu_impl
@LAYERS.register_module
@@ -44,8 +46,8 @@ class ViTMLP2D(ParallelLayer):
act_func: str = 'gelu',
dropout_prob: float = 0.,
dtype=None,
- checkpoint: bool = False
- ):
+ checkpoint: bool = False,
+ weight_init='torch'):
super().__init__()
assert_summa_initialization()
@@ -53,27 +55,40 @@ class ViTMLP2D(ParallelLayer):
self.in_features = in_features
self.mlp_ratio = mlp_ratio
self.checkpoint = checkpoint
+ assert weight_init in ('torch', 'jax')
+
+ if act_func == 'fused_gelu':
+ self.act = bias_gelu_impl
+ skip_dense_1_add_bias = True
+ else:
+ self.act = ACT2FN[act_func]
+ skip_dense_1_add_bias = False
# Project to mlp_ratio * h.
self.dense_1 = Linear2D(
self.in_features,
self.mlp_ratio * self.in_features,
dtype=dtype,
+ init_weight=weight_init, init_bias=weight_init,
+ skip_bias_add=skip_dense_1_add_bias
)
- self.act = ACT2FN[act_func]
-
# Project back to h.
self.dense_2 = Linear2D(
self.mlp_ratio * self.in_features,
self.in_features,
dtype=dtype,
+ init_weight=weight_init, init_bias=weight_init
)
self.dropout = nn.Dropout(dropout_prob)
def _forward(self, hidden_states: Tensor) -> Tensor:
- intermediate_output = self.dense_1(hidden_states)
- intermediate_output = self.act(intermediate_output)
+ if self.act == bias_gelu_impl:
+ intermediate_output, bias = self.dense_1(hidden_states)
+ intermediate_output = self.act(intermediate_output, bias)
+ else:
+ intermediate_output = self.dense_1(hidden_states)
+ intermediate_output = self.act(intermediate_output)
with seed(ParallelMode.TENSOR):
intermediate_output = self.dropout(intermediate_output)
@@ -117,8 +132,8 @@ class ViTSelfAttention2D(ParallelLayer):
attention_dropout_prob: float,
hidden_dropout_prob: float,
dtype=None,
- checkpoint: bool = False
- ):
+ checkpoint: bool = False,
+ weight_init='torch'):
super().__init__()
assert_summa_initialization()
@@ -128,17 +143,24 @@ class ViTSelfAttention2D(ParallelLayer):
self.attention_head_size = divide(hidden_size, num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.checkpoint = checkpoint
+ assert weight_init in ('torch', 'jax')
+ if weight_init == 'jax':
+ self.init_bias = 'zero'
+ else:
+ self.init_bias = weight_init
self.query_key_value = Linear2D(
hidden_size,
3 * hidden_size,
dtype=dtype,
+ init_weight=weight_init, init_bias=self.init_bias
)
self.attention_dropout = nn.Dropout(attention_dropout_prob)
self.dense = Linear2D(
hidden_size,
hidden_size,
dtype=dtype,
+ init_weight=weight_init, init_bias=self.init_bias
)
self.dropout = nn.Dropout(hidden_dropout_prob)
self.softmax = nn.Softmax(dim=-1)
@@ -146,7 +168,7 @@ class ViTSelfAttention2D(ParallelLayer):
def _forward(self, hidden_states: Tensor) -> Tensor:
query_key_value = self.query_key_value(hidden_states)
new_qkv_shape = query_key_value.shape[:-1] + \
- (self.num_attention_heads, 3 * self.attention_head_size)
+ (self.num_attention_heads, 3 * self.attention_head_size)
query_key_value = query_key_value.view(new_qkv_shape)
query_key_value = query_key_value.permute((0, 2, 1, 3))
query_layer, key_layer, value_layer = torch.chunk(
@@ -155,7 +177,7 @@ class ViTSelfAttention2D(ParallelLayer):
attention_scores = torch.matmul(
query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / \
- math.sqrt(self.attention_head_size)
+ math.sqrt(self.attention_head_size)
attention_probs = self.softmax(attention_scores)
@@ -165,7 +187,7 @@ class ViTSelfAttention2D(ParallelLayer):
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.transpose(1, 2)
new_context_layer_shape = context_layer.size()[
- :-2] + (self.all_head_size,)
+ :-2] + (self.all_head_size,)
context_layer = context_layer.reshape(new_context_layer_shape)
output = self.dense(context_layer)
@@ -199,14 +221,22 @@ class ViTHead2D(ParallelLayer):
hidden_size,
num_classes,
dtype=None,
- ):
+ weight_init='torch'):
super().__init__()
assert_summa_initialization()
+ assert weight_init in ('torch', 'jax')
+ if weight_init == 'jax':
+ self.init_weight = 'zero'
+ self.init_bias = 'zero'
+ else:
+ self.init_weight = weight_init
+ self.init_bias = weight_init
self.summa_dim = get_summa_dim_from_env()
self.linear = Linear2D(
hidden_size,
num_classes,
dtype=dtype,
+ init_weight=self.init_weight, init_bias=self.init_bias
)
def forward(self, x: Tensor) -> Tensor:
@@ -236,7 +266,8 @@ class ViTPatchEmbedding2D(ParallelLayer):
patch_size,
embed_dim,
in_chans=3,
- flatten=True):
+ flatten=True,
+ weight_init='torch'):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
@@ -249,39 +280,28 @@ class ViTPatchEmbedding2D(ParallelLayer):
img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
- self.embed_dim = embed_dim // self.summa_dim
+ self.embed_dim = embed_dim // (self.summa_dim ** 2)
with seed(ParallelMode.TENSOR):
- # ensure the partitions are initialized differently
self.proj = nn.Conv2d(in_chans,
self.embed_dim,
kernel_size=patch_size,
- stride=patch_size
+ stride=patch_size,
+ device=get_current_device()
)
+ self._set_tensor_parallel_attribute()
- # sync
- self._broadcast_conv_params()
- self.proj.weight.register_hook(self._sync_grad_during_backward)
- self.proj.bias.register_hook(self._sync_grad_during_backward)
+ if weight_init == 'jax':
+ with seed(ParallelMode.TENSOR):
+ fan_in, _ = _calculate_fan_in_and_fan_out(self.proj.weight)
+ std = math.sqrt(1.0 / fan_in)
+ nn.init.trunc_normal_(self.proj.weight, std=std / .87962566103423978)
+ nn.init.zeros_(self.proj.bias)
def _set_tensor_parallel_attribute(self):
- set_tensor_parallel_attribute(self.proj.weight)
- set_tensor_parallel_attribute(self.proj.bias)
-
- def _broadcast_conv_params(self) -> None:
- self.to(get_current_device())
- ranks_in_col = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2D_COL)
-
- dist.broadcast(self.proj.weight, src=ranks_in_col[0],
- group=gpc.get_group(ParallelMode.PARALLEL_2D_COL))
- dist.broadcast(self.proj.bias, src=ranks_in_col[0],
- group=gpc.get_group(ParallelMode.PARALLEL_2D_COL))
-
- def _sync_grad_during_backward(self, grad: Tensor) -> None:
- dist.all_reduce(grad, group=gpc.get_group(
- ParallelMode.PARALLEL_2D_COL))
- grad = grad / self.summa_dim
- return grad
+ num_partition = gpc.get_world_size(ParallelMode.TENSOR)
+ set_tensor_parallel_attribute_by_partition(self.proj.weight, num_partition)
+ set_tensor_parallel_attribute_by_partition(self.proj.bias, num_partition)
def forward(self, x: Tensor) -> Tensor:
B, C, H, W = x.shape
@@ -293,6 +313,24 @@ class ViTPatchEmbedding2D(ParallelLayer):
return x
+@LAYERS.register_module
+class ViTInputSplitter2D(ParallelLayer):
+ """Split the input tensor for 2D parallel Vision Transformer
+ """
+
+ def __init__(self):
+ super().__init__()
+ assert_summa_initialization()
+ self.summa_dim = get_summa_dim_from_env()
+
+ def forward(self, x: Tensor) -> Tensor:
+ x = AllGatherLast.apply(
+ x, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
+ x = SplitFirst.apply(
+ x, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
+ return x
+
+
@LAYERS.register_module
class ViTTokenFuser2D(ParallelLayer):
"""
@@ -328,64 +366,32 @@ class ViTTokenFuser2D(ParallelLayer):
self.embed_dim = embed_dim
self.cls_token = nn.Parameter(torch.zeros(
- 1, 1, self.embed_dim // self.summa_dim))
- self.pos_embed = nn.Parameter(torch.zeros(
- 1, self.num_patches + 1, self.embed_dim // self.summa_dim))
+ (1, 1, self.embed_dim // (self.summa_dim ** 2)),
+ device=get_current_device()))
+ self.pos_embed = nn.Parameter(torch.empty(
+ (1, self.num_patches + 1, self.embed_dim // (self.summa_dim ** 2)),
+ device=get_current_device()))
+ with seed(ParallelMode.TENSOR):
+ nn.init.trunc_normal_(self.pos_embed, std=.02)
- # move to cuda before broadcast
- self.to(get_current_device())
-
- # sync param in both forward and backward
- _cls_token = self.cls_token.view(-1)
- _pos_embed = self.pos_embed.view(-1)
- self._param = torch.cat([_cls_token, _pos_embed], dim=0)
-
- self._broadcast_params(self._param)
- self._param.register_hook(self._sync_grad_hook)
self.pos_drop = nn.Dropout(p=drop_rate)
self._set_tensor_parallel_attribute()
def _set_tensor_parallel_attribute(self):
- set_tensor_parallel_attribute(self.cls_token)
- set_tensor_parallel_attribute(self.pos_embed)
-
- def _broadcast_params(self, param) -> None:
- " broadcast to all column ranks for data consistency "
- ranks_in_col = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2D_COL)
- col_group = gpc.get_group(ParallelMode.PARALLEL_2D_COL)
- dist.broadcast(param, src=ranks_in_col[0],
- group=col_group)
-
- def _sync_grad_hook(self, grad) -> None:
- dist.all_reduce(grad, group=gpc.get_group(
- ParallelMode.PARALLEL_2D_COL))
- grad = grad / self.summa_dim
- return grad
+ num_partition = gpc.get_world_size(ParallelMode.TENSOR)
+ set_tensor_parallel_attribute_by_partition(self.cls_token, num_partition)
+ set_tensor_parallel_attribute_by_partition(self.pos_embed, num_partition)
def forward(self, x: Tensor) -> Tensor:
# stole cls_tokens impl from Phil Wang, thanks
- cls_token = self.cls_token.expand(x.shape[0], -1, -1)
+ cls_token = AllGatherLast.apply(
+ self.cls_token, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
+ cls_token = cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_token, x), dim=1)
+
+ pos_embed = AllGatherLast.apply(
+ self.pos_embed, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
+ x = x + pos_embed
with seed(ParallelMode.TENSOR):
- x = self.pos_drop(x + self.pos_embed)
+ x = self.pos_drop(x)
return x
-
-
-@LAYERS.register_module
-class ViTInputSplitter2D(ParallelLayer):
- """Split the input tensor for 2D parallel Vision Transformer
- """
-
- def __init__(self):
- super().__init__()
- assert_summa_initialization()
- self.summa_dim = get_summa_dim_from_env()
-
- def forward(self, x: Tensor) -> Tensor:
- batch_size = x.size(0)
- return _ViT_Split_Input_2D.apply(
- x,
- batch_size,
- self.summa_dim,
- ParallelMode.PARALLEL_2D_COL
- )
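The 2D ViT hunks above drop the broadcast-and-gradient-hook synchronisation in favour of AllGatherLast (gather the hidden-dimension shards across the column group) followed by SplitFirst (keep only the local slice of the batch dimension). Below is a minimal single-process sketch of the intended shape transformation, using plain torch ops in place of the collectives; `shards` and `rank` are stand-ins for the column-group tensors and the local rank, not the actual distributed implementation.

import torch

def emulate_input_split(shards, rank):
    # AllGatherLast: every rank contributes its hidden-dim shard, producing
    # the full hidden size on the last dimension.
    full = torch.cat(shards, dim=-1)                 # q x [B, S, H/q] -> [B, S, H]
    # SplitFirst: the batch dimension is then partitioned across the same
    # group, so each rank keeps only its slice of the batch.
    return full.chunk(len(shards), dim=0)[rank]      # [B, S, H] -> [B/q, S, H]

shards = [torch.randn(4, 16, 8) for _ in range(2)]  # q = 2, local hidden = 8
out = emulate_input_split(shards, rank=0)
assert out.shape == (2, 16, 16)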
diff --git a/colossalai/nn/layer/parallel_2d/layers.py b/colossalai/nn/layer/parallel_2d/layers.py
index 570cf1c25..f29354356 100644
--- a/colossalai/nn/layer/parallel_2d/layers.py
+++ b/colossalai/nn/layer/parallel_2d/layers.py
@@ -11,7 +11,7 @@ from colossalai.registry import LAYERS
from colossalai.utils import get_current_device
from ._operation import Matmul_AB_2D, Add_Bias_2D, _LayerNorm_2D
from ._utils import get_summa_dim_from_env, assert_summa_initialization
-from .._common_utils import divide, set_tensor_parallel_attribute
+from .._common_utils import divide, set_tensor_parallel_attribute_by_partition
from ..base_layer import ParallelLayer
@@ -36,8 +36,9 @@ class Linear2D(ParallelLayer):
out_features: int,
bias: bool = True,
dtype=None,
- skip_bias_add: bool = False
- ):
+ skip_bias_add: bool = False,
+ init_weight='torch',
+ init_bias='torch'):
super().__init__()
self.in_features = in_features
@@ -72,31 +73,45 @@ class Linear2D(ParallelLayer):
self.register_parameter('bias', None)
# initialize parameters
- self.reset_parameters()
+ with seed(ParallelMode.TENSOR):
+ self.reset_parameters(init_weight, init_bias)
self._set_tensor_parallel_attributes()
def _set_tensor_parallel_attributes(self):
- set_tensor_parallel_attribute(self.weight)
+ num_partition = gpc.get_world_size(ParallelMode.TENSOR)
+ set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
if self.bias is not None:
- set_tensor_parallel_attribute(self.bias)
+ set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
- def reset_parameters(self) -> None:
+ def reset_parameters(self, init_weight, init_bias) -> None:
+ assert init_weight in ('torch', 'jax', 'zero')
+ assert init_bias in ('torch', 'jax', 'zero')
# setting
- fan_in = self.in_features
- a = math.sqrt(5)
- nonlinearity = 'leaky_relu'
+ fan_in, fan_out = self.in_features, self.out_features
# init weight
- std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
- bound = math.sqrt(3.0) * std
- with seed(ParallelMode.TENSOR):
+ if init_weight == 'torch':
+ a = math.sqrt(5)
+ nonlinearity = 'leaky_relu'
+ std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
+ bound = math.sqrt(3.0) * std
init.uniform_(self.weight, -bound, bound)
+ elif init_weight == 'jax':
+ std = math.sqrt(2.0 / float(fan_in + fan_out))
+ a = math.sqrt(3.0) * std
+ init.uniform_(self.weight, -a, a)
+ elif init_weight == 'zero':
+ init.zeros_(self.weight)
# init bias
if self.bias is not None:
- bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
- with seed(ParallelMode.TENSOR):
+ if init_bias == 'torch':
+ bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
init.uniform_(self.bias, -bound, bound)
+ elif init_bias == 'jax':
+ init.normal_(self.bias, std=1e-6)
+ elif init_bias == 'zero':
+ init.zeros_(self.bias)
def forward(self, x: Tensor) -> Tensor:
# input: [m/q, n/q, k/q]
@@ -192,28 +207,19 @@ class LayerNorm2D(ParallelLayer):
# create parameters
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
- if self.row_rank == 0:
- self.gamma = Parameter(torch.ones(
- self.partitioned_partition,
- **factory_kwargs))
- self.beta = Parameter(torch.zeros(
- self.partitioned_partition,
- **factory_kwargs))
- else:
- self.gamma = Parameter(torch.tensor(
- 1.0,
- requires_grad=True,
- **factory_kwargs))
- self.beta = Parameter(torch.tensor(
- 1.0,
- requires_grad=True,
- **factory_kwargs))
+ self.gamma = Parameter(torch.ones(
+ self.partitioned_partition,
+ **factory_kwargs))
+ self.beta = Parameter(torch.zeros(
+ self.partitioned_partition,
+ **factory_kwargs))
self._set_tensor_parallel_attributes()
def _set_tensor_parallel_attributes(self):
- set_tensor_parallel_attribute(self.gamma)
- set_tensor_parallel_attribute(self.beta)
+ num_partition = gpc.get_world_size(ParallelMode.TENSOR)
+ set_tensor_parallel_attribute_by_partition(self.gamma, num_partition)
+ set_tensor_parallel_attribute_by_partition(self.beta, num_partition)
def forward(self, x: Tensor) -> Tensor:
with torch.no_grad():
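The new init_weight/init_bias arguments accept 'torch', 'jax' and 'zero'. For reference, here are the same three schemes applied to a plain nn.Linear; this is only a sketch of what each option computes, while the parallel layer applies them to its local weight shard under the ParallelMode.TENSOR seed context.

import math
import torch.nn as nn
from torch.nn import init

def reset_linear(layer: nn.Linear, init_weight: str = 'torch', init_bias: str = 'torch'):
    # 'torch' -> the kaiming-uniform style default of nn.Linear
    # 'jax'   -> xavier-uniform weights with a tiny-normal bias (timm/ViT style)
    # 'zero'  -> all zeros (used e.g. for the classification head)
    fan_in, fan_out = layer.in_features, layer.out_features
    if init_weight == 'torch':
        std = init.calculate_gain('leaky_relu', math.sqrt(5)) / math.sqrt(fan_in)
        bound = math.sqrt(3.0) * std
        init.uniform_(layer.weight, -bound, bound)
    elif init_weight == 'jax':
        std = math.sqrt(2.0 / float(fan_in + fan_out))
        a = math.sqrt(3.0) * std
        init.uniform_(layer.weight, -a, a)
    elif init_weight == 'zero':
        init.zeros_(layer.weight)
    if layer.bias is not None:
        if init_bias == 'torch':
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            init.uniform_(layer.bias, -bound, bound)
        elif init_bias == 'jax':
            init.normal_(layer.bias, std=1e-6)
        elif init_bias == 'zero':
            init.zeros_(layer.bias)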
diff --git a/colossalai/nn/layer/parallel_2p5d/__init__.py b/colossalai/nn/layer/parallel_2p5d/__init__.py
index b4ebc12ea..ab91862db 100644
--- a/colossalai/nn/layer/parallel_2p5d/__init__.py
+++ b/colossalai/nn/layer/parallel_2p5d/__init__.py
@@ -1,11 +1,10 @@
-from ._operation import Matmul_AB_2p5D, Matmul_ABT_2p5D, Matmul_ATB_2p5D, Sum_2p5D, Add_Bias_2p5D
+from ._operation import Matmul_AB_2p5D, Matmul_ABT_2p5D, Matmul_ATB_2p5D, Add_Bias_2p5D
from ._transformer import TransformerMLP2p5D, TransformerSelfAttention2p5D, TransformerLayer2p5D
-from ._vit import (ViTMLP2p5D, ViTSelfAttention2p5D, ViTHead2p5D, ViTPatchEmbedding2p5D, ViTTokenFuser2p5D,
- ViTInputSplitter2p5D)
+from ._vit import ViTMLP2p5D, ViTSelfAttention2p5D, ViTHead2p5D, ViTPatchEmbedding2p5D, ViTTokenFuser2p5D, ViTInputSplitter2p5D
from .layers import Linear2p5D, LayerNorm2p5D
__all__ = [
- 'Matmul_AB_2p5D', 'Matmul_ABT_2p5D', 'Matmul_ATB_2p5D', 'Sum_2p5D', 'Add_Bias_2p5D',
+ 'Matmul_AB_2p5D', 'Matmul_ABT_2p5D', 'Matmul_ATB_2p5D', 'Add_Bias_2p5D',
'TransformerMLP2p5D', 'TransformerSelfAttention2p5D', 'TransformerLayer2p5D',
'ViTMLP2p5D', 'ViTSelfAttention2p5D', 'ViTHead2p5D', 'ViTPatchEmbedding2p5D', 'ViTTokenFuser2p5D',
'ViTInputSplitter2p5D',
diff --git a/colossalai/nn/layer/parallel_2p5d/_operation.py b/colossalai/nn/layer/parallel_2p5d/_operation.py
index db50b44fb..a8970963b 100644
--- a/colossalai/nn/layer/parallel_2p5d/_operation.py
+++ b/colossalai/nn/layer/parallel_2p5d/_operation.py
@@ -6,7 +6,8 @@ from torch import Tensor
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
-from colossalai.utils import get_current_device, empty_cache
+from colossalai.utils import get_current_device
+from torch.cuda.amp import custom_bwd, custom_fwd
def get_parallel_group(parallel_mode: ParallelMode):
@@ -26,18 +27,17 @@ class Matmul_AB_2p5D(torch.autograd.Function):
"""
@staticmethod
+ @custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
A: Tensor,
B: Tensor,
tesseract_dim: int,
- tesseract_dep: int,
out_shape: Tuple[int, ...],
row_rank: int,
col_rank: int,
dep_rank: int,
row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode,
- dep_parallel_mode: ParallelMode,
data_parallel_rank: int,
pipeline_parallel_rank: int,
pipeline_parallel_size: int,
@@ -49,41 +49,43 @@ class Matmul_AB_2p5D(torch.autograd.Function):
assert A.shape[-1] == B.shape[-2], \
'Invalid shapes: A={}, B={} for AB.'.format(A.shape, B.shape)
- empty_cache()
if ctx:
ctx.save_for_backward(A, B)
A_shape = A.shape
- A = A.reshape((-1, A_shape[-1]))
+ A = A.reshape((-1, A_shape[-1])).contiguous()
B_shape = B.shape
- B = B.reshape((-1, B_shape[-1]))
+ B = B.reshape((-1, B_shape[-1])).contiguous()
C_shape = (A.shape[0], B.shape[-1])
C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device())
- for i in range(tesseract_dim):
- A_temp = A.clone()
- B_temp = B.clone()
- src_a = i + row_rank * tesseract_dim + dep_rank * (
- tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + pipeline_parallel_rank * tensor_parallel_size
- dist.broadcast(A_temp, src=src_a,
- group=get_parallel_group(row_parallel_mode))
- src_b = col_rank + i * tesseract_dim + dep_rank * (
- tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + pipeline_parallel_rank * tensor_parallel_size
- dist.broadcast(B_temp, src=src_b,
- group=get_parallel_group(col_parallel_mode))
- torch.addmm(C, A_temp, B_temp, out=C)
+ A_list = [torch.empty_like(A) for _ in range(gpc.get_world_size(row_parallel_mode)-1)]
+ B_list = [torch.empty_like(B) for _ in range(gpc.get_world_size(col_parallel_mode)-1)]
+ A_list.insert(gpc.get_local_rank(row_parallel_mode), A)
+ B_list.insert(gpc.get_local_rank(col_parallel_mode), B)
+ op_a = dist.all_gather(A_list, A, group=gpc.get_group(row_parallel_mode), async_op=True)
+ op_a.wait()
+ op_b = dist.all_gather(B_list, B, group=gpc.get_group(col_parallel_mode), async_op=True)
+ for op in [op_a, op_b]:
+ op.wait()
+ for i in range(tesseract_dim):
+ src_a = i + tesseract_dim * row_rank
+ src_b = i + tesseract_dim * col_rank
+ src_a = src_a % tesseract_dim
+ src_b = src_b % tesseract_dim
+ A_temp = A_list[src_a]
+ B_temp = B_list[src_b]
+ torch.addmm(C, A_temp, B_temp, out=C)
out = C.reshape(out_shape)
if ctx:
ctx.tesseract_dim = tesseract_dim
- ctx.tesseract_dep = tesseract_dep
ctx.row_rank = row_rank
ctx.col_rank = col_rank
ctx.dep_rank = dep_rank
ctx.row_parallel_mode = row_parallel_mode
ctx.col_parallel_mode = col_parallel_mode
- ctx.dep_parallel_mode = dep_parallel_mode
ctx.A_shape = A_shape
ctx.B_shape = B_shape
ctx.data_parallel_rank = data_parallel_rank
@@ -94,34 +96,32 @@ class Matmul_AB_2p5D(torch.autograd.Function):
return out
@staticmethod
+ @custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors
- A_grad = Matmul_ABT_2p5D.forward(
- None,
- output_grad, B,
- ctx.tesseract_dim, ctx.tesseract_dep, ctx.A_shape,
- ctx.row_rank, ctx.col_rank, ctx.dep_rank,
- ctx.row_parallel_mode,
- ctx.col_parallel_mode,
- ctx.dep_parallel_mode,
- ctx.data_parallel_rank,
- ctx.pipeline_parallel_rank,
- ctx.pipeline_parallel_size,
- ctx.tensor_parallel_size
- )
- B_grad = Matmul_ATB_2p5D.forward(
- None,
- A, output_grad,
- ctx.tesseract_dim, ctx.tesseract_dep, ctx.B_shape,
- ctx.row_rank, ctx.col_rank, ctx.dep_rank,
- ctx.row_parallel_mode,
- ctx.col_parallel_mode,
- ctx.dep_parallel_mode,
- ctx.data_parallel_rank,
- ctx.pipeline_parallel_rank,
- ctx.pipeline_parallel_size,
- ctx.tensor_parallel_size
- )
+ with torch.no_grad():
+ A_grad = Matmul_ABT_2p5D.apply(
+ output_grad, B,
+ ctx.tesseract_dim, ctx.A_shape,
+ ctx.row_rank, ctx.col_rank, ctx.dep_rank,
+ ctx.row_parallel_mode,
+ ctx.col_parallel_mode,
+ ctx.data_parallel_rank,
+ ctx.pipeline_parallel_rank,
+ ctx.pipeline_parallel_size,
+ ctx.tensor_parallel_size
+ )
+ B_grad = Matmul_ATB_2p5D.apply(
+ A, output_grad,
+ ctx.tesseract_dim, ctx.B_shape,
+ ctx.row_rank, ctx.col_rank, ctx.dep_rank,
+ ctx.row_parallel_mode,
+ ctx.col_parallel_mode,
+ ctx.data_parallel_rank,
+ ctx.pipeline_parallel_rank,
+ ctx.pipeline_parallel_size,
+ ctx.tensor_parallel_size
+ )
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None
@@ -130,18 +130,17 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
"""
@staticmethod
+ @custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
A: Tensor,
B: Tensor,
tesseract_dim: int,
- tesseract_dep: int,
out_shape: Tuple[int, ...],
row_rank: int,
col_rank: int,
dep_rank: int,
row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode,
- dep_parallel_mode: ParallelMode,
data_parallel_rank: int,
pipeline_parallel_rank: int,
pipeline_parallel_size: int,
@@ -151,7 +150,6 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
assert A.shape[-1] == B.shape[-1], \
'Invalid shapes: A={}, B={} for ABT.'.format(A.shape, B.shape)
- empty_cache()
if ctx:
ctx.save_for_backward(A, B)
@@ -180,13 +178,11 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
if ctx:
ctx.tesseract_dim = tesseract_dim
- ctx.tesseract_dep = tesseract_dep
ctx.row_rank = row_rank
ctx.col_rank = col_rank
ctx.dep_rank = dep_rank
ctx.row_parallel_mode = row_parallel_mode
ctx.col_parallel_mode = col_parallel_mode
- ctx.dep_parallel_mode = dep_parallel_mode
ctx.A_shape = A_shape
ctx.B_shape = B_shape
ctx.data_parallel_rank = data_parallel_rank
@@ -197,34 +193,32 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
return out
@staticmethod
+ @custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors
- A_grad = Matmul_AB_2p5D.forward(
- None,
- output_grad, B,
- ctx.tesseract_dim, ctx.tesseract_dep, ctx.A_shape,
- ctx.row_rank, ctx.col_rank, ctx.dep_rank,
- ctx.row_parallel_mode,
- ctx.col_parallel_mode,
- ctx.dep_parallel_mode,
- ctx.data_parallel_rank,
- ctx.pipeline_parallel_rank,
- ctx.pipeline_parallel_size,
- ctx.tensor_parallel_size
- )
- B_grad = Matmul_ATB_2p5D.forward(
- None,
- output_grad, A,
- ctx.tesseract_dim, ctx.tesseract_dep, ctx.B_shape,
- ctx.row_rank, ctx.col_rank, ctx.dep_rank,
- ctx.row_parallel_mode,
- ctx.col_parallel_mode,
- ctx.dep_parallel_mode,
- ctx.data_parallel_rank,
- ctx.pipeline_parallel_rank,
- ctx.pipeline_parallel_size,
- ctx.tensor_parallel_size
- )
+ with torch.no_grad():
+ A_grad = Matmul_AB_2p5D.apply(
+ output_grad, B,
+ ctx.tesseract_dim, ctx.A_shape,
+ ctx.row_rank, ctx.col_rank, ctx.dep_rank,
+ ctx.row_parallel_mode,
+ ctx.col_parallel_mode,
+ ctx.data_parallel_rank,
+ ctx.pipeline_parallel_rank,
+ ctx.pipeline_parallel_size,
+ ctx.tensor_parallel_size
+ )
+ B_grad = Matmul_ATB_2p5D.apply(
+ output_grad, A,
+ ctx.tesseract_dim, ctx.B_shape,
+ ctx.row_rank, ctx.col_rank, ctx.dep_rank,
+ ctx.row_parallel_mode,
+ ctx.col_parallel_mode,
+ ctx.data_parallel_rank,
+ ctx.pipeline_parallel_rank,
+ ctx.pipeline_parallel_size,
+ ctx.tensor_parallel_size
+ )
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None
@@ -233,18 +227,17 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
"""
@staticmethod
+ @custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
A: Tensor,
B: Tensor,
tesseract_dim: int,
- tesseract_dep: int,
out_shape: Tuple[int, ...],
row_rank: int,
col_rank: int,
dep_rank: int,
row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode,
- dep_parallel_mode: ParallelMode,
data_parallel_rank: int,
pipeline_parallel_rank: int,
pipeline_parallel_size: int,
@@ -253,7 +246,6 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
assert A.shape[-2] == B.shape[-2], \
'Invalid shapes: A={}, B={} for ATB.'.format(A.shape, B.shape)
- empty_cache()
if ctx:
ctx.save_for_backward(A, B)
@@ -284,13 +276,11 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
if ctx:
ctx.tesseract_dim = tesseract_dim
- ctx.tesseract_dep = tesseract_dep
ctx.row_rank = row_rank
ctx.col_rank = col_rank
ctx.dep_rank = dep_rank
ctx.row_parallel_mode = row_parallel_mode
ctx.col_parallel_mode = col_parallel_mode
- ctx.dep_parallel_mode = dep_parallel_mode
ctx.A_shape = A_shape
ctx.B_shape = B_shape
ctx.data_parallel_rank = data_parallel_rank
@@ -301,34 +291,32 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
return out
@staticmethod
+ @custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors
- A_grad = Matmul_ABT_2p5D.forward(
- None,
- B, output_grad,
- ctx.tesseract_dim, ctx.tesseract_dep, ctx.A_shape,
- ctx.row_rank, ctx.col_rank, ctx.dep_rank,
- ctx.row_parallel_mode,
- ctx.col_parallel_mode,
- ctx.dep_parallel_mode,
- ctx.data_parallel_rank,
- ctx.pipeline_parallel_rank,
- ctx.pipeline_parallel_size,
- ctx.tensor_parallel_size
- )
- B_grad = Matmul_AB_2p5D.forward(
- None,
- A, output_grad,
- ctx.tesseract_dim, ctx.tesseract_dep, ctx.B_shape,
- ctx.row_rank, ctx.col_rank, ctx.dep_rank,
- ctx.row_parallel_mode,
- ctx.col_parallel_mode,
- ctx.dep_parallel_mode,
- ctx.data_parallel_rank,
- ctx.pipeline_parallel_rank,
- ctx.pipeline_parallel_size,
- ctx.tensor_parallel_size
- )
+ with torch.no_grad():
+ A_grad = Matmul_ABT_2p5D.apply(
+ B, output_grad,
+ ctx.tesseract_dim, ctx.A_shape,
+ ctx.row_rank, ctx.col_rank, ctx.dep_rank,
+ ctx.row_parallel_mode,
+ ctx.col_parallel_mode,
+ ctx.data_parallel_rank,
+ ctx.pipeline_parallel_rank,
+ ctx.pipeline_parallel_size,
+ ctx.tensor_parallel_size
+ )
+ B_grad = Matmul_AB_2p5D.apply(
+ A, output_grad,
+ ctx.tesseract_dim, ctx.B_shape,
+ ctx.row_rank, ctx.col_rank, ctx.dep_rank,
+ ctx.row_parallel_mode,
+ ctx.col_parallel_mode,
+ ctx.data_parallel_rank,
+ ctx.pipeline_parallel_rank,
+ ctx.pipeline_parallel_size,
+ ctx.tensor_parallel_size
+ )
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None
@@ -337,18 +325,16 @@ class Add_Bias_2p5D(torch.autograd.Function):
"""
@staticmethod
+ @custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
input: Tensor,
bias: Tensor,
output_size_per_partition: int,
tesseract_dim: int,
- tesseract_dep: int,
row_rank: int,
col_rank: int,
dep_rank: int,
- row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode,
- dep_parallel_mode: ParallelMode,
skip_bias_add: bool,
data_parallel_rank: int,
pipeline_parallel_rank: int,
@@ -371,10 +357,7 @@ class Add_Bias_2p5D(torch.autograd.Function):
ctx.col_rank = col_rank
ctx.dep_rank = dep_rank
ctx.tesseract_dim = tesseract_dim
- ctx.tesseract_dep = tesseract_dep
- ctx.row_parallel_mode = row_parallel_mode
ctx.col_parallel_mode = col_parallel_mode
- ctx.dep_parallel_mode = dep_parallel_mode
ctx.bias = skip_bias_add
ctx.data_parallel_rank = data_parallel_rank
ctx.pipeline_parallel_rank = pipeline_parallel_rank
@@ -388,15 +371,13 @@ class Add_Bias_2p5D(torch.autograd.Function):
return output
@staticmethod
- def backward(ctx, output_grad):
+ @custom_bwd
+ def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
row_rank = ctx.row_rank
col_rank = ctx.col_rank
dep_rank = ctx.dep_rank
tesseract_dim = ctx.tesseract_dim
- tesseract_dep = ctx.tesseract_dep
- row_parallel_mode = ctx.row_parallel_mode
col_parallel_mode = ctx.col_parallel_mode
- dep_parallel_mode = ctx.dep_parallel_mode
data_parallel_rank = ctx.data_parallel_rank
pipeline_parallel_rank = ctx.pipeline_parallel_rank
pipeline_parallel_size = ctx.pipeline_parallel_size
@@ -428,29 +409,25 @@ class Add_Bias_2p5D(torch.autograd.Function):
class _LayerNorm_2p5D(torch.autograd.Function):
@staticmethod
+ @custom_fwd(cast_inputs=torch.float32)
def forward(ctx: Any,
input: Tensor,
E_x: Tensor,
Var_x: Tensor,
hidden_size: int,
- row_parallel_mode: ParallelMode,
- col_parallel_mode: ParallelMode,
- dep_parallel_mode: ParallelMode) -> Tensor:
+ row_parallel_mode: ParallelMode) -> Tensor:
input = input - E_x
# in here, input = x - E[x], Var_x = 1 / sqrt(Var[x] + eps)
ctx.hidden_size = hidden_size
output = input * Var_x
ctx.save_for_backward(output, Var_x)
ctx.row_parallel_mode = row_parallel_mode
- ctx.col_parallel_mode = col_parallel_mode
- ctx.dep_parallel_mode = dep_parallel_mode
return output
@staticmethod
+ @custom_bwd
def backward(ctx, output_grad):
row_parallel_mode = ctx.row_parallel_mode
- col_parallel_mode = ctx.col_parallel_mode
- dep_parallel_mode = ctx.dep_parallel_mode
x, Var_x = ctx.saved_tensors
# in here, Var_x = 1 / sqrt(Var[x] + eps), x = (x - E[x]) * Var_x
with torch.no_grad():
@@ -473,63 +450,122 @@ class _LayerNorm_2p5D(torch.autograd.Function):
return input_grad, None, None, None, None, None, None
-class Sum_2p5D(torch.autograd.Function):
- """Compute the sum of input tensors
- """
+# class Sum_2p5D(torch.autograd.Function):
+# """Compute the sum of input tensors
+# """
+
+# @staticmethod
+# def forward(ctx,
+# inputs,
+# dim,
+# tesseract_dim,
+# row_parallel_mode,
+# keepdim=False):
+# # input: [b/q, s, h/q]
+# ctx.save_for_backward(inputs)
+# # sum: [b/q, s]
+# out = torch.sum(inputs, dim=dim, keepdim=keepdim)
+# torch.distributed.all_reduce(
+# out, group=gpc.get_group(row_parallel_mode))
+# return out
+
+# @staticmethod
+# def backward(ctx, output_grad):
+# with torch.no_grad():
+# inputs = ctx.saved_tensors
+# input_grad = torch.ones(inputs.shape, dtype=output_grad.dtype)
+# return input_grad, None, None, None, None, None
+
+
+# class _ViT_Split_2p5D(torch.autograd.Function):
+# @staticmethod
+# @custom_fwd(cast_inputs=torch.float16)
+# def forward(ctx, inputs, batch_size,
+# tesseract_dim, tesseract_dep,
+# xz_parallel_mode):
+# # inputs: [b, s, h/q]
+# # output: [b/dq, s, h/q]
+
+# ctx.BATCH_SIZE = batch_size
+# ctx.tesseract_dim = tesseract_dim
+# ctx.tesseract_dep = tesseract_dep
+# ctx.xz_parallel_mode = xz_parallel_mode
+# xz_rank = gpc.get_local_rank(xz_parallel_mode)
+# output = torch.chunk(inputs, tesseract_dep *
+# tesseract_dim, dim=0)[xz_rank]
+# output = output.clone()
+# return output
+
+# @staticmethod
+# @custom_bwd
+# def backward(ctx, output_grad):
+# # output_grad: [b/dq, s, h/q]
+# # grads: [b, s, h/q]
+# # *
+# grads_shape = (ctx.BATCH_SIZE,) + output_grad.shape[1:]
+# grads = torch.empty(grads_shape,
+# dtype=output_grad.dtype,
+# device=get_current_device())
+# dist.all_gather(list(grads.chunk(ctx.tesseract_dim * ctx.tesseract_dep, dim=0)),
+# output_grad.contiguous(),
+# group=get_parallel_group(ctx.xz_parallel_mode))
+# return grads, None, None, None, None
+
+class AllGatherLast(torch.autograd.Function):
@staticmethod
- def forward(ctx,
- inputs,
- dim,
- tesseract_dim,
- row_parallel_mode,
- keepdim=False):
- # input: [b/q, s, h/q]
- empty_cache()
- ctx.save_for_backward(inputs)
- # sum: [b/q, s]
- out = torch.sum(inputs, dim=dim, keepdim=keepdim)
- torch.distributed.all_reduce(
- out, group=gpc.get_group(row_parallel_mode))
- return out
-
- @staticmethod
- def backward(ctx, output_grad):
- with torch.no_grad():
- inputs = ctx.saved_tensors
- input_grad = torch.ones(inputs.shape, dtype=output_grad.dtype)
- return input_grad, None, None, None, None, None
-
-
-class _ViT_Split_2p5D(torch.autograd.Function):
- @staticmethod
- def forward(ctx, inputs, batch_size,
- tesseract_dim, tesseract_dep,
- xz_parallel_mode):
- # inputs: [b, s, h/q]
- # output: [b/dq, s, h/q]
- empty_cache()
-
- ctx.batch_size = batch_size
+ @custom_fwd(cast_inputs=torch.float16)
+ def forward(ctx: Any,
+ inputs: Tensor,
+ tesseract_dim: int,
+ col_parallel_mode: ParallelMode) -> Tensor:
ctx.tesseract_dim = tesseract_dim
- ctx.tesseract_dep = tesseract_dep
- ctx.xz_parallel_mode = xz_parallel_mode
- xz_rank = gpc.get_local_rank(xz_parallel_mode)
- output = torch.chunk(inputs, tesseract_dep *
- tesseract_dim, dim=0)[xz_rank]
- output = output.clone()
- return output
+ ctx.row_rank = gpc.get_local_rank(col_parallel_mode)
+
+ last_dim = tesseract_dim * inputs.size(-1)
+ outputs_shape = (last_dim,) + inputs.shape[:-1]
+ outputs = torch.empty(
+ outputs_shape, dtype=inputs.dtype, device=get_current_device())
+ dist.all_gather(
+ list(outputs.chunk(tesseract_dim, dim=0)),
+ inputs.permute(2, 0, 1).contiguous(),
+ group=gpc.get_group(col_parallel_mode)
+ )
+ outputs = outputs.permute(1, 2, 0).contiguous()
+ return outputs
@staticmethod
- def backward(ctx, output_grad):
- # output_grad: [b/dq, s, h/q]
- # grads: [b, s, h/q]
- # *
- grads_shape = (ctx.batch_size,) + output_grad.shape[1:]
- grads = torch.empty(grads_shape,
- dtype=output_grad.dtype,
- device=get_current_device())
- dist.all_gather(list(grads.chunk(ctx.tesseract_dim * ctx.tesseract_dep, dim=0)),
- output_grad.contiguous(),
- group=get_parallel_group(ctx.xz_parallel_mode))
- return grads, None, None, None, None
+ @custom_bwd
+ def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
+ grad = output_grad.chunk(ctx.tesseract_dim, dim=-1)[ctx.row_rank]
+ return grad.contiguous(), None, None
+
+
+class SplitFirst(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float16)
+ def forward(ctx: Any,
+ inputs: Tensor,
+ tesseract_dim: int,
+ col_parallel_mode: ParallelMode) -> Tensor:
+ ctx.tesseract_dim = tesseract_dim
+ ctx.batch_size = inputs.size(0)
+ ctx.para_mode = col_parallel_mode
+ row_rank = gpc.get_local_rank(col_parallel_mode)
+
+ outputs = inputs.chunk(tesseract_dim, dim=0)[row_rank]
+ return outputs
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
+ grad_shape = (ctx.batch_size,) + output_grad.shape[1:]
+ grad = torch.empty(
+ grad_shape, dtype=output_grad.dtype, device=get_current_device())
+ dist.all_gather(
+ list(grad.chunk(ctx.tesseract_dim, dim=0)),
+ output_grad.contiguous(),
+ group=gpc.get_group(ctx.para_mode)
+ )
+ return grad, None, None
\ No newline at end of file
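Matmul_AB_2p5D now all-gathers the A blocks over the row group and the B blocks over the column group once, then accumulates the partial products locally, instead of broadcasting one block per step. A single-process sketch of that accumulation follows, where A_list/B_list stand in for the gathered tensors and the toy check verifies the loop against a dense matmul.

import torch

def summa_matmul_ab(A_list, B_list, row_rank, col_rank, tesseract_dim):
    # Accumulate C = sum_i A_i @ B_i over the contracted dimension, mirroring
    # the loop in Matmul_AB_2p5D.forward after the two all_gather calls.
    C = torch.zeros(A_list[0].shape[0], B_list[0].shape[-1], dtype=A_list[0].dtype)
    for i in range(tesseract_dim):
        src_a = (i + tesseract_dim * row_rank) % tesseract_dim
        src_b = (i + tesseract_dim * col_rank) % tesseract_dim
        torch.addmm(C, A_list[src_a], B_list[src_b], out=C)
    return C

A_full, B_full = torch.randn(4, 6), torch.randn(6, 5)
A_list = list(A_full.chunk(2, dim=1))   # blocks of the contracted dimension
B_list = list(B_full.chunk(2, dim=0))
C = summa_matmul_ab(A_list, B_list, row_rank=0, col_rank=0, tesseract_dim=2)
assert torch.allclose(C, A_full @ B_full, atol=1e-5)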
diff --git a/colossalai/nn/layer/parallel_2p5d/_transformer.py b/colossalai/nn/layer/parallel_2p5d/_transformer.py
index c13ef87b4..ed469ba7d 100644
--- a/colossalai/nn/layer/parallel_2p5d/_transformer.py
+++ b/colossalai/nn/layer/parallel_2p5d/_transformer.py
@@ -12,10 +12,11 @@ from ._utils import assert_tesseract_initialization, \
get_tesseract_dim_dep_from_env
from .layers import Linear2p5D, LayerNorm2p5D
from .._common_utils import ACT2FN
+from ..base_layer import ParallelLayer
@LAYERS.register_module
-class TransformerMLP2p5D(nn.Module):
+class TransformerMLP2p5D(ParallelLayer):
"""
MLP will take the input with h hidden state, project it to mlp_ratio * h
hidden dimension, perform nonlinear transformation, and project the
@@ -36,21 +37,24 @@ class TransformerMLP2p5D(nn.Module):
def __init__(self,
in_features: int,
- mlp_ratio: int,
+ mlp_ratio: int = 4.0,
act_func: str = 'gelu',
dropout_prob: float = 0.,
dtype=None,
+ skip_bias_add: bool = False
):
super().__init__()
assert_tesseract_initialization()
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
self.in_features = in_features
+ self.skip_bias_add = skip_bias_add
# Project to h * mlp_ratio.
self.dense_1 = Linear2p5D(
in_features,
- mlp_ratio * in_features,
- dtype=dtype
+ int(mlp_ratio * in_features),
+ dtype=dtype,
+ skip_bias_add=skip_bias_add
)
assert act_func in ACT2FN.keys(), f'Invalid value for argument act_func, ' \
@@ -59,24 +63,34 @@ class TransformerMLP2p5D(nn.Module):
# Project back to h.
self.dense_2 = Linear2p5D(
- mlp_ratio * in_features,
+ int(mlp_ratio * in_features),
in_features,
- dtype=dtype
+ dtype=dtype,
+ skip_bias_add=skip_bias_add
)
self.dropout = nn.Dropout(dropout_prob)
self.layernorm = LayerNorm2p5D(in_features, dtype=dtype)
def forward(self, x: Tensor) -> Tensor:
- intermediate_output = self.dense_1(x)
+ if self.skip_bias_add:
+ intermediate_output, _ = self.dense_1(x)
+ else:
+ intermediate_output = self.dense_1(x)
+
intermediate_output = self.activation_func(intermediate_output)
- output = self.dense_2(intermediate_output)
+
+ if self.skip_bias_add:
+ output, _ = self.dense_2(intermediate_output)
+ else:
+ output = self.dense_2(intermediate_output)
+
output = self.dropout(output)
output = self.layernorm(x + output)
return output
@LAYERS.register_module
-class TransformerSelfAttention2p5D(nn.Module):
+class TransformerSelfAttention2p5D(ParallelLayer):
"""Self attention layer for 2.5D parallel Transformer
:param hidden_size: hidden size
@@ -92,10 +106,10 @@ class TransformerSelfAttention2p5D(nn.Module):
"""
def __init__(self,
- hidden_size,
- num_attention_heads,
- attention_dropout_prob,
- hidden_dropout_prob,
+ hidden_size: int,
+ num_attention_heads: int,
+ attention_dropout_prob: float,
+ hidden_dropout_prob: float,
dtype=None,
):
super().__init__()
@@ -127,7 +141,7 @@ class TransformerSelfAttention2p5D(nn.Module):
def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
query_key_value = self.query_key_value(hidden_states)
new_qkv_shape = query_key_value.shape[:-1] + \
- (self.num_attention_heads, 3 * self.attention_head_size)
+ (self.num_attention_heads, 3 * self.attention_head_size)
query_key_value = query_key_value.view(new_qkv_shape)
query_key_value = query_key_value.permute((0, 2, 1, 3))
query_layer, key_layer, value_layer = torch.chunk(
@@ -136,7 +150,7 @@ class TransformerSelfAttention2p5D(nn.Module):
attention_scores = torch.matmul(
query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / \
- math.sqrt(self.attention_head_size)
+ math.sqrt(self.attention_head_size)
attention_scores = attention_scores + attention_mask
attention_probs = nn.Softmax(dim=-1)(attention_scores)
attention_probs = self.attention_dropout(attention_probs)
@@ -144,7 +158,7 @@ class TransformerSelfAttention2p5D(nn.Module):
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute((0, 2, 1, 3)).contiguous()
new_context_layer_shape = context_layer.size()[
- :-2] + (self.all_head_size,)
+ :-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
output = self.dense(context_layer)
@@ -155,7 +169,7 @@ class TransformerSelfAttention2p5D(nn.Module):
@LAYERS.register_module
-class TransformerLayer2p5D(nn.Module):
+class TransformerLayer2p5D(ParallelLayer):
"""Transformer layer which contains a self-attention layer and a MLP layer
:param hidden_size: hidden size
@@ -175,10 +189,10 @@ class TransformerLayer2p5D(nn.Module):
"""
def __init__(self,
- hidden_size,
- num_attention_heads,
- act_func='gelu',
- mlp_ratio=4,
+ hidden_size: int,
+ num_attention_heads: int,
+ act_func: str = 'gelu',
+ mlp_ratio: float = 4.0,
attention_dropout_prob: float = 0.,
hidden_dropout_prob: float = 0.,
dtype=None,
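TransformerMLP2p5D now threads skip_bias_add through both Linear2p5D layers and unpacks the (output, bias) pair they return in that mode. Below is a plain-PyTorch sketch of that return contract; LinearSkipBias is a hypothetical stand-in, not the parallel layer itself.

import torch
import torch.nn as nn
import torch.nn.functional as F

class LinearSkipBias(nn.Linear):
    """When skip_bias_add is set, return (x @ W^T, bias) so the caller can
    defer the bias add, e.g. to fuse it with a later activation kernel."""

    def __init__(self, in_features, out_features, skip_bias_add=False):
        super().__init__(in_features, out_features)
        self.skip_bias_add = skip_bias_add

    def forward(self, x):
        if self.skip_bias_add:
            return F.linear(x, self.weight), self.bias
        return F.linear(x, self.weight, self.bias)

layer = LinearSkipBias(8, 8, skip_bias_add=True)
out, bias = layer(torch.randn(2, 8))   # caller decides where the bias is added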
diff --git a/colossalai/nn/layer/parallel_2p5d/_vit.py b/colossalai/nn/layer/parallel_2p5d/_vit.py
index 4e992ac34..180e27b3e 100644
--- a/colossalai/nn/layer/parallel_2p5d/_vit.py
+++ b/colossalai/nn/layer/parallel_2p5d/_vit.py
@@ -5,22 +5,25 @@ import math
import torch
from torch import nn as nn, Tensor, distributed as dist
+from torch.nn.init import _calculate_fan_in_and_fan_out
-from colossalai.context.parallel_mode import ParallelMode
+from colossalai.context import seed, ParallelMode
from colossalai.core import global_context as gpc
-from colossalai.nn.layer.vanilla_vision_transformer.layers import to_2tuple
from colossalai.registry import LAYERS
+from colossalai.utils import checkpoint
from colossalai.utils import get_current_device
-from ._operation import _ViT_Split_2p5D
+from ._operation import AllGatherLast, SplitFirst
from ._utils import assert_tesseract_initialization, \
get_tesseract_dim_dep_from_env
from .layers import Linear2p5D
-from .._common_utils import ACT2FN, divide, CheckpointModule
-from .._common_utils import set_tensor_parallel_attribute
+from ..base_layer import ParallelLayer
+from ..fused_bias_gelu import bias_gelu_impl
+from .._common_utils import (ACT2FN, divide, to_2tuple,
+ set_tensor_parallel_attribute_by_partition)
@LAYERS.register_module
-class ViTMLP2p5D(CheckpointModule):
+class ViTMLP2p5D(ParallelLayer):
"""MLP layer for 2.5D parallel Vision Transformer
:param in_features: size of each input sample
@@ -43,19 +46,32 @@ class ViTMLP2p5D(CheckpointModule):
act_func: str = 'gelu',
dropout_prob: float = 0.,
dtype=None,
- checkpoint: bool = False
+ checkpoint: bool = False,
+ weight_init='torch'
):
- super().__init__(checkpoint=checkpoint)
+ super().__init__()
assert_tesseract_initialization()
self.in_features = in_features
self.mlp_ratio = mlp_ratio
+ self.checkpoint = checkpoint
+ assert weight_init in ('torch', 'jax')
+
+ if act_func == 'fused_gelu':
+ self.act = bias_gelu_impl
+ skip_dense_1_add_bias = True
+ else:
+ self.act = ACT2FN[act_func]
+ skip_dense_1_add_bias = False
# Project to mlp_ratio * h.
self.dense_1 = Linear2p5D(
self.in_features,
self.mlp_ratio * self.in_features,
dtype=dtype,
+ init_weight=weight_init,
+ init_bias=weight_init,
+ skip_bias_add=skip_dense_1_add_bias
)
-        self.act = ACT2FN[act_func]
@@ -65,20 +81,39 @@ class ViTMLP2p5D(CheckpointModule):
self.mlp_ratio * self.in_features,
self.in_features,
dtype=dtype,
+ init_weight=weight_init,
+ init_bias=weight_init
)
self.dropout = nn.Dropout(dropout_prob)
def _forward(self, hidden_states: Tensor) -> Tensor:
- intermediate_output = self.dense_1(hidden_states)
- intermediate_output = self.act(intermediate_output)
- intermediate_output = self.dropout(intermediate_output)
+ if self.act == bias_gelu_impl:
+ intermediate_output, bias = self.dense_1(hidden_states)
+ intermediate_output = self.act(intermediate_output, bias)
+ else:
+ intermediate_output = self.dense_1(hidden_states)
+ intermediate_output = self.act(intermediate_output)
+
+ with seed(ParallelMode.TENSOR):
+ intermediate_output = self.dropout(intermediate_output)
output = self.dense_2(intermediate_output)
- output = self.dropout(output)
+
+ with seed(ParallelMode.TENSOR):
+ output = self.dropout(output)
return output
+ def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
+ return checkpoint(self._forward, hidden_states)
+
+ def forward(self, hidden_states: Tensor) -> Tensor:
+ if self.checkpoint:
+ return self._checkpoint_forward(hidden_states)
+ else:
+ return self._forward(hidden_states)
+
@LAYERS.register_module
-class ViTSelfAttention2p5D(CheckpointModule):
+class ViTSelfAttention2p5D(ParallelLayer):
"""Self-attention layer for 2.5D parallel Vision Transformer
:param hidden_size: hidden size
@@ -101,9 +136,10 @@ class ViTSelfAttention2p5D(CheckpointModule):
attention_dropout_prob,
hidden_dropout_prob,
dtype=None,
- checkpoint: bool = False
+ checkpoint: bool = False,
+ weight_init='torch'
):
- super().__init__(checkpoint=checkpoint)
+ super().__init__()
assert_tesseract_initialization()
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
@@ -112,19 +148,30 @@ class ViTSelfAttention2p5D(CheckpointModule):
num_attention_heads, self.tesseract_dim) # *
self.attention_head_size = divide(hidden_size, num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
+ self.checkpoint = checkpoint
+ assert weight_init in ('torch', 'jax')
+ if weight_init == 'jax':
+ self.init_bias = 'zero'
+ else:
+ self.init_bias = weight_init
self.query_key_value = Linear2p5D(
hidden_size,
3 * hidden_size,
dtype=dtype,
+ init_weight=weight_init,
+ init_bias=self.init_bias
)
self.attention_dropout = nn.Dropout(attention_dropout_prob)
self.dense = Linear2p5D(
hidden_size,
hidden_size,
dtype=dtype,
+ init_weight=weight_init,
+ init_bias=self.init_bias
)
self.dropout = nn.Dropout(hidden_dropout_prob)
+ self.softmax = nn.Softmax(dim=-1)
def _forward(self, hidden_states: Tensor) -> Tensor:
query_key_value = self.query_key_value(hidden_states)
@@ -140,8 +187,10 @@ class ViTSelfAttention2p5D(CheckpointModule):
attention_scores = attention_scores / \
math.sqrt(self.attention_head_size)
- attention_probs = nn.Softmax(dim=-1)(attention_scores)
- attention_probs = self.attention_dropout(attention_probs)
+ attention_probs = self.softmax(attention_scores)
+
+ with seed(ParallelMode.TENSOR):
+ attention_probs = self.attention_dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.transpose(1, 2)
@@ -150,12 +199,22 @@ class ViTSelfAttention2p5D(CheckpointModule):
context_layer = context_layer.reshape(new_context_layer_shape)
output = self.dense(context_layer)
- output = self.dropout(output)
+ with seed(ParallelMode.TENSOR):
+ output = self.dropout(output)
return output
+ def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
+ return checkpoint(self._forward, hidden_states)
+
+ def forward(self, hidden_states: Tensor) -> Tensor:
+ if self.checkpoint:
+ return self._checkpoint_forward(hidden_states)
+ else:
+ return self._forward(hidden_states)
+
@LAYERS.register_module
-class ViTHead2p5D(nn.Module):
+class ViTHead2p5D(ParallelLayer):
"""Output layer for 2.5D parallel Vision Transformer
:param hidden_size: hidden size
@@ -170,13 +229,24 @@ class ViTHead2p5D(nn.Module):
hidden_size,
num_classes,
dtype=None,
+ weight_init='torch'
):
super().__init__()
assert_tesseract_initialization()
+ assert weight_init in ('torch', 'jax')
+ if weight_init == 'jax':
+ self.init_weight = 'zero'
+ self.init_bias = 'zero'
+ else:
+ self.init_weight = weight_init
+ self.init_bias = weight_init
+
self.linear = Linear2p5D(
hidden_size,
num_classes,
dtype=dtype,
+ init_weight=self.init_weight,
+ init_bias=self.init_bias
)
def forward(self, x: Tensor) -> Tensor:
@@ -186,7 +256,7 @@ class ViTHead2p5D(nn.Module):
@LAYERS.register_module
-class ViTPatchEmbedding2p5D(nn.Module):
+class ViTPatchEmbedding2p5D(ParallelLayer):
""" 2.5D Image to Patch Embedding
    :param img_size: image size
@@ -206,7 +276,8 @@ class ViTPatchEmbedding2p5D(nn.Module):
patch_size,
embed_dim,
in_chans=3,
- flatten=True):
+ flatten=True,
+ weight_init='torch'):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
@@ -219,34 +290,28 @@ class ViTPatchEmbedding2p5D(nn.Module):
img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
- self.embed_dim = embed_dim // self.tesseract_dim # *
+ self.embed_dim = embed_dim // (self.tesseract_dep * self.tesseract_dim ** 2) # *
- self.proj = nn.Conv2d(in_chans,
- self.embed_dim,
- kernel_size=patch_size,
- stride=patch_size,
- )
+ with seed(ParallelMode.TENSOR):
+ self.proj = nn.Conv2d(in_chans,
+ self.embed_dim,
+ kernel_size=patch_size,
+ stride=patch_size,
+ device=get_current_device()
+ )
+ self._set_tensor_parallel_attribute()
- # move self to cuda before sync
- self.to(get_current_device())
+ if weight_init == 'jax':
+ with seed(ParallelMode.TENSOR):
+ fan_in, _ = _calculate_fan_in_and_fan_out(self.proj.weight)
+ std = math.sqrt(1.0 / fan_in)
+ nn.init.trunc_normal_(self.proj.weight, std=std / .87962566103423978)
+ nn.init.zeros_(self.proj.bias)
- # sync
- self._broadcast_conv_params()
- self.proj.weight.register_hook(self._sync_grad_during_backward)
- self.proj.bias.register_hook(self._sync_grad_during_backward)
-
- def _broadcast_conv_params(self) -> None:
- xz_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2P5D_XZ)
- dist.broadcast(self.proj.weight, src=xz_rank[0],
- group=gpc.get_group(ParallelMode.PARALLEL_2P5D_XZ))
- dist.broadcast(self.proj.bias, src=xz_rank[0],
- group=gpc.get_group(ParallelMode.PARALLEL_2P5D_XZ))
-
- def _sync_grad_during_backward(self, grad: Tensor) -> None:
- dist.all_reduce(grad, group=gpc.get_group(
- ParallelMode.PARALLEL_2P5D_XZ))
- grad = grad / self.tesseract_dim / self.tesseract_dep # *
- return grad
+ def _set_tensor_parallel_attribute(self):
+ num_partition = gpc.get_world_size(ParallelMode.TENSOR)
+ set_tensor_parallel_attribute_by_partition(self.proj.weight, num_partition)
+ set_tensor_parallel_attribute_by_partition(self.proj.bias, num_partition)
def forward(self, x: Tensor) -> Tensor:
B, C, H, W = x.shape
@@ -259,7 +324,25 @@ class ViTPatchEmbedding2p5D(nn.Module):
@LAYERS.register_module
-class ViTTokenFuser2p5D(nn.Module):
+class ViTInputSplitter2p5D(ParallelLayer):
+ """Split the input tensor for 2D parallel Vision Transformer
+ """
+
+ def __init__(self):
+ super().__init__()
+ assert_tesseract_initialization()
+ self.tesseract_dim, _ = get_tesseract_dim_dep_from_env()
+
+ def forward(self, x: Tensor) -> Tensor:
+ x = AllGatherLast.apply(
+ x, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
+ x = SplitFirst.apply(
+ x, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
+ return x
+
+
+@LAYERS.register_module
+class ViTTokenFuser2p5D(ParallelLayer):
"""
Fuse cls token and pos embedding to the input
@@ -293,59 +376,46 @@ class ViTTokenFuser2p5D(nn.Module):
self.embed_dim = embed_dim
self.cls_token = nn.Parameter(torch.zeros(
- 1, 1, self.embed_dim // self.tesseract_dim)) # *
- self.pos_embed = nn.Parameter(torch.zeros(
- 1, self.num_patches + 1, self.embed_dim // self.tesseract_dim)) # *
+ (1, 1, self.embed_dim // (self.tesseract_dep * self.tesseract_dim ** 2)),
+ device=get_current_device()))
+ self.pos_embed = nn.Parameter(torch.empty(
+ (1, self.num_patches + 1, self.embed_dim // (self.tesseract_dep * self.tesseract_dim ** 2)),
+ device=get_current_device()))
+ with seed(ParallelMode.TENSOR):
+ nn.init.trunc_normal_(self.pos_embed, std=.02)
- # move to cuda before broadcast
- self.to(get_current_device())
-
- self._broadcast_params()
- self.cls_token.register_hook(self._sync_grad_hook)
- self.pos_embed.register_hook(self._sync_grad_hook)
self.pos_drop = nn.Dropout(p=drop_rate)
self._set_tensor_parallel_attribute()
def _set_tensor_parallel_attribute(self):
- set_tensor_parallel_attribute(self.cls_token)
- set_tensor_parallel_attribute(self.pos_embed)
+ num_partition = gpc.get_world_size(ParallelMode.TENSOR)
+ set_tensor_parallel_attribute_by_partition(self.cls_token, num_partition)
+ set_tensor_parallel_attribute_by_partition(self.pos_embed, num_partition)
- def _broadcast_params(self) -> None:
+ def _broadcast_params(self, param) -> None:
" broadcast to all column ranks for data consistency "
- xz_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2P5D_XZ)
- dist.broadcast(self.cls_token, src=xz_rank[0],
- group=gpc.get_group(ParallelMode.PARALLEL_2P5D_XZ))
- dist.broadcast(self.pos_embed, src=xz_rank[0],
- group=gpc.get_group(ParallelMode.PARALLEL_2P5D_XZ))
+ if self.tesseract_dep > 1:
+ xz_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2P5D_XZ)
+ xz_group = gpc.get_group(ParallelMode.PARALLEL_2P5D_XZ)
+ dist.broadcast(param, src=xz_rank[0],
+ group=xz_group)
def _sync_grad_hook(self, grad) -> None:
dist.all_reduce(grad, group=gpc.get_group(
ParallelMode.PARALLEL_2P5D_XZ))
- grad = grad / self.tesseract_dim / self.tesseract_dep # *
+ grad = grad / self.tesseract_dim # / self.tesseract_dep # *
return grad
def forward(self, x: Tensor) -> Tensor:
# stole cls_tokens impl from Phil Wang, thanks
- cls_token = self.cls_token.expand(x.shape[0], -1, -1)
+ cls_token = AllGatherLast.apply(
+ self.cls_token, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
+ cls_token = cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_token, x), dim=1)
- x = self.pos_drop(x + self.pos_embed)
+
+ pos_embed = AllGatherLast.apply(
+ self.pos_embed, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
+ x = x + pos_embed
+ with seed(ParallelMode.TENSOR):
+ x = self.pos_drop(x)
return x
-
-
-@LAYERS.register_module
-class ViTInputSplitter2p5D(nn.Module):
-
- def __init__(self):
- super().__init__()
- assert_tesseract_initialization()
- self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
-
- def forward(self, x: Tensor) -> Tensor:
- batch_size = x.size(0)
- return _ViT_Split_2p5D.apply(
- x,
- batch_size,
- self.tesseract_dim,
- self.tesseract_dep,
- ParallelMode.PARALLEL_2P5D_XZ,
- )
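When act_func is 'fused_gelu', dense_1 skips its bias add and the MLP calls bias_gelu_impl(output, bias) so the bias add and the GeLU run as one fused op. A rough sketch of the underlying function is shown below, assuming it follows the usual Megatron-style tanh approximation; the actual kernel lives in ..fused_bias_gelu.

import torch

def bias_gelu(x: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # The bias that dense_1 skipped (skip_bias_add=True) is added here and
    # fused with the GeLU approximation in a single elementwise pass.
    y = x + bias
    return y * 0.5 * (1.0 + torch.tanh(0.79788456 * y * (1.0 + 0.044715 * y * y)))

hidden = torch.randn(2, 4, 16)
bias = torch.zeros(16)
out = bias_gelu(hidden, bias)   # same shape as hidden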
diff --git a/colossalai/nn/layer/parallel_2p5d/layers.py b/colossalai/nn/layer/parallel_2p5d/layers.py
index 63b0f2a5c..224fa615f 100644
--- a/colossalai/nn/layer/parallel_2p5d/layers.py
+++ b/colossalai/nn/layer/parallel_2p5d/layers.py
@@ -10,7 +10,7 @@ from colossalai.registry import LAYERS
from colossalai.utils import get_current_device
from ._operation import Matmul_AB_2p5D, Add_Bias_2p5D, _LayerNorm_2p5D
from ._utils import get_tesseract_dim_dep_from_env, assert_tesseract_initialization
-from .._common_utils import divide, set_tensor_parallel_attribute
+from .._common_utils import divide, set_tensor_parallel_attribute_by_partition
from ..base_layer import ParallelLayer
@@ -33,7 +33,9 @@ class Linear2p5D(ParallelLayer):
out_features: int,
bias: bool = True,
dtype=None,
- skip_bias_add: bool = False
+ skip_bias_add: bool = False,
+ init_weight='torch',
+ init_bias='torch'
):
super().__init__()
@@ -46,7 +48,7 @@ class Linear2p5D(ParallelLayer):
self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
self.dep_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
- self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
+ self.tesseract_dim, _ = get_tesseract_dim_dep_from_env()
# partitioning dimension
self.input_size_per_partition = divide(in_features, self.tesseract_dim)
@@ -69,46 +71,59 @@ class Linear2p5D(ParallelLayer):
self.register_parameter('bias', None)
# initialize parameters
- self.reset_parameters()
+ with seed(ParallelMode.TENSOR):
+ self.reset_parameters(init_weight, init_bias)
self._set_tensor_parallel_attributes()
def _set_tensor_parallel_attributes(self):
- set_tensor_parallel_attribute(self.weight)
+ num_partition = gpc.get_world_size(ParallelMode.TENSOR)
+ set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
if self.bias is not None:
- set_tensor_parallel_attribute(self.bias)
+ set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
- def reset_parameters(self) -> None:
+ def reset_parameters(self, init_weight, init_bias) -> None:
+ assert init_weight in ('torch', 'jax', 'zero')
+ assert init_bias in ('torch', 'jax', 'zero')
# setting
- fan_in = self.in_features
- a = math.sqrt(5)
- nonlinearity = 'leaky_relu'
+ fan_in, fan_out = self.in_features, self.out_features
# init weight
- std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
- bound = math.sqrt(3.0) * std
- with seed(ParallelMode.TENSOR):
+ if init_weight == 'torch':
+ a = math.sqrt(5)
+ nonlinearity = 'leaky_relu'
+ std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
+ bound = math.sqrt(3.0) * std
init.uniform_(self.weight, -bound, bound)
+ elif init_weight == 'jax':
+ std = math.sqrt(2.0 / float(fan_in + fan_out))
+ a = math.sqrt(3.0) * std
+ init.uniform_(self.weight, -a, a)
+ elif init_weight == 'zero':
+ init.zeros_(self.weight)
# init bias
if self.bias is not None:
- bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
- with seed(ParallelMode.TENSOR):
+ if init_bias == 'torch':
+ bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
init.uniform_(self.bias, -bound, bound)
+ elif init_bias == 'jax':
+ init.normal_(self.bias, std=1e-6)
+ elif init_bias == 'zero':
+ init.zeros_(self.bias)
def forward(self, x: Tensor) -> Tensor:
# input: [m/dq, n/q, k/q]
# output: [m/dq, n/q, h/q]
out_shape = x.shape[:-1] + (self.hidden_size_per_partition,)
+
output = Matmul_AB_2p5D.apply(
x,
self.weight,
self.tesseract_dim,
- self.tesseract_dep,
out_shape,
self.row_rank, self.col_rank, self.dep_rank,
ParallelMode.PARALLEL_2P5D_ROW,
ParallelMode.PARALLEL_2P5D_COL,
- ParallelMode.PARALLEL_2P5D_DEP,
self.data_parallel_rank,
self.pipeline_parallel_rank,
self.pipeline_parallel_size,
@@ -121,11 +136,9 @@ class Linear2p5D(ParallelLayer):
None,
self.bias,
self.hidden_size_per_partition,
- self.tesseract_dim, self.tesseract_dep,
+ self.tesseract_dim,
self.row_rank, self.col_rank, self.dep_rank,
- ParallelMode.PARALLEL_2P5D_ROW,
ParallelMode.PARALLEL_2P5D_COL,
- ParallelMode.PARALLEL_2P5D_DEP,
True,
self.data_parallel_rank,
self.pipeline_parallel_rank,
@@ -138,11 +151,9 @@ class Linear2p5D(ParallelLayer):
output,
self.bias,
self.hidden_size_per_partition,
- self.tesseract_dim, self.tesseract_dep,
+ self.tesseract_dim,
self.row_rank, self.col_rank, self.dep_rank,
- ParallelMode.PARALLEL_2P5D_ROW,
ParallelMode.PARALLEL_2P5D_COL,
- ParallelMode.PARALLEL_2P5D_DEP,
False,
self.data_parallel_rank,
self.pipeline_parallel_rank,
@@ -168,6 +179,7 @@ class LayerNorm2p5D(ParallelLayer):
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
+
def __init__(self,
normalized_shape: int,
eps: float = 1e-05,
@@ -184,7 +196,7 @@ class LayerNorm2p5D(ParallelLayer):
self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
self.dep_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
- self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
+ self.tesseract_dim, _ = get_tesseract_dim_dep_from_env()
# partitioning dimension
self.partitioned_partition = divide(
@@ -193,27 +205,19 @@ class LayerNorm2p5D(ParallelLayer):
# create parameters
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
- if self.row_rank == 0:
- self.gamma = Parameter(torch.ones(
- self.partitioned_partition,
- **factory_kwargs))
- self.beta = Parameter(torch.zeros(
- self.partitioned_partition,
- **factory_kwargs))
- else:
- self.gamma = Parameter(torch.tensor(
- 1.0,
- requires_grad=True,
- **factory_kwargs))
- self.beta = Parameter(torch.tensor(
- 1.0,
- requires_grad=True,
- **factory_kwargs))
+ self.gamma = Parameter(torch.ones(
+ self.partitioned_partition,
+ **factory_kwargs))
+ self.beta = Parameter(torch.zeros(
+ self.partitioned_partition,
+ **factory_kwargs))
+
self._set_tensor_parallel_attribute()
def _set_tensor_parallel_attribute(self):
- set_tensor_parallel_attribute(self.gamma)
- set_tensor_parallel_attribute(self.beta)
+ num_partition = gpc.get_world_size(ParallelMode.TENSOR)
+ set_tensor_parallel_attribute_by_partition(self.gamma, num_partition)
+ set_tensor_parallel_attribute_by_partition(self.beta, num_partition)
def forward(self, x: Tensor) -> Tensor:
with torch.no_grad():
@@ -233,16 +237,12 @@ class LayerNorm2p5D(ParallelLayer):
Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon)
output = _LayerNorm_2p5D.apply(x, E_x, Var_x, self.normalized_shape,
- ParallelMode.PARALLEL_2P5D_ROW,
- ParallelMode.PARALLEL_2P5D_COL,
- ParallelMode.PARALLEL_2P5D_DEP)
+ ParallelMode.PARALLEL_2P5D_ROW)
bias = Add_Bias_2p5D.apply(
None, self.beta, self.partitioned_partition,
- self.tesseract_dim, self.tesseract_dep,
+ self.tesseract_dim,
self.row_rank, self.col_rank, self.dep_rank,
- ParallelMode.PARALLEL_2P5D_ROW,
ParallelMode.PARALLEL_2P5D_COL,
- ParallelMode.PARALLEL_2P5D_DEP,
True,
self.data_parallel_rank,
self.pipeline_parallel_rank,
@@ -251,11 +251,9 @@ class LayerNorm2p5D(ParallelLayer):
)
scale = Add_Bias_2p5D.apply(
None, self.gamma, self.partitioned_partition,
- self.tesseract_dim, self.tesseract_dep,
+ self.tesseract_dim,
self.row_rank, self.col_rank, self.dep_rank,
- ParallelMode.PARALLEL_2P5D_ROW,
ParallelMode.PARALLEL_2P5D_COL,
- ParallelMode.PARALLEL_2P5D_DEP,
True,
self.data_parallel_rank,
self.pipeline_parallel_rank,
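LayerNorm2p5D now keeps a real gamma/beta shard on every rank and builds E[x] and Var[x] from local partial sums reduced across the row group. The following single-process sketch emulates those reductions by summing over a list of hidden-dimension shards and checks the result against torch's dense layer norm.

import torch
import torch.nn.functional as F

def partitioned_layernorm(shards, gamma_shards, beta_shards, eps=1e-5):
    # Each rank holds a shard of the hidden dimension; mean and variance are
    # built from local partial sums and combined (all-reduce emulated by the
    # sums over the list), mirroring the reductions in LayerNorm2p5D.forward.
    hidden = sum(s.shape[-1] for s in shards)
    E_x = sum(s.sum(dim=-1, keepdim=True) for s in shards) / hidden
    E_x2 = sum((s * s).sum(dim=-1, keepdim=True) for s in shards) / hidden
    inv_std = 1.0 / torch.sqrt(E_x2 - E_x * E_x + eps)
    return [(s - E_x) * inv_std * g + b
            for s, g, b in zip(shards, gamma_shards, beta_shards)]

x = torch.randn(2, 8)
out = torch.cat(partitioned_layernorm(list(x.chunk(2, dim=-1)),
                                      [torch.ones(4)] * 2,
                                      [torch.zeros(4)] * 2), dim=-1)
assert torch.allclose(out, F.layer_norm(x, (8,)), atol=1e-4)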
diff --git a/colossalai/nn/layer/parallel_3d/_operation.py b/colossalai/nn/layer/parallel_3d/_operation.py
index cb790fb51..f8287f932 100644
--- a/colossalai/nn/layer/parallel_3d/_operation.py
+++ b/colossalai/nn/layer/parallel_3d/_operation.py
@@ -1,21 +1,223 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
-from typing import Any, Tuple
+from typing import Any, Optional, Tuple
import torch
import torch.distributed as dist
-from colossalai.communication import all_gather, reduce_scatter, scatter
+from colossalai.communication import all_gather, all_reduce, reduce_scatter
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
-from colossalai.utils import empty_cache, get_current_device
from torch import Tensor
+from torch.cuda.amp import custom_bwd, custom_fwd
+
+
+class linear_3d(torch.autograd.Function):
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float16)
+ def forward(ctx: Any,
+ input_: Tensor,
+ weight: Tensor,
+ bias: Optional[Tensor],
+ input_parallel_mode: ParallelMode,
+ weight_parallel_mode: ParallelMode,
+ output_parallel_mode: ParallelMode,
+ input_dim: int = 0,
+ weight_dim: int = -1,
+ output_dim: int = 0) -> Tensor:
+ assert input_.shape[-1] == weight.shape[0], \
+ 'Invalid shapes: input = {}, weight = {}.'.format(input_.shape, weight.shape)
+
+ ctx.use_bias = bias is not None
+
+ input_ = all_gather(input_, input_dim, input_parallel_mode)
+ input_ = torch.cat(input_, dim=input_dim)
+ # weight = all_gather(weight, weight_dim, weight_parallel_mode)
+ ctx.save_for_backward(input_, weight)
+
+ output = torch.matmul(input_, weight)
+ output = reduce_scatter(output, output_dim, output_parallel_mode)
+
+ if bias is not None:
+ # ranks_in_group = gpc.get_ranks_in_group(output_parallel_mode)
+ # src_rank = ranks_in_group[gpc.get_local_rank(input_parallel_mode)]
+ # dist.broadcast(bias,
+ # src=src_rank,
+ # group=gpc.get_group(output_parallel_mode))
+ # bias = all_gather(bias, -1, weight_parallel_mode)
+ output += bias
+ # ctx.src_rank = src_rank
+
+ # ctx.save_for_backward(input_, weight)
+ # output = torch.matmul(input_, weight)
+ # dist.all_reduce(output, group=gpc.get_group(output_parallel_mode))
+ # output += bias
+
+ ctx.input_parallel_mode = input_parallel_mode
+ ctx.weight_parallel_mode = weight_parallel_mode
+ ctx.output_parallel_mode = output_parallel_mode
+ ctx.input_dim = input_dim
+ ctx.weight_dim = weight_dim
+ ctx.output_dim = output_dim
+ return output
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
+ input_, weight = ctx.saved_tensors
+ with torch.no_grad():
+ # input_grad = torch.matmul(output_grad, weight.transpose(0, 1))
+ # dist.all_reduce(input_grad,
+ # group=gpc.get_group(ctx.input_parallel_mode))
+ # weight_grad = torch.matmul(
+ # input_.reshape(-1, input_.shape[-1]).transpose(0, 1),
+ # output_grad.reshape(-1, output_grad.shape[-1]))
+ # dist.all_reduce(weight_grad,
+ # group=gpc.get_group(ctx.weight_parallel_mode))
+
+ # bias_grad = torch.sum(output_grad,
+ # dim=tuple(
+ # range(len(output_grad.shape))[:-1]))
+ # bias_grad = reduce_scatter(bias_grad, -1,
+ # ctx.weight_parallel_mode)
+ # dist.reduce(bias_grad,
+ # dst=ctx.src_rank,
+ # group=gpc.get_group(ctx.output_parallel_mode))
+ # if gpc.get_local_rank(
+ # ctx.output_parallel_mode) != gpc.get_local_rank(
+ # ctx.input_parallel_mode):
+ # bias_grad = None
+
+ # input_ = all_gather(input_, ctx.input_dim, ctx.input_parallel_mode)
+ # weight = all_gather(weight, ctx.weight_dim,
+ # ctx.weight_parallel_mode)
+
+ output_grad = all_gather(output_grad, ctx.output_dim,
+ ctx.output_parallel_mode)
+ output_grad = torch.cat(output_grad, dim=ctx.output_dim)
+
+ input_grad = torch.matmul(output_grad, weight.transpose(0, 1))
+
+ input_grad, input_op = reduce_scatter(input_grad, ctx.input_dim,
+ ctx.input_parallel_mode,
+ async_op=True)
+ weight_grad = torch.matmul(
+ input_.reshape(-1, input_.shape[-1]).transpose(0, 1),
+ output_grad.reshape(-1, output_grad.shape[-1]))
+
+ # weight_grad = torch.matmul(
+ # input_.reshape(-1, input_.shape[-1]).transpose(0, 1),
+ # output_grad.reshape(-1, output_grad.shape[-1]))
+ # weight_grad = reduce_scatter(weight_grad, ctx.weight_dim,
+ # ctx.weight_parallel_mode)
+ if ctx.use_bias:
+ bias_grad = torch.sum(output_grad,
+ dim=tuple(
+ range(len(output_grad.shape))[:-1]))
+ # bias_grad = all_reduce(bias_grad, ctx.output_parallel_mode)
+ # dist.all_reduce(bias_grad,
+ # group=gpc.get_group(ctx.weight_parallel_mode))
+ weight_grad = torch.cat([weight_grad, torch.unsqueeze(bias_grad, dim=0)])
+
+ weight_grad, weight_op = all_reduce(weight_grad, ctx.weight_parallel_mode, async_op=True)
+
+ input_op.wait()
+ weight_op.wait()
+ if ctx.use_bias:
+ bias_grad = weight_grad[-1]
+ weight_grad = weight_grad[:-1]
+ else:
+ bias_grad = None
+
+ return input_grad, weight_grad, bias_grad, None, None, None, None, None, None
+
+
+class layer_norm_3d(torch.autograd.Function):
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float16)
+ def forward(ctx: Any, input_: Tensor, weight: Tensor, bias: Tensor,
+ normalized_shape: int, eps: float,
+ input_parallel_mode: ParallelMode,
+ weight_parallel_mode: ParallelMode,
+ output_parallel_mode: ParallelMode) -> Tensor:
+ # mean = torch.sum(input_, dim=-1)
+ # dist.all_reduce(mean, group=gpc.get_group(output_parallel_mode))
+ # mean /= normalized_shape
+ # mu = input_ - mean
+ # var = torch.sum(torch.pow(mu, 2), dim=-1)
+ # dist.all_reduce(var, group=gpc.get_group(output_parallel_mode))
+ # var /= normalized_shape
+ # std_dev = torch.sqrt(var + eps)
+ # ctx.save_for_backward(input_, mu, std_dev, weight)
+
+ # output = weight * mu / std_dev + bias
+
+ mean = all_reduce(torch.sum(input_, dim=-1, keepdim=True),
+ output_parallel_mode) / normalized_shape
+ mu = input_ - mean
+ var = all_reduce(torch.sum(mu**2, dim=-1, keepdim=True),
+ output_parallel_mode) / normalized_shape
+ sigma = torch.sqrt(var + eps)
+
+ # ranks_in_group = gpc.get_ranks_in_group(input_parallel_mode)
+ # src_rank = ranks_in_group[gpc.get_local_rank(output_parallel_mode)]
+ # transforms = torch.stack([weight, bias]).contiguous()
+ # dist.broadcast(transforms,
+ # src=src_rank,
+ # group=gpc.get_group(input_parallel_mode))
+ # transforms = all_gather(transforms, -1, weight_parallel_mode)
+ # weight, bias = transforms[0], transforms[1]
+
+ ctx.save_for_backward(mu, sigma, weight)
+
+ z = mu / sigma
+ output = weight * z + bias
+
+ # ctx.src_rank = src_rank
+ ctx.normalized_shape = normalized_shape
+ ctx.input_parallel_mode = input_parallel_mode
+ ctx.weight_parallel_mode = weight_parallel_mode
+ ctx.output_parallel_mode = output_parallel_mode
+
+ return output
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
+ mu, sigma, weight = ctx.saved_tensors
+ with torch.no_grad():
+ bias_grad, weight_grad = output_grad, output_grad * mu / sigma
+ grads = torch.stack([bias_grad, weight_grad]).contiguous()
+ grads = torch.sum(grads, dim=tuple(range(len(grads.shape))[1:-1]))
+ grads = all_reduce(grads, ctx.weight_parallel_mode)
+ grads = all_reduce(grads, ctx.input_parallel_mode)
+ bias_grad, weight_grad = grads[0], grads[1]
+
+ # grads = reduce_scatter(grads, -1, ctx.weight_parallel_mode)
+ # dist.reduce(grads,
+ # dst=ctx.src_rank,
+ # group=gpc.get_group(ctx.input_parallel_mode))
+ # if gpc.get_local_rank(
+ # ctx.input_parallel_mode) == gpc.get_local_rank(
+ # ctx.output_parallel_mode):
+ # bias_grad, weight_grad = grads[0], grads[1]
+ # else:
+ # bias_grad, weight_grad = None, None
+
+ dz = output_grad * weight
+ dvar = dz * mu * (-0.5) * sigma**(-3)
+ dvar = all_reduce(torch.sum(dvar, dim=-1, keepdim=True), ctx.output_parallel_mode)
+ dmean = dz * (-1 / sigma) + dvar * -2 * mu / ctx.normalized_shape
+ dmean = all_reduce(torch.sum(dmean, dim=-1, keepdim=True), ctx.output_parallel_mode)
+
+ input_grad = dz / sigma + dvar * 2 * mu / ctx.normalized_shape + dmean / ctx.normalized_shape
+
+ return input_grad, weight_grad, bias_grad, None, None, None, None, None
class Matmul_AB_3D(torch.autograd.Function):
"""Matrix multiplication for :math:`C = AB`
"""
@staticmethod
+ @custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
A: Tensor,
B: Tensor,
@@ -29,7 +231,6 @@ class Matmul_AB_3D(torch.autograd.Function):
# A: [m/q^2, n, k/q]
# B: [k/q, h/q^2]
# C: [m/q^2, n, h/q]
- empty_cache()
ctx.save_for_backward(A, B)
assert A.shape[-1] == B.shape[0], \
@@ -52,6 +253,7 @@ class Matmul_AB_3D(torch.autograd.Function):
return out
@staticmethod
+ @custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors
with torch.no_grad():
@@ -72,6 +274,7 @@ class Matmul_ABT_3D(torch.autograd.Function):
"""Matrix multiplication for :math:`C = AB^T`
"""
@staticmethod
+ @custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
A: Tensor,
B: Tensor,
@@ -85,7 +288,6 @@ class Matmul_ABT_3D(torch.autograd.Function):
# A: [m/q^2, n, h/q]
# B: [k/q, h/q^2]
# C: [m/q^2, n, k/q]
- empty_cache()
ctx.save_for_backward(A, B)
A_temp = all_gather(A, input_dim, input_parallel_mode)
@@ -105,6 +307,7 @@ class Matmul_ABT_3D(torch.autograd.Function):
return out
@staticmethod
+ @custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors
with torch.no_grad():
@@ -125,6 +328,7 @@ class Matmul_ATB_3D(torch.autograd.Function):
"""Matrix multiplication for :math:`C = A^TB`
"""
@staticmethod
+ @custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
A: Tensor,
B: Tensor,
@@ -138,7 +342,6 @@ class Matmul_ATB_3D(torch.autograd.Function):
# A: [m/q^2, n, k/q]
# B: [m/q^2, n, h/q]
# C: [k/q, h/q^2]
- empty_cache()
ctx.save_for_backward(A, B)
A_temp = all_gather(A, input_dim, input_parallel_mode)
@@ -160,6 +363,7 @@ class Matmul_ATB_3D(torch.autograd.Function):
return out
@staticmethod
+ @custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors
with torch.no_grad():
@@ -180,6 +384,7 @@ class Add_3D(torch.autograd.Function):
"""Matrix add bias: :math:`C = A + b`
"""
@staticmethod
+ @custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, input_: Tensor, bias: Tensor, depth: int,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
@@ -206,6 +411,7 @@ class Add_3D(torch.autograd.Function):
return out
@staticmethod
+ @custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
# output_grad: [m/q^2, n, h/q]
with torch.no_grad():
@@ -217,8 +423,8 @@ class Add_3D(torch.autograd.Function):
dst=ctx.src_rank,
group=gpc.get_group(ctx.A_group_parallel_mode))
if gpc.get_local_rank(
- ctx.A_group_parallel_mode) != gpc.get_local_rank(
- ctx.C_group_parallel_mode):
+ ctx.A_group_parallel_mode) != gpc.get_local_rank(
+ ctx.C_group_parallel_mode):
bias_grad = None
return output_grad, bias_grad, None, None, None, None
@@ -227,6 +433,7 @@ class Mul_3D(torch.autograd.Function):
"""Matrix multiplication for :math:`C = A * b`
"""
@staticmethod
+ @custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, input_: Tensor, bias: Tensor, depth: int,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
@@ -243,7 +450,7 @@ class Mul_3D(torch.autograd.Function):
# [h/q]
bias_temp = all_gather(bias_temp, -1, weight_parallel_mode)
- empty_cache()
+ # empty_cache()
ctx.save_for_backward(input_, bias_temp)
out = torch.mul(input_, bias_temp)
@@ -257,6 +464,7 @@ class Mul_3D(torch.autograd.Function):
return out
@staticmethod
+ @custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
# output_grad: [m/q^2, n, h/q]
with torch.no_grad():
@@ -272,8 +480,8 @@ class Mul_3D(torch.autograd.Function):
dst=ctx.src_rank,
group=gpc.get_group(ctx.A_group_parallel_mode))
if gpc.get_local_rank(
- ctx.A_group_parallel_mode) != gpc.get_local_rank(
- ctx.C_group_parallel_mode):
+ ctx.A_group_parallel_mode) != gpc.get_local_rank(
+ ctx.C_group_parallel_mode):
bias_grad = None
return input_grad, bias_grad, None, None, None, None
@@ -282,6 +490,7 @@ class Sum_3D(torch.autograd.Function):
"""Compute the sum of input tensors
"""
@staticmethod
+ @custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
input_: Tensor,
dim: int,
@@ -299,6 +508,7 @@ class Sum_3D(torch.autograd.Function):
return out
@staticmethod
+ @custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
with torch.no_grad():
output_grad = output_grad.contiguous()
@@ -315,35 +525,39 @@ class Reduce_3D(torch.autograd.Function):
"""Reduce input tensors
"""
@staticmethod
+ @custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, input_: Tensor, depth: int,
parallel_mode: ParallelMode) -> Tensor:
dist.all_reduce(input_, group=gpc.get_group(parallel_mode))
return input_.clone()
@staticmethod
+ @custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
return output_grad, None, None
-class Slice_3D(torch.autograd.Function):
- """Slice input tensor
- """
- @staticmethod
- def forward(ctx: Any, input_: Tensor, dim: int, depth: int,
- parallel_mode: ParallelMode) -> Tensor:
- rank = gpc.get_local_rank(parallel_mode)
- out = torch.chunk(input_, depth, dim=dim)[rank].contiguous()
+# class Slice_3D(torch.autograd.Function):
+# """Slice input tensor
+# """
+# @staticmethod
+# @custom_fwd(cast_inputs=torch.float16)
+# def forward(ctx: Any, input_: Tensor, dim: int, depth: int,
+# parallel_mode: ParallelMode) -> Tensor:
+# rank = gpc.get_local_rank(parallel_mode)
+# out = torch.chunk(input_, depth, dim=dim)[rank].contiguous()
- ctx.depth = depth
- ctx.parallel_mode = parallel_mode
- ctx.dim = dim
- ctx.input_shape = input_.shape
+# ctx.depth = depth
+# ctx.parallel_mode = parallel_mode
+# ctx.dim = dim
+# ctx.input_shape = input_.shape
- return out
+# return out
- @staticmethod
- def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
- with torch.no_grad():
- input_grad = all_gather(output_grad, ctx.dim, ctx.parallel_mode)
- input_grad.reshape(ctx.input_shape)
- return input_grad, None, None, None
+# @staticmethod
+# @custom_bwd
+# def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
+# with torch.no_grad():
+# input_grad = all_gather(output_grad, ctx.dim, ctx.parallel_mode)
+# input_grad.reshape(ctx.input_shape)
+# return input_grad, None, None, None
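For reference, the new `linear_3d` autograd function above gathers the input along the input group, multiplies it by the local weight shard, reduce-scatters the result over the output group, and in backward packs the bias gradient under the weight gradient so a single all-reduce over the weight group covers both. Below is a minimal single-process sketch of that data flow; the collectives are replaced by their local equivalents and all sizes and names are purely illustrative, not part of the patch.

```python
import torch

q = 2                                   # assumed cube depth
b, s, k, h = 4, 8, 16, 32               # batch, seq, in-features, out-features

# each rank holds input [b/q, s, k] and weight [k, h/q]; here we stand in for
# "all_gather along the input dim" by simply keeping the full batch
x_full = torch.randn(b, s, k)           # result of all_gather(input_, input_dim)
w_local = torch.randn(k, h // q)
bias_local = torch.zeros(h // q)

out = torch.matmul(x_full, w_local)     # [b, s, h/q]
# reduce_scatter over the output group would hand each rank a b/q slice of dim 0
out_local = out.chunk(q, dim=0)[0] + bias_local

# backward trick from the patch: stack bias_grad under weight_grad so one
# all_reduce over the weight group covers both, then split them apart again
grad_out = torch.randn_like(out)
weight_grad = x_full.reshape(-1, k).t() @ grad_out.reshape(-1, h // q)
bias_grad = grad_out.sum(dim=(0, 1))
packed = torch.cat([weight_grad, bias_grad.unsqueeze(0)])   # single collective payload
bias_grad, weight_grad = packed[-1], packed[:-1]
```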
diff --git a/colossalai/nn/layer/parallel_3d/_utils.py b/colossalai/nn/layer/parallel_3d/_utils.py
index 3c9236017..ca3b405ea 100644
--- a/colossalai/nn/layer/parallel_3d/_utils.py
+++ b/colossalai/nn/layer/parallel_3d/_utils.py
@@ -3,7 +3,8 @@
import os
-from colossalai.constants import DEPTH_3D
+from colossalai.constants import (DEPTH_3D, INPUT_GROUP_3D, OUTPUT_GROUP_3D,
+ WEIGHT_GROUP_3D)
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from torch import Tensor
@@ -23,6 +24,10 @@ def get_depth_from_env() -> int:
)
+def get_parallel_mode_from_env(group):
+ return getattr(ParallelMode, os.environ[group])
+
+
def get_last_group(a, b):
mapping = {
ParallelMode.PARALLEL_3D_INPUT: 'A',
@@ -41,6 +46,11 @@ def get_last_group(a, b):
return ParallelMode.PARALLEL_3D_OUTPUT
+def swap_in_out_group():
+ os.environ[INPUT_GROUP_3D], os.environ[OUTPUT_GROUP_3D] = \
+ os.environ[OUTPUT_GROUP_3D], os.environ[INPUT_GROUP_3D]
+
+
def dbg_check_shape(tensor: Tensor, shape: tuple):
rank = gpc.get_global_rank()
if rank == 0:
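The `_utils.py` changes above replace hard-coded parallel modes with an environment-variable lookup plus a `swap_in_out_group` helper. The sketch below shows that indirection in isolation; it uses a stand-in enum and hypothetical constant values rather than the real `colossalai.constants`.

```python
import os
from enum import Enum

class ParallelMode(Enum):                    # stand-in for colossalai's ParallelMode
    PARALLEL_3D_INPUT = '3d_input'
    PARALLEL_3D_OUTPUT = '3d_output'
    PARALLEL_3D_WEIGHT = '3d_weight'

# hypothetical constant values; the real ones live in colossalai.constants
INPUT_GROUP_3D, OUTPUT_GROUP_3D = 'INPUT_GROUP_3D', 'OUTPUT_GROUP_3D'
os.environ[INPUT_GROUP_3D] = 'PARALLEL_3D_INPUT'
os.environ[OUTPUT_GROUP_3D] = 'PARALLEL_3D_OUTPUT'

def get_parallel_mode_from_env(group):
    # layers look the mode up by name instead of taking it as a constructor argument
    return getattr(ParallelMode, os.environ[group])

def swap_in_out_group():
    # each Linear3D swaps the two groups so the next layer reads the right one
    os.environ[INPUT_GROUP_3D], os.environ[OUTPUT_GROUP_3D] = \
        os.environ[OUTPUT_GROUP_3D], os.environ[INPUT_GROUP_3D]

print(get_parallel_mode_from_env(INPUT_GROUP_3D))   # ParallelMode.PARALLEL_3D_INPUT
swap_in_out_group()
print(get_parallel_mode_from_env(INPUT_GROUP_3D))   # ParallelMode.PARALLEL_3D_OUTPUT
```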
diff --git a/colossalai/nn/layer/parallel_3d/_vit.py b/colossalai/nn/layer/parallel_3d/_vit.py
index ffe7a146a..46fb83b92 100644
--- a/colossalai/nn/layer/parallel_3d/_vit.py
+++ b/colossalai/nn/layer/parallel_3d/_vit.py
@@ -1,17 +1,20 @@
import math
-from typing import Tuple
+import os
+from typing import Tuple, Optional
import torch
import torch.distributed as dist
+from colossalai.constants import (INPUT_GROUP_3D, OUTPUT_GROUP_3D,
+ WEIGHT_GROUP_3D)
from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc
from colossalai.registry import LAYERS
+from colossalai.nn.init import init_bias_, init_weight_
from colossalai.utils import checkpoint, get_current_device
from torch import Tensor, dtype, nn
-from .._common_utils import ACT2FN, divide, set_tensor_parallel_attribute
-from ..vanilla_vision_transformer.layers import to_2tuple
-from ._utils import get_depth_from_env
+from .._common_utils import ACT2FN, divide, set_tensor_parallel_attribute_by_size, to_2tuple
+from ._utils import get_depth_from_env, get_parallel_mode_from_env, get_last_group
from .layers import Linear3D
@@ -32,34 +35,42 @@ class ViTPatchEmbedding3D(nn.Module):
:param flatten: whether to flatten output tensor, defaults to True
:type flatten: bool, optional
"""
+
def __init__(self,
img_size: int,
patch_size: int,
in_chans: int,
embed_size: int,
drop_prob: float,
- flatten: bool = True):
+ flatten: bool = True,
+ init_method: str = 'torch'):
super().__init__()
self.depth = get_depth_from_env()
- self.input_parallel_mode = ParallelMode.PARALLEL_3D_INPUT
- self.weight_parallel_mode = ParallelMode.PARALLEL_3D_WEIGHT
- self.output_parallel_mode = ParallelMode.PARALLEL_3D_OUTPUT
+ self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
+ self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
+ self.output_parallel_mode = get_last_group(self.input_parallel_mode,
+ self.weight_parallel_mode)
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0],
img_size[1] // patch_size[1])
+ self.in_chans = in_chans
self.embed_size = embed_size
self.embed_size_per_partition = divide(self.embed_size, self.depth)
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
+ self.init_weight = 'torch'
+ self.init_bias = 'torch'
+ if init_method == 'jax':
+ self.init_weight = 'jax_embed'
+ self.init_bias = 'zero'
- with seed(ParallelMode.TENSOR):
- self.proj = nn.Conv2d(in_chans,
- self.embed_size_per_partition,
- kernel_size=patch_size,
- stride=patch_size)
+ self.proj = nn.Conv2d(self.in_chans,
+ self.embed_size_per_partition,
+ kernel_size=patch_size,
+ stride=patch_size)
self.cls_token = nn.Parameter(
torch.zeros(1, 1, self.embed_size_per_partition))
@@ -68,23 +79,26 @@ class ViTPatchEmbedding3D(nn.Module):
self.embed_size_per_partition))
self.pos_drop = nn.Dropout(drop_prob)
- self._sync_parameters()
- self.proj.weight.register_hook(self._sync_grad_hook)
- self.proj.bias.register_hook(self._sync_grad_hook)
- self.cls_token.register_hook(self._sync_grad_hook)
- self.pos_embed.register_hook(self._sync_grad_hook)
- self._set_tensor_parallel_attribute()
+ self.reset_parameters(self.init_weight, self.init_bias)
+ self._set_tensor_parallel_attributes()
- def _set_tensor_parallel_attribute(self):
- set_tensor_parallel_attribute(self.proj.weight)
- set_tensor_parallel_attribute(self.proj.bias)
- set_tensor_parallel_attribute(self.cls_token)
- set_tensor_parallel_attribute(self.pos_embed)
+ def _set_tensor_parallel_attributes(self):
+ set_tensor_parallel_attribute_by_size(self.proj.weight, self.in_chans * self.embed_size * self.num_patches)
+ set_tensor_parallel_attribute_by_size(self.proj.bias, self.embed_size)
+ set_tensor_parallel_attribute_by_size(self.cls_token, 1 * 1 * self.embed_size)
+ set_tensor_parallel_attribute_by_size(self.pos_embed, 1 * (self.num_patches + 1) * self.embed_size)
- def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
- return self.input_parallel_mode, self.weight_parallel_mode
+ def reset_parameters(self, init_weight, init_bias):
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.proj.weight)
+ # std = math.sqrt(1.0 / fan_in)
+ # nn.init.trunc_normal_(self.proj.weight, std=std / .87962566103423978)
+ # nn.init.zeros_(self.proj.bias)
+ if init_weight != 'torch':
+ init_weight_(self.proj.weight, fan_in, init_method=init_weight)
+ init_bias_(self.pos_embed, fan_in, init_method=init_weight)
+ if init_bias != 'torch':
+ init_bias_(self.proj.bias, fan_in, init_method=init_bias)
- def _sync_parameters(self):
self.to(get_current_device())
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
dist.broadcast(self.proj.weight,
@@ -100,10 +114,11 @@ class ViTPatchEmbedding3D(nn.Module):
dist.broadcast(self.proj.bias,
src=input_src_rank,
group=gpc.get_group(self.input_parallel_mode))
- set_tensor_parallel_attribute(self.proj.weight)
- set_tensor_parallel_attribute(self.proj.bias)
- set_tensor_parallel_attribute(self.cls_token)
- set_tensor_parallel_attribute(self.pos_embed)
+
+ self.proj.weight.register_hook(self._sync_grad_hook)
+ self.proj.bias.register_hook(self._sync_grad_hook)
+ self.cls_token.register_hook(self._sync_grad_hook)
+ self.pos_embed.register_hook(self._sync_grad_hook)
def _sync_grad_hook(self, grad) -> None:
dist.all_reduce(grad, group=gpc.get_group(self.input_parallel_mode))
@@ -111,6 +126,12 @@ class ViTPatchEmbedding3D(nn.Module):
return grad
def forward(self, x: Tensor) -> Tensor:
+ # split a partition from inputs
+ x = torch.chunk(x, self.depth, dim=0)[gpc.get_local_rank(
+ self.weight_parallel_mode)].contiguous()
+ x = torch.chunk(x, self.depth, dim=0)[gpc.get_local_rank(
+ self.input_parallel_mode)].contiguous()
+
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
@@ -118,12 +139,6 @@ class ViTPatchEmbedding3D(nn.Module):
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
- # split a partition from embedded states
- x = torch.chunk(x, self.depth, dim=0)[gpc.get_local_rank(
- self.weight_parallel_mode)].contiguous()
- x = torch.chunk(x, self.depth, dim=0)[gpc.get_local_rank(
- self.input_parallel_mode)].contiguous()
-
# add cls token & pos embedding
# [b/q^2,s,h/q] --> [b/q^2, 1+s, h/q]
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
@@ -158,6 +173,7 @@ class ViTSelfAttention3D(nn.Module):
:param bias: whether to add bias, defaults to True
:type bias: bool, optional
"""
+
def __init__(self,
hidden_size: int,
num_attention_heads: int,
@@ -165,41 +181,52 @@ class ViTSelfAttention3D(nn.Module):
hidden_dropout_prob: float,
dtype: dtype = None,
bias: bool = True,
- checkpoint: bool = False):
+ checkpoint: bool = False,
+ init_method: str = 'torch'):
super().__init__()
self.depth = get_depth_from_env()
- self.input_parallel_mode = ParallelMode.PARALLEL_3D_INPUT
- self.weight_parallel_mode = ParallelMode.PARALLEL_3D_WEIGHT
- self.output_parallel_mode = ParallelMode.PARALLEL_3D_OUTPUT
+ # self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
+ # self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
+ # self.output_parallel_mode = get_last_group(self.input_parallel_mode,
+ # self.weight_parallel_mode)
self.hidden_size = hidden_size
self.num_attention_heads = divide(num_attention_heads, self.depth)
self.attention_head_size = divide(hidden_size, num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.checkpoint = checkpoint
+ self.init_weight = 'torch'
+ self.init_bias = 'torch'
+ if init_method == 'jax':
+ self.init_weight = 'jax'
+ self.init_bias = 'zero'
self.query_key_value = Linear3D(self.hidden_size,
3 * self.hidden_size,
- self.input_parallel_mode,
- self.weight_parallel_mode,
+ # self.input_parallel_mode,
+ # self.weight_parallel_mode,
dtype=dtype,
- bias=bias)
+ bias=bias,
+ init_weight=self.init_weight,
+ init_bias=self.init_bias)
self.attention_dropout = nn.Dropout(attention_probs_dropout_prob)
self.dense = Linear3D(self.hidden_size,
self.hidden_size,
- self.output_parallel_mode,
- self.weight_parallel_mode,
+ # self.output_parallel_mode,
+ # self.weight_parallel_mode,
dtype=dtype,
- bias=bias)
+ bias=bias,
+ init_weight=self.init_weight,
+ init_bias=self.init_bias)
self.dropout = nn.Dropout(hidden_dropout_prob)
self.softmax = nn.Softmax(dim=-1)
- def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
- return self.input_parallel_mode, self.weight_parallel_mode
+ # def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
+ # return self.input_parallel_mode, self.weight_parallel_mode
def _forward(self, hidden_states: Tensor) -> Tensor:
query_key_value = self.query_key_value(hidden_states)
new_qkv_shape = query_key_value.shape[:-1] + \
- (self.num_attention_heads, 3 * self.attention_head_size)
+ (self.num_attention_heads, 3 * self.attention_head_size)
query_key_value = query_key_value.view(new_qkv_shape)
query_key_value = query_key_value.permute((0, 2, 1, 3))
query_layer, key_layer, value_layer = torch.chunk(query_key_value,
@@ -259,6 +286,7 @@ class ViTMLP3D(nn.Module):
:param bias: whether to add bias, defaults to True
:type bias: bool, optional
"""
+
def __init__(self,
hidden_size: int,
mlp_ratio: int,
@@ -266,33 +294,41 @@ class ViTMLP3D(nn.Module):
hidden_act: str = 'gelu',
dtype: dtype = None,
bias: bool = True,
- checkpoint: bool = False):
+ checkpoint: bool = False,
+ init_method: str = 'torch'):
super().__init__()
- self.depth = get_depth_from_env()
- self.input_parallel_mode = ParallelMode.PARALLEL_3D_INPUT
- self.weight_parallel_mode = ParallelMode.PARALLEL_3D_WEIGHT
- self.output_parallel_mode = ParallelMode.PARALLEL_3D_OUTPUT
+ # self.depth = get_depth_from_env()
+ # self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
+ # self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
+ # self.output_parallel_mode = get_last_group(self.input_parallel_mode,
+ # self.weight_parallel_mode)
self.hidden_size = hidden_size
self.mlp_ratio = mlp_ratio
self.checkpoint = checkpoint
+ self.init_weight = init_method
+ self.init_bias = init_method
self.dense_1 = Linear3D(self.hidden_size,
self.mlp_ratio * self.hidden_size,
- self.input_parallel_mode,
- self.weight_parallel_mode,
+ # self.input_parallel_mode,
+ # self.weight_parallel_mode,
dtype=dtype,
- bias=bias)
+ bias=bias,
+ init_weight=self.init_weight,
+ init_bias=self.init_bias)
self.activation_func = ACT2FN[hidden_act]
self.dense_2 = Linear3D(self.mlp_ratio * self.hidden_size,
self.hidden_size,
- self.output_parallel_mode,
- self.weight_parallel_mode,
+ # self.output_parallel_mode,
+ # self.weight_parallel_mode,
dtype=dtype,
- bias=bias)
+ bias=bias,
+ init_weight=self.init_weight,
+ init_bias=self.init_bias)
self.dropout = nn.Dropout(hidden_dropout_prob)
- def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
- return self.input_parallel_mode, self.weight_parallel_mode
+ # def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
+ # return self.input_parallel_mode, self.weight_parallel_mode
def _forward(self, hidden_states: Tensor) -> Tensor:
intermediate_output = self.dense_1(hidden_states)
@@ -331,37 +367,46 @@ class ViTHead3D(nn.Module):
:param bias: whether to add bias, defaults to True
:type bias: bool, optional
"""
+
def __init__(self,
in_features: int,
num_classes: int,
dtype: dtype = None,
- bias: bool = True):
+ bias: bool = True,
+ init_method: str = 'torch'):
super().__init__()
- self.depth = get_depth_from_env()
- self.input_parallel_mode = ParallelMode.PARALLEL_3D_INPUT
- self.weight_parallel_mode = ParallelMode.PARALLEL_3D_WEIGHT
- self.output_parallel_mode = ParallelMode.PARALLEL_3D_OUTPUT
+ # self.depth = get_depth_from_env()
+ # self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
+ # self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
+ # self.output_parallel_mode = get_last_group(self.input_parallel_mode,
+ # self.weight_parallel_mode)
self.in_features = in_features
self.num_classes = num_classes
- out_features = math.ceil(self.num_classes /
- (self.depth**2)) * (self.depth**2)
- self.num_classes_per_partition = divide(self.num_classes, self.depth)
- self.linear = Linear3D(self.in_features,
- out_features,
- self.input_parallel_mode,
- self.weight_parallel_mode,
- dtype=dtype,
- bias=bias)
+ # out_features = math.ceil(self.num_classes /
+ # (self.depth**2)) * (self.depth**2)
+ # self.num_classes_per_partition = divide(self.num_classes, self.depth)
+ self.init_weight = 'torch'
+ self.init_bias = 'torch'
+ if init_method == 'jax':
+ self.init_weight = 'zero'
+ self.init_bias = 'zero'
- def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
- return self.linear.groups_for_next_layer()
+ self.linear = Linear3D(self.in_features,
+ self.num_classes,
+ # self.input_parallel_mode,
+ # self.weight_parallel_mode,
+ dtype=dtype,
+ bias=bias,
+ init_weight=self.init_weight,
+ init_bias=self.init_bias)
def forward(self, x: Tensor) -> Tensor:
# [b/q^2, s, h/q] --> [b/q^2, h/q]
x = x[:, 0]
# [b/q^2, h/q] --> [b/q^2, c/q]
x = self.linear(x)
- return x[:, :self.num_classes_per_partition]
+ # return x[:, :self.num_classes_per_partition]
+ return x
def extra_repr(self):
return 'in_features={}, num_classes={}'.format(self.in_features,
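One behavioural change in `_vit.py` worth noting: `ViTPatchEmbedding3D.forward` now slices its partition from the raw image batch (by weight rank, then input rank) before the convolution, instead of slicing the embedded states afterwards. A toy single-process sketch of that split, with the local ranks hard-coded for illustration:

```python
import torch

depth = 2                                   # assumed cube depth q
weight_rank, input_rank = 0, 1              # hypothetical local ranks in the two groups
x = torch.randn(8, 3, 32, 32)               # full [B, C, H, W] batch

# slice by the weight-group rank, then by the input-group rank: [B] -> [B/q] -> [B/q^2]
x = torch.chunk(x, depth, dim=0)[weight_rank].contiguous()
x = torch.chunk(x, depth, dim=0)[input_rank].contiguous()
print(x.shape)                               # torch.Size([2, 3, 32, 32])
```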
diff --git a/colossalai/nn/layer/parallel_3d/layers.py b/colossalai/nn/layer/parallel_3d/layers.py
index c6d631008..60e4a2c8a 100644
--- a/colossalai/nn/layer/parallel_3d/layers.py
+++ b/colossalai/nn/layer/parallel_3d/layers.py
@@ -2,19 +2,28 @@
# -*- encoding: utf-8 -*-
import math
+import os
from typing import Tuple
import torch
+import torch.distributed as dist
import torch.nn as nn
+from colossalai.constants import (INPUT_GROUP_3D, OUTPUT_GROUP_3D,
+ WEIGHT_GROUP_3D)
from colossalai.context import ParallelMode, seed
+from colossalai.core import global_context as gpc
+from colossalai.nn.init import init_bias_, init_weight_
from colossalai.registry import LAYERS
from colossalai.utils import get_current_device
from torch import Tensor, dtype
from torch.nn import Parameter
+from torch.nn import init as init
-from .._common_utils import divide, set_tensor_parallel_attribute
-from ._operation import Add_3D, Matmul_AB_3D, Mul_3D, Sum_3D
-from ._utils import get_depth_from_env, get_last_group
+from .._common_utils import divide, set_tensor_parallel_attribute_by_size
+from ._operation import (Add_3D, Matmul_AB_3D, Mul_3D, Sum_3D, layer_norm_3d,
+ linear_3d)
+from ._utils import (get_depth_from_env, get_last_group,
+ get_parallel_mode_from_env, swap_in_out_group)
@LAYERS.register_module
@@ -22,20 +31,19 @@ class LayerNorm3D(nn.Module):
def __init__(
self,
normalized_shape: int,
- input_parallel_mode: ParallelMode,
- weight_parallel_mode: ParallelMode,
+ # input_parallel_mode: ParallelMode,
+ # weight_parallel_mode: ParallelMode,
eps: float = 1e-12,
dtype: dtype = None,
):
super().__init__()
- self.input_parallel_mode = input_parallel_mode
- self.weight_parallel_mode = weight_parallel_mode
+ self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
+ self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode,
self.weight_parallel_mode)
self.depth = get_depth_from_env()
self.normalized_shape = normalized_shape
- self.normalized_shape_per_partition = divide(normalized_shape,
- self.depth**2)
+ self.normalized_shape_per_partition = divide(normalized_shape, self.depth)
self.weight = Parameter(
torch.ones(self.normalized_shape_per_partition,
@@ -49,37 +57,40 @@ class LayerNorm3D(nn.Module):
self._set_tensor_parallel_attributes()
def _set_tensor_parallel_attributes(self):
- set_tensor_parallel_attribute(self.weight)
- set_tensor_parallel_attribute(self.bias)
-
- def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
- return self.input_parallel_mode, self.weight_parallel_mode
+ set_tensor_parallel_attribute_by_size(self.weight, self.normalized_shape)
+ set_tensor_parallel_attribute_by_size(self.bias, self.normalized_shape)
def reset_parameters(self):
- nn.init.zeros_(self.bias)
- nn.init.ones_(self.weight)
+ init.zeros_(self.bias)
+ init.ones_(self.weight)
def forward(self, input_: Tensor) -> Tensor:
- '''x = weight * (x - mean) / sqrt(var + eps) + bias'''
- # input: [m/q^2, n, h/q]
- # [m/q^2, n, 1]
- mean = Sum_3D.apply(input_, -1, self.depth, self.output_parallel_mode,
- True) / self.normalized_shape
- # [m/q^2, n, 1]
- var = (input_ - mean).pow(2)
- var = Sum_3D.apply(var, -1, self.depth, self.output_parallel_mode,
- True) / self.normalized_shape
+ # '''x = weight * (x - mean) / sqrt(var + eps) + bias'''
+ # # input: [m/q^2, n, h/q]
+ # # [m/q^2, n, 1]
+ # mean = Sum_3D.apply(input_, -1, self.depth, self.output_parallel_mode,
+ # True) / self.normalized_shape
+ # # [m/q^2, n, 1]
+ # var = (input_ - mean).pow(2)
+ # var = Sum_3D.apply(var, -1, self.depth, self.output_parallel_mode,
+ # True) / self.normalized_shape
- output = (input_ - mean) / torch.sqrt(var + self.variance_epsilon)
- output = Mul_3D.apply(output, self.weight, self.depth,
- self.input_parallel_mode,
- self.weight_parallel_mode,
- self.output_parallel_mode)
- output = Add_3D.apply(output, self.bias, self.depth,
- self.input_parallel_mode,
- self.weight_parallel_mode,
- self.output_parallel_mode)
- return output
+ # output = (input_ - mean) / torch.sqrt(var + self.variance_epsilon)
+ # output = Mul_3D.apply(output, self.weight, self.depth,
+ # self.input_parallel_mode,
+ # self.weight_parallel_mode,
+ # self.output_parallel_mode)
+ # output = Add_3D.apply(output, self.bias, self.depth,
+ # self.input_parallel_mode,
+ # self.weight_parallel_mode,
+ # self.output_parallel_mode)
+ # return output
+ return layer_norm_3d.apply(input_, self.weight, self.bias,
+ self.normalized_shape,
+ self.variance_epsilon,
+ self.input_parallel_mode,
+ self.weight_parallel_mode,
+ self.output_parallel_mode)
def extra_repr(self):
return '{}, eps={}'.format(self.normalized_shape,
@@ -88,33 +99,36 @@ class LayerNorm3D(nn.Module):
@LAYERS.register_module
class Linear3D(nn.Module):
- def __init__(self,
- in_features: int,
- out_features: int,
- input_parallel_mode: ParallelMode,
- weight_parallel_mode: ParallelMode,
- bias: bool = True,
- dtype: dtype = None):
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ # input_parallel_mode: ParallelMode,
+ # weight_parallel_mode: ParallelMode,
+ bias: bool = True,
+ dtype: dtype = None,
+ init_weight: str = 'torch',
+ init_bias: str = 'torch'):
super().__init__()
self.in_features = in_features
self.out_features = out_features
- self.input_parallel_mode = input_parallel_mode
- self.weight_parallel_mode = weight_parallel_mode
+ self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
+ self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode,
self.weight_parallel_mode)
- self.with_bias = bias
+ # self.with_bias = bias
self.depth = get_depth_from_env()
self.in_features_per_partition = divide(in_features, self.depth)
- self.out_features_per_partition = divide(out_features, self.depth**2)
+ self.out_features_per_partition = divide(out_features, self.depth)
- # [k/q, h/q^2]
+ # [k/q, h/q]
self.weight = Parameter(
torch.empty(self.in_features_per_partition,
self.out_features_per_partition,
device=get_current_device(),
dtype=dtype))
- # [h/q^2]
+ # [h/q]
if bias:
self.bias = Parameter(
torch.zeros(self.out_features_per_partition,
@@ -123,49 +137,54 @@ class Linear3D(nn.Module):
else:
self.register_parameter('bias', None)
- self.reset_parameters()
+ self.reset_parameters(init_weight, init_bias)
self._set_tensor_parallel_attributes()
+ swap_in_out_group()
def _set_tensor_parallel_attributes(self):
- set_tensor_parallel_attribute(self.weight)
+ set_tensor_parallel_attribute_by_size(self.weight, self.in_features * self.out_features)
if self.bias is not None:
- set_tensor_parallel_attribute(self.bias)
+ set_tensor_parallel_attribute_by_size(self.bias, self.out_features)
- def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
- return self.output_parallel_mode, self.weight_parallel_mode
-
- def reset_parameters(self):
+ def reset_parameters(self, init_weight, init_bias) -> None:
# setting
- fan_in = self.in_features
- a = math.sqrt(5)
- nonlinearity = 'leaky_relu'
+ fan_in, fan_out = self.in_features, self.out_features
+ weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
+ output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0]
# init weight
- std = nn.init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
- bound = math.sqrt(3.0) * std
- with seed(ParallelMode.TENSOR):
- nn.init.uniform_(self.weight, -bound, bound)
-
+ init_weight_(self.weight, fan_in, fan_out, init_method=init_weight)
+ dist.broadcast(self.weight,
+ src=weight_src_rank,
+ group=gpc.get_group(self.weight_parallel_mode))
# init bias
- if self.with_bias:
- bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
- with seed(ParallelMode.TENSOR):
- nn.init.uniform_(self.bias, -bound, bound)
+ if self.bias is not None:
+ init_bias_(self.bias, fan_in, init_method=init_bias)
+ dist.broadcast(self.bias,
+ src=weight_src_rank,
+ group=gpc.get_group(self.weight_parallel_mode))
+ dist.broadcast(self.bias,
+ src=output_src_rank,
+ group=gpc.get_group(self.output_parallel_mode))
def forward(self, input_: Tensor) -> Tensor:
- # input: [m/q^2, n, k/q]
- # output: [m/q^2, n, h/q]
- output = Matmul_AB_3D.apply(input_, self.weight, self.depth,
- self.input_parallel_mode,
- self.weight_parallel_mode,
- self.output_parallel_mode)
+ # # input: [m/q^2, n, k/q]
+ # # output: [m/q^2, n, h/q]
+ # output = Matmul_AB_3D.apply(input_, self.weight, self.depth,
+ # self.input_parallel_mode,
+ # self.weight_parallel_mode,
+ # self.output_parallel_mode)
- if self.with_bias:
- output = Add_3D.apply(output, self.bias, self.depth,
- self.output_parallel_mode,
- self.weight_parallel_mode,
- self.input_parallel_mode)
- return output
+ # if self.bias is not None:
+ # output = Add_3D.apply(output, self.bias, self.depth,
+ # self.output_parallel_mode,
+ # self.weight_parallel_mode,
+ # self.input_parallel_mode)
+ # return output
+ return linear_3d.apply(input_, self.weight, self.bias,
+ self.input_parallel_mode,
+ self.weight_parallel_mode,
+ self.output_parallel_mode)
def extra_repr(self):
return 'in_features={}, out_features={}, bias={}'.format(
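With the `layers.py` changes, `Linear3D` and `LayerNorm3D` no longer take parallel modes as constructor arguments and partition their features by `depth` rather than `depth**2`. A hedged usage sketch follows; it assumes Colossal-AI has been launched with a 3D tensor-parallel configuration and will not run standalone.

```python
from colossalai.nn.layer.parallel_3d.layers import LayerNorm3D, Linear3D

# parallel modes are now resolved internally from the environment
norm = LayerNorm3D(1024, eps=1e-6)       # previously: LayerNorm3D(1024, input_mode, weight_mode, ...)
proj = Linear3D(1024, 4096,
                bias=True,
                init_weight='jax',       # new per-partition init options
                init_bias='zero')
# each rank now holds a weight shard of shape [in_features/q, out_features/q]
# instead of [in_features/q, out_features/q^2]
```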
diff --git a/colossalai/nn/layer/parallel_vision_transformer/__init__.py b/colossalai/nn/layer/parallel_vision_transformer/__init__.py
deleted file mode 100644
index 8adf9eb30..000000000
--- a/colossalai/nn/layer/parallel_vision_transformer/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .layers import ViTBlock
-
-__all__ = ['ViTBlock']
diff --git a/colossalai/nn/layer/parallel_vision_transformer/layers.py b/colossalai/nn/layer/parallel_vision_transformer/layers.py
deleted file mode 100644
index 8624f7f66..000000000
--- a/colossalai/nn/layer/parallel_vision_transformer/layers.py
+++ /dev/null
@@ -1,59 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-from torch import nn as nn
-
-from colossalai.builder import build_layer
-from colossalai.registry import LAYERS
-
-
-@LAYERS.register_module
-class ViTBlock(nn.Module):
- """Vision Transformer block
-
- :param attention_cfg: config of attention layer
- :type attention_cfg: dict
- :param droppath_cfg: config of drop path
- :type droppath_cfg: dict
- :param mlp_cfg: config of MLP layer
- :type mlp_cfg: dict
- :param norm_cfg: config of normlization layer
- :type norm_cfg: dict
- """
-
- def __init__(self,
- attention_cfg: dict,
- droppath_cfg: dict,
- mlp_cfg: dict,
- norm_cfg: dict,
- ):
- super().__init__()
- self.norm1 = build_layer(norm_cfg)
- self.attn = build_layer(attention_cfg)
- self.drop_path = build_layer(
- droppath_cfg) if droppath_cfg['drop_path'] > 0. else nn.Identity()
- self.norm2 = build_layer(norm_cfg)
- self.mlp = build_layer(mlp_cfg)
-
- def forward(self, x):
- x = x + self.drop_path(self.attn(self.norm1(x)))
- x = x + self.drop_path(self.mlp(self.norm2(x)))
-
- # x_ = x
- # x_ = self.norm1(x_)
- # if self.checkpoint:
- # x_ = checkpoint(self.attn, x_)
- # else:
- # x_ = self.attn(x_)
- # x_ = self.drop_path(x_)
- # x = x + x_
- #
- # x_ = x
- # x_ = self.norm2(x_)
- # if self.checkpoint:
- # x_ = checkpoint(self.mlp, x_)
- # else:
- # x_ = self.mlp(x_)
- # x_ = self.drop_path(x_)
- # x = x + x_
- return x
diff --git a/colossalai/nn/layer/vanilla_resnet/__init__.py b/colossalai/nn/layer/vanilla_resnet/__init__.py
deleted file mode 100644
index 289b8749e..000000000
--- a/colossalai/nn/layer/vanilla_resnet/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from .basic_block import ResNetBasicBlock
-from .bottleneck import ResNetBottleneck
-from .reslayer import ResLayer
-
-__all__ = ['ResLayer', 'ResNetBottleneck', 'ResNetBasicBlock']
diff --git a/colossalai/nn/layer/vanilla_resnet/basic_block.py b/colossalai/nn/layer/vanilla_resnet/basic_block.py
deleted file mode 100644
index 320dac2fd..000000000
--- a/colossalai/nn/layer/vanilla_resnet/basic_block.py
+++ /dev/null
@@ -1,64 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-from typing import Optional, Callable
-
-import torch.nn as nn
-from torch import Tensor
-
-from colossalai.registry import LAYERS
-from .conv import conv3x3
-
-
-@LAYERS.register_module
-class ResNetBasicBlock(nn.Module):
- """Basic ResNet block
- """
- expansion: int = 1
-
- def __init__(
- self,
- inplanes: int,
- planes: int,
- stride: int = 1,
- downsample: Optional[nn.Module] = None,
- groups: int = 1,
- base_width: int = 64,
- dilation: int = 1,
- norm_layer: Optional[Callable[..., nn.Module]] = None
- ) -> None:
- super().__init__()
- if norm_layer is None:
- norm_layer = nn.BatchNorm2d
- if groups != 1 or base_width != 64:
- raise ValueError(
- 'BasicBlock only supports groups=1 and base_width=64')
- if dilation > 1:
- raise NotImplementedError(
- "Dilation > 1 not supported in BasicBlock")
- # Both self.conv1 and self.downsample layers downsample the input when stride != 1
- self.conv1 = conv3x3(inplanes, planes, stride)
- self.bn1 = norm_layer(planes)
- self.relu = nn.ReLU(inplace=True)
- self.conv2 = conv3x3(planes, planes)
- self.bn2 = norm_layer(planes)
- self.downsample = downsample
- self.stride = stride
-
- def forward(self, x: Tensor) -> Tensor:
- identity = x
-
- out = self.conv1(x)
- out = self.bn1(out)
- out = self.relu(out)
-
- out = self.conv2(out)
- out = self.bn2(out)
-
- if self.downsample is not None:
- identity = self.downsample(x)
-
- out += identity
- out = self.relu(out)
-
- return out
diff --git a/colossalai/nn/layer/vanilla_resnet/bottleneck.py b/colossalai/nn/layer/vanilla_resnet/bottleneck.py
deleted file mode 100644
index d75f9534b..000000000
--- a/colossalai/nn/layer/vanilla_resnet/bottleneck.py
+++ /dev/null
@@ -1,69 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-from typing import Optional, Callable
-
-import torch.nn as nn
-from torch import Tensor
-
-from colossalai.registry import LAYERS
-from .conv import conv3x3, conv1x1
-
-
-@LAYERS.register_module
-class ResNetBottleneck(nn.Module):
- # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
- # while original implementation places the stride at the first 1x1 convolution(self.conv1)
- # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
- # This variant is also known as ResNet V1.5 and improves accuracy according to
- # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
-
- expansion: int = 4
-
- def __init__(
- self,
- inplanes: int,
- planes: int,
- stride: int = 1,
- downsample: Optional[nn.Module] = None,
- groups: int = 1,
- base_width: int = 64,
- dilation: int = 1,
- norm_layer: Optional[Callable[..., nn.Module]] = None
- ) -> None:
- super().__init__()
- if norm_layer is None:
- norm_layer = nn.BatchNorm2d
- width = int(planes * (base_width / 64.)) * groups
- # Both self.conv2 and self.downsample layers downsample the input when stride != 1
- self.conv1 = conv1x1(inplanes, width)
- self.bn1 = norm_layer(width)
- self.conv2 = conv3x3(width, width, stride, groups, dilation)
- self.bn2 = norm_layer(width)
- self.conv3 = conv1x1(width, planes * self.expansion)
- self.bn3 = norm_layer(planes * self.expansion)
- self.relu = nn.ReLU(inplace=True)
- self.downsample = downsample
- self.stride = stride
-
- def forward(self, x: Tensor) -> Tensor:
- identity = x
-
- out = self.conv1(x)
- out = self.bn1(out)
- out = self.relu(out)
-
- out = self.conv2(out)
- out = self.bn2(out)
- out = self.relu(out)
-
- out = self.conv3(out)
- out = self.bn3(out)
-
- if self.downsample is not None:
- identity = self.downsample(x)
-
- out += identity
- out = self.relu(out)
-
- return out
diff --git a/colossalai/nn/layer/vanilla_resnet/conv.py b/colossalai/nn/layer/vanilla_resnet/conv.py
deleted file mode 100644
index c918d94c4..000000000
--- a/colossalai/nn/layer/vanilla_resnet/conv.py
+++ /dev/null
@@ -1,15 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-import torch.nn as nn
-
-
-def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
- """3x3 convolution with padding"""
- return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
- padding=dilation, groups=groups, bias=False, dilation=dilation)
-
-
-def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
- """1x1 convolution"""
- return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
diff --git a/colossalai/nn/layer/vanilla_resnet/reslayer.py b/colossalai/nn/layer/vanilla_resnet/reslayer.py
deleted file mode 100644
index 4e1b48c5e..000000000
--- a/colossalai/nn/layer/vanilla_resnet/reslayer.py
+++ /dev/null
@@ -1,63 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-import torch.nn as nn
-
-from colossalai.registry import LAYERS
-from .conv import conv1x1
-
-
-@LAYERS.register_module
-class ResLayer(nn.Module):
-
- def __init__(self,
- block_type: str,
- norm_layer_type: str,
- inplanes: int,
- planes: int,
- blocks: int,
- groups: int,
- base_width: int,
- stride: int = 1,
- dilation: int = 1,
- dilate: bool = False,
- ):
- super().__init__()
- self.block = LAYERS.get_module(block_type)
- self.norm_layer = LAYERS.get_module(norm_layer_type)
- self.inplanes = inplanes
- self.planes = planes
- self.blocks = blocks
- self.groups = groups
- self.dilation = dilation
- self.base_width = base_width
- self.dilate = dilate
- self.stride = stride
- self.layer = self._make_layer()
-
- def _make_layer(self):
- norm_layer = self.norm_layer
- downsample = None
- previous_dilation = self.dilation
- if self.dilate:
- self.dilation *= self.stride
- self.stride = 1
- if self.stride != 1 or self.inplanes != self.planes * self.block.expansion:
- downsample = nn.Sequential(
- conv1x1(self.inplanes, self.planes * self.block.expansion, self.stride),
- norm_layer(self.planes * self.block.expansion),
- )
-
- layers = []
- layers.append(self.block(self.inplanes, self.planes, self.stride, downsample, self.groups,
- self.base_width, previous_dilation, norm_layer))
- self.inplanes = self.planes * self.block.expansion
- for _ in range(1, self.blocks):
- layers.append(self.block(self.inplanes, self.planes, groups=self.groups,
- base_width=self.base_width, dilation=self.dilation,
- norm_layer=norm_layer))
-
- return nn.Sequential(*layers)
-
- def forward(self, x):
- return self.layer(x)
diff --git a/colossalai/nn/layer/vanilla_vision_transformer/__init__.py b/colossalai/nn/layer/vanilla_vision_transformer/__init__.py
deleted file mode 100644
index 90d614e0a..000000000
--- a/colossalai/nn/layer/vanilla_vision_transformer/__init__.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from .layers import (VanillaViTBlock, VanillaViTMLP, VanillaViTPatchEmbedding,
- VanillaViTAttention, VanillaViTDropPath, VanillaViTHead)
-
-__all__ = [
- 'VanillaViTBlock', 'VanillaViTMLP', 'VanillaViTPatchEmbedding',
- 'VanillaViTAttention', 'VanillaViTDropPath', 'VanillaViTHead'
-]
diff --git a/colossalai/nn/loss/__init__.py b/colossalai/nn/loss/__init__.py
index 6015c55c6..19c83b747 100644
--- a/colossalai/nn/loss/__init__.py
+++ b/colossalai/nn/loss/__init__.py
@@ -1,4 +1,3 @@
-from .base_loss import BaseLoss
from .cross_entropy_2d import CrossEntropyLoss2D
from .cross_entropy_2p5d import CrossEntropyLoss2p5D
from .cross_entropy_3d import CrossEntropyLoss3D
diff --git a/colossalai/nn/loss/base_loss.py b/colossalai/nn/loss/base_loss.py
deleted file mode 100644
index bf5bbe6b2..000000000
--- a/colossalai/nn/loss/base_loss.py
+++ /dev/null
@@ -1,13 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-from abc import ABC, abstractmethod
-
-
-class BaseLoss(ABC):
- """Absctract loss class
- """
-
- @abstractmethod
- def calc_loss(self, *args, **kwargs):
- pass
diff --git a/colossalai/nn/loss/cross_entropy_1d.py b/colossalai/nn/loss/cross_entropy_1d.py
deleted file mode 100644
index 667c00734..000000000
--- a/colossalai/nn/loss/cross_entropy_1d.py
+++ /dev/null
@@ -1,120 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-import torch
-import torch.nn.functional as F
-from torch.nn.modules.loss import _Loss
-
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.nn.layer.parallel_1d._utils import vocab_range_from_per_partition_vocab_size
-
-
-class _VocabParallelCrossEntropy_1D(torch.autograd.Function):
-
- @staticmethod
- def forward(ctx, vocab_parallel_logits, target):
- # Maximum value along vocab dimension across all GPUs.
- logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
- torch.distributed.all_reduce(logits_max,
- op=torch.distributed.ReduceOp.MAX,
- group=gpc.get_group(ParallelMode.PARALLEL_1D))
- # Subtract the maximum value.
- vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
-
- # Get the partition's vocab indecies
- partition_vocab_size = vocab_parallel_logits.size()[-1]
- rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
- world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
- vocab_start_index, vocab_end_index = vocab_range_from_per_partition_vocab_size(
- partition_vocab_size, rank, world_size)
-
- # Create a mask of valid vocab ids (1 means it needs to be masked).
- target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
- masked_target = target.clone() - vocab_start_index
- masked_target[target_mask] = 0
-
- # Get predicted-logits = logits[target].
- # For Simplicity, we convert logits to a 2-D tensor with size
- # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
- logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
- masked_target_1d = masked_target.view(-1)
- arange_1d = torch.arange(start=0, end=logits_2d.size()[0],
- device=logits_2d.device)
- predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
- predicted_logits_1d = predicted_logits_1d.clone().contiguous()
- predicted_logits = predicted_logits_1d.view_as(target)
- predicted_logits[target_mask] = 0.0
- # All reduce is needed to get the chunks from other GPUs.
- torch.distributed.all_reduce(predicted_logits,
- op=torch.distributed.ReduceOp.SUM,
- group=gpc.get_group(ParallelMode.PARALLEL_1D))
-
- # Sum of exponential of logits along vocab dimension across all GPUs.
- exp_logits = vocab_parallel_logits
- torch.exp(vocab_parallel_logits, out=exp_logits)
- sum_exp_logits = exp_logits.sum(dim=-1)
- torch.distributed.all_reduce(sum_exp_logits,
- op=torch.distributed.ReduceOp.SUM,
- group=gpc.get_group(ParallelMode.PARALLEL_1D))
-
- # Loss = log(sum(exp(logits))) - predicted-logit.
- loss = torch.log(sum_exp_logits) - predicted_logits
-
- # Store softmax, target-mask and masked-target for backward pass.
- exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
- ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
-
- return loss
-
- @staticmethod
- def backward(ctx, grad_output):
- # Retreive tensors from the forward path.
- softmax, target_mask, masked_target_1d = ctx.saved_tensors
-
- # All the inputs have softmax as thier gradient.
- grad_input = softmax
- # For simplicity, work with the 2D gradient.
- partition_vocab_size = softmax.size()[-1]
- grad_2d = grad_input.view(-1, partition_vocab_size)
-
- # Add the gradient from matching classes.
- arange_1d = torch.arange(start=0, end=grad_2d.size()[0],
- device=grad_2d.device)
- grad_2d[arange_1d, masked_target_1d] -= (
- 1.0 - target_mask.view(-1).float())
-
- # Finally elementwise multiplication with the output gradients.
- grad_input.mul_(grad_output.unsqueeze(dim=-1))
-
- return grad_input, None
-
-
-class LmLoss1D(_Loss):
-
- def forward(self, lm_logits, lm_labels, loss_mask):
- lm_loss = _VocabParallelCrossEntropy_1D.apply(lm_logits, lm_labels)
- lm_loss = torch.sum(
- lm_loss.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
- return lm_loss
-
-
-class SopLoss1D(_Loss):
-
- def forward(self, sop_logits, sentence_order):
- sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
- sentence_order.view(-1),
- ignore_index=-1)
- return sop_loss
-
-
-class BERTDualHeadLoss(_Loss):
-
- def __init__(self):
- self.lm_loss = LmLoss1D()
- self.sop_loss = SopLoss1D()
-
- def forward(self, lm_logits, sop_logits, lm_labels, loss_mask, sentence_order):
- lm_loss = self.lm_loss(lm_logits, lm_labels, loss_mask)
- sop_loss = self.sop_loss(sop_logits, sentence_order)
- return lm_loss + sop_loss
diff --git a/colossalai/nn/loss/cross_entropy_2d.py b/colossalai/nn/loss/cross_entropy_2d.py
index fe7ca6aa8..3bb5712aa 100644
--- a/colossalai/nn/loss/cross_entropy_2d.py
+++ b/colossalai/nn/loss/cross_entropy_2d.py
@@ -7,18 +7,18 @@ from colossalai.core import global_context as gpc
from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization, get_summa_dim_from_env
from colossalai.registry import LOSSES
from colossalai.utils import get_current_device
+from torch.cuda.amp import custom_bwd, custom_fwd
class _ParallelCrossEntropyLossFunction_2D(torch.autograd.Function):
### Modified based on megatron.mpu.cross_entropy ###
@staticmethod
+ @custom_fwd(cast_inputs=torch.float32)
def forward(ctx, logits, targets):
# logits: [b/q, h/q]
# labels: [b/q]
- # loss: [b/q]
- # vocab_parallel_logits: [b/q, s, v/q]
- # target: [b/q, s]
+
logits_max = torch.max(logits, dim=-1)[0]
torch.distributed.all_reduce(
logits_max,
@@ -58,6 +58,7 @@ class _ParallelCrossEntropyLossFunction_2D(torch.autograd.Function):
return loss
@staticmethod
+ @custom_bwd
def backward(ctx, output_grad):
# Retreive tensors from the forward path.
softmax, target_mask, masked_target = ctx.saved_tensors
@@ -91,12 +92,14 @@ class _ReduceByColumn(torch.autograd.Function):
return input_
@staticmethod
+ @custom_fwd(cast_inputs=torch.float32)
def forward(ctx, input_):
dist.all_reduce(input_, group=gpc.get_group(
ParallelMode.PARALLEL_2D_COL))
return input_
@staticmethod
+ @custom_bwd
def backward(ctx, grad_output):
return grad_output
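The `custom_fwd(cast_inputs=torch.float32)` / `custom_bwd` decorators added to the 2D cross-entropy keep the loss math in fp32 under autocast, while the matmul-heavy layers elsewhere in the patch cast to fp16. A minimal sketch of the same pattern on a toy autograd function (not the colossalai implementation):

```python
import torch
from torch.cuda.amp import custom_bwd, custom_fwd

class _ToyLoss(torch.autograd.Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)   # inputs arrive as fp32 inside autocast
    def forward(ctx, logits, targets):
        log_probs = torch.log_softmax(logits, dim=-1)
        ctx.save_for_backward(log_probs, targets)
        return -log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)

    @staticmethod
    @custom_bwd                               # grads come back in the forward dtype
    def backward(ctx, grad_output):
        log_probs, targets = ctx.saved_tensors
        grad = log_probs.exp()                # softmax
        # softmax - one_hot(target), then scale by the incoming gradient
        grad.scatter_add_(-1, targets.unsqueeze(-1), -torch.ones_like(grad[..., :1]))
        return grad * grad_output.unsqueeze(-1), None
```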
diff --git a/colossalai/nn/loss/cross_entropy_3d.py b/colossalai/nn/loss/cross_entropy_3d.py
index b1ef7731b..97409322d 100644
--- a/colossalai/nn/loss/cross_entropy_3d.py
+++ b/colossalai/nn/loss/cross_entropy_3d.py
@@ -1,32 +1,20 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
+import os
+
import torch
import torch.distributed as dist
-from torch.nn.modules.loss import _Loss
-
-from colossalai.communication import all_gather
+from colossalai.constants import (INPUT_GROUP_3D, OUTPUT_GROUP_3D,
+ WEIGHT_GROUP_3D)
from colossalai.core import global_context as gpc
from colossalai.nn.layer.parallel_3d._operation import Reduce_3D
-from colossalai.nn.layer.parallel_3d._utils import get_last_group, get_depth_from_env
+from colossalai.nn.layer.parallel_3d._utils import (get_depth_from_env,
+ get_last_group,
+ get_parallel_mode_from_env)
from colossalai.registry import LOSSES
from colossalai.utils import get_current_device
-
-
-def accuracy_3d(output, target, input_parallel_mode, weight_parallel_mode):
- depth = get_depth_from_env()
- output_parallel_mode = get_last_group(input_parallel_mode,
- weight_parallel_mode)
- j = gpc.get_local_rank(input_parallel_mode)
- i = gpc.get_local_rank(weight_parallel_mode)
- target = torch.chunk(target, depth, dim=0)[i]
- target = torch.chunk(target, depth, dim=0)[j]
- output = all_gather(output, -1, output_parallel_mode)
- prediction = torch.argmax(output, dim=-1)
- correct = torch.sum(prediction == target)
- dist.all_reduce(correct, group=gpc.get_group(input_parallel_mode))
- dist.all_reduce(correct, group=gpc.get_group(weight_parallel_mode))
- return correct.item()
+from torch.nn.modules.loss import _Loss
class _ParallelCrossEntropyLossFunction_3D(torch.autograd.Function):
@@ -112,16 +100,18 @@ class CrossEntropyLoss3D(_Loss):
:param reduction: whether to average the loss, defaults to True
:type reduction: bool, optional
"""
- def __init__(self,
- input_parallel_mode,
- weight_parallel_mode,
- reduction=True):
+ def __init__(
+ self,
+ # input_parallel_mode,
+ # weight_parallel_mode,
+ reduction=True,
+ label_smoothing=0.0):
super().__init__()
self.depth = get_depth_from_env()
- self.input_parallel_mode = input_parallel_mode
- self.weight_parallel_mode = weight_parallel_mode
- self.output_parallel_mode = get_last_group(input_parallel_mode,
- weight_parallel_mode)
+ self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
+ self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
+ self.output_parallel_mode = get_last_group(self.input_parallel_mode,
+ self.weight_parallel_mode)
self.input_rank = gpc.get_local_rank(self.input_parallel_mode)
self.weight_rank = gpc.get_local_rank(self.weight_parallel_mode)
self.reduction_mean = reduction
@@ -141,53 +131,53 @@ class CrossEntropyLoss3D(_Loss):
return loss
-@LOSSES.register_module
-class LabelSmoothingCrossEntropy3D(_Loss):
- """
- NLL loss with label smoothing, adapted from timm.loss.LabelSmoothingCrossEntropy
+# @LOSSES.register_module
+# class LabelSmoothingCrossEntropy3D(_Loss):
+# """
+# NLL loss with label smoothing, adapted from timm.loss.LabelSmoothingCrossEntropy
- :param input_parallel_mode: parallel mode for input tensor
- :type input_parallel_mode: ParallelMode
- :param weight_parallel_mode: parallel mode for weight
- :type weight_parallel_mode: ParallelMode
- :param smoothing: label smoothing value, defaults to 0.1
- :type smoothing: float
- :param reduction: whether to average the loss, defaults to True
- :type reduction: bool, optional
- """
- def __init__(self,
- input_parallel_mode,
- weight_parallel_mode,
- smoothing=0.1,
- reduction=True):
- super().__init__()
- assert smoothing < 1.0
- self.smoothing = smoothing
- self.confidence = 1. - smoothing
- self.depth = get_depth_from_env()
- self.input_parallel_mode = input_parallel_mode
- self.weight_parallel_mode = weight_parallel_mode
- self.output_parallel_mode = get_last_group(input_parallel_mode,
- weight_parallel_mode)
- self.reduction_mean = reduction
+# :param input_parallel_mode: parallel mode for input tensor
+# :type input_parallel_mode: ParallelMode
+# :param weight_parallel_mode: parallel mode for weight
+# :type weight_parallel_mode: ParallelMode
+# :param smoothing: label smoothing value, defaults to 0.1
+# :type smoothing: float
+# :param reduction: whether to average the loss, defaults to True
+# :type reduction: bool, optional
+# """
+# def __init__(self,
+# input_parallel_mode,
+# weight_parallel_mode,
+# smoothing=0.1,
+# reduction=True):
+# super().__init__()
+# assert smoothing < 1.0
+# self.smoothing = smoothing
+# self.confidence = 1. - smoothing
+# self.depth = get_depth_from_env()
+# self.input_parallel_mode = input_parallel_mode
+# self.weight_parallel_mode = weight_parallel_mode
+# self.output_parallel_mode = get_last_group(input_parallel_mode,
+# weight_parallel_mode)
+# self.reduction_mean = reduction
- def forward(self, logits, targets):
- # split label partition from the entire batch
- j = gpc.get_local_rank(self.input_parallel_mode)
- i = gpc.get_local_rank(self.weight_parallel_mode)
- targets = torch.chunk(targets, self.depth, dim=0)[i]
- targets = torch.chunk(targets, self.depth, dim=0)[j]
- exp_logits = torch.exp(logits)
- sum_exp_logits = Sum3D.apply(exp_logits, -1, depth,
- self.output_parallel_mode, False)
- log_probs = torch.log(sum_exp_logits) - logits
- nll_loss = _ParallelCrossEntropyLossFunction_3D.apply(
- logits, targets, self.depth, self.output_parallel_mode)
- smooth_loss = -log_probs.mean(dim=-1)
- loss = self.confidence * nll_loss + self.smoothing * smooth_loss
- if self.reduction_mean:
- loss = loss.sum()
- loss = Reduce_3D.apply(loss, self.depth, self.input_parallel_mode)
- loss = Reduce_3D.apply(loss, self.depth, self.weight_parallel_mode)
- loss /= batch_size
- return loss
+# def forward(self, logits, targets):
+# # split label partition from the entire batch
+# j = gpc.get_local_rank(self.input_parallel_mode)
+# i = gpc.get_local_rank(self.weight_parallel_mode)
+# targets = torch.chunk(targets, self.depth, dim=0)[i]
+# targets = torch.chunk(targets, self.depth, dim=0)[j]
+# exp_logits = torch.exp(logits)
+# sum_exp_logits = Sum3D.apply(exp_logits, -1, depth,
+# self.output_parallel_mode, False)
+# log_probs = torch.log(sum_exp_logits) - logits
+# nll_loss = _ParallelCrossEntropyLossFunction_3D.apply(
+# logits, targets, self.depth, self.output_parallel_mode)
+# smooth_loss = -log_probs.mean(dim=-1)
+# loss = self.confidence * nll_loss + self.smoothing * smooth_loss
+# if self.reduction_mean:
+# loss = loss.sum()
+# loss = Reduce_3D.apply(loss, self.depth, self.input_parallel_mode)
+# loss = Reduce_3D.apply(loss, self.depth, self.weight_parallel_mode)
+# loss /= batch_size
+# return loss
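With the parallel modes now resolved from the environment constants (INPUT_GROUP_3D / WEIGHT_GROUP_3D), the 3D loss no longer needs them passed in. A minimal usage sketch, assuming a 3D tensor-parallel context has already been set up through colossalai's launch utilities; shapes are illustrative only:

    import torch
    from colossalai.nn.loss.cross_entropy_3d import CrossEntropyLoss3D

    # assumes the 3D process groups were initialised by the launcher/config
    criterion = CrossEntropyLoss3D(reduction=True, label_smoothing=0.0)

    logits = torch.randn(8, 1000, device='cuda')          # local shard of the classifier output
    labels = torch.randint(0, 1000, (8,), device='cuda')  # full labels; the loss splits them per rank
    loss = criterion(logits, labels)
    loss.backward()
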
diff --git a/colossalai/nn/lr_scheduler/delayed.py b/colossalai/nn/lr_scheduler/delayed.py
index 173d2f52c..0f7bc1df6 100644
--- a/colossalai/nn/lr_scheduler/delayed.py
+++ b/colossalai/nn/lr_scheduler/delayed.py
@@ -48,8 +48,10 @@ class DelayerScheduler(_LRScheduler):
if self.finished:
if epoch is None:
self.after_scheduler.step(None)
+ self._last_lr = self.after_scheduler.get_last_lr()
else:
self.after_scheduler.step(epoch - self.delay_epochs)
+ self._last_lr = self.after_scheduler.get_last_lr()
else:
return super(DelayerScheduler, self).step(epoch)
@@ -66,6 +68,7 @@ class WarmupScheduler(_LRScheduler):
:param last_epoch: The index of last epoch, defaults to -1
:type last_epoch: int, optional
"""
+
def __init__(self, optimizer, warmup_epochs, after_scheduler, last_epoch=-1):
self.warmup_epochs = int(warmup_epochs)
self.after_scheduler = after_scheduler
@@ -85,8 +88,10 @@ class WarmupScheduler(_LRScheduler):
if self.finished:
if epoch is None:
self.after_scheduler.step(None)
+ self._last_lr = self.after_scheduler.get_last_lr()
else:
self.after_scheduler.step(epoch - self.warmup_epochs)
+ self._last_lr = self.after_scheduler.get_last_lr()
else:
return super().step(epoch)
@@ -136,7 +141,9 @@ class WarmupDelayerScheduler(_LRScheduler):
if self.finished:
if epoch is None:
self.after_scheduler.step(None)
+ self._last_lr = self.after_scheduler.get_last_lr()
else:
self.after_scheduler.step(epoch - self.warmup_epochs)
+ self._last_lr = self.after_scheduler.get_last_lr()
else:
return super().step(epoch)
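The added self._last_lr assignments matter because these wrappers delegate stepping to after_scheduler once the warmup/delay phase has finished: without copying the inner scheduler's value, get_last_lr() on the wrapper keeps reporting a stale learning rate. An illustrative sketch (the CosineAnnealingLR pairing is only an example, not something this diff prescribes):

    import torch
    from torch.optim.lr_scheduler import CosineAnnealingLR
    from colossalai.nn.lr_scheduler.delayed import WarmupScheduler

    model = torch.nn.Linear(10, 10)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    cosine = CosineAnnealingLR(optimizer, T_max=90)
    scheduler = WarmupScheduler(optimizer, warmup_epochs=10, after_scheduler=cosine)

    for epoch in range(100):
        optimizer.step()
        scheduler.step()
        # after epoch 10 this now tracks the cosine schedule instead of the warmup value
        current_lr = scheduler.get_last_lr()[0]
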
diff --git a/colossalai/nn/lr_scheduler/multistep.py b/colossalai/nn/lr_scheduler/multistep.py
index 5def4a1fa..cdb89b53f 100644
--- a/colossalai/nn/lr_scheduler/multistep.py
+++ b/colossalai/nn/lr_scheduler/multistep.py
@@ -12,7 +12,6 @@ class MultiStepLR(_MultiStepLR):
number of epoch reaches one of the milestones. Notice that such decay can
happen simultaneously with other changes to the learning rate from outside
this scheduler. When last_epoch=-1, sets initial lr as lr.
-
:param optimizer: Wrapped optimizer
:type optimizer: torch.optim.Optimizer
:param total_steps: number of total training steps
@@ -34,7 +33,6 @@ class MultiStepLR(_MultiStepLR):
@LR_SCHEDULERS.register_module
class MultiStepWarmupLR(WarmupScheduler):
"""Multi-step laerning rate scheduler with warmup.
-
:param optimizer: Wrapped optimizer
:type optimizer: torch.optim.Optimizer
:param total_steps: number of total training steps
diff --git a/colossalai/nn/lr_scheduler/onecycle.py b/colossalai/nn/lr_scheduler/onecycle.py
index 743855470..4384e61e2 100644
--- a/colossalai/nn/lr_scheduler/onecycle.py
+++ b/colossalai/nn/lr_scheduler/onecycle.py
@@ -12,28 +12,21 @@ class OneCycleLR(_OneCycleLR):
than the initial learning rate.
This policy was initially described in the paper `Super-Convergence:
Very Fast Training of Neural Networks Using Large Learning Rates`_.
-
The 1cycle learning rate policy changes the learning rate after every batch.
`step` should be called after a batch has been used for training.
-
This scheduler is not chainable.
-
Note also that the total number of steps in the cycle can be determined in one
of two ways (listed in order of precedence):
-
#. A value for total_steps is explicitly provided.
#. A number of epochs (epochs) and a number of steps per epoch
(steps_per_epoch) are provided.
In this case, the number of total steps is inferred by
total_steps = epochs * steps_per_epoch
-
You must either provide a value for total_steps or provide a value for both
epochs and steps_per_epoch.
-
The default behaviour of this scheduler follows the fastai implementation of 1cycle, which
claims that "unpublished work has shown even better results by using only two phases". To
mimic the behaviour of the original paper instead, set ``three_phase=True``.
-
:param optimizer: Wrapped optimizer
:type optimizer: torch.optim.Optimizer
:param total_steps: number of total training steps
@@ -71,7 +64,6 @@ class OneCycleLR(_OneCycleLR):
number of *batches* computed, not the total number of epochs computed.
When last_epoch=-1, the schedule is started from the beginning, defaults to -1
:type last_epoch: int, optional
-
.. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates:
https://arxiv.org/abs/1708.07120
"""
diff --git a/colossalai/nn/lr_scheduler/poly.py b/colossalai/nn/lr_scheduler/poly.py
index ee77b2f9b..ae9c1d2d2 100644
--- a/colossalai/nn/lr_scheduler/poly.py
+++ b/colossalai/nn/lr_scheduler/poly.py
@@ -7,7 +7,6 @@ from .delayed import WarmupScheduler
@LR_SCHEDULERS.register_module
class PolynomialLR(_LRScheduler):
"""Polynomial learning rate scheduler.
-
:param optimizer: Wrapped optimizer
:type optimizer: torch.optim.Optimizer
:param total_steps: number of total training steps
@@ -43,7 +42,6 @@ class PolynomialLR(_LRScheduler):
@LR_SCHEDULERS.register_module
class PolynomialWarmupLR(WarmupScheduler):
"""Polynomial learning rate scheduler with warmup.
-
:param optimizer: Wrapped optimizer
:type optimizer: torch.optim.Optimizer
:param total_steps: number of total training steps
diff --git a/colossalai/nn/lr_scheduler/torch.py b/colossalai/nn/lr_scheduler/torch.py
index e739084b6..abd0f4f39 100644
--- a/colossalai/nn/lr_scheduler/torch.py
+++ b/colossalai/nn/lr_scheduler/torch.py
@@ -10,7 +10,6 @@ from colossalai.registry import LR_SCHEDULERS
class LambdaLR(_LambdaLR):
"""Sets the learning rate of each parameter group to the initial lr
times a given function. When last_epoch=-1, sets initial lr as lr.
-
:param optimizer: Wrapped optimizer
:type optimizer: torch.optim.Optimizer
:param total_steps: number of total training steps
@@ -33,7 +32,6 @@ class LambdaLR(_LambdaLR):
class MultiplicativeLR(_MultiplicativeLR):
"""Multiply the learning rate of each parameter group by the factor given
in the specified function. When last_epoch=-1, sets initial lr as lr
-
:param optimizer: Wrapped optimizer
:type optimizer: torch.optim.Optimizer
:param total_steps: number of total training steps
@@ -58,7 +56,6 @@ class StepLR(_StepLR):
step_size epochs. Notice that such decay can happen simultaneously with
other changes to the learning rate from outside this scheduler. When
last_epoch=-1, sets initial lr as lr
-
:param optimizer: Wrapped optimizer
:type optimizer: torch.optim.Optimizer
:param total_steps: number of total training steps
@@ -82,7 +79,6 @@ class StepLR(_StepLR):
class ExponentialLR(_ExponentialLR):
"""Decays the learning rate of each parameter group by gamma every epoch.
When last_epoch=-1, sets initial lr as lr
-
:param optimizer: Wrapped optimizer
:type optimizer: torch.optim.Optimizer
:param total_steps: number of total training steps
diff --git a/colossalai/nn/model/__init__.py b/colossalai/nn/model/__init__.py
index 5d5ccd96e..6ced17054 100644
--- a/colossalai/nn/model/__init__.py
+++ b/colossalai/nn/model/__init__.py
@@ -1,3 +1,3 @@
-from .base_model import BaseModel
-from .vanilla_resnet import VanillaResNet
-from .vision_transformer import *
+from .model_from_config import ModelFromConfig
+
+__all__ = ['ModelFromConfig']
diff --git a/colossalai/nn/model/base_model.py b/colossalai/nn/model/model_from_config.py
similarity index 92%
rename from colossalai/nn/model/base_model.py
rename to colossalai/nn/model/model_from_config.py
index cbe38fefa..24903ca36 100644
--- a/colossalai/nn/model/base_model.py
+++ b/colossalai/nn/model/model_from_config.py
@@ -8,10 +8,10 @@ import torch.nn as nn
from colossalai.builder import build_layer
-class BaseModel(nn.Module, ABC):
+class ModelFromConfig(nn.Module, ABC):
def __init__(self):
- super(BaseModel, self).__init__()
+ super(ModelFromConfig, self).__init__()
self.layers = nn.ModuleList()
self.layers_cfg = []
@@ -32,7 +32,6 @@ class BaseModel(nn.Module, ABC):
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
-
"""Use this function to override the state dict for
saving checkpoints."""
return self.state_dict(destination, prefix, keep_vars)
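ModelFromConfig keeps the old BaseModel behaviour of describing a network as a list of layer configs. How self.layers is populated from layers_cfg (via colossalai.builder.build_layer) happens outside this hunk, so the sketch below only mirrors the pattern the removed VanillaResNet used: fill layers_cfg in __init__ and iterate self.layers in forward. The TinyMLP class and its sizes are purely hypothetical:

    from colossalai.nn.model import ModelFromConfig

    class TinyMLP(ModelFromConfig):
        def __init__(self, in_features=784, num_cls=10):
            super().__init__()
            self.layers_cfg = [
                dict(type='Linear', in_features=in_features, out_features=256),
                dict(type='ReLU', inplace=True),
                dict(type='Linear', in_features=256, out_features=num_cls),
            ]

        def forward(self, x):
            # self.layers is built from layers_cfg by the config-driven builder
            for layer in self.layers:
                x = layer(x)
            return x
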
diff --git a/colossalai/nn/model/vanilla_resnet/__init__.py b/colossalai/nn/model/vanilla_resnet/__init__.py
deleted file mode 100644
index 1740de7dc..000000000
--- a/colossalai/nn/model/vanilla_resnet/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .resnet import VanillaResNet
-
-__all__ = ['VanillaResNet']
diff --git a/colossalai/nn/model/vanilla_resnet/resnet.py b/colossalai/nn/model/vanilla_resnet/resnet.py
deleted file mode 100644
index 905889649..000000000
--- a/colossalai/nn/model/vanilla_resnet/resnet.py
+++ /dev/null
@@ -1,163 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-from typing import List, Optional
-
-import torch
-import torch.nn as nn
-from torch import Tensor
-
-from colossalai.registry import LAYERS
-from colossalai.registry import MODELS
-from ..base_model import BaseModel
-
-
-@MODELS.register_module
-class VanillaResNet(BaseModel):
- """ResNet from
- `"Deep Residual Learning for Image Recognition" `_.
- """
-
- def __init__(
- self,
- num_cls: int,
- block_type: str,
- layers: List[int],
- norm_layer_type: str = 'BatchNorm2d',
- in_channels: int = 3,
- groups: int = 1,
- width_per_group: int = 64,
- zero_init_residual: bool = False,
- replace_stride_with_dilation: Optional[List[bool]] = None,
- dilations=(1, 1, 1, 1)
- ) -> None:
- super().__init__()
-
- self.inplanes = 64
- self.zero_init_residual = zero_init_residual
- self.blocks = layers
- self.block_expansion = LAYERS.get_module(block_type).expansion
- self.dilations = dilations
- self.reslayer_common_cfg = dict(
- type='ResLayer',
- block_type=block_type,
- norm_layer_type=norm_layer_type,
- groups=groups,
- base_width=width_per_group
- )
-
- if replace_stride_with_dilation is None:
- # each element in the tuple indicates if we should replace
- # the 2x2 stride with a dilated convolution instead
- replace_stride_with_dilation = [False, False, False]
-
- if len(replace_stride_with_dilation) != 3:
- raise ValueError("replace_stride_with_dilation should be None "
- "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
-
- self.layers_cfg = [
- # conv1
- dict(type='Conv2d',
- in_channels=in_channels,
- out_channels=self.inplanes,
- kernel_size=7,
- stride=2,
- padding=3,
- bias=False),
- # bn1
- dict(
- type=norm_layer_type,
- num_features=self.inplanes
- ),
- # relu
- dict(
- type='ReLU',
- inplace=True
- ),
- # maxpool
- dict(
- type='MaxPool2d',
- kernel_size=3,
- stride=2,
- padding=1
- ),
- # layer 1
- dict(
- inplanes=self.inplanes,
- planes=64,
- blocks=self.blocks[0],
- dilation=self.dilations[0],
- **self.reslayer_common_cfg
- ),
- # layer 2
- dict(
- inplanes=64 * self.block_expansion,
- planes=128,
- blocks=self.blocks[1],
- stride=2,
- dilate=replace_stride_with_dilation[0],
- dilation=self.dilations[1],
- **self.reslayer_common_cfg
- ),
- # layer 3
- dict(
- inplanes=128 * self.block_expansion,
- planes=256,
- blocks=layers[2],
- stride=2,
- dilate=replace_stride_with_dilation[1],
- dilation=self.dilations[2],
- **self.reslayer_common_cfg
- ),
- # layer 4
- dict(
- inplanes=256 * self.block_expansion,
- planes=512,
- blocks=layers[3], stride=2,
- dilate=replace_stride_with_dilation[2],
- dilation=self.dilations[3],
- **self.reslayer_common_cfg
- ),
- # avg pool
- dict(
- type='AdaptiveAvgPool2d',
- output_size=(1, 1)
- ),
- # flatten
- dict(
- type='LambdaWrapper',
- func=lambda mod, x: torch.flatten(x, 1)
- ),
- # linear
- dict(
- type='Linear',
- in_features=512 * self.block_expansion,
- out_features=num_cls
- )
- ]
-
- def forward(self, x: Tensor):
- for layer in self.layers:
- x = layer(x)
- return x,
-
- def init_weights(self):
- for m in self.modules():
- if isinstance(m, nn.Conv2d):
- nn.init.kaiming_normal_(
- m.weight, mode='fan_out', nonlinearity='relu')
- elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
- nn.init.constant_(m.weight, 1)
- nn.init.constant_(m.bias, 0)
-
- # Zero-initialize the last BN in each residual branch,
- # so that the residual branch starts with zeros, and each residual block behaves like an identity.
- # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
- if self.zero_init_residual:
- for m in self.modules():
- if isinstance(m, LAYERS.get_module('ResNetBottleneck')):
- # type: ignore[arg-type]
- nn.init.constant_(m.bn3.weight, 0)
- elif isinstance(m, LAYERS.get_module('ResNetBasicBlock')):
- # type: ignore[arg-type]
- nn.init.constant_(m.bn2.weight, 0)
diff --git a/colossalai/nn/model/vision_transformer/__init__.py b/colossalai/nn/model/vision_transformer/__init__.py
deleted file mode 100644
index ab9d7e640..000000000
--- a/colossalai/nn/model/vision_transformer/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .vision_transformer import VisionTransformerFromConfig
-
-__all__ = ['VisionTransformerFromConfig']
diff --git a/colossalai/nn/optimizer/__init__.py b/colossalai/nn/optimizer/__init__.py
index f9993c470..c084c5c86 100644
--- a/colossalai/nn/optimizer/__init__.py
+++ b/colossalai/nn/optimizer/__init__.py
@@ -1,14 +1,10 @@
-from .fp16_optimizer import FP16Optimizer
+from .colossalai_optimizer import ColossalaiOptimizer
from .fused_adam import FusedAdam
from .fused_lamb import FusedLAMB
from .fused_sgd import FusedSGD
from .lamb import Lamb
from .lars import Lars
-from .zero_redundancy_optimizer_level_1 import ZeroRedundancyOptimizer_Level_1
-from .zero_redundancy_optimizer_level_2 import ZeroRedundancyOptimizer_Level_2
-from .zero_redundancy_optimizer_level_3 import ZeroRedundancyOptimizer_Level_3
__all__ = [
- 'ZeroRedundancyOptimizer_Level_1', 'ZeroRedundancyOptimizer_Level_2', 'ZeroRedundancyOptimizer_Level_3',
- 'FusedLAMB', 'FusedAdam', 'FusedSGD', 'Lamb', 'FP16Optimizer', 'Lars'
+ 'ColossalaiOptimizer', 'FusedLAMB', 'FusedAdam', 'FusedSGD', 'Lamb', 'Lars'
]
diff --git a/colossalai/nn/optimizer/_utils.py b/colossalai/nn/optimizer/_utils.py
deleted file mode 100644
index 6cd92bb38..000000000
--- a/colossalai/nn/optimizer/_utils.py
+++ /dev/null
@@ -1,168 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-import torch
-from torch._six import inf
-
-try:
- import colossal_C
-except:
- print('Colossalai should be built with cuda extension to use the FP16 optimizer')
-
-from ..multi_tensor_apply import multi_tensor_applier
-
-from colossalai.constants import IS_TENSOR_PARALLEL, TENSOR_PARALLEL_ATTRIBUTES
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
-
-
-def is_model_parallel_parameter(p):
- return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL)
-
-
-def _calc_l2_norm(grads):
- norm = 0.0
- if len(grads) > 0:
- dummy_overflow_buf = torch.cuda.IntTensor([0])
- norm, _ = multi_tensor_applier(
- colossal_C.multi_tensor_l2norm,
- dummy_overflow_buf,
- [grads],
- False # no per-parameter norm
- )
- return norm
-
-
-def _calc_lp(grads, norm_type):
- norm = 0.0
- for grad in grads:
- grad_norm = torch.norm(grad, norm_type)
- norm += grad_norm ** norm_type
- return norm
-
-# ======== Gradient Clipping =========
-
-
-def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
- """Clips gradient norm of an iterable of parameters whose gradients
- are in fp32.
-
- This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
- added functionality to handle model parallel parameters. Note that
- the gradients are modified in place.
-
- Arguments:
- parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
- single Tensor that will have gradients normalized
- max_norm (float or int): max norm of the gradients
- norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
- infinity norm.
-
- Returns:
- Total norm of the parameters (viewed as a single vector).
- """
-
- if isinstance(parameters, torch.Tensor):
- parameters = [parameters]
-
- # Filter parameters based on:
- # - grad should not be none
- # - parameter should not be shared
- # - should not be a replica due to tensor model parallelism
- params = []
- for param in parameters:
- if param.grad is not None:
- # Make sure the grads are in fp32
- assert param.grad.type() == 'torch.cuda.FloatTensor'
- params.append(param)
- # Norm parameters.
- max_norm = float(max_norm)
- norm_type = float(norm_type)
-
- # Calculate norm.
- if norm_type == inf:
- total_norm = max(p.grad.data.abs().max() for p in params)
- total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
- if gpc.is_initialized(ParallelMode.TENSOR):
- # Take max across all model-parallel GPUs.
- torch.distributed.all_reduce(total_norm_cuda,
- op=torch.distributed.ReduceOp.MAX,
- group=gpc.get_group(ParallelMode.TENSOR))
- total_norm = total_norm_cuda[0].item()
- else:
- tensor_parallel_grads = []
- no_tensor_parallel_grads = []
- for p in params:
- if is_model_parallel_parameter(p):
- tensor_parallel_grads.append(p.grad.data)
- else:
- no_tensor_parallel_grads.append(p.grad.data)
- if norm_type == 2.0:
- tensor_parallel_norm = _calc_l2_norm(
- tensor_parallel_grads) ** norm_type
- no_tensor_parallel_norm = _calc_l2_norm(
- no_tensor_parallel_grads) ** norm_type
- else:
- tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type)
- no_tensor_parallel_grads = _calc_lp(
- no_tensor_parallel_grads, norm_type)
- if gpc.is_initialized(ParallelMode.TENSOR) and len(tensor_parallel_grads) > 0:
- # Sum across all model-parallel GPUs.
- torch.distributed.all_reduce(tensor_parallel_norm,
- op=torch.distributed.ReduceOp.SUM,
- group=gpc.get_group(ParallelMode.TENSOR))
- total_norm = (tensor_parallel_norm +
- no_tensor_parallel_norm) ** (1.0 / norm_type)
- if type(total_norm) == 'torch.cuda.FloatTensor':
- total_norm = total_norm.item()
-
- # Scale.
- clip_coeff = max_norm / (total_norm + 1.0e-6)
- if clip_coeff < 1.0:
- grads = [p.grad.detach() for p in params]
- dummy_overflow_buf = torch.cuda.IntTensor([0])
- multi_tensor_applier(colossal_C.multi_tensor_scale,
- dummy_overflow_buf,
- [grads, grads],
- clip_coeff)
-
- return total_norm
-
-
-def count_zeros_fp32(parameters):
- if isinstance(parameters, torch.Tensor):
- parameters = [parameters]
-
- # Filter parameters based on:
- # - grad should not be none
- # - parameter should not be shared
- # - should not be a replica due to tensor model parallelism
- total_num_zeros = 0.0
- for param in parameters:
- grad_not_none = param.grad is not None
- is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
- if grad_not_none and is_not_tp_duplicate:
- grad = param.grad.detach()
- num_zeros = grad.numel() - torch.count_nonzero(grad)
- total_num_zeros = num_zeros + total_num_zeros
-
- # Sum across all model-parallel GPUs.
- torch.distributed.all_reduce(total_num_zeros,
- op=torch.distributed.ReduceOp.SUM,
- group=gpc.get_group(ParallelMode.TENSOR))
- total_num_zeros = total_num_zeros.item()
-
- return total_num_zeros
-
-
-def copy_tensor_parallel_attributes(src_tensor, dst_tensor):
- for attr in TENSOR_PARALLEL_ATTRIBUTES:
- if hasattr(src_tensor, attr):
- val = getattr(src_tensor, attr)
- setattr(dst_tensor, attr, val)
-
-
-def param_is_not_tensor_parallel_duplicate(param):
- return (hasattr(param, IS_TENSOR_PARALLEL) and
- getattr(param, IS_TENSOR_PARALLEL)) or (
- gpc.get_local_rank(ParallelMode.TENSOR) == 0)
diff --git a/colossalai/nn/optimizer/colossalai_optimizer.py b/colossalai/nn/optimizer/colossalai_optimizer.py
new file mode 100644
index 000000000..fb0c43903
--- /dev/null
+++ b/colossalai/nn/optimizer/colossalai_optimizer.py
@@ -0,0 +1,47 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+from torch.optim import Optimizer
+from colossalai.utils import clip_grad_norm_fp32
+
+
+class ColossalaiOptimizer(Optimizer):
+
+ def __init__(self, optim: Optimizer):
+ self.optim = optim
+
+ @property
+ def param_groups(self):
+ return self.optim.param_groups
+
+ @property
+ def defaults(self):
+ return self.optim.defaults
+
+ def add_param_group(self, *args, **kwargs):
+ return self.optim.add_param_group(*args, **kwargs)
+
+ def step(self, *args, **kwargs):
+ return self.optim.step(*args, **kwargs)
+
+ def zero_grad(self, *args, **kwargs):
+ self.optim.zero_grad(*args, **kwargs)
+
+ def load_state_dict(self, *args, **kwargs):
+ self.optim.load_state_dict(*args, **kwargs)
+
+ def state_dict(self):
+ return self.optim.state_dict()
+
+ def backward(self, loss: Tensor):
+ loss.backward()
+
+ def backward_by_grad(self, tensor: Tensor, grad: Tensor):
+ torch.autograd.backward(tensors=tensor, grad_tensors=grad)
+
+ def clip_grad_norm(self, model: nn.Module, max_norm: float):
+ if max_norm > 0.0:
+ clip_grad_norm_fp32(model.parameters(), max_norm)
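ColossalaiOptimizer is a thin pass-through wrapper: every standard Optimizer call is forwarded to the wrapped instance, while backward and clip_grad_norm provide hooks that the AMP variants (e.g. ApexAMPOptimizer) override. A minimal sketch of direct use, assuming a CUDA device and that the kernels backing clip_grad_norm_fp32 are available:

    import torch
    import torch.nn as nn
    from colossalai.nn.optimizer import ColossalaiOptimizer

    model = nn.Linear(16, 4).cuda()
    optim = ColossalaiOptimizer(torch.optim.SGD(model.parameters(), lr=0.01))

    loss = model(torch.randn(8, 16, device='cuda')).sum()
    optim.zero_grad()
    optim.backward(loss)              # plain loss.backward() in the base class
    optim.clip_grad_norm(model, 1.0)  # delegates to clip_grad_norm_fp32
    optim.step()
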
diff --git a/colossalai/nn/optimizer/fused_adam.py b/colossalai/nn/optimizer/fused_adam.py
index 5ab31b363..8bcd3841a 100644
--- a/colossalai/nn/optimizer/fused_adam.py
+++ b/colossalai/nn/optimizer/fused_adam.py
@@ -2,7 +2,7 @@
import torch
from colossalai.registry import OPTIMIZERS
-from ..multi_tensor_apply import multi_tensor_applier
+from colossalai.utils import multi_tensor_applier
@OPTIMIZERS.register_module
diff --git a/colossalai/nn/optimizer/fused_lamb.py b/colossalai/nn/optimizer/fused_lamb.py
index 14b1167a9..8a340a9f3 100644
--- a/colossalai/nn/optimizer/fused_lamb.py
+++ b/colossalai/nn/optimizer/fused_lamb.py
@@ -2,7 +2,7 @@
import torch
from colossalai.registry import OPTIMIZERS
-from ..multi_tensor_apply import multi_tensor_applier
+from colossalai.utils import multi_tensor_applier
@OPTIMIZERS.register_module
diff --git a/colossalai/nn/optimizer/fused_sgd.py b/colossalai/nn/optimizer/fused_sgd.py
index 3950c40be..4986aa5f5 100644
--- a/colossalai/nn/optimizer/fused_sgd.py
+++ b/colossalai/nn/optimizer/fused_sgd.py
@@ -3,7 +3,7 @@ import torch
from torch.optim.optimizer import Optimizer, required
from colossalai.registry import OPTIMIZERS
-from ..multi_tensor_apply import multi_tensor_applier
+from colossalai.utils import multi_tensor_applier
@OPTIMIZERS.register_module
diff --git a/colossalai/nn/optimizer/zero_redundancy_optimizer_level_1.py b/colossalai/nn/optimizer/zero_redundancy_optimizer_level_1.py
deleted file mode 100644
index 05848f1dd..000000000
--- a/colossalai/nn/optimizer/zero_redundancy_optimizer_level_1.py
+++ /dev/null
@@ -1,707 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-import math
-from collections import defaultdict
-
-import torch
-import torch.distributed as dist
-from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
-from torch.optim import Optimizer
-
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.registry import OPTIMIZER_WRAPPERS
-from colossalai.utils import get_current_device, print_rank_0
-
-
-def get_alignment_padding(flattened_lean_size, sub_partition_id, sub_partition_size):
- sub_partition_high_limit = (sub_partition_id + 1) * sub_partition_size
- if sub_partition_high_limit <= flattened_lean_size:
- return 0
- else:
- return min(sub_partition_size, sub_partition_high_limit - flattened_lean_size)
-
-
-def get_group_alignment_padding(tensor_list, sub_partition_size, sub_partition_count):
- group_paddings = []
- flattened_size = sum([tensor.numel() for tensor in tensor_list])
- for i in range(sub_partition_count):
- padding = get_alignment_padding(flattened_size, i, sub_partition_size)
- group_paddings.append(padding)
-
- return group_paddings
-
-
-def _single_range_check(current_index, start_index, end_index, tensor_size):
- offset = 0
- if (current_index >= start_index) and (current_index < end_index):
- # Fully inside bounds
- return True, offset
- elif (start_index > current_index) and (start_index < (current_index + tensor_size)):
- # Partially contained, compute offset
- offset = start_index - current_index
- return True, offset
- else:
- return False, offset
-
-
-def _range_check(current_index, element_intervals, tensor_size):
- results = []
- for comm_idx, interval in enumerate(element_intervals):
- start_index, end_index = interval
- contained, offset = _single_range_check(
- current_index, start_index, end_index, tensor_size)
- if contained:
- results.append((contained, offset, comm_idx))
- if len(results) == 0:
- return [(False, 0, -1)]
- return results
-
-
-@OPTIMIZER_WRAPPERS.register_module
-class ZeroRedundancyOptimizer_Level_1(Optimizer):
- """
- ZeroRedundancyOptimizer_Level_1 designed to reduce the memory footprint
- required for training large deep learning models.
-
- For more details please see ZeRO: Memory Optimization Towards Training A Trillion Parameter Models
- https://arxiv.org/abs/1910.02054
-
- This version aligns with stage-1 in the paper above.
- """
-
- def __init__(self,
- init_optimizer: Optimizer,
- dp_parallel_mode: ParallelMode = ParallelMode.DATA,
- max_elements_per_comm=5e8,
- verbose=False
- ):
- # TODO: this class does not work with fp16 AMP_TYPE.PARALLEL, fix it
- assert get_current_device() != 'cpu', 'ZeRO optimizer cannot be used on CPU only'
-
- self.flatten = _flatten_dense_tensors
- self.unflatten = _unflatten_dense_tensors
- self.optimizer = init_optimizer
- self.dp_parallel_mode = dp_parallel_mode
- self.verbose = verbose
-
- # for compatibility with pytorch optim
- self.defaults = init_optimizer.defaults
-
- # param flattened by groups
- self._param_groups = []
- self._param_groups_flat = []
-
- # parallel_sub_partitioned_fp16_groups[group-idx] -> [comm-ids] -> [rank-ids]
- self.parallel_sub_partitioned_groups = []
- # same underlying data as above but viewed as: [groups] -> [rank-ids] -> [comm-ids]
- self.parallel_comm_sub_partitioned_groups = []
-
- # param partition info
- # parameters in each group that will not be updated by this process directly
- self.params_not_local = []
-
- # parameters that will be updated by this process directly
- self.params_in_rank_sub_partitions = []
-
- # parameter offsets for parameters in sub-partitions. Parameter
- # boundaries may not align with sub-partition boundaries
- # so we need to keep track of the offsets
- self.params_in_rank_sub_partitions_offsets = []
-
- # number of elements per sub-partition in each group
- self.sub_partition_sizes = []
-
- # number of communication intervals for each group
- self.num_comm_intervals_per_group = []
-
- self.local_rank = gpc.get_local_rank(self.dp_parallel_mode)
- self.partition_count = self.world_size = gpc.get_world_size(
- self.dp_parallel_mode)
-
- self.group_paddings = []
- self.default_device = self.optimizer.param_groups[0]['params'][0].device
-
- # max elems per param group
- self.max_elems_per_comm = []
-
- # loop to deal with groups
- for i, param_group in enumerate(self.optimizer.param_groups):
- # push this group to list before modify
- self._param_groups.append(param_group['params'])
-
- # calculate best max elements per comm based to minimize padding
- self.max_elems_per_comm.append(
- self.best_max_elems_per_comm(
- num_elements=sum(t.numel() for t in self._param_groups[i]),
- max_elements_per_comm=max_elements_per_comm
- )
- )
-
- # flattens all tensors into single 1d tensor aligned with sub-partition size for later dividing
- # RS: create aligned sub-partitions
- flat_aligned_params = self.flatten_dense_tensors_sub_partition_aligned(
- tensor_list=self._param_groups[i],
- max_elements_per_comm=self.max_elems_per_comm[i],
- )
- self._param_groups_flat.append(flat_aligned_params)
-
- updated_params = self.unflatten(self._param_groups_flat[i],
- self._param_groups[i])
- for p, q in zip(self._param_groups[i], updated_params):
- p.data = q.data
-
- # divide the flat weights into near equal partition equal to the data parallel degree
- # each process will compute on a different part of the partition
- # RS: split into two layer list -> [comm-id] -> [sub-partitions per rank]
- comm_partitions, dp_sub_partitions, element_intervals, sub_partition_size, num_comm_intervals = \
- self.get_data_parallel_sub_partitions(
- tensor=self._param_groups_flat[i],
- max_elements_per_comm=self.max_elems_per_comm[i],
- )
- self.parallel_comm_sub_partitioned_groups.append(
- comm_partitions) # comm -> rank
- self.parallel_sub_partitioned_groups.append(
- dp_sub_partitions) # rank -> comm
- self.sub_partition_sizes.append(sub_partition_size)
- self.num_comm_intervals_per_group.append(num_comm_intervals)
-
- # Compute sub_partition paddings
- sub_partition_paddings = get_group_alignment_padding(
- tensor_list=self._param_groups[i],
- sub_partition_size=sub_partition_size,
- sub_partition_count=num_comm_intervals * self.partition_count)
- self.group_paddings.append(sub_partition_paddings)
-
- # modify optimizer of have flat master weight
- param_group['params'] = self.parallel_sub_partitioned_groups[i][self.local_rank]
-
- # RS: divide up the sub-partitions and keep track of offsets for each param
- # partition_size = len(self.fp16_groups_flat[i]) / dist.get_world_size(group=self.dp_process_group)
- params_in_rank_sub_partition, params_in_rank_sub_partitions_offsets, params_not_local = self.get_all_sub_partition_info(
- tensor_list=self._param_groups[i],
- all_element_intervals=element_intervals,
- )
-
- self.params_in_rank_sub_partitions.append(
- params_in_rank_sub_partition)
- self.params_not_local.append(params_not_local)
- self.params_in_rank_sub_partitions_offsets.append(
- params_in_rank_sub_partitions_offsets)
-
- self.local_sub_partitions_of_groups = [
- group[self.local_rank] for group in self.parallel_sub_partitioned_groups]
- self._initialize_optimizer_states()
-
- @property
- def state(self):
- return self.optimizer.state
-
- @state.setter
- def state(self, value):
- self.optimizer.state = value
-
- @property
- def param_groups(self):
- # LSG: return the full param groups instead of local partitions
- # of the param groups for compatibility with torch.cuda.amp
- param_groups = []
-
- for group_id, group in enumerate(self.optimizer.param_groups):
- group_containing_all_param = {
- 'params': self._param_groups[group_id],
- **{k: v for k, v in group.items() if k != 'params'}
- }
- # LSG: for compatibility with unknown bug with lr scheduler
- # TODO: fix this
- group_containing_all_param.setdefault('initial_lr', group['lr'])
- param_groups.append(group_containing_all_param)
- return param_groups
-
- @param_groups.setter
- def param_groups(self, value):
- self.optimizer.param_groups = value
-
- def _initialize_optimizer_states(self):
- for group_idx, group in enumerate(self.local_sub_partitions_of_groups):
- for idx, sub_partition_param in enumerate(group):
- sub_partition_grad = torch.zeros(int(
- self.sub_partition_sizes[group_idx]),
- dtype=sub_partition_param.dtype).cuda()
- sub_partition_param.grad = sub_partition_grad
-
- self.optimizer.step()
-
- # LSG: comment out for compatibility with torch.cuda.amp
- # for group in self.local_sub_partitions_of_groups:
- # for idx, sub_partition_param in enumerate(group):
- # sub_partition_param.grad = None
-
- def best_max_elems_per_comm(self, num_elements, max_elements_per_comm):
- # if we use max-elems-per-comm as is, how many comm intervals will there be
- max_comm_intervals = math.ceil(num_elements / max_elements_per_comm)
- padding_for_max_comm = (max_elements_per_comm *
- max_comm_intervals) - num_elements
-
- # if we use 1 less comm interval how much extra comm padding would be required
- min_comm_intervals = num_elements // max_elements_per_comm
- if min_comm_intervals == 0:
- if self.verbose:
- print_rank_0(
- f'Using default max_elements_per_comm {max_elements_per_comm}')
- return max_elements_per_comm
-
- padding_for_min_comm = math.ceil(
- num_elements / (self.world_size * min_comm_intervals))
-
- # choose padding that uses least amount of overhead
- if padding_for_max_comm > padding_for_min_comm:
- new_max_elements_per_comm = padding_for_min_comm + max_elements_per_comm
- if self.verbose:
- print_rank_0(
- f'Updating max_elements_per_comm from {max_elements_per_comm} -> {new_max_elements_per_comm}')
- return new_max_elements_per_comm
- else:
- if self.verbose:
- print_rank_0(
- f'Using default max_elements_per_comm {max_elements_per_comm}')
- return max_elements_per_comm
-
- def get_data_parallel_sub_partitions(self,
- tensor,
- max_elements_per_comm,
- ):
- total_num_elements = tensor.numel()
-
- # if total elements is less than our max, revert to splitting into dp partitions
- max_elements_per_comm = min(total_num_elements, max_elements_per_comm)
- sub_partition_size = int(max_elements_per_comm // self.world_size)
-
- # Ensure partition alignment was done correctly
- num_sub_partitions = int(total_num_elements // sub_partition_size)
- assert total_num_elements % sub_partition_size == 0, "{} % {} != 0".format(total_num_elements,
- sub_partition_size)
-
- # Ensure comm interval alignment was done correctly.
- num_comm_intervals = int(num_sub_partitions // self.world_size)
- assert num_sub_partitions % self.world_size == 0, "{} % {} != 0".format(
- num_sub_partitions, self.world_size)
-
- if self.verbose:
- print_rank_0("**** partition info:")
- print_rank_0(f"\t total_num_elements={total_num_elements}")
- print_rank_0(f"\t world_size={self.world_size}")
- print_rank_0(f"\t max_elements_per_comm={max_elements_per_comm}")
- print_rank_0(f"\t sub_partition_size={sub_partition_size}")
- print_rank_0(f"\t num_sub_partitions={num_sub_partitions}")
- print_rank_0(f"\t num_comm_intervals={num_comm_intervals}")
- print_rank_0("****")
-
- # [comm_id] -> [rank]
- comm_partitions = []
- for _ in range(num_comm_intervals):
- comm_partitions.append([])
-
- start = 0
- comm_id = 0
- element_intervals = defaultdict(
- list) # [rank] -> [(start,end), (start,end), ...]
- for idx in range(num_sub_partitions):
- rank_id = idx % self.world_size
- sub_partition = tensor.narrow(
- 0, start, sub_partition_size).detach()
- element_intervals[rank_id].append(
- (start, start + sub_partition_size))
- comm_partitions[comm_id].append(sub_partition)
- start = start + sub_partition_size
- if rank_id == (self.world_size - 1):
- comm_id += 1
-
- # [rank] -> [comm_id]
- sub_partitions = []
- for _ in range(self.world_size):
- sub_partitions.append([])
- for comm_id, partitions in enumerate(comm_partitions):
- for rank_id, partition in enumerate(partitions):
- sub_partitions[rank_id].append(partition)
-
- return comm_partitions, sub_partitions, element_intervals, sub_partition_size, num_comm_intervals
-
- def get_all_sub_partition_info(self,
- tensor_list,
- all_element_intervals,
- ):
- params_not_local = []
-
- # [rank] -> [comm-id] -> [param/offset]
- params_in_rank_sub_partition = []
- params_in_rank_sub_partitions_offsets = []
-
- for rank in range(self.world_size):
- params_in_local_sub_partition = []
- local_sub_partition_offsets = []
- comm_tensor_list = []
- comm_offset_list = []
- current_index = 0
- prev_comm_idx = 0
- for iii, tensor in enumerate(tensor_list):
- tensor_size = tensor.numel()
- results_list = _range_check(current_index,
- all_element_intervals[rank],
- tensor_size)
- for contained, offset, comm_idx in results_list:
- if contained:
- if prev_comm_idx != comm_idx:
- params_in_local_sub_partition.append(
- comm_tensor_list)
- comm_tensor_list = []
- local_sub_partition_offsets.append(
- comm_offset_list)
- comm_offset_list = []
- comm_tensor_list.append(tensor)
- comm_offset_list.append(offset)
- prev_comm_idx = comm_idx
- elif rank == self.local_rank:
- params_not_local.append(tensor)
-
- current_index = current_index + tensor_size
-
- # assert len(comm_tensor_list) > 0
- # assert len(comm_offset_list) > 0
- params_in_local_sub_partition.append(comm_tensor_list)
- local_sub_partition_offsets.append(comm_offset_list)
-
- params_in_rank_sub_partition.append(params_in_local_sub_partition)
- params_in_rank_sub_partitions_offsets.append(
- local_sub_partition_offsets)
-
- return params_in_rank_sub_partition, params_in_rank_sub_partitions_offsets, params_not_local
-
- def get_flat_sub_partitions(self,
- comm_tensor_list,
- comm_param_offsets,
- sub_partition_size,
- dtype,
- default_device,
- num_comm_intervals=None,
- return_partition_params=False):
- partition_params = []
- final_param_offsets = []
- flat_sub_partitions = []
- for tensor_list, param_offsets in zip(comm_tensor_list, comm_param_offsets):
- flat_tensor_list = []
- current_size = 0
- my_offsets = []
- my_params = []
-
- for i, tensor in enumerate(tensor_list):
- if tensor.grad is None:
- tensor.grad = torch.zeros(tensor.size(),
- dtype=tensor.dtype,
- device=tensor.device)
- param = tensor
- tensor = tensor.grad
- num_elements = tensor.numel()
- tensor_offset = 0
-
- # we need to offset to get to the right element
- if i == 0 and param_offsets[i] > 0:
- tensor_offset = param_offsets[i]
- num_elements = num_elements - tensor_offset
-
- # We don't need all elements of the tensor if this tensor is
- # larger than we have space for in our curr sub-partition
- if num_elements > (sub_partition_size - current_size):
- num_elements = sub_partition_size - current_size
-
- # we need a narrow view of the tensor based on the tensor offset and number of elements that
- # we need from this tensor
- if tensor_offset > 0 or num_elements < tensor.numel():
- flat_tensor_list.append(tensor.contiguous().view(-1).narrow(
- 0,
- int(tensor_offset),
- int(num_elements)).to(dtype))
- else:
- flat_tensor_list.append(tensor.to(dtype))
- my_params.append(param)
-
- # remember offset into partition and #elems for this tensor
- my_offsets.append((current_size, num_elements))
-
- current_size = current_size + num_elements
-
- # this means its the last partition and does not align with the dp boundary. We need to pad before flattening
- if current_size < sub_partition_size:
- my_offsets.append((None, None))
- my_params.append(None)
- if len(tensor_list) == 0:
- assert default_device != None
- flat_tensor_list.append(
- torch.zeros(int(sub_partition_size - current_size),
- dtype=dtype,
- device=default_device))
- else:
- flat_tensor_list.append(
- torch.zeros(int(sub_partition_size - current_size),
- dtype=dtype,
- device=tensor_list[0].device))
- partition_params.append(my_params) # flat_tensor_list)
- final_param_offsets.append(my_offsets)
- assert len(flat_tensor_list) == len(my_offsets), "{} {}".format(
- len(flat_tensor_list), len(my_offsets))
- flat_sub_partitions.append(self.flatten(flat_tensor_list))
- if num_comm_intervals is not None and len(
- flat_sub_partitions) < num_comm_intervals:
- # print("padding w. sub partitions to ensure uniform communication")
- device = flat_sub_partitions[0].device
- for _ in range(num_comm_intervals - len(flat_sub_partitions)):
- flat_sub_partitions.append(
- torch.zeros(int(sub_partition_size),
- dtype=dtype,
- device=device))
- partition_params.append([None])
- final_param_offsets.append([(None, None)])
-
- if return_partition_params:
- assert len(flat_sub_partitions) == len(partition_params)
- assert len(partition_params) == len(final_param_offsets), "{} {}".format(len(partition_params),
- len(final_param_offsets))
- return flat_sub_partitions, partition_params, final_param_offsets
- return flat_sub_partitions
-
- def zero_grad(self, set_grads_to_None=False):
- """
- Zero FP16 parameter grads.
- """
- # FP32 grad should never exist.
- # For speed, set model fp16 grad to None by default
- for group in self._param_groups:
- for p in group:
- if set_grads_to_None:
- p.grad = None
- else:
- if p.grad is not None:
- p.grad.detach_()
- p.grad.zero_()
-
- def free_grad_in_param_list(self, param_list):
- for p in param_list:
- if isinstance(p, list):
- for _p in p:
- _p.grad = None
- else:
- p.grad = None
-
- def flatten_dense_tensors_sub_partition_aligned(self,
- tensor_list,
- max_elements_per_comm
- ):
- assert max_elements_per_comm >= self.world_size, f"max_elements_per_comm {max_elements_per_comm} < dp {self.world_size}"
-
- num_elements = sum(t.numel() for t in tensor_list)
-
- # Compute aligned partition size based on parameter count
- aligned_param_partition_size = math.ceil(
- num_elements / self.world_size)
-
- # Compute aligned partition size based on communication size
- aligned_comm_partition_size = int(
- max_elements_per_comm // self.world_size)
-
- if aligned_param_partition_size <= aligned_comm_partition_size:
- sub_partition_count = 1
- sub_partition_size = aligned_param_partition_size
- else:
- sub_partition_count = math.ceil(aligned_param_partition_size /
- aligned_comm_partition_size)
- sub_partition_size = aligned_comm_partition_size
-
- # Compute required padding for alignment to dp and max_elements_per_comm
- padding = (sub_partition_count * sub_partition_size *
- self.world_size) - num_elements
-
- if self.verbose:
- print_rank_0(
- f"sub_partition_count: {sub_partition_count}, sub_partition_size: {sub_partition_size}, padding: {padding}")
- print_rank_0(
- f"number of elements with padding: {num_elements} + {padding} = {num_elements + padding}")
-
- if padding == 0:
- aligned_tensor_list = tensor_list
- else:
- pad_tensor = torch.zeros(padding,
- device=tensor_list[0].device,
- dtype=tensor_list[0].dtype)
- aligned_tensor_list = tensor_list + [pad_tensor]
-
- flat_tensors = self.flatten(aligned_tensor_list)
- return flat_tensors
-
- # def reduce_gradients(self):
- # # LSG: this reduce gradients method no longer works
- # # after code change, please use DataParallelGradientHandler instead
- #
- # world_size = gpc.get_world_size(self.parallel_mode)
- # local_rank = gpc.get_local_rank(self.parallel_mode)
- #
- # for i, group in enumerate(self._param_groups):
- # num_comm_intervals = self.num_comm_intervals_per_group[i]
- # all_sub_partitions = []
- # for rank in range(world_size):
- # # gsp is list of partitions indexed by comm_idx
- # grad_sub_partitions = self.get_flat_sub_partitions(
- # comm_tensor_list=self.params_in_rank_sub_partitions[i][rank],
- # comm_param_offsets=self.params_in_rank_sub_partitions_offsets[i][rank],
- # dtype=self.local_sub_partitions_of_groups[i][0].dtype,
- # default_device=self.default_device,
- # sub_partition_size=self.sub_partition_sizes[i],
- # num_comm_intervals=self.num_comm_intervals_per_group[i])
- # all_sub_partitions.append(grad_sub_partitions)
- #
- # assert len(grad_sub_partitions) == num_comm_intervals
- #
- # local_comm_partitions = []
- # for comm_idx in range(num_comm_intervals):
- # single_comm_all_partitions = []
- # for rank in range(world_size):
- # single_comm_all_partitions.append(all_sub_partitions[rank][comm_idx])
- #
- # for partition in single_comm_all_partitions:
- # partition.div_(world_size)
- #
- # dist.reduce_scatter(output=single_comm_all_partitions[local_rank],
- # input_list=single_comm_all_partitions,
- # group=gpc.get_group(self.parallel_mode))
-
- def step(self, closure=None):
- local_sub_partitions_grad_groups = []
-
- for i, group in enumerate(self._param_groups):
- # RS: update free grads w.r.t. sub partitions
- # free gradients for all the parameters that are not updated by this process
- self.free_grad_in_param_list(self.params_not_local[i])
-
- # create flat gradient partitions for parameters updated by this process
- local_grad_sub_partitions = self.get_flat_sub_partitions(
- comm_tensor_list=self.params_in_rank_sub_partitions[i][self.local_rank],
- comm_param_offsets=self.params_in_rank_sub_partitions_offsets[i][self.local_rank],
- sub_partition_size=self.sub_partition_sizes[i],
- dtype=self.local_sub_partitions_of_groups[i][0].dtype,
- num_comm_intervals=self.num_comm_intervals_per_group[i],
- default_device=self.default_device)
-
- # RS: update all our local params with sub-partition grads
- for idx, sub_partition_param in enumerate(self.local_sub_partitions_of_groups[i]):
- sub_partition_param.grad = local_grad_sub_partitions[idx]
-
- # RS: update free grads for sub-partitions
- # release all the gradient since we have already created a necessary copy in dp_grad_partition
- self.free_grad_in_param_list(
- self.params_in_rank_sub_partitions[i][self.local_rank])
-
- local_sub_partitions_grad_groups.append(local_grad_sub_partitions)
-
- if closure is None:
- loss = self.optimizer.step()
- else:
- loss = self.optimizer.step(closure=closure)
-
- # RS: clear our sub partition grads
- # LSG: not needed as amp is used instead
- # get rid of the fp32 gradients. Not needed anymore
- # for group in self.local_sub_partitions_of_groups:
- # for idx, sub_partition_param in enumerate(group):
- # sub_partition_param.grad = None
-
- # RS: all_gather/broadcast sub-partitions in separate comm calls
- # gather the updated weights from everyone
- for all_sub_partitions in self.parallel_comm_sub_partitioned_groups:
- for comm_id, sub_partitions in enumerate(all_sub_partitions):
- dist.all_gather(sub_partitions,
- sub_partitions[self.local_rank],
- group=gpc.get_group(self.dp_parallel_mode))
-
- # TODO: we probably don't need this? just to be safe
- for i in range(len(self._param_groups)):
- updated_params = self.unflatten(self._param_groups_flat[i],
- self._param_groups[i])
- for p, q in zip(self._param_groups[i], updated_params):
- p.data = q.data
-
- return loss
-
- def _rigid_state_dict(self):
- """Returns a dict that can be loaded for continued training with same DP degree
-
- Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
- This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
- of the contained Pytorch optimizer.
-
- Example::
-
- checkpoint = {}
- checkpoint['model'] = model.state_dict()
- checkpoint['optimizer'] = optimizer.state_dict()
- torch.save(checkpoint, "saved.pth")
- """
- state_dict = {}
- for k, v in self.optimizer.state_dict().items():
- state_dict[k] = v
- state_dict[
- 'local_sub_partitions_of_groups'] = self.local_sub_partitions_of_groups
- return state_dict
-
- def state_dict(self):
- """
- Returns a dict containing the current state of this Optimizer instance.
- This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
- of the contained Pytorch optimizer.
-
- Example::
-
- checkpoint = {}
- checkpoint['model'] = model.state_dict()
- checkpoint['optimizer'] = optimizer.state_dict()
- torch.save(checkpoint, "saved.pth")
- """
- return self._rigid_state_dict()
-
- def load_state_dict(self,
- state_dict,
- load_optimizer_states=True,
- ):
- """
- Loads a state_dict created by an earlier call to state_dict().
- If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
- whose parameters in turn came from ``model``, it is expected that the user
- will call ``model.load_state_dict()`` before
- ``fp16_optimizer_instance.load_state_dict()`` is called.
-
- Example::
-
- model = torch.nn.Linear(D_in, D_out).cuda().half()
- optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
- optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
- ...
- checkpoint = torch.load("saved.pth")
- model.load_state_dict(checkpoint['model'])
- optimizer.load_state_dict(checkpoint['optimizer'])
- """
- self._rigid_load_state_dict(
- state_dict,
- load_optimizer_states)
-
- def _rigid_load_state_dict(self, state_dict, load_optimizer_states=True):
- # I think it should actually be ok to reload the optimizer before the model.
- state_dict_ = state_dict.copy()
- local_sub_partitions_of_groups = state_dict_.pop(
- 'local_sub_partitions_of_groups')
-
- if load_optimizer_states:
- self.optimizer.load_state_dict(state_dict_)
-
- for curr_group, saved_group in zip(self.local_sub_partitions_of_groups,
- local_sub_partitions_of_groups):
- for curr_param, saved_param in zip(curr_group, saved_group):
- curr_param.data.copy_(saved_param.data)
diff --git a/colossalai/registry/__init__.py b/colossalai/registry/__init__.py
index 1de1c98ae..492b278a4 100644
--- a/colossalai/registry/__init__.py
+++ b/colossalai/registry/__init__.py
@@ -2,7 +2,8 @@ import torch.distributed.optim as dist_optim
import torch.nn as nn
import torch.optim as optim
import torchvision.models as tv_models
-from torchvision.transforms import transforms
+import torchvision.datasets as tv_datasets
+from torchvision import transforms
from .registry import Registry
@@ -10,14 +11,12 @@ LAYERS = Registry('layers', third_party_library=[nn])
LOSSES = Registry('losses')
MODELS = Registry('models', third_party_library=[tv_models])
OPTIMIZERS = Registry('optimizers', third_party_library=[optim, dist_optim])
-OPTIMIZER_WRAPPERS = Registry('optimizer_wrappers')
-DATASETS = Registry('datasets')
+DATASETS = Registry('datasets', third_party_library=[tv_datasets])
DIST_GROUP_INITIALIZER = Registry('dist_group_initializer')
GRADIENT_HANDLER = Registry('gradient_handler')
LOSSES = Registry('losses', third_party_library=[nn])
HOOKS = Registry('hooks')
TRANSFORMS = Registry('transforms', third_party_library=[transforms])
-PIPE_ALLOC_POLICY = Registry('pipeline_allocation_policy')
-SAMPLERS = Registry('samplers')
+DATA_SAMPLERS = Registry('data_samplers')
LR_SCHEDULERS = Registry('lr_schedulers')
SCHEDULE = Registry('schedules')
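Registering custom components works the same way after the registry clean-up (DATASETS now backed by torchvision, SAMPLERS renamed to DATA_SAMPLERS). A small sketch, assuming the register_module decorator and get_module lookup used elsewhere in this diff:

    import torch.nn as nn
    from colossalai.registry import LOSSES

    @LOSSES.register_module
    class MyScaledCrossEntropy(nn.Module):
        def __init__(self, scale=2.0):
            super().__init__()
            self.scale = scale
            self.base = nn.CrossEntropyLoss()

        def forward(self, logits, targets):
            return self.scale * self.base(logits, targets)

    # e.g. from a config-driven builder
    criterion = LOSSES.get_module('MyScaledCrossEntropy')(scale=1.5)
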
diff --git a/colossalai/trainer/__init__.py b/colossalai/trainer/__init__.py
index 57f7b7495..84e53dc4e 100644
--- a/colossalai/trainer/__init__.py
+++ b/colossalai/trainer/__init__.py
@@ -1,5 +1,3 @@
from ._trainer import Trainer
-from .hooks import *
-from .metric import Loss, Accuracy2D, Accuracy3D, Accuracy2p5D, LearningRate
-__all__ = ['Trainer', 'Loss', 'Accuracy3D', 'Accuracy2D', 'Accuracy2p5D', 'LearningRate']
+__all__ = ['Trainer']
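The _trainer.py hunks below rework Trainer to take an explicit schedule and an explicit list of hook objects instead of hook configs. A usage sketch of the resulting interface, assuming the engine, dataloaders and logger come from an earlier colossalai.initialize/launch call (they are passed in here to keep the sketch self-contained):

    from typing import List
    from torch.utils.data import DataLoader
    from colossalai.engine import Engine
    from colossalai.logging import DistributedLogger
    from colossalai.trainer import Trainer
    from colossalai.trainer.hooks import BaseHook
    from colossalai.utils import MultiTimer

    def run_training(engine: Engine,
                     train_dataloader: DataLoader,
                     test_dataloader: DataLoader,
                     logger: DistributedLogger,
                     hooks: List[BaseHook]):
        # schedule defaults to NonPipelineSchedule when none is given
        trainer = Trainer(engine=engine, timer=MultiTimer(), logger=logger)
        trainer.fit(train_dataloader=train_dataloader,
                    epochs=10,
                    test_dataloader=test_dataloader,
                    test_interval=1,
                    hooks=hooks,
                    display_progress=True)
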
diff --git a/colossalai/trainer/_trainer.py b/colossalai/trainer/_trainer.py
index 96a82d995..6cce0a3e4 100644
--- a/colossalai/trainer/_trainer.py
+++ b/colossalai/trainer/_trainer.py
@@ -2,18 +2,21 @@
# -*- encoding: utf-8 -*-
from typing import Union, List
+from colossalai import engine
+from colossalai.context.parallel_mode import ParallelMode
import torch
from torch import Tensor
from torch.utils.data import DataLoader
from tqdm import tqdm
-from colossalai.builder import build_hooks
+from colossalai.core import global_context as gpc
from colossalai.engine import Engine
-from colossalai.logging import get_global_dist_logger
-from colossalai.nn.data import DataParallelSampler
+from colossalai.engine.schedule import NonPipelineSchedule, BaseSchedule
+from colossalai.logging import DistributedLogger
from colossalai.utils import MultiTimer
from colossalai.utils import is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage
+from .hooks import BaseHook
class Trainer:
@@ -31,8 +34,9 @@ class Trainer:
def __init__(self,
engine: Engine,
- verbose: bool = False,
- timer: MultiTimer = None):
+ schedule: BaseSchedule = None,
+ timer: MultiTimer = None,
+ logger: DistributedLogger = None):
# training-related params
self._engine = engine
self._max_epochs = 0
@@ -42,8 +46,8 @@ class Trainer:
self._steps_per_epoch = 0
# misc params
- self._logger = get_global_dist_logger()
- self._verbose = verbose
+ self._logger = logger
+ self._verbose = logger is not None
# hooks can store states in this dict, and could be consumed by other hooks
self.states = dict()
@@ -54,6 +58,15 @@ class Trainer:
# multi-timer for time benchmarking
self._timer = timer
+ # set schedule which specifies the training iteration for the engine
+ if schedule is None:
+ schedule = NonPipelineSchedule()
+ if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
+ assert not isinstance(schedule, NonPipelineSchedule), \
+ 'NonPipelineSchedule cannot be used for pipeline parallel training, please use PipelineSchedule instead.'
+ self._schedule = schedule
+ self._schedule.pre_processing(engine)
+
@property
def cur_epoch(self):
"""Returns the index of the current epoch.
@@ -89,9 +102,9 @@ class Trainer:
def engine(self):
return self._engine
- @engine.setter
- def engine(self, engine_: Engine):
- self._engine = engine_
+ @property
+ def schedule(self):
+ return self._schedule
def _set_current_step(self, epoch: int):
"""Sets current step number.
@@ -129,9 +142,9 @@ class Trainer:
# Only after iter hook will receive output
for hook in self.hooks:
if output is None:
- getattr(hook, func)()
+ getattr(hook, func)(self)
else:
- getattr(hook, func)(*output)
+ getattr(hook, func)(self, *output)
@staticmethod
def _should_display_progress(display_progress: bool):
@@ -143,12 +156,6 @@ class Trainer:
train_dataloader: DataLoader,
epoch: int = None,
display_progress: bool = False):
- # set sampler epoch
- if epoch is not None and \
- hasattr(train_dataloader, 'sampler') and \
- isinstance(train_dataloader.sampler, DataParallelSampler):
- train_dataloader.sampler.set_epoch(epoch)
-
# set training state
self._engine.train()
data_iter = iter(train_dataloader)
@@ -159,20 +166,17 @@ class Trainer:
else:
progress = tqdm(progress, desc=f'[Epoch {epoch} train]')
- # train 1 epoch
self._call_hooks('before_train_epoch')
self._call_timer(action='start', item='train-epoch')
for i in progress:
self._call_hooks('before_train_iter')
self._call_timer(action='start', item='train-step')
- if i == self._steps_per_epoch - 1:
- is_last_iteration = True
- else:
- is_last_iteration = False
-
# run 1 training step
- logits, label, loss = self._engine.step(data_iter, is_last_iteration)
+ self.engine.zero_grad()
+ logits, label, loss = self.schedule.forward_backward_step(
+ self.engine, data_iter, forward_only=False, return_loss=True)
+ self.engine.step()
self._call_timer(action='stop', item='train-step', keep_in_history=True)
self._call_hooks('after_train_iter', output=(logits, label, loss))
@@ -197,32 +201,33 @@ class Trainer:
num_steps = len(test_dataloader)
self._call_hooks('before_test')
- with torch.no_grad():
- # prepare progress bar
- progress = range(num_steps)
- if display_progress:
- desc = 'Evaluation'
- if epoch is not None:
- desc = '[Epoch %d val]' % epoch
- progress = tqdm(progress, desc=desc)
+ # prepare progress bar
+ progress = range(num_steps)
+ if display_progress:
+ desc = 'Evaluation'
+ if epoch is not None:
+ desc = '[Epoch %d val]' % epoch
+ progress = tqdm(progress, desc=desc)
- self._call_hooks('before_test_epoch')
- self._call_timer(action='start', item='test-epoch')
+ self._call_hooks('before_test_epoch')
+ self._call_timer(action='start', item='test-epoch')
+ with torch.no_grad():
for _ in progress:
self._call_hooks('before_test_iter')
self._call_timer(action='start', item='test-step')
- logits, label, loss = self._engine.step(data_iter, return_loss=True)
+ logits, label, loss = self.schedule.forward_backward_step(
+ self.engine, data_iter, forward_only=True, return_loss=True)
self._call_timer(action='stop', item='test-step', keep_in_history=True)
self._call_hooks('after_test_iter',
output=(logits, label, loss))
- self._call_timer(action='stop', item='test-epoch', keep_in_history=True)
- self._call_hooks('after_test_epoch')
+ self._call_timer(action='stop', item='test-epoch', keep_in_history=True)
+ self._call_hooks('after_test_epoch')
self._call_hooks('after_test')
self._call_timer(action='reset', item='test-step')
self._call_timer(action='reset', item='test-epoch')
def _exceed_max_step(self):
- return self._max_steps is not None and self._cur_step > self._max_steps
+ return self._max_steps is not None and self._cur_step >= self._max_steps
def fit(self,
train_dataloader: DataLoader,
@@ -230,7 +235,7 @@ class Trainer:
max_steps: int = None,
test_dataloader: DataLoader = None,
test_interval: int = 1,
- hooks_cfg: dict = None,
+ hooks: List[BaseHook] = None,
display_progress: bool = False,
):
"""Trains the model to fit training data.
@@ -253,7 +258,7 @@ class Trainer:
"""
# set epochs and steps, consider gradient accumulation
- self._steps_per_epoch = len(train_dataloader) // self._engine.gradient_accumulation
+ self._steps_per_epoch = len(train_dataloader)
self._max_steps = max_steps
self._max_epochs = epochs
@@ -266,19 +271,18 @@ class Trainer:
# reset hooks
self._reset_states()
- self.hooks = list()
-
- # build hooks
- if hooks_cfg is not None:
- for cfg in hooks_cfg:
- hook = build_hooks(cfg, self)
- self.hooks.append(hook)
+ if hooks is not None:
+ assert isinstance(hooks, list), f'expected argument hooks to be a list, but got {type(hooks)}'
+ else:
+ hooks = []
+ self.hooks = hooks
self.hooks.sort(key=lambda hook: hook.priority)
if self._verbose:
for hook in self.hooks:
self._logger.info(
- f'build {hook.__class__.__name__} for training, priority = {hook.priority}', ranks=[0])
- self._logger.info("Lower value means higher priority for calling hook function")
+ f'Using {hook.__class__.__name__} for training, priority = {hook.priority}', ranks=[0])
+ self._logger.info("Lower value means higher priority for calling hook function", ranks=[0])
+ self._call_hooks('after_hook_is_attached')
# start train
self._engine.train()
@@ -309,13 +313,15 @@ class Trainer:
# check for termination
if self._exceed_max_step():
self._logger.info(
- f"Max number of steps {max_steps} has been reached, training is stopped automatically")
+ f"Max number of steps {max_steps} has been reached, training is stopped automatically",
+ ranks=[0])
break
self._call_hooks('after_train')
self._call_timer('reset', 'train-epoch')
def evaluate(self,
test_dataloader: DataLoader,
+ hooks: List[BaseHook] = None,
display_progress: bool = False):
"""Evaluates the model with testing data.
@@ -327,6 +333,21 @@ class Trainer:
# set display
display_progress = self._should_display_progress(display_progress)
+ # reset hooks
+ self._reset_states()
+ if hooks is not None:
+ assert isinstance(hooks, list), f'expected argument hooks to be a list, but got {type(hooks)}'
+ else:
+ hooks = []
+ self.hooks = hooks
+ self.hooks.sort(key=lambda hook: hook.priority)
+ if self._verbose:
+ for hook in self.hooks:
+ self._logger.info(
+ f'Using {hook.__class__.__name__} for evaluation, priority = {hook.priority}', ranks=[0])
+ self._logger.info("Lower value means higher priority for calling hook function", ranks=[0])
+ self._call_hooks('after_hook_is_attached')
+
# eval
self._eval(test_dataloader=test_dataloader,
display_progress=display_progress,
@@ -351,5 +372,6 @@ class Trainer:
# for compatibility with schedule
simple_dataloader = [(data, None)]
data_iter = iter(simple_dataloader)
- output, _, _ = self._engine.step(data_iter, return_loss=False)
+ output, _, _ = self.schedule.forward_backward_step(
+ self.engine, data_iter, forward_only=True, return_loss=False)
return output
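With this change the trainer no longer builds hooks from a config dict and delegates the iteration logic to a schedule. A minimal usage sketch under stated assumptions: `engine`, `train_dataloader` and `test_dataloader` are created elsewhere (e.g. via `colossalai.initialize` and `get_dataloader`), and the hook classes are assumed to be re-exported from `colossalai.trainer.hooks`.

from colossalai.logging import get_dist_logger
from colossalai.trainer import Trainer
from colossalai.trainer.hooks import LossHook, LogMetricByEpochHook
from colossalai.utils import MultiTimer

logger = get_dist_logger()
# schedule defaults to NonPipelineSchedule unless pipeline parallelism is active
trainer = Trainer(engine=engine, timer=MultiTimer(on=True), logger=logger)
trainer.fit(train_dataloader=train_dataloader,
            epochs=10,
            test_dataloader=test_dataloader,
            test_interval=1,
            hooks=[LossHook(), LogMetricByEpochHook(logger=logger)],
            display_progress=True)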
diff --git a/colossalai/trainer/hooks/_base_hook.py b/colossalai/trainer/hooks/_base_hook.py
index 4d510ab0f..e4b5edfbf 100644
--- a/colossalai/trainer/hooks/_base_hook.py
+++ b/colossalai/trainer/hooks/_base_hook.py
@@ -5,9 +5,6 @@ from abc import ABC
from torch import Tensor
-from colossalai.logging import get_global_dist_logger
-from .._trainer import Trainer
-
class BaseHook(ABC):
"""This class allows users to add desired actions in specific time points
@@ -18,27 +15,31 @@ class BaseHook(ABC):
:type trainer: Trainer
:type priority: int
"""
- def __init__(self, trainer: Trainer, priority: int) -> None:
- self.trainer = trainer
- self.priority = priority
- self.logger = get_global_dist_logger()
- def before_train(self):
+ def __init__(self, priority: int) -> None:
+ self.priority = priority
+
+ def after_hook_is_attached(self, trainer):
+ """Actions after hooks are attached to trainer.
+ """
+ pass
+
+ def before_train(self, trainer):
"""Actions before training.
"""
pass
- def after_train(self):
+ def after_train(self, trainer):
"""Actions after training.
"""
pass
- def before_train_iter(self):
+ def before_train_iter(self, trainer):
"""Actions before running a training iteration.
"""
pass
- def after_train_iter(self, output: Tensor, label: Tensor, loss: Tensor):
+ def after_train_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
"""Actions after running a training iteration.
:param output: Output of the model
@@ -50,42 +51,42 @@ class BaseHook(ABC):
"""
pass
- def before_train_epoch(self):
+ def before_train_epoch(self, trainer):
"""Actions before starting a training epoch.
"""
pass
- def after_train_epoch(self):
+ def after_train_epoch(self, trainer):
"""Actions after finishing a training epoch.
"""
pass
- def before_test(self):
+ def before_test(self, trainer):
"""Actions before evaluation.
"""
pass
- def after_test(self):
+ def after_test(self, trainer):
"""Actions after evaluation.
"""
pass
- def before_test_epoch(self):
+ def before_test_epoch(self, trainer):
"""Actions before starting a testing epoch.
"""
pass
- def after_test_epoch(self):
+ def after_test_epoch(self, trainer):
"""Actions after finishing a testing epoch.
"""
pass
- def before_test_iter(self):
+ def before_test_iter(self, trainer):
"""Actions before running a testing iteration.
"""
pass
- def after_test_iter(self, output: Tensor, label: Tensor, loss: Tensor):
+ def after_test_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
"""Actions after running a testing iteration.
:param output: Output of the model
@@ -97,11 +98,11 @@ class BaseHook(ABC):
"""
pass
- def init_runner_states(self, key, val):
+ def init_runner_states(self, trainer, key, val):
"""Initializes trainer's state.
:param key: Key of reseting state
:param val: Value of reseting state
"""
- if key not in self.trainer.states:
- self.trainer.states[key] = val
+ if key not in trainer.states:
+ trainer.states[key] = val
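Since hooks no longer hold a trainer reference, every callback now receives the trainer as its first argument. A minimal sketch of a custom hook under the new API (the hook name is hypothetical):

from colossalai.logging import get_dist_logger
from colossalai.trainer.hooks import BaseHook

class PrintLossHook(BaseHook):
    def __init__(self, priority: int = 10):
        super().__init__(priority=priority)
        self.logger = get_dist_logger()

    def after_train_iter(self, trainer, output, label, loss):
        # read state from the trainer argument instead of a stored reference
        self.logger.info(f'step {trainer.cur_step}: loss = {loss.item()}', ranks=[0])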
diff --git a/colossalai/trainer/hooks/_checkpoint_hook.py b/colossalai/trainer/hooks/_checkpoint_hook.py
index e1d9d4714..939e957bd 100644
--- a/colossalai/trainer/hooks/_checkpoint_hook.py
+++ b/colossalai/trainer/hooks/_checkpoint_hook.py
@@ -2,9 +2,9 @@
# -*- encoding: utf-8 -*-
import os.path as osp
+from colossalai.logging import get_dist_logger
from colossalai.registry import HOOKS
-from colossalai.trainer import Trainer
from colossalai.trainer.hooks import BaseHook
from colossalai.utils import is_dp_rank_0
from colossalai.utils.checkpointing import get_latest_checkpoint_path, get_checkpoint_path
@@ -16,12 +16,10 @@ from ._lr_scheduler_hook import LRSchedulerHook
class SaveCheckpointHook(BaseHook):
"""Saves the model by interval in training process.
- :param trainer: Trainer attached with current hook
:param interval: Saving interval
:param checkpoint_dir: Directory of saving checkpoint
:param suffix: Saving suffix of the file
:param priority: Priority in the printing, hooks with small priority will be printed in front
- :type trainer: Trainer
:type interval: int, optional
:type checkpoint_dir: int, optional
:type suffix: str, optional
@@ -29,59 +27,55 @@ class SaveCheckpointHook(BaseHook):
"""
def __init__(self,
- trainer: Trainer,
interval: int = 1,
checkpoint_dir: str = None,
suffix: str = '',
priority: int = 10):
- super().__init__(trainer=trainer, priority=priority)
- assert isinstance(trainer, Trainer), \
- f'SaveCheckpointHook expects a Trainer, got {type(trainer)}'
+ super().__init__(priority=priority)
self.interval = interval
self.checkpoint_dir = checkpoint_dir
self.suffix = suffix
+ self.logger = get_dist_logger()
# get lr scheduler from the LRSchedulerHook before train
self._lr_scheduler = None
- def before_train(self):
+ def after_hook_is_attached(self, trainer):
# check if lr scheduler is present in LRSchedulerHook
- for hook in self.trainer.hooks:
+ for hook in trainer.hooks:
if isinstance(hook, LRSchedulerHook):
self._lr_scheduler = hook.lr_scheduler
break
- def after_train_epoch(self):
+ def after_train_epoch(self, trainer):
"""Saves the model after a training epoch.
"""
# save by interval
- if self.trainer.cur_epoch % self.interval == 0:
+ if trainer.cur_epoch % self.interval == 0:
# only gpus with data parallel rank equals to 0 write to the disk
if is_dp_rank_0():
save_path = get_checkpoint_path(self.checkpoint_dir,
- self.trainer.cur_epoch,
+ trainer.cur_epoch,
suffix=self.suffix)
save_checkpoint(save_path,
- self.trainer.cur_epoch,
- self.trainer.engine.model,
- self.trainer.engine.optimizer,
+ trainer.cur_epoch,
+ trainer.engine.model,
+ trainer.engine.optimizer,
self._lr_scheduler)
self.logger.info(
- f'checkpoint for epoch {self.trainer.cur_epoch} is saved to {self.checkpoint_dir}')
+ f'checkpoint for epoch {trainer.cur_epoch} is saved to {self.checkpoint_dir}', ranks=[0])
@HOOKS.register_module
class LoadCheckpointHook(BaseHook):
"""Loads the model before training process.
- :param trainer: Trainer attached with current hook
:param checkpoint_dir: Directory of saving checkpoint
:param epoch: Epoch number to be set
:param finetune: Whether allows to load a part of the model
:param strict: Whether loads a model that has the same shape of parameters
:param priority: Priority in the printing, hooks with small priority will be printed in front
- :type trainer: Trainer
:type checkpoint_dir: str, optional
:type epoch: str, optional
:type finetune: bool, optional
@@ -90,28 +84,26 @@ class LoadCheckpointHook(BaseHook):
"""
def __init__(self,
- trainer: Trainer = None,
checkpoint_dir: str = None,
epoch: int = -1,
finetune: bool = False,
strict: bool = False,
suffix: str = '',
priority: int = 0) -> None:
- super().__init__(trainer=trainer, priority=priority)
- assert isinstance(trainer, Trainer), \
- f'LoadLatestCheckpointHook excepts a Trainer, got {type(trainer)}'
+ super().__init__(priority=priority)
self.epoch = epoch
self.checkpoint_dir = checkpoint_dir
self.finetune = finetune
self.suffix = suffix
self.strict = strict
+ self.logger = get_dist_logger()
- def before_train(self):
+ def before_train(self, trainer):
"""Loads parameters to the model before training.
"""
# check if lr scheduler is present in LRSchedulerHook
lr_scheduler = None
- for hook in self.trainer.hooks:
+ for hook in trainer.hooks:
if isinstance(hook, LRSchedulerHook):
lr_scheduler = hook.lr_scheduler
break
@@ -124,17 +116,17 @@ class LoadCheckpointHook(BaseHook):
if osp.exists(path):
last_epoch, _ = load_checkpoint(path,
- self.trainer.engine.model,
- self.trainer.engine.optimizer,
+ trainer.engine.model,
+ trainer.engine.optimizer,
lr_scheduler,
finetune=self.finetune,
strict=self.strict)
if self.finetune:
- self.trainer.cur_epoch = 0
+ trainer.cur_epoch = 0
else:
- self.trainer.cur_epoch = last_epoch
+ trainer.cur_epoch = last_epoch
self.logger.info(
- f'loaded checkpoint from {path}')
+ f'loaded checkpoint from {path}', ranks=[0])
else:
raise FileNotFoundError(f'checkpoint is not found at {path}')
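With the trainer argument dropped from the constructors, the checkpoint hooks are built standalone and handed to `Trainer.fit` through `hooks`. A sketch, assuming the classes are re-exported from `colossalai.trainer.hooks` (the checkpoint directory is illustrative):

from colossalai.trainer.hooks import LoadCheckpointHook, SaveCheckpointHook

ckpt_hooks = [
    LoadCheckpointHook(checkpoint_dir='./ckpt', epoch=-1, strict=False),   # runs before training
    SaveCheckpointHook(interval=1, checkpoint_dir='./ckpt'),               # runs after each epoch
]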
diff --git a/colossalai/trainer/hooks/_log_hook.py b/colossalai/trainer/hooks/_log_hook.py
index 3c3fdfc43..8693dd515 100644
--- a/colossalai/trainer/hooks/_log_hook.py
+++ b/colossalai/trainer/hooks/_log_hook.py
@@ -6,35 +6,40 @@ import os.path as osp
import torch
from torch.utils.tensorboard import SummaryWriter
-
+from typing import List
+from decimal import Decimal
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import HOOKS
-from colossalai.trainer._trainer import Trainer
-from colossalai.utils import get_global_multitimer, set_global_multitimer_status, report_memory_usage, is_dp_rank_0, \
- is_tp_rank_0, is_no_pp_or_last_stage
+from colossalai.logging import DistributedLogger
+from colossalai.utils import report_memory_usage, is_dp_rank_0, \
+ is_tp_rank_0, is_no_pp_or_last_stage, MultiTimer
from ._base_hook import BaseHook
def _format_number(val):
if isinstance(val, float):
- return f'{val:.5f}'
- elif torch.is_floating_point(val):
- return f'{val.item():.5f}'
+ return f'{val:.5g}'
+ elif torch.is_tensor(val) and torch.is_floating_point(val):
+ return f'{val.item():.5g}'
return val
-class EpochIntervalHook(BaseHook):
- def __init__(self, trainer: Trainer, interval: int = 1, priority: int = 1):
- super().__init__(trainer, priority)
+class LogByEpochHook(BaseHook):
+ def __init__(self,
+ logger,
+ interval: int = 1,
+ priority: int = 1):
+ super().__init__(priority)
+ self.logger = logger
self._interval = interval
- def _is_epoch_to_log(self):
- return self.trainer.cur_epoch % self._interval == 0
+ def _is_epoch_to_log(self, trainer):
+ return trainer.cur_epoch % self._interval == 0
@HOOKS.register_module
-class LogMetricByEpochHook(EpochIntervalHook):
+class LogMetricByEpochHook(LogByEpochHook):
"""Specialized Hook to record the metric to log.
:param trainer: Trainer attached with current hook
@@ -45,32 +50,35 @@ class LogMetricByEpochHook(EpochIntervalHook):
:type priority: int, optional
"""
- def __init__(self, trainer: Trainer, interval: int = 1, priority: int = 10) -> None:
- super().__init__(trainer=trainer, interval=interval, priority=priority)
+ def __init__(self,
+ logger,
+ interval: int = 1,
+ priority: int = 10) -> None:
+ super().__init__(logger, interval, priority)
self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage()
- def _get_str(self, mode):
+ def _get_str(self, trainer, mode):
msg = []
- for metric_name, metric_calculator in self.trainer.states['metrics'][mode].items():
+ for metric_name, metric_calculator in trainer.states['metrics'][mode].items():
msg.append(
f'{metric_name} = {_format_number(metric_calculator.get_accumulated_value())}')
msg = ', '.join(msg)
return msg
- def after_train_epoch(self):
- if self._is_epoch_to_log():
- msg = self._get_str(mode='train')
+ def after_train_epoch(self, trainer):
+ if self._is_epoch_to_log(trainer):
+ msg = self._get_str(trainer=trainer, mode='train')
if self._is_rank_to_log:
self.logger.info(
- f'Training - Epoch {self.trainer.cur_epoch} - {self.__class__.__name__}: {msg}')
+ f'Training - Epoch {trainer.cur_epoch} - {self.__class__.__name__}: {msg}')
- def after_test_epoch(self):
- if self._is_epoch_to_log():
- msg = self._get_str(mode='test')
+ def after_test_epoch(self, trainer):
+ if self._is_epoch_to_log(trainer):
+ msg = self._get_str(trainer=trainer, mode='test')
if self._is_rank_to_log:
self.logger.info(
- f'Testing - Epoch {self.trainer.cur_epoch} - {self.__class__.__name__}: {msg}')
+ f'Testing - Epoch {trainer.cur_epoch} - {self.__class__.__name__}: {msg}')
@HOOKS.register_module
@@ -86,74 +94,79 @@ class TensorboardHook(BaseHook):
"""
def __init__(self,
- trainer: Trainer,
log_dir: str,
- dp_rank_0_only: bool = True,
- tp_rank_0_only: bool = True,
+ ranks: List = None,
+ parallel_mode: ParallelMode = ParallelMode.GLOBAL,
priority: int = 10,
) -> None:
- super().__init__(trainer=trainer, priority=priority)
+ super().__init__(priority=priority)
# create log dir
if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0:
os.makedirs(log_dir, exist_ok=True)
# determine the ranks to generate tensorboard logs
- self._is_valid_rank_to_log = is_no_pp_or_last_stage()
+ self._is_valid_rank_to_log = False
+ if not gpc.is_initialized(parallel_mode):
+ self._is_valid_rank_to_log = True
+ else:
+ local_rank = gpc.get_local_rank(parallel_mode)
- if dp_rank_0_only:
- self._is_valid_rank_to_log = self._is_valid_rank_to_log and is_dp_rank_0()
+ if ranks is None or local_rank in ranks:
+ self._is_valid_rank_to_log = True
- if tp_rank_0_only:
- self._is_valid_rank_to_log = self._is_valid_rank_to_log and is_tp_rank_0()
+ # check that tensorboard logging only happens on the last pipeline stage
+ if gpc.is_initialized(ParallelMode.PIPELINE) and \
+ not gpc.is_last_rank(ParallelMode.PIPELINE) and self._is_valid_rank_to_log:
+ raise ValueError("Tensorboard hook can only log on the last rank of pipeline process group")
if self._is_valid_rank_to_log:
# create workspace on only one rank
- if gpc.is_initialized(ParallelMode.GLOBAL):
- rank = gpc.get_global_rank()
+ if gpc.is_initialized(parallel_mode):
+ rank = gpc.get_local_rank(parallel_mode)
else:
rank = 0
# create workspace
- log_dir = osp.join(log_dir, f'rank_{rank}')
+ log_dir = osp.join(log_dir, f'{parallel_mode}_rank_{rank}')
os.makedirs(log_dir, exist_ok=True)
self.writer = SummaryWriter(
log_dir=log_dir, filename_suffix=f'_rank_{rank}')
- def _log_by_iter(self, mode: str):
- for metric_name, metric_calculator in self.trainer.states['metrics'][mode].items():
+ def _log_by_iter(self, trainer, mode: str):
+ for metric_name, metric_calculator in trainer.states['metrics'][mode].items():
if metric_calculator.epoch_only:
continue
val = metric_calculator.get_last_step_value()
if self._is_valid_rank_to_log:
self.writer.add_scalar(f'{metric_name}/{mode}', val,
- self.trainer.cur_step)
+ trainer.cur_step)
- def _log_by_epoch(self, mode: str):
- for metric_name, metric_calculator in self.trainer.states['metrics'][mode].items():
+ def _log_by_epoch(self, trainer, mode: str):
+ for metric_name, metric_calculator in trainer.states['metrics'][mode].items():
if metric_calculator.epoch_only:
val = metric_calculator.get_accumulated_value()
if self._is_valid_rank_to_log:
self.writer.add_scalar(f'{metric_name}/{mode}', val,
- self.trainer.cur_step)
+ trainer.cur_step)
- def after_test_iter(self, *args):
- self._log_by_iter(mode='test')
+ def after_test_iter(self, trainer, *args):
+ self._log_by_iter(trainer, mode='test')
- def after_test_epoch(self):
- self._log_by_epoch(mode='test')
+ def after_test_epoch(self, trainer):
+ self._log_by_epoch(trainer, mode='test')
- def after_train_iter(self, *args):
- self._log_by_iter(mode='train')
+ def after_train_iter(self, trainer, *args):
+ self._log_by_iter(trainer, mode='train')
- def after_train_epoch(self):
- self._log_by_epoch(mode='train')
+ def after_train_epoch(self, trainer):
+ self._log_by_epoch(trainer, mode='train')
@HOOKS.register_module
-class LogTimingByEpochHook(EpochIntervalHook):
+class LogTimingByEpochHook(LogByEpochHook):
"""Specialized Hook to write timing record to log.
:param trainer: Trainer attached with current hook
@@ -167,53 +180,61 @@ class LogTimingByEpochHook(EpochIntervalHook):
"""
def __init__(self,
- trainer: Trainer,
+ timer: MultiTimer,
+ logger: DistributedLogger,
interval: int = 1,
priority: int = 10,
- log_eval: bool = True
+ log_eval: bool = True,
+ ignore_num_train_steps: int = 0
) -> None:
- super().__init__(trainer=trainer, interval=interval, priority=priority)
- set_global_multitimer_status(True)
- self._global_timer = get_global_multitimer()
+ super().__init__(logger=logger, interval=interval, priority=priority)
+ self._timer = timer
self._log_eval = log_eval
self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0()
+ # extra handling to avoid the unstable readings of the first
+ # few training steps to affect the history mean time
+ self._ignore_num_train_steps = ignore_num_train_steps
+ self._is_train_step_history_trimmed = False
+
def _get_message(self):
msg = []
- for timer_name, timer in self._global_timer:
+ for timer_name, timer in self._timer:
last_elapsed_time = timer.get_elapsed_time()
if timer.has_history:
+ if timer_name == 'train-step' and not self._is_train_step_history_trimmed:
+ timer._history = timer._history[self._ignore_num_train_steps:]
+ self._is_train_step_history_trimmed = True
history_mean = timer.get_history_mean()
history_sum = timer.get_history_sum()
msg.append(
- f'{timer_name}: last elapsed time = {last_elapsed_time}, '
- f'history sum = {history_sum}, history mean = {history_mean}')
+ f'{timer_name}: last = {_format_number(last_elapsed_time)} s, mean = {_format_number(history_mean)} s')
else:
msg.append(
- f'{timer_name}: last elapsed time = {last_elapsed_time}')
+ f'{timer_name}: last = {_format_number(last_elapsed_time)} s')
msg = ', '.join(msg)
return msg
- def after_train_epoch(self):
+ def after_train_epoch(self, trainer):
"""Writes log after finishing a training epoch.
"""
- if self._is_epoch_to_log() and self._is_rank_to_log:
+ if self._is_epoch_to_log(trainer) and self._is_rank_to_log:
msg = self._get_message()
self.logger.info(
- f'Training - Epoch {self.trainer.cur_epoch} - {self.__class__.__name__}: {msg}')
+ f'Training - Epoch {trainer.cur_epoch} - {self.__class__.__name__}: {msg}, num steps per epoch={trainer.steps_per_epoch}')
- def after_test_epoch(self):
+ def after_test_epoch(self, trainer):
"""Writes log after finishing a testing epoch.
"""
- if self._is_epoch_to_log() and self._is_rank_to_log and self._log_eval:
+ if self._is_epoch_to_log(trainer) and self._is_rank_to_log and self._log_eval:
msg = self._get_message()
self.logger.info(
- f'Testing - Epoch {self.trainer.cur_epoch} - {self.__class__.__name__}: {msg}')
+ f'Testing - Epoch {trainer.cur_epoch} - {self.__class__.__name__}: {msg}')
@HOOKS.register_module
-class LogMemoryByEpochHook(EpochIntervalHook):
+class LogMemoryByEpochHook(LogByEpochHook):
"""Specialized Hook to write memory usage record to log.
:param trainer: Trainer attached with current hook
@@ -227,33 +248,34 @@ class LogMemoryByEpochHook(EpochIntervalHook):
"""
def __init__(self,
- trainer: Trainer,
+ logger: DistributedLogger,
interval: int = 1,
priority: int = 10,
- log_eval: bool = True
+ log_eval: bool = True,
+ report_cpu: bool = False
) -> None:
- super().__init__(trainer=trainer, interval=interval, priority=priority)
- set_global_multitimer_status(True)
- self._global_timer = get_global_multitimer()
+ super().__init__(logger=logger, interval=interval, priority=priority)
self._log_eval = log_eval
self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0()
- def before_train(self):
+ def before_train(self, trainer):
"""Resets before training.
"""
- if self._is_epoch_to_log() and self._is_rank_to_log:
- report_memory_usage('before-train')
+ if self._is_epoch_to_log(trainer) and self._is_rank_to_log:
+ report_memory_usage('before-train', self.logger)
- def after_train_epoch(self):
+ def after_train_epoch(self, trainer):
"""Writes log after finishing a training epoch.
"""
- if self._is_epoch_to_log() and self._is_rank_to_log:
+ if self._is_epoch_to_log(trainer) and self._is_rank_to_log:
report_memory_usage(
- f'After Train - Epoch {self.trainer.cur_epoch} - {self.__class__.__name__}')
+ f'After Train - Epoch {trainer.cur_epoch} - {self.__class__.__name__}',
+ self.logger)
- def after_test(self):
+ def after_test(self, trainer):
"""Reports after testing.
"""
- if self._is_epoch_to_log() and self._is_rank_to_log and self._log_eval:
+ if self._is_epoch_to_log(trainer) and self._is_rank_to_log and self._log_eval:
report_memory_usage(
- f'After Test - Epoch {self.trainer.cur_epoch} - {self.__class__.__name__}')
+ f'After Test - Epoch {trainer.cur_epoch} - {self.__class__.__name__}',
+ self.logger)
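The logging hooks now take an explicit logger (and, for timing, a MultiTimer) instead of a trainer. A construction sketch, assuming the classes are re-exported from `colossalai.trainer.hooks` (the tensorboard directory is illustrative):

from colossalai.logging import get_dist_logger
from colossalai.trainer.hooks import (LogMetricByEpochHook, LogTimingByEpochHook,
                                      LogMemoryByEpochHook, TensorboardHook)
from colossalai.utils import MultiTimer

logger = get_dist_logger()
timer = MultiTimer(on=True)
log_hooks = [
    LogMetricByEpochHook(logger=logger, interval=1),
    # skip the first few steps so warm-up does not skew the mean step time
    LogTimingByEpochHook(timer=timer, logger=logger, interval=1, ignore_num_train_steps=5),
    LogMemoryByEpochHook(logger=logger, interval=1),
    TensorboardHook(log_dir='./tb_logs', ranks=[0]),
]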
diff --git a/colossalai/trainer/hooks/_lr_scheduler_hook.py b/colossalai/trainer/hooks/_lr_scheduler_hook.py
index ca483aebe..d5bbe7591 100644
--- a/colossalai/trainer/hooks/_lr_scheduler_hook.py
+++ b/colossalai/trainer/hooks/_lr_scheduler_hook.py
@@ -3,7 +3,6 @@ from torch import Tensor
from colossalai.builder import build_lr_scheduler
from colossalai.registry import HOOKS
from ._metric_hook import MetricHook
-from .._trainer import Trainer
from ..metric import LearningRate
@@ -22,37 +21,26 @@ class LRSchedulerHook(MetricHook):
"""
def __init__(self,
- trainer: Trainer,
- lr_scheduler_cfg: dict,
- by_epoch: bool = True,
+ lr_scheduler,
+ by_epoch: bool,
store_lr_in_state: bool = True,
priority: int = 1,
):
- super().__init__(trainer=trainer, priority=priority)
+ super().__init__(priority=priority)
self.by_epoch = by_epoch
+ self.lr_scheduler = lr_scheduler
+ self.store_lr_in_state = store_lr_in_state
- if by_epoch:
- total_steps = trainer.max_epochs
- else:
- total_steps = trainer.max_epochs * trainer.steps_per_epoch
- if trainer.max_steps is not None:
- total_steps = min(total_steps, trainer.max_steps)
+ def after_hook_is_attached(self, trainer):
+ trainer.states['metrics']['train']['lr'] = LearningRate(epoch_only=self.by_epoch,
+ initial_lr=self.lr_scheduler.get_last_lr()[0])
- lr_scheduler_cfg['total_steps'] = total_steps
-
- self.lr_scheduler = build_lr_scheduler(
- lr_scheduler_cfg, trainer.engine.optimizer)
-
- if store_lr_in_state:
- self.trainer.states['metrics']['train']['lr'] = LearningRate(epoch_only=by_epoch,
- initial_lr=self.lr_scheduler.get_lr()[0])
-
- def after_train_epoch(self):
+ def after_train_epoch(self, trainer):
if self.by_epoch:
self.lr_scheduler.step()
- self.trainer.states['metrics']['train']['lr'].update(self.lr_scheduler.get_lr()[0])
+ trainer.states['metrics']['train']['lr'].update(self.lr_scheduler.get_last_lr()[0])
- def after_train_iter(self, output: Tensor, label: Tensor, loss: Tensor):
+ def after_train_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
if not self.by_epoch:
self.lr_scheduler.step()
- self.trainer.states['metrics']['train']['lr'].update(self.lr_scheduler.get_lr()[0])
+ trainer.states['metrics']['train']['lr'].update(self.lr_scheduler.get_last_lr()[0])
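LRSchedulerHook now wraps an already-constructed scheduler rather than building one from `lr_scheduler_cfg`. A sketch, with `CosineAnnealingLR` chosen purely for illustration and `optimizer`/`total_steps` assumed to exist:

from torch.optim.lr_scheduler import CosineAnnealingLR
from colossalai.trainer.hooks import LRSchedulerHook

lr_scheduler = CosineAnnealingLR(optimizer, T_max=total_steps)
# by_epoch is now a required argument: False means stepping once per iteration
lr_hook = LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False)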
diff --git a/colossalai/trainer/hooks/_metric_hook.py b/colossalai/trainer/hooks/_metric_hook.py
index 8c3478c71..aa2e22fa0 100644
--- a/colossalai/trainer/hooks/_metric_hook.py
+++ b/colossalai/trainer/hooks/_metric_hook.py
@@ -5,8 +5,7 @@ from colossalai.context import ParallelMode
from colossalai.registry import HOOKS
from colossalai.utils import is_no_pp_or_last_stage
from ._base_hook import BaseHook
-from .._trainer import Trainer
-from ..metric import Loss, Accuracy2D, Accuracy, Accuracy2p5D, Accuracy3D
+from ..metric import Loss, Accuracy1D, Accuracy2D, Accuracy, Accuracy2p5D, Accuracy3D
class MetricHook(BaseHook):
@@ -22,16 +21,14 @@ class MetricHook(BaseHook):
"""
def __init__(self,
- trainer: Trainer,
priority: int,
):
- super().__init__(trainer, priority)
+ super().__init__(priority)
self._is_stage_to_compute = is_no_pp_or_last_stage()
- self._check_metric_states_initialization()
- def _check_metric_states_initialization(self):
- if 'metrics' not in self.trainer.states:
- self.init_runner_states('metrics', dict(train={}, test={}))
+ def _check_metric_states_initialization(self, trainer):
+ if 'metrics' not in trainer.states:
+ self.init_runner_states(trainer, 'metrics', dict(train={}, test={}))
@HOOKS.register_module
@@ -44,36 +41,71 @@ class LossHook(MetricHook):
:type priority: int, optional
"""
- def __init__(self, trainer: Trainer, priority: int = 0):
- super().__init__(trainer, priority)
+ def __init__(self, priority: int = 0):
+ super().__init__(priority)
+
+ def after_hook_is_attached(self, trainer):
+ self._check_metric_states_initialization(trainer)
if self._is_stage_to_compute:
self.train_loss = Loss(epoch_only=False)
self.test_loss = Loss(epoch_only=True)
# register the metric calculator
- self.trainer.states['metrics']['train'][
+ trainer.states['metrics']['train'][
self.train_loss.__class__.__name__] = self.train_loss
- self.trainer.states['metrics']['test'][
+ trainer.states['metrics']['test'][
self.test_loss.__class__.__name__] = self.test_loss
- def before_train_epoch(self):
+ def before_train_epoch(self, trainer):
if self._is_stage_to_compute:
self.train_loss.reset()
- def after_train_iter(self, logits, label, loss):
+ def after_train_iter(self, trainer, logits, label, loss):
if self._is_stage_to_compute:
self.train_loss.update(loss)
- def before_test_epoch(self):
+ def before_test_epoch(self, trainer):
if self._is_stage_to_compute:
self.test_loss.reset()
- def after_test_iter(self, logits, label, loss):
+ def after_test_iter(self, trainer, logits, label, loss):
if self._is_stage_to_compute:
self.test_loss.update(loss)
+@HOOKS.register_module
+class Accuracy1DHook(MetricHook):
+ """Specialized hook class for :class:`Accuracy1D`.
+ It acts the same as :class:`AccuracyHook`.
+
+ :param priority: Priority in the printing, hooks with small priority will be printed in front
+ :type priority: int, optional
+ """
+
+ def __init__(self, priority: int = 10):
+ super().__init__(priority)
+
+ def after_hook_is_attached(self, trainer):
+ self._check_metric_states_initialization(trainer)
+ if self._is_stage_to_compute:
+ self.metric = Accuracy1D(epoch_only=True)
+
+ # register the metric
+ trainer.states['metrics']['test'][
+ self.metric.__class__.__name__] = self.metric
+
+ def before_test(self, trainer):
+ if self._is_stage_to_compute:
+ self.metric.reset()
+
+ def after_test_iter(self, trainer, logits, label, *args):
+ if self._is_stage_to_compute:
+ self.metric.update(logits, label)
+
+
@HOOKS.register_module
class Accuracy2DHook(MetricHook):
"""Specialized hook class for :class:`Accuracy2D`.
@@ -85,42 +117,46 @@ class Accuracy2DHook(MetricHook):
:type priority: int, optional
"""
- def __init__(self, trainer: Trainer, priority: int = 0):
- super().__init__(trainer, priority)
+ def __init__(self, priority: int = 0):
+ super().__init__(priority)
+ def after_hook_is_attached(self, trainer):
+ self._check_metric_states_initialization(trainer)
if self._is_stage_to_compute:
self.metric = Accuracy2D(epoch_only=True)
# register the metric
- self.trainer.states['metrics']['test'][
+ trainer.states['metrics']['test'][
self.metric.__class__.__name__] = self.metric
- def before_test(self):
+ def before_test(self, trainer):
if self._is_stage_to_compute:
self.metric.reset()
- def after_test_iter(self, logits, label, *args):
+ def after_test_iter(self, trainer, logits, label, *args):
if self._is_stage_to_compute:
self.metric.update(logits, label)
@HOOKS.register_module
class Accuracy2p5DHook(MetricHook):
- def __init__(self, trainer: Trainer, priority: int = 0):
- super().__init__(trainer, priority)
+ def __init__(self, priority: int = 0):
+ super().__init__(priority)
+ def after_hook_is_attached(self, trainer):
+ self._check_metric_states_initialization(trainer)
if self._is_stage_to_compute:
self.metric = Accuracy2p5D(epoch_only=True)
# register the metric
- self.trainer.states['metrics']['test'][
+ trainer.states['metrics']['test'][
self.metric.__class__.__name__] = self.metric
- def before_test(self):
+ def before_test(self, trainer):
if self._is_stage_to_compute:
self.metric.reset()
- def after_test_iter(self, logits, label, *args):
+ def after_test_iter(self, trainer, logits, label, *args):
if self._is_stage_to_compute:
self.metric.update(logits, label)
@@ -136,26 +172,22 @@ class Accuracy3DHook(MetricHook):
"""
def __init__(self,
- trainer: Trainer,
- input_parallel_mode: ParallelMode,
- weight_parallel_mode: ParallelMode,
priority: int = 10):
- super().__init__(trainer, priority)
+ super().__init__(priority)
+ def after_hook_is_attached(self, trainer):
if self._is_stage_to_compute:
- self.metric = Accuracy3D(epoch_only=True,
- input_parallel_mode=input_parallel_mode,
- weight_parallel_mode=weight_parallel_mode)
+ self.metric = Accuracy3D(epoch_only=True)
# register the metric
- self.trainer.states['metrics']['test'][
+ trainer.states['metrics']['test'][
self.metric.__class__.__name__] = self.metric
- def before_test(self):
+ def before_test(self, trainer):
if self._is_stage_to_compute:
self.metric.reset()
- def after_test_iter(self, logits, label, *args):
+ def after_test_iter(self, trainer, logits, label, *args):
if self._is_stage_to_compute:
self.metric.update(logits, label)
@@ -170,20 +202,21 @@ class AccuracyHook(MetricHook):
:type priority: int
"""
- def __init__(self, trainer: Trainer, priority: int = 0):
- super().__init__(trainer, priority)
+ def __init__(self, priority: int = 0):
+ super().__init__(priority)
+ def after_hook_is_attached(self, trainer):
if self._is_stage_to_compute:
self.metric = Accuracy(epoch_only=True)
# register the metric
- self.trainer.states['metrics']['test'][
+ trainer.states['metrics']['test'][
self.metric.__class__.__name__] = self.metric
- def before_test(self):
+ def before_test(self, trainer):
if self._is_stage_to_compute:
self.metric.reset()
- def after_test_iter(self, logits, label, *args):
+ def after_test_iter(self, trainer, logits, label, *args):
if self._is_stage_to_compute:
self.metric.update(logits, label)
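Because metric registration moved into `after_hook_is_attached`, the metric hooks are constructed without a trainer and only register their `Metric` objects in `trainer.states['metrics']` once attached. For example (assuming the hook classes are re-exported from `colossalai.trainer.hooks`):

from colossalai.trainer.hooks import LossHook, AccuracyHook, Accuracy1DHook

metric_hooks = [LossHook(), AccuracyHook()]      # data-parallel / no tensor parallelism
# metric_hooks = [LossHook(), Accuracy1DHook()]  # e.g. when 1D tensor parallelism is used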
diff --git a/colossalai/trainer/metric.py b/colossalai/trainer/metric.py
index b595d37b8..5038826c9 100644
--- a/colossalai/trainer/metric.py
+++ b/colossalai/trainer/metric.py
@@ -3,12 +3,14 @@ from abc import ABC, abstractmethod
import torch
import torch.distributed as dist
-
from colossalai.communication import all_gather
+from colossalai.constants import (INPUT_GROUP_3D, OUTPUT_GROUP_3D,
+ WEIGHT_GROUP_3D)
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer._parallel_utilities import _gather
-from colossalai.nn.layer.parallel_3d._utils import get_last_group
+from colossalai.nn.layer.parallel_3d._utils import (get_last_group,
+ get_parallel_mode_from_env)
from colossalai.utils import get_current_device
@@ -22,7 +24,6 @@ class Metric(ABC):
:param epoch_only: Whether the metric only read for the full epoch
:type epoch_only: bool
"""
-
def __init__(self, epoch_only: bool):
# is the metric only read for the full epoch
self._epoch_only = epoch_only
@@ -80,7 +81,6 @@ class Loss(Metric):
:param epoch_only: Whether the metric only read for the full epoch
:type epoch_only: bool
"""
-
def __init__(self, epoch_only):
super().__init__(epoch_only=epoch_only)
self.last_step_loss = torch.zeros(1, device=get_current_device())
@@ -110,7 +110,8 @@ class Loss(Metric):
"""Returns accumulated loss.
"""
if gpc.is_initialized(ParallelMode.DATA):
- dist.all_reduce(self.accum_loss, op=dist.ReduceOp.SUM,
+ dist.all_reduce(self.accum_loss,
+ op=dist.ReduceOp.SUM,
group=gpc.get_group(ParallelMode.DATA))
self.accum_loss.div_(gpc.get_world_size(ParallelMode.DATA))
@@ -132,7 +133,6 @@ class LearningRate(Metric):
:param epoch_only: Whether the metric only read for the full epoch
:type epoch_only: bool
"""
-
def __init__(self, epoch_only: bool, initial_lr: float = 0.):
super().__init__(epoch_only=epoch_only)
self.lr = 0.
@@ -160,7 +160,6 @@ class Accuracy(Metric):
:param epoch_only: Whether the metric only read for the full epoch
:type epoch_only: bool
"""
-
def __init__(self, epoch_only: bool):
super().__init__(epoch_only=epoch_only)
self.last_step_sum = torch.zeros(1, device=get_current_device())
@@ -211,12 +210,42 @@ class Accuracy(Metric):
def is_better(a, b) -> bool:
return a > b
-
class Accuracy2D(Accuracy):
"""A metric collector for accuracy. It only works for classification
tasks. This class is the same as :class:`Accuracy` but used in 2D
model parallelism.
+ :param epoch_only: Whether the metric only read for the full epoch
+ :type epoch_only: bool
+ """
+ def __init__(self, epoch_only: bool):
+ super().__init__(epoch_only=epoch_only)
+
+ def update(self, logits, label) -> None:
+ if isinstance(logits, (list, tuple)):
+ logits = logits[0]
+ if isinstance(label, (list, tuple)):
+ label = label[0]
+
+ logits = _gather(logits, ParallelMode.PARALLEL_2D_ROW, 1)
+ logits = _gather(
+ logits,
+ ParallelMode.PARALLEL_2D_COL,
+ 0,
+ )
+ # update
+ preds = torch.argmax(logits, dim=-1)
+ correct = torch.sum(label == preds)
+ self.last_step_sum.fill_(label.size(0))
+ self.last_step_correct.fill_(correct)
+ self.accumulated_sum += self.last_step_sum
+ self.accumulated_correct += self.last_step_correct
+
+class Accuracy1D(Accuracy):
+ """A metric collector for accuracy. It only works for classification
+ tasks. This class is the same as :class:`Accuracy` but used in 1D
+ model parallelism.
+
:param epoch_only: Whether the metric only read for the full epoch
:type epoch_only: bool
"""
@@ -232,14 +261,10 @@ class Accuracy2D(Accuracy):
logits = _gather(
logits,
- ParallelMode.PARALLEL_2D_ROW,
+ ParallelMode.PARALLEL_1D,
1
)
- logits = _gather(
- logits,
- ParallelMode.PARALLEL_2D_COL,
- 0,
- )
+
# update
preds = torch.argmax(logits, dim=-1)
correct = torch.sum(label == preds)
@@ -259,11 +284,7 @@ class Accuracy2p5D(Accuracy):
if isinstance(label, (list, tuple)):
label = label[0]
- logits = _gather(
- logits,
- ParallelMode.PARALLEL_2P5D_ROW,
- 1
- )
+ logits = _gather(logits, ParallelMode.PARALLEL_2P5D_ROW, 1)
logits = _gather(
logits,
ParallelMode.PARALLEL_2P5D_COL,
@@ -298,14 +319,14 @@ class Accuracy3D(Accuracy):
:param epoch_only: Whether the metric only read for the full epoch
:type epoch_only: bool
"""
-
- def __init__(self, epoch_only, input_parallel_mode, weight_parallel_mode):
+ def __init__(self, epoch_only):
super().__init__(epoch_only=epoch_only)
self.depth = int(os.environ['DEPTH_3D'])
- self.input_parallel_mode = input_parallel_mode
- self.weight_parallel_mode = weight_parallel_mode
- self.output_parallel_mode = get_last_group(input_parallel_mode,
- weight_parallel_mode)
+ self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
+ self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
+ self.output_parallel_mode = get_last_group(self.input_parallel_mode,
+ self.weight_parallel_mode)
def update(self, logits, target):
if isinstance(logits, (list, tuple)):
@@ -321,6 +342,7 @@ class Accuracy3D(Accuracy):
target = torch.chunk(target, self.depth, dim=0)[j]
logits = all_gather(logits, -1, self.output_parallel_mode)
+ logits = torch.cat(logits, dim=-1)
prediction = torch.argmax(logits, dim=-1)
correct = torch.sum(prediction == target)
diff --git a/colossalai/utils/__init__.py b/colossalai/utils/__init__.py
index f7ef2259b..b155ad0a3 100644
--- a/colossalai/utils/__init__.py
+++ b/colossalai/utils/__init__.py
@@ -1,22 +1,26 @@
from .activation_checkpoint import checkpoint
-from .common import print_rank_0, sync_model_param_in_dp, is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage
+from .common import (print_rank_0, sync_model_param_in_dp, is_dp_rank_0,
+ is_tp_rank_0, is_no_pp_or_last_stage, is_using_ddp,
+ is_using_pp, conditional_context, is_model_parallel_parameter,
+ clip_grad_norm_fp32, count_zeros_fp32, copy_tensor_parallel_attributes,
+ param_is_not_tensor_parallel_duplicate)
from .cuda import get_current_device, synchronize, empty_cache, set_to_cuda
from .memory import report_memory_usage
from .timer import MultiTimer, Timer
+from .multi_tensor_apply import multi_tensor_applier
+from .gradient_accumulation import accumulate_gradient
+from .data_sampler import DataParallelSampler, get_dataloader
-_GLOBAL_MULTI_TIMER = MultiTimer(on=False)
-
-
-def get_global_multitimer():
- return _GLOBAL_MULTI_TIMER
-
-
-def set_global_multitimer_status(mode: bool):
- _GLOBAL_MULTI_TIMER.set_status(mode)
-
-
-__all__ = ['checkpoint', 'print_rank_0', 'sync_model_param_in_dp', 'get_current_device',
- 'synchronize', 'empty_cache', 'set_to_cuda', 'report_memory_usage', 'Timer', 'MultiTimer',
- 'get_global_multitimer', 'set_global_multitimer_status',
- 'is_dp_rank_0', 'is_tp_rank_0', 'is_no_pp_or_last_stage'
+__all__ = ['checkpoint',
+ 'print_rank_0', 'sync_model_param_in_dp', 'is_dp_rank_0',
+ 'is_tp_rank_0', 'is_no_pp_or_last_stage', 'is_using_ddp',
+ 'is_using_pp', 'conditional_context', 'is_model_parallel_parameter',
+ 'clip_grad_norm_fp32', 'count_zeros_fp32', 'copy_tensor_parallel_attributes',
+ 'param_is_not_tensor_parallel_duplicate',
+ 'get_current_device', 'synchronize', 'empty_cache', 'set_to_cuda',
+ 'report_memory_usage',
+ 'Timer', 'MultiTimer',
+ 'multi_tensor_applier',
+ 'accumulate_gradient',
+ 'DataParallelSampler', 'get_dataloader'
]
diff --git a/colossalai/utils/common.py b/colossalai/utils/common.py
index d8c6663ba..ed4523c75 100644
--- a/colossalai/utils/common.py
+++ b/colossalai/utils/common.py
@@ -1,8 +1,21 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
-import torch.distributed as dist
+import torch
+from torch._six import inf
+try:
+ import colossal_C
+except:
+ pass
+
+import torch.distributed as dist
+from contextlib import contextmanager
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.core import global_context as gpc
+from .multi_tensor_apply import multi_tensor_applier
+from colossalai.constants import IS_TENSOR_PARALLEL, TENSOR_PARALLEL_ATTRIBUTES, NUM_PARTITIONS
+import torch.distributed as dist
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
@@ -18,7 +31,6 @@ def print_rank_0(msg: str, logger=None):
print(msg, flush=True)
else:
logger.info(msg)
- # print(msg, flush=True)
def sync_model_param_in_dp(model):
@@ -26,17 +38,214 @@ def sync_model_param_in_dp(model):
:param model: A pyTorch nn.model on whose parameters you check the consistency
'''
-
if gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1:
for param in model.parameters():
ranks = gpc.get_ranks_in_group(ParallelMode.DATA)
- dist.broadcast(param, src=ranks[0], group=gpc.get_group(ParallelMode.DATA))
+ dist.broadcast(
+ param, src=ranks[0], group=gpc.get_group(ParallelMode.DATA))
+
def is_dp_rank_0():
return not gpc.is_initialized(ParallelMode.DATA) or gpc.is_first_rank(ParallelMode.DATA)
+
def is_tp_rank_0():
return not gpc.is_initialized(ParallelMode.TENSOR) or gpc.is_first_rank(ParallelMode.TENSOR)
+
def is_no_pp_or_last_stage():
- return not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE)
\ No newline at end of file
+ return not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE)
+
+
+def is_using_ddp():
+ return gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1
+
+
+def is_using_pp():
+ return gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1
+
+
+@contextmanager
+def conditional_context(context_manager, enable=True):
+ if enable:
+ with context_manager:
+ yield
+ else:
+ yield
+
+
+def is_model_parallel_parameter(p):
+ return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL)
+
+
+def _calc_l2_norm(grads):
+ norm = 0.0
+ if len(grads) > 0:
+ dummy_overflow_buf = torch.cuda.IntTensor([0])
+ norm, _ = multi_tensor_applier(
+ colossal_C.multi_tensor_l2norm,
+ dummy_overflow_buf,
+ [grads],
+ False # no per-parameter norm
+ )
+ return norm
+
+
+def _calc_lp(grads, norm_type):
+ norm = 0.0
+ for grad in grads:
+ grad_norm = torch.norm(grad, norm_type)
+ norm += grad_norm ** norm_type
+ return norm
+
+# ======== Gradient Clipping =========
+
+
+def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
+ """Clips gradient norm of an iterable of parameters whose gradients
+ are in fp32.
+
+ This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
+ added functionality to handle model parallel parameters. Note that
+ the gradients are modified in place.
+
+ Arguments:
+ parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
+ single Tensor that will have gradients normalized
+ max_norm (float or int): max norm of the gradients
+ norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
+ infinity norm.
+
+ Returns:
+ Total norm of the parameters (viewed as a single vector).
+ """
+
+ if isinstance(parameters, torch.Tensor):
+ parameters = [parameters]
+
+ # Filter parameters based on:
+ # - grad should not be none
+ # - parameter should not be shared
+ # - should not be a replica due to tensor model parallelism
+ params = []
+ for param in parameters:
+ if param.grad is not None:
+ # Make sure the grads are in fp32
+ assert param.grad.type() == 'torch.cuda.FloatTensor', \
+ f'expected gradient to be dtype torch.cuda.FloatTensor, but got {param.grad.type()}'
+ params.append(param)
+ # Norm parameters.
+ max_norm = float(max_norm)
+ norm_type = float(norm_type)
+
+ # Calculate norm.
+ if norm_type == inf:
+ total_norm = max(p.grad.data.abs().max() for p in params)
+ total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
+ ops = []
+ # Take max across all model-parallel GPUs.
+ if gpc.is_initialized(ParallelMode.TENSOR) and gpc.get_world_size(ParallelMode.TENSOR) > 1:
+ ops.append(dist.all_reduce(total_norm_cuda,
+ op=dist.ReduceOp.MAX,
+ group=gpc.get_group(
+ ParallelMode.TENSOR),
+ async_op=True))
+ if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
+ ops.append(dist.all_reduce(total_norm_cuda,
+ op=dist.ReduceOp.MAX,
+ group=gpc.get_group(
+ ParallelMode.PIPELINE),
+ async_op=True))
+ for req in ops:
+ req.wait()
+ total_norm = total_norm_cuda[0].item()
+ else:
+ tensor_parallel_grads = []
+ no_tensor_parallel_grads = []
+ for p in params:
+ if is_model_parallel_parameter(p):
+ reductor = (gpc.get_world_size(ParallelMode.TENSOR) / getattr(p, NUM_PARTITIONS)) ** (1 / norm_type)
+ tensor_parallel_grads.append(p.grad.data / reductor)
+ else:
+ no_tensor_parallel_grads.append(p.grad.data)
+ if norm_type == 2.0:
+ tensor_parallel_norm = _calc_l2_norm(
+ tensor_parallel_grads) ** norm_type
+ no_tensor_parallel_norm = _calc_l2_norm(
+ no_tensor_parallel_grads) ** norm_type
+ else:
+ tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type)
+ no_tensor_parallel_norm = _calc_lp(
+ no_tensor_parallel_grads, norm_type)
+ # Sum across all model-parallel GPUs.
+ if gpc.is_initialized(ParallelMode.TENSOR) and len(tensor_parallel_grads) > 0:
+ dist.all_reduce(tensor_parallel_norm,
+ op=dist.ReduceOp.SUM,
+ group=gpc.get_group(ParallelMode.TENSOR))
+ total_norm = tensor_parallel_norm + no_tensor_parallel_norm
+ if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
+ dist.all_reduce(total_norm,
+ op=dist.ReduceOp.SUM,
+ group=gpc.get_group(ParallelMode.PIPELINE))
+ total_norm = total_norm ** (1.0 / norm_type)
+ if isinstance(total_norm, torch.Tensor):
+ total_norm = total_norm.item()
+
+ # Scale.
+ clip_coeff = max_norm / (total_norm + 1.0e-6)
+ if clip_coeff < 1.0:
+ grads = [p.grad.detach() for p in params]
+ dummy_overflow_buf = torch.cuda.IntTensor([0])
+ multi_tensor_applier(colossal_C.multi_tensor_scale,
+ dummy_overflow_buf,
+ [grads, grads],
+ clip_coeff)
+
+ return total_norm
+
+
+def count_zeros_fp32(parameters):
+ if isinstance(parameters, torch.Tensor):
+ parameters = [parameters]
+
+ # Filter parameters based on:
+ # - grad should not be none
+ # - parameter should not be shared
+ # - should not be a replica due to tensor model parallelism
+ total_num_zeros = 0.0
+ for param in parameters:
+ grad_not_none = param.grad is not None
+ is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
+ if grad_not_none and is_not_tp_duplicate:
+ grad = param.grad.detach()
+ num_zeros = grad.numel() - torch.count_nonzero(grad)
+ total_num_zeros = num_zeros + total_num_zeros
+
+ # Sum across all model-parallel GPUs.
+ ops = []
+ ops.append(dist.all_reduce(total_num_zeros,
+ op=dist.ReduceOp.SUM,
+ group=gpc.get_group(ParallelMode.TENSOR),
+ async_op=True))
+ ops.append(dist.all_reduce(total_num_zeros,
+ op=dist.ReduceOp.SUM,
+ group=gpc.get_group(ParallelMode.PIPELINE),
+ async_op=True))
+ for req in ops:
+ req.wait()
+ total_num_zeros = total_num_zeros.item()
+
+ return total_num_zeros
+
+
+def copy_tensor_parallel_attributes(src_tensor, dst_tensor):
+ for attr in TENSOR_PARALLEL_ATTRIBUTES:
+ if hasattr(src_tensor, attr):
+ val = getattr(src_tensor, attr)
+ setattr(dst_tensor, attr, val)
+
+
+def param_is_not_tensor_parallel_duplicate(param):
+ return (hasattr(param, IS_TENSOR_PARALLEL) and
+ getattr(param, IS_TENSOR_PARALLEL)) or (
+ gpc.get_local_rank(ParallelMode.TENSOR) == 0)
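A hedged usage sketch of the clipping helper: it expects fp32 CUDA gradients, relies on the fused `colossal_C` kernels for the 2-norm path, and is called between backward and the optimizer step (`model`, `optimizer` and `loss` are assumed to exist):

from colossalai.utils import clip_grad_norm_fp32

loss.backward()
total_norm = clip_grad_norm_fp32(model.parameters(), max_norm=1.0)
optimizer.step()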
diff --git a/colossalai/utils/data_sampler/__init__.py b/colossalai/utils/data_sampler/__init__.py
new file mode 100644
index 000000000..12798a94c
--- /dev/null
+++ b/colossalai/utils/data_sampler/__init__.py
@@ -0,0 +1,4 @@
+from .base_sampler import BaseSampler
+from .data_parallel_sampler import DataParallelSampler, get_dataloader
+
+__all__ = ['BaseSampler', 'DataParallelSampler', 'get_dataloader']
diff --git a/colossalai/nn/data/sampler/base_sampler.py b/colossalai/utils/data_sampler/base_sampler.py
similarity index 100%
rename from colossalai/nn/data/sampler/base_sampler.py
rename to colossalai/utils/data_sampler/base_sampler.py
diff --git a/colossalai/nn/data/sampler/data_parallel_sampler.py b/colossalai/utils/data_sampler/data_parallel_sampler.py
similarity index 68%
rename from colossalai/nn/data/sampler/data_parallel_sampler.py
rename to colossalai/utils/data_sampler/data_parallel_sampler.py
index 2b3817e03..afd20add2 100644
--- a/colossalai/nn/data/sampler/data_parallel_sampler.py
+++ b/colossalai/utils/data_sampler/data_parallel_sampler.py
@@ -3,19 +3,21 @@
# adapted from torch.utils.data.DistributedSampler
import math
+import random
+import numpy as np
from typing import TypeVar, Iterator
import torch
-from torch.utils.data import Sampler, Dataset
+from torch.utils.data import Sampler, Dataset, DataLoader
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
-from colossalai.registry import SAMPLERS
+from colossalai.registry import DATA_SAMPLERS
T_co = TypeVar('T_co', covariant=True)
-@SAMPLERS.register_module
+@DATA_SAMPLERS.register_module
class DataParallelSampler(Sampler):
"""A data sampler for distributed data parallelism
@@ -66,6 +68,10 @@ class DataParallelSampler(Sampler):
g.manual_seed(self.seed + self.epoch)
# type: ignore[arg-type]
indices = torch.randperm(len(self.dataset), generator=g).tolist()
+
+ # update for next epoch so that there is no need to call
+ # set_epoch manually
+ self.epoch += 1
else:
indices = list(range(len(self.dataset))) # type: ignore[arg-type]
@@ -100,3 +106,44 @@ class DataParallelSampler(Sampler):
:type epoch: int
"""
self.epoch = epoch
+
+
+def get_dataloader(dataset, shuffle=False, seed=1024, add_sampler=True, **kwargs):
+ '''Set up a deterministic dataloader (also configures seed workers, the sampler and whether to shuffle)
+
+ .. note: when pipeline parallel is enabled, shuffle cannot be True
+ as it will result in mismatch between input data on the 1st
+ stage and label on the last stage
+
+ :param dataset: the dataset to be loaded
+ :type dataset: torch.utils.data.Dataset
+ :param shuffle: whether to shuffle the dataset, defaults to False
+ :type shuffle: bool, optional
+ :param seed: random worker seed, defaults to 1024
+ :type seed: int, optional
+ :param add_sampler: add a DataParallelSampler when data parallel training is enabled, defaults to True
+ :type add_sampler: bool, optional
+ :return: a dataloader used for training or testing
+ :rtype: torch.utils.data.DataLoader
+ '''
+ _kwargs = kwargs.copy()
+
+ if add_sampler and gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1:
+ sampler = DataParallelSampler(dataset, shuffle=shuffle)
+ else:
+ sampler = None
+
+ # Deterministic dataloader
+ def seed_worker(worker_id):
+ worker_seed = seed
+ np.random.seed(worker_seed)
+ torch.manual_seed(worker_seed)
+ random.seed(worker_seed)
+
+ if sampler is None:
+ return DataLoader(dataset,
+ worker_init_fn=seed_worker,
+ shuffle=shuffle,
+ **_kwargs)
+ else:
+ return DataLoader(dataset,
+ sampler=sampler,
+ worker_init_fn=seed_worker,
+ **_kwargs)
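A sketch of the new dataloader helper; CIFAR10/torchvision is only an illustrative dataset choice. Extra keyword arguments are forwarded to `torch.utils.data.DataLoader`, and a `DataParallelSampler` is attached automatically when the DATA group has more than one rank:

from torchvision import datasets, transforms
from colossalai.utils import get_dataloader

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True,
                                 transform=transforms.ToTensor())
train_dataloader = get_dataloader(train_dataset,
                                  shuffle=True,
                                  batch_size=64,
                                  num_workers=2,
                                  pin_memory=True)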
diff --git a/colossalai/utils/gradient_accumulation/__init__.py b/colossalai/utils/gradient_accumulation/__init__.py
new file mode 100644
index 000000000..342f360c1
--- /dev/null
+++ b/colossalai/utils/gradient_accumulation/__init__.py
@@ -0,0 +1,29 @@
+import torch.nn as nn
+from typing import List
+from colossalai.engine import BaseGradientHandler
+from typing import Iterable
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
+from ._gradient_accumulation import GradAccumDataloader, GradAccumOptimizer, GradAccumLrSchedulerByStep, GradAccumGradientHandler
+
+
+def accumulate_gradient(model: nn.Module,
+ optimizer: Optimizer,
+ dataloader: Iterable,
+ accumulate_size: int,
+ gradient_handlers: List[BaseGradientHandler] = None,
+ lr_scheduler: _LRScheduler = None):
+ optimizer = GradAccumOptimizer(optimizer, accumulate_size=accumulate_size, model=model)
+ dataloader = GradAccumDataloader(dataloader, accumulate_size=accumulate_size)
+
+ if gradient_handlers is not None:
+ gradient_handlers = [GradAccumGradientHandler(handler, accumulate_size) for handler in gradient_handlers]
+
+ if lr_scheduler is not None:
+ lr_scheduler = GradAccumLrSchedulerByStep(lr_scheduler, accumulate_size=accumulate_size)
+
+ return optimizer, dataloader, gradient_handlers, lr_scheduler
+
+
+__all__ = ['accumulate_gradient', 'GradAccumDataloader', 'GradAccumOptimizer',
+ 'GradAccumLrSchedulerByStep', 'GradAccumGradientHandler']
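A hedged sketch of enabling gradient accumulation with the new helper; the returned wrappers behave like the originals but only apply an optimizer/scheduler step every `accumulate_size` iterations (`model`, `optimizer`, `train_dataloader` and `lr_scheduler` are assumed to exist, with the optimizer already being a ColossalaiOptimizer whose `backward`/`step` can be delegated to):

from colossalai.utils import accumulate_gradient

optimizer, train_dataloader, grad_handlers, lr_scheduler = accumulate_gradient(
    model=model,
    optimizer=optimizer,
    dataloader=train_dataloader,
    accumulate_size=4,
    lr_scheduler=lr_scheduler)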
diff --git a/colossalai/utils/gradient_accumulation/_gradient_accumulation.py b/colossalai/utils/gradient_accumulation/_gradient_accumulation.py
new file mode 100644
index 000000000..0aa25188a
--- /dev/null
+++ b/colossalai/utils/gradient_accumulation/_gradient_accumulation.py
@@ -0,0 +1,154 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+import torch.nn as nn
+from torch import Tensor
+from typing import Iterable, Any
+from colossalai.nn.optimizer import ColossalaiOptimizer
+from torch.nn.parallel.distributed import DistributedDataParallel
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
+from torch.utils.data import DataLoader
+from colossalai.utils import conditional_context
+from colossalai.engine import BaseGradientHandler
+
+
+class GradAccumOptimizer(ColossalaiOptimizer):
+
+ def __init__(self, optim: Optimizer, accumulate_size: int, model: nn.Module = None):
+ super().__init__(optim)
+ self.accumulate_size = accumulate_size
+ self.accumulate_step = 0
+
+ # handle pytorch ddp auto all reduce
+ self.model = model
+ self.is_torch_ddp = isinstance(self.model, DistributedDataParallel)
+
+ def zero_grad(self, *args, **kwargs):
+ if self.accumulate_step == 0:
+ self.optim.zero_grad(*args, **kwargs)
+
+ def step(self, *args, **kwargs):
+ if self.accumulate_step < self.accumulate_size:
+ return None
+ else:
+ self.accumulate_step = 0
+ return self.optim.step(*args, **kwargs)
+
+ def clip_grad_norm(self, model: nn.Module, max_norm: float):
+ if self.accumulate_step < self.accumulate_size:
+ pass
+ else:
+ self.optim.clip_grad_norm(model, max_norm)
+
+ def backward(self, loss: Tensor):
+ self.accumulate_step += 1
+
+ if self.is_torch_ddp:
+ no_sync = self.accumulate_step < self.accumulate_size
+ with conditional_context(self.model.no_sync(), enable=no_sync):
+ scaled_loss = loss / self.accumulate_size
+ self.optim.backward(scaled_loss)
+ else:
+ scaled_loss = loss / self.accumulate_size
+ self.optim.backward(scaled_loss)
+
+ def backward_by_grad(self, tensor: Tensor, grad: Tensor):
+ no_sync = self.is_torch_ddp and self.accumulate_step < self.accumulate_size
+
+ if no_sync:
+ with self.model.no_sync():
+ self.optim.backward_by_grad(tensor, grad)
+ else:
+ self.optim.backward_by_grad(tensor, grad)
+
+
+class GradAccumDataloader():
+
+ def __init__(self, dataloader: Iterable, accumulate_size: int) -> None:
+ self.dataloader = dataloader
+ self.consume_remain_data = not isinstance(dataloader, DataLoader)
+ self.steps_per_epoch = len(dataloader) - len(dataloader) % accumulate_size
+
+ def __getattr__(self, __name: str) -> Any:
+ return getattr(self.dataloader, __name)
+
+ def __len__(self):
+ return self.steps_per_epoch
+
+ def __iter__(self):
+ self._cur_step = 0
+ self._dataiter = iter(self.dataloader)
+ return self
+
+ def __next__(self) -> Any:
+ if self._cur_step < self.steps_per_epoch:
+ self._cur_step += 1
+
+ if self._cur_step == self.steps_per_epoch and self.consume_remain_data:
+ # this is to handle non standard pytorch dataloader
+ # such as dali dataloader
+ while True:
+ try:
+ _ = next(self._dataiter)
+ except StopIteration:
+ break
+ return next(self._dataiter)
+ else:
+ raise StopIteration
+
+
+class GradAccumLrSchedulerByStep(_LRScheduler):
+
+ def __init__(self, lr_scheduler: _LRScheduler, accumulate_size: int) -> None:
+ self.lr_scheduler = lr_scheduler
+ self.accumulate_size = accumulate_size
+ self.accumulate_step = 0
+
+ @staticmethod
+ def compute_effective_steps_per_epoch(dataloader: Iterable, accumulate_size: int):
+ return len(dataloader) // accumulate_size
+
+ def __getattr__(self, __name: str) -> Any:
+ return getattr(self.lr_scheduler, __name)
+
+ def step(self, *args, **kwargs):
+ self.accumulate_step += 1
+ if self.accumulate_step < self.accumulate_size:
+ pass
+ else:
+ self.accumulate_step = 0
+ self.lr_scheduler.step(*args, **kwargs)
+
+ def get_lr(self):
+ return self.lr_scheduler.get_lr()
+
+ def get_last_lr(self):
+ return self.lr_scheduler.get_last_lr()
+
+ def print_lr(self, *args, **kwargs):
+ self.lr_scheduler.print_lr(*args, **kwargs)
+
+ def state_dict(self) -> dict:
+ return self.lr_scheduler.state_dict()
+
+ def load_state_dict(self, state_dict: dict) -> None:
+ self.lr_scheduler.load_state_dict(state_dict)
+
+
+class GradAccumGradientHandler():
+
+ def __init__(self, grad_handler: BaseGradientHandler, accumulate_size: int) -> None:
+ assert isinstance(grad_handler, BaseGradientHandler), \
+ f'expected grad_handler to be type BaseGradientHandler, but got {type(grad_handler)}'
+ self.grad_handler = grad_handler
+ self.accumulate_size = accumulate_size
+ self.accumulate_step = 0
+
+ def handle_gradient(self):
+ self.accumulate_step += 1
+ if self.accumulate_step < self.accumulate_size:
+ pass
+ else:
+ self.accumulate_step = 0
+ self.grad_handler.handle_gradient()
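To make the intended interplay of these wrappers concrete, here is a minimal sketch (not an official example) of how they might compose in a plain training loop. `ColossalaiOptimizer` is used only to give the inner optimizer the `backward`/`step` interface the wrapper expects, and the model, data, and hyperparameters are dummy placeholders.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils.gradient_accumulation import (GradAccumDataloader,
                                                    GradAccumLrSchedulerByStep,
                                                    GradAccumOptimizer)

ACCUM_SIZE = 4  # number of micro-steps folded into one real optimizer update

model = nn.Linear(32, 4).cuda()
criterion = nn.CrossEntropyLoss()
inner_optim = ColossalaiOptimizer(torch.optim.SGD(model.parameters(), lr=0.1))
inner_sched = torch.optim.lr_scheduler.StepLR(inner_optim.optim, step_size=10)
train_loader = DataLoader(TensorDataset(torch.randn(64, 32), torch.randint(0, 4, (64,))),
                          batch_size=8)

optimizer = GradAccumOptimizer(inner_optim, accumulate_size=ACCUM_SIZE, model=model)
dataloader = GradAccumDataloader(train_loader, accumulate_size=ACCUM_SIZE)
lr_scheduler = GradAccumLrSchedulerByStep(inner_sched, accumulate_size=ACCUM_SIZE)

for data, label in dataloader:
    output = model(data.cuda())
    loss = criterion(output, label.cuda())
    optimizer.backward(loss)   # scales the loss by 1 / ACCUM_SIZE internally
    optimizer.step()           # a real step fires only every ACCUM_SIZE calls
    optimizer.zero_grad()      # clears gradients only right after a real step
    lr_scheduler.step()        # likewise advances once per effective step
```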
diff --git a/colossalai/utils/memory.py b/colossalai/utils/memory.py
index b47b4099d..904ec894b 100644
--- a/colossalai/utils/memory.py
+++ b/colossalai/utils/memory.py
@@ -8,19 +8,28 @@ import torch
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
-from colossalai.logging import get_global_dist_logger
+from colossalai.logging import get_dist_logger
def bytes_to_GB(val, decimal=2):
'''A byte-to-Gigabyte converter, using binary notation by default.
- :param val: X bytes to convert
- :return: X' Gb
+ :param val: X bytes to convert
+ :return: X' GB
'''
return round(val / (1024 * 1024 * 1024), decimal)
-def report_memory_usage(message):
+def bytes_to_MB(val, decimal=2):
+'''A byte-to-Megabyte converter, using binary notation by default.
+
+ :param val: X bytes to convert
+ :return: X' MB
+ '''
+ return round(val / (1024 * 1024), decimal)
+
+
+def report_memory_usage(message, logger=None, report_cpu=False):
'''Calculate and log the GPU memory usage (in MB), and optionally the host RAM usage
:param message: a prefix message to add in the log
@@ -30,19 +39,24 @@ def report_memory_usage(message):
if not gpc.is_initialized(ParallelMode.GLOBAL):
raise EnvironmentError("No distributed environment is initialized")
- # python doesn't do real-time garbage collection so do it explicitly to get the correct RAM reports
- gc.collect()
- vm_stats = psutil.virtual_memory()
- vm_used = bytes_to_GB(vm_stats.total - vm_stats.available)
+ gpu_allocated = bytes_to_MB(torch.cuda.memory_allocated())
+ gpu_max_allocated = bytes_to_MB(torch.cuda.max_memory_allocated())
+ gpu_cached = bytes_to_MB(torch.cuda.memory_reserved())
+ gpu_max_cached = bytes_to_MB(torch.cuda.max_memory_reserved())
- gpu_allocated = bytes_to_GB(torch.cuda.memory_allocated())
- gpu_max_allocated = bytes_to_GB(torch.cuda.max_memory_allocated())
- gpu_cached = bytes_to_GB(torch.cuda.memory_cached())
- gpu_max_cached = bytes_to_GB(torch.cuda.max_memory_cached())
+ full_log = f"{message} - GPU: allocated {gpu_allocated} MB, max allocated {gpu_max_allocated} MB, \
+ cached: {gpu_cached} MB, max cached: {gpu_max_cached} MB"
- get_global_dist_logger().info(
- f"{message} - GPU: allocated {gpu_allocated}GB, max allocated {gpu_max_allocated}GB, cached: {gpu_cached} GB, "
- f"max cached: {gpu_max_cached}GB, CPU Virtual Memory: used = {vm_used}GB, percent = {vm_stats.percent}%")
+ if report_cpu:
+ # python doesn't do real-time garbage collection so do it explicitly to get the correct RAM reports
+ gc.collect()
+ vm_stats = psutil.virtual_memory()
+ vm_used = bytes_to_MB(vm_stats.total - vm_stats.available)
+ full_log += f", CPU Virtual Memory: used = {vm_used} MB, percent = {vm_stats.percent}%"
+
+ if logger is None:
+ logger = get_dist_logger()
+ logger.info(full_log)
# get the peak memory to report correct data, so reset the counter for the next call
if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+
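As a quick illustration of the new signature (the message strings are arbitrary, and a distributed context must already have been launched, otherwise the function raises `EnvironmentError`):

```python
from colossalai.logging import get_dist_logger
from colossalai.utils import report_memory_usage

logger = get_dist_logger()

# GPU figures only (reported in MB), sent to the default logger
report_memory_usage("after building the model")

# include host RAM as well and reuse an existing logger instance
report_memory_usage("after the first training step", logger=logger, report_cpu=True)
```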
diff --git a/colossalai/nn/multi_tensor_apply/__init__.py b/colossalai/utils/multi_tensor_apply/__init__.py
similarity index 100%
rename from colossalai/nn/multi_tensor_apply/__init__.py
rename to colossalai/utils/multi_tensor_apply/__init__.py
diff --git a/colossalai/nn/multi_tensor_apply/multi_tensor_apply.py b/colossalai/utils/multi_tensor_apply/multi_tensor_apply.py
similarity index 100%
rename from colossalai/nn/multi_tensor_apply/multi_tensor_apply.py
rename to colossalai/utils/multi_tensor_apply/multi_tensor_apply.py
diff --git a/colossalai/utils/timer.py b/colossalai/utils/timer.py
index a516592dd..bc0205344 100644
--- a/colossalai/utils/timer.py
+++ b/colossalai/utils/timer.py
@@ -1,8 +1,6 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
-
import time
-
from .cuda import synchronize
@@ -10,7 +8,6 @@ class Timer:
'''
A timer object which helps to log the execution times, and provides different tools to assess the times.
'''
-
def __init__(self):
self._started = False
self._start_time = time.time()
@@ -31,7 +28,6 @@ class Timer:
def stop(self, keep_in_history: bool = False):
'''Stop the timer and record the start-stop time interval.
-
:param keep_in_history: whether does it record into history each start-stop interval, defaults to False
:type keep_in_history: bool, optional
:return: start-stop interval
@@ -48,7 +44,6 @@ class Timer:
def get_history_mean(self):
'''mean of all history start-stop time intervals.
-
:return: mean of time intervals
:rtype: int
'''
@@ -56,7 +51,6 @@ class Timer:
def get_history_sum(self):
'''add up all the start-stop time intervals.
-
:return: sum of time intervals
:rtype: int
'''
@@ -64,7 +58,6 @@ class Timer:
def get_elapsed_time(self):
'''return the last start-stop time interval. *use it only when timer is not in progress*
-
:return: the last time interval
:rtype: int
'''
@@ -89,7 +82,6 @@ class MultiTimer:
def start(self, name: str):
'''Start namely one of the timers
-
:param name: timer's key
:type name: str
'''
@@ -100,7 +92,6 @@ class MultiTimer:
def stop(self, name: str, keep_in_history: bool):
'''Stop namely one of the timers.
-
:param name: timer's key
:param keep_in_history: whether does it record into history each start-stop interval
:type keep_in_history: bool
@@ -112,7 +103,6 @@ class MultiTimer:
def get_timer(self, name):
'''Get timer by its name (from multitimer)
-
:param name: timer's key
:return: timer with the name you give correctly
:rtype: Timer
@@ -121,7 +111,6 @@ class MultiTimer:
def reset(self, name=None):
'''Reset timers.
-
:param name: if name is designated, the named timer will be reset and others will not, defaults to None
'''
if self._on:
@@ -132,7 +121,6 @@ class MultiTimer:
timer.reset()
def is_on(self):
-
return self._on
def set_status(self, mode: bool):
@@ -140,4 +128,4 @@ class MultiTimer:
def __iter__(self):
for name, timer in self._timers.items():
- yield name, timer
+ yield name, timer
\ No newline at end of file
diff --git a/colossalai/zero/__init__.py b/colossalai/zero/__init__.py
new file mode 100644
index 000000000..8fe3dcab9
--- /dev/null
+++ b/colossalai/zero/__init__.py
@@ -0,0 +1,28 @@
+import torch.nn as nn
+from torch.optim import Optimizer
+from colossalai.amp.naive_amp import NaiveAMPModel
+from colossalai.utils import is_no_pp_or_last_stage
+
+from .zero_redundancy_optimizer_level_2 import ZeroRedundancyOptimizer_Level_2
+from .zero_redundancy_optimizer_level_3 import ZeroRedundancyOptimizer_Level_3
+
+
+def convert_to_zero(model: nn.Module,
+ optimizer: Optimizer,
+ level: int,
+ zero_config):
+ assert level == 2 or level == 3, 'Only ZERO Optimizer Level 2 and 3 are provided'
+
+ if is_no_pp_or_last_stage():
+ model = NaiveAMPModel(model, output_to_fp32=True)
+ else:
+ model = NaiveAMPModel(model, output_to_fp32=False)
+
+ if level == 2:
+ optimizer = ZeroRedundancyOptimizer_Level_2(init_optimizer=optimizer, **zero_config)
+ else:
+ optimizer = ZeroRedundancyOptimizer_Level_3(init_optimizer=optimizer, module=model, **zero_config)
+ return model, optimizer
+
+
+__all__ = ['convert_to_zero', 'ZeroRedundancyOptimizer_Level_2', 'ZeroRedundancyOptimizer_Level_3']
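A minimal sketch of how `convert_to_zero` might be called, assuming the distributed context has already been launched (both the pipeline-stage check and the ZeRO optimizers rely on it). The keys inside `zero_config` are illustrative only; they are forwarded verbatim to the chosen `ZeroRedundancyOptimizer_Level_*` constructor.

```python
import torch
import torch.nn as nn
from colossalai.zero import convert_to_zero

# placeholder model and optimizer
model = nn.Linear(1024, 1024).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# forwarded as **zero_config to ZeroRedundancyOptimizer_Level_2;
# allgather_bucket_size appears in that optimizer's constructor checks,
# all other parameters keep their defaults
zero_config = dict(allgather_bucket_size=int(5e8))

model, optimizer = convert_to_zero(model, optimizer, level=2, zero_config=zero_config)
```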
diff --git a/colossalai/nn/optimizer/loss_scaler.py b/colossalai/zero/loss_scaler.py
similarity index 100%
rename from colossalai/nn/optimizer/loss_scaler.py
rename to colossalai/zero/loss_scaler.py
diff --git a/colossalai/nn/optimizer/zero_redundancy_optimizer_level_2.py b/colossalai/zero/zero_redundancy_optimizer_level_2.py
similarity index 99%
rename from colossalai/nn/optimizer/zero_redundancy_optimizer_level_2.py
rename to colossalai/zero/zero_redundancy_optimizer_level_2.py
index 1a57c5876..f022aaa6f 100644
--- a/colossalai/nn/optimizer/zero_redundancy_optimizer_level_2.py
+++ b/colossalai/zero/zero_redundancy_optimizer_level_2.py
@@ -21,11 +21,10 @@ from torch.distributed.distributed_c10d import _get_global_rank
from torch.optim import Optimizer
from colossalai.core import global_context as gpc
-from colossalai.registry import OPTIMIZER_WRAPPERS
from colossalai.utils import report_memory_usage
-from ._utils import is_model_parallel_parameter
+from colossalai.utils.common import is_model_parallel_parameter
from .loss_scaler import LossScaler, DynamicLossScaler
-from ...context.parallel_mode import ParallelMode
+from colossalai.context import ParallelMode
# Toggle this to true to enable correctness test
# with gradient partitioning and without
@@ -74,7 +73,6 @@ def print_rank_msg(msg):
print(f"rank {dist.get_rank()} - {msg}")
-@OPTIMIZER_WRAPPERS.register_module
class ZeroRedundancyOptimizer_Level_2(Optimizer):
"""
ZeroRedundancyOptimizer_Level_2 designed to reduce the memory footprint
@@ -252,7 +250,7 @@ class ZeroRedundancyOptimizer_Level_2(Optimizer):
self.nccl_start_alignment_factor = 2
assert (
- allgather_bucket_size % self.nccl_start_alignment_factor == 0), f"allgather_bucket_size must be a multiple of nccl_start_alignment_factor, {self.nccl_start_alignment_factor} "
+ allgather_bucket_size % self.nccl_start_alignment_factor == 0), f"allgather_bucket_size must be a multiple of nccl_start_alignment_factor, {self.nccl_start_alignment_factor} "
self.all_reduce_print = False
self.dtype = self.optimizer.param_groups[0]['params'][0].dtype
@@ -760,7 +758,7 @@ class ZeroRedundancyOptimizer_Level_2(Optimizer):
elif start_index > current_index and start_index < (current_index +
param_size):
assert (
- first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition"
+ first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition"
first_offset = start_index - current_index
set_key_value_list(self.param_to_partition_ids[i],
@@ -804,7 +802,7 @@ class ZeroRedundancyOptimizer_Level_2(Optimizer):
def report_ipg_memory_usage(self, tag, param_elems):
elem_count = self.elements_in_ipg_bucket + param_elems
percent_of_bucket_size = (
- 100.0 * elem_count) // self.reduce_bucket_size
+ 100.0 * elem_count) // self.reduce_bucket_size
if self.verbose:
report_memory_usage(
f"{tag}: elems in_bucket {self.elements_in_ipg_bucket} param {param_elems} max_percent {percent_of_bucket_size}"
@@ -1492,7 +1490,7 @@ class ZeroRedundancyOptimizer_Level_2(Optimizer):
params_in_partition.append(tensor)
assert (
- first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition"
+ first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition"
first_offset = start_index - current_index
else:
@@ -1527,6 +1525,11 @@ class ZeroRedundancyOptimizer_Level_2(Optimizer):
op=op,
group=self.model_parallel_group)
+ def clip_grad_norm(self, *args, **kwargs):
+ # dummy function to retain the same function interface
+ # as ColossalaiOptimizer for compatibility
+ pass
+
def get_grad_norm_direct(self, gradients, params, norm_type=2):
"""Clips gradient norm of an iterable of parameters.
@@ -1800,7 +1803,7 @@ class ZeroRedundancyOptimizer_Level_2(Optimizer):
num_elements = shard_size
assert shard_size * \
- num_shards <= partitioned_params[partition_id].numel()
+ num_shards <= partitioned_params[partition_id].numel()
for shard_id in range(num_shards):
@@ -2249,7 +2252,7 @@ def estimate_zero2_model_states_mem_needs(total_params,
if cpu_offload:
gpu_mem = 2 * total_params
cpu_mem = total_params * \
- max(4 * total_gpus, 16) * additional_buffer_factor
+ max(4 * total_gpus, 16) * additional_buffer_factor
else:
gpu_mem = 4 * total_params + int(16 * total_params / total_gpus)
cpu_mem = total_params * 4 * num_gpus_per_node * additional_buffer_factor
diff --git a/colossalai/nn/optimizer/zero_redundancy_optimizer_level_3.py b/colossalai/zero/zero_redundancy_optimizer_level_3.py
similarity index 99%
rename from colossalai/nn/optimizer/zero_redundancy_optimizer_level_3.py
rename to colossalai/zero/zero_redundancy_optimizer_level_3.py
index 4e54f3cd3..cf281d6b4 100644
--- a/colossalai/nn/optimizer/zero_redundancy_optimizer_level_3.py
+++ b/colossalai/zero/zero_redundancy_optimizer_level_3.py
@@ -28,10 +28,9 @@ from torch.distributed.distributed_c10d import _get_global_rank
from torch.optim import Optimizer
from colossalai.core import global_context as gpc
-from colossalai.registry import OPTIMIZER_WRAPPERS
from colossalai.utils import report_memory_usage
from .loss_scaler import LossScaler, DynamicLossScaler
-from ...context.parallel_mode import ParallelMode
+from colossalai.context import ParallelMode
# Toggle this to true to enable correctness test
# with gradient partitioning and without
@@ -412,7 +411,7 @@ class PartitionedParameterCoordinator(object):
)
params_to_fetch = [
param for _,
- param in sub_module.named_parameters(recurse=False)
+ param in sub_module.named_parameters(recurse=False)
]
# print([n for n,p in sub_module.named_parameters(recurse=False)])
@@ -422,7 +421,7 @@ class PartitionedParameterCoordinator(object):
)
params_to_fetch += [
param for _,
- param in sub_module.ds_external_parameters()
+ param in sub_module.ds_external_parameters()
]
# for _, param in sub_module.named_parameters(recurse=False):
for param in params_to_fetch:
@@ -474,14 +473,14 @@ class PartitionedParameterCoordinator(object):
)
params_to_release = [
param for _,
- param in sub_module.named_parameters(recurse=False)
+ param in sub_module.named_parameters(recurse=False)
]
if hasattr(sub_module, 'ds_external_parameters'):
# print_rank_0(f"Releasing external parameters {sub_module.ds_external_parameters()}")
params_to_release += [
param for _,
- param in sub_module.ds_external_parameters()
+ param in sub_module.ds_external_parameters()
]
# for _, param in sub_module.named_parameters(recurse=False):
@@ -604,7 +603,6 @@ class PostBackwardFunction(torch.autograd.Function):
INITIAL_MICRO_STEP_ID = -1
-@OPTIMIZER_WRAPPERS.register_module
class ZeroRedundancyOptimizer_Level_3(Optimizer):
"""
ZeroRedundancyOptimizer_Level_3 designed to reduce the memory footprint
@@ -718,7 +716,7 @@ class ZeroRedundancyOptimizer_Level_3(Optimizer):
self.offload_optimizer_pin_memory = offload_optimizer_config[
OFFLOAD_OPTIMIZER_PIN_MEMORY]
self.swap_optimizer = offload_optimizer_config[
- OFFLOAD_OPTIMIZER_DEVICE] == OFFLOAD_NVME_DEVICE
+ OFFLOAD_OPTIMIZER_DEVICE] == OFFLOAD_NVME_DEVICE
self.offload_optimizer_fast_init = offload_optimizer_config[
OFFLOAD_OPTIMIZER_FAST_INIT]
@@ -733,7 +731,7 @@ class ZeroRedundancyOptimizer_Level_3(Optimizer):
self.offload_param_pin_memory = offload_param_config[
OFFLOAD_PARAM_PIN_MEMORY]
self.params_in_nvme_and_cpu = offload_param_config[
- OFFLOAD_PARAM_DEVICE] == OFFLOAD_NVME_DEVICE
+ OFFLOAD_PARAM_DEVICE] == OFFLOAD_NVME_DEVICE
self.max_params_in_cpu = offload_param_config[OFFLOAD_PARAM_MAX_IN_CPU]
if self.verbose:
print_rank_0(
@@ -1360,7 +1358,7 @@ class ZeroRedundancyOptimizer_Level_3(Optimizer):
if self.params_in_nvme_and_cpu and tensor is None:
num_swap_from_nvme_partitions += 1
swap_from_nvme_memory_usage += (
- fp32_element_size * num_elements)
+ fp32_element_size * num_elements)
if self.offload_optimizer_fast_init:
sub_group_partitions = self._get_sub_group_partitions(
i)
@@ -1380,7 +1378,7 @@ class ZeroRedundancyOptimizer_Level_3(Optimizer):
else:
num_swap_from_cpu_partitions += 1
swap_from_cpu_memory_usage += (
- fp32_element_size * num_elements)
+ fp32_element_size * num_elements)
swappable_fp32_tensors.append(
self.fp32_partitioned_groups_flat[i])
swappable_fp16_src_tensors.append(
@@ -1944,7 +1942,7 @@ class ZeroRedundancyOptimizer_Level_3(Optimizer):
elif start_index > current_index and start_index < (current_index +
param_size):
assert (
- first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition"
+ first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition"
first_offset = start_index - current_index
set_key_value_list(self.param_to_partition_ids[i],
@@ -2003,7 +2001,7 @@ class ZeroRedundancyOptimizer_Level_3(Optimizer):
def report_ipg_memory_usage(self, tag, param_elems):
elem_count = self.elements_in_ipg_bucket + param_elems
percent_of_bucket_size = (
- 100.0 * elem_count) // self.reduce_bucket_size
+ 100.0 * elem_count) // self.reduce_bucket_size
report_memory_usage(
f"{tag}: elems in_bucket {self.elements_in_ipg_bucket} param {param_elems} max_percent {percent_of_bucket_size}")
@@ -2200,7 +2198,7 @@ class ZeroRedundancyOptimizer_Level_3(Optimizer):
if self.offload_optimizer:
allocate_grads_in_partition = self.grads_in_partition is None \
- and self.gradient_accumulation_steps > 1
+ and self.gradient_accumulation_steps > 1
else:
allocate_grads_in_partition = self.grads_in_partition is None
@@ -2308,7 +2306,7 @@ class ZeroRedundancyOptimizer_Level_3(Optimizer):
self.partition_previous_reduced_grads()
params_to_reduce = [param for i, param,
- param_id in self.params_in_ipg_bucket]
+ param_id in self.params_in_ipg_bucket]
# print(f"Params in ipg bucket {self.params_in_ipg_bucket}")
# print(f"Reducing {[(debug_param2name_id_shape(param), param.grad) for param in params_to_reduce]}")
# exit(0)
@@ -2522,7 +2520,7 @@ class ZeroRedundancyOptimizer_Level_3(Optimizer):
params_in_partition.append(tensor)
assert (
- first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition"
+ first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition"
first_offset = start_index - current_index
else:
@@ -2557,6 +2555,11 @@ class ZeroRedundancyOptimizer_Level_3(Optimizer):
op=op,
group=self.model_parallel_group)
+ def clip_grad_norm(self, *args, **kwargs):
+ # dummy function to retain the same function interface
+ # as ColossalaiOptimizer for compatibility
+ pass
+
def get_grad_norm_direct(self, gradients, params, norm_type=2):
"""Clips gradient norm of an iterable of parameters.
@@ -2824,7 +2827,7 @@ class ZeroRedundancyOptimizer_Level_3(Optimizer):
self.optimizer_swapper.swap_out_optimizer_state(
parameter=self.fp32_partitioned_groups_flat[sub_group_id],
async_swap=self.next_swappable_fp32_partitioned_groups[sub_group_id] is
- not None)
+ not None)
self.stop_timers([OPTIMIZER_SWAP_OUT_STATE])
if self.verbose:
@@ -3175,7 +3178,7 @@ class ZeroRedundancyOptimizer_Level_3(Optimizer):
individual_tensors = self.unflatten(
padded_flattened_tensor, group_tensors)
lean_lengths = [t.numel() - pad for t,
- pad in zip(group_tensors, paddings)]
+ pad in zip(group_tensors, paddings)]
lean_tensors = [t[:len]
for t, len in zip(individual_tensors, lean_lengths)]
# print()(f'rank {dist.get_rank()}: lean_tensors = {[t.numel() for t in lean_tensors]}')
diff --git a/configs/vit/vit_2d.py b/configs/vit/vit_2d.py
index f36a03acc..b771b583e 100644
--- a/configs/vit/vit_2d.py
+++ b/configs/vit/vit_2d.py
@@ -11,7 +11,7 @@ DIM = 512
NUM_ATTENTION_HEADS = 2
SUMMA_DIM = 2
NUM_CLASSES = 10
-DEPTH = 1
+DEPTH = 6
NUM_EPOCHS = 60
train_data = dict(
@@ -30,6 +30,7 @@ train_data = dict(
),
dataloader=dict(
batch_size=BATCH_SIZE,
+ drop_last=True,
pin_memory=True,
shuffle=True,
)
@@ -136,14 +137,14 @@ hooks = [
warmup_steps=5
)
),
- dict(type='TensorboardHook', log_dir='./tb_logs'),
+ # dict(type='TensorboardHook', log_dir='./tb_logs'),
# dict(type='SaveCheckpointHook', interval=5, checkpoint_dir='./ckpt'),
# dict(type='LoadCheckpointHook', epoch=20, checkpoint_dir='./ckpt')
]
parallel = dict(
pipeline=dict(size=1),
- tensor=dict(size=1, mode='2d'),
+ tensor=dict(size=4, mode='2d'),
)
# for fp16 training
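The tensor-parallel size jumps from 1 to 4 because the '2d' (SUMMA-style) mode lays the tensor-parallel ranks out on a square grid; with `SUMMA_DIM = 2` defined earlier in this config, the required size is `SUMMA_DIM ** 2`. A hedged sketch of the resulting parallel setting:

```python
SUMMA_DIM = 2  # side length of the square process grid used by the '2d' mode

parallel = dict(
    pipeline=dict(size=1),
    tensor=dict(size=SUMMA_DIM ** 2, mode='2d'),  # 2 x 2 = 4 tensor-parallel ranks
)
```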
diff --git a/docs/run_demo.md b/docs/run_demo.md
index 6d8c5b49a..2b0c4bdf3 100644
--- a/docs/run_demo.md
+++ b/docs/run_demo.md
@@ -32,13 +32,13 @@ realizes the training process.
```python
import colossalai
from colossalai.core import global_context as gpc
-from colossalai.logging import get_global_dist_logger
+from colossalai.logging import get_dist_logger
from colossalai.trainer import Trainer
def run_trainer():
engine, train_dataloader, test_dataloader = colossalai.initialize()
- logger = get_global_dist_logger()
+ logger = get_dist_logger()
logger.info("engine is built", ranks=[0])
diff --git a/docs/run_demo_zh.md b/docs/run_demo_zh.md
index 54839760d..5eadef6f2 100644
--- a/docs/run_demo_zh.md
+++ b/docs/run_demo_zh.md
@@ -24,13 +24,13 @@ HOST=xxx.xxx.xxx.xxx srun ./scripts/slurm_dist_train.sh ./examples/run_trainer.p
```python
import colossalai
from colossalai.core import global_context as gpc
-from colossalai.logging import get_global_dist_logger
+from colossalai.logging import get_dist_logger
from colossalai.trainer import Trainer
def run_trainer():
engine, train_dataloader, test_dataloader = colossalai.initialize()
- logger = get_global_dist_logger()
+ logger = get_dist_logger()
logger.info("engine is built", ranks=[0])
trainer = Trainer(engine=engine,
diff --git a/docs/trainer_engine.md b/docs/trainer_engine.md
index 88b872826..c2abf1808 100644
--- a/docs/trainer_engine.md
+++ b/docs/trainer_engine.md
@@ -36,7 +36,7 @@ from colossalai.engine import Engine
model = models.resnet18()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
-schedule = colossalai.engine.NoPipelineSchedule()
+schedule = colossalai.engine.NonPipelineSchedule()
MyEngine = Engine(
model=model,
diff --git a/docs/trainer_engine_zh.md b/docs/trainer_engine_zh.md
index 737d6745b..5729a0599 100644
--- a/docs/trainer_engine_zh.md
+++ b/docs/trainer_engine_zh.md
@@ -31,7 +31,7 @@ model = models.resnet18()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model)
lr_scheduler = colossalai.nn.lr_scheduler.CosineAnnealingLR(optimizer, 1000)
-schedule = colossalai.engine.NoPipelineSchedule()
+schedule = colossalai.engine.NonPipelineSchedule()
MyEngine = Engine(
model=model,
diff --git a/examples/colossal_cifar_demo.ipynb b/examples/colossal_cifar_demo.ipynb
index 221707bbb..266fd2543 100644
--- a/examples/colossal_cifar_demo.ipynb
+++ b/examples/colossal_cifar_demo.ipynb
@@ -1,20 +1,4 @@
{
- "nbformat": 4,
- "nbformat_minor": 0,
- "metadata": {
- "colab": {
- "name": "colossal_cifar_demo.ipynb",
- "provenance": []
- },
- "kernelspec": {
- "name": "python3",
- "display_name": "Python 3"
- },
- "language_info": {
- "name": "python"
- },
- "accelerator": "GPU"
- },
"cells": [
{
"cell_type": "markdown",
@@ -27,6 +11,7 @@
},
{
"cell_type": "code",
+ "execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@@ -34,14 +19,10 @@
"id": "vP7LvCpG23a2",
"outputId": "b37f7203-8a02-4736-c527-603f2bb34d7d"
},
- "source": [
- "!pip install ColossalAI deepspeed"
- ],
- "execution_count": null,
"outputs": [
{
- "output_type": "stream",
"name": "stdout",
+ "output_type": "stream",
"text": [
"Requirement already satisfied: ColossalAI in /usr/local/lib/python3.7/dist-packages (0.1)\n",
"Requirement already satisfied: deepspeed in /usr/local/lib/python3.7/dist-packages (0.5.4)\n",
@@ -60,10 +41,14 @@
"Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from triton->deepspeed) (3.3.0)\n"
]
}
+ ],
+ "source": [
+ "!pip install ColossalAI deepspeed"
]
},
{
"cell_type": "code",
+ "execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@@ -71,24 +56,23 @@
"id": "UVKEurtS4SFS",
"outputId": "99fb6050-5da7-4f27-b4eb-9b3ccf830efb"
},
- "source": [
- "import colossalai\n",
- "from colossalai.engine import Engine, NoPipelineSchedule\n",
- "from colossalai.trainer import Trainer\n",
- "from colossalai.context import Config\n",
- "import torch"
- ],
- "execution_count": null,
"outputs": [
{
- "output_type": "stream",
"name": "stdout",
+ "output_type": "stream",
"text": [
"Please install apex to use FP16 Optimizer\n",
"Apex should be installed to use the FP16 optimizer\n",
"apex is required for mixed precision training\n"
]
}
+ ],
+ "source": [
+ "import colossalai\n",
+ "from colossalai.engine import Engine, NonPipelineSchedule\n",
+ "from colossalai.trainer import Trainer\n",
+ "from colossalai.context import Config\n",
+ "import torch"
]
},
{
@@ -102,6 +86,7 @@
},
{
"cell_type": "code",
+ "execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@@ -109,6 +94,28 @@
"id": "8yF7Lc-K7NAS",
"outputId": "01312349-a8b0-4de4-9103-7d1b48e6cc36"
},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,596 INFO: Added key: store_based_barrier_key:1 to store for rank: 0\n",
+ "colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,598 INFO: Rank 0: Completed store-based barrier for 1 nodes.\n",
+ "colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,602 INFO: Added key: store_based_barrier_key:2 to store for rank: 0\n",
+ "colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,605 INFO: Rank 0: Completed store-based barrier for 1 nodes.\n",
+ "colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,608 INFO: Added key: store_based_barrier_key:3 to store for rank: 0\n",
+ "colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,610 INFO: Rank 0: Completed store-based barrier for 1 nodes.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "process rank 0 is bound to device 0\n",
+ "initialized seed on rank 0, numpy: 1024, python random: 1024, ParallelMode.DATA: 1024, ParallelMode.TENSOR: 1124,the default parallel seed is ParallelMode.DATA.\n"
+ ]
+ }
+ ],
"source": [
"parallel_cfg = Config(dict(parallel=dict(\n",
" data=dict(size=1),\n",
@@ -121,29 +128,6 @@
" host='127.0.0.1',\n",
" port=8888,\n",
" backend='nccl')"
- ],
- "execution_count": null,
- "outputs": [
- {
- "output_type": "stream",
- "name": "stderr",
- "text": [
- "colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,596 INFO: Added key: store_based_barrier_key:1 to store for rank: 0\n",
- "colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,598 INFO: Rank 0: Completed store-based barrier for 1 nodes.\n",
- "colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,602 INFO: Added key: store_based_barrier_key:2 to store for rank: 0\n",
- "colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,605 INFO: Rank 0: Completed store-based barrier for 1 nodes.\n",
- "colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,608 INFO: Added key: store_based_barrier_key:3 to store for rank: 0\n",
- "colossalai - torch.distributed.distributed_c10d - 2021-10-15 03:27:51,610 INFO: Rank 0: Completed store-based barrier for 1 nodes.\n"
- ]
- },
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "process rank 0 is bound to device 0\n",
- "initialized seed on rank 0, numpy: 1024, python random: 1024, ParallelMode.DATA: 1024, ParallelMode.TENSOR: 1124,the default parallel seed is ParallelMode.DATA.\n"
- ]
- }
]
},
{
@@ -157,13 +141,24 @@
},
{
"cell_type": "code",
+ "execution_count": null,
"metadata": {
- "id": "ZyGhyD47-dUY",
"colab": {
"base_uri": "https://localhost:8080/"
},
+ "id": "ZyGhyD47-dUY",
"outputId": "98bbf2d1-a1c4-4bb4-b6df-600777b1e8f5"
},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Files already downloaded and verified\n",
+ "Files already downloaded and verified\n"
+ ]
+ }
+ ],
"source": [
"transform_cfg = [\n",
" dict(type='ToTensor'),\n",
@@ -179,17 +174,6 @@
"\n",
"testset = colossalai.nn.data.CIFAR10Dataset(transform_cfg, root='./data', train=False)\n",
"testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)"
- ],
- "execution_count": null,
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "Files already downloaded and verified\n",
- "Files already downloaded and verified\n"
- ]
- }
]
},
{
@@ -203,9 +187,11 @@
},
{
"cell_type": "code",
+ "execution_count": null,
"metadata": {
"id": "cQ_y7lBG09LS"
},
+ "outputs": [],
"source": [
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
@@ -232,9 +218,7 @@
"\n",
"\n",
"model = Net().cuda()"
- ],
- "execution_count": null,
- "outputs": []
+ ]
},
{
"cell_type": "markdown",
@@ -247,6 +231,7 @@
},
{
"cell_type": "code",
+ "execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@@ -254,6 +239,18 @@
"id": "YtaDoCax1BCf",
"outputId": "b33b1641-03d8-4597-c8c2-1a4c1d61e9b0"
},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "colossalai - rank_0 - 2021-10-15 03:27:56,018 WARNING: No gradient handler is set up, please make sure you do not need to all-reduce the gradients after a training step.\n",
+ "colossalai - rank_0 - 2021-10-15 03:27:56,024 INFO: build LogMetricByEpochHook for train, priority = 1\n",
+ "colossalai - rank_0 - 2021-10-15 03:27:56,026 INFO: build LossHook for train, priority = 10\n",
+ "colossalai - rank_0 - 2021-10-15 03:27:56,029 INFO: build AccuracyHook for train, priority = 10\n"
+ ]
+ }
+ ],
"source": [
"import torch.optim as optim\n",
"\n",
@@ -270,19 +267,6 @@
"trainer = Trainer(engine=engine,\n",
" hooks_cfg=[dict(type='LossHook'), dict(type='LogMetricByEpochHook'), dict(type='AccuracyHook')],\n",
" verbose=True)"
- ],
- "execution_count": null,
- "outputs": [
- {
- "output_type": "stream",
- "name": "stderr",
- "text": [
- "colossalai - rank_0 - 2021-10-15 03:27:56,018 WARNING: No gradient handler is set up, please make sure you do not need to all-reduce the gradients after a training step.\n",
- "colossalai - rank_0 - 2021-10-15 03:27:56,024 INFO: build LogMetricByEpochHook for train, priority = 1\n",
- "colossalai - rank_0 - 2021-10-15 03:27:56,026 INFO: build LossHook for train, priority = 10\n",
- "colossalai - rank_0 - 2021-10-15 03:27:56,029 INFO: build AccuracyHook for train, priority = 10\n"
- ]
- }
]
},
{
@@ -296,6 +280,7 @@
},
{
"cell_type": "code",
+ "execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@@ -303,22 +288,10 @@
"id": "w-J3IP-J1sfx",
"outputId": "bdb76939-04f1-4124-ce5e-3af44c0d902c"
},
- "source": [
- "num_epochs = 10\n",
- "test_interval = 1\n",
- "trainer.fit(\n",
- " train_dataloader=trainloader,\n",
- " test_dataloader=testloader,\n",
- " max_epochs=num_epochs,\n",
- " display_progress=True,\n",
- " test_interval=test_interval\n",
- " )"
- ],
- "execution_count": null,
"outputs": [
{
- "output_type": "stream",
"name": "stderr",
+ "output_type": "stream",
"text": [
"[Epoch 0 train]: 0%| | 0/391 [00:00, ?it/s]/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /pytorch/c10/core/TensorImpl.h:1156.)\n",
" return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)\n",
@@ -364,7 +337,34 @@
"colossalai - rank_0 - 2021-10-15 03:30:57,332 INFO: Testing - Epoch 10 - LogMetricByEpochHook: Loss = 1.41242, Accuracy = 0.48500\n"
]
}
+ ],
+ "source": [
+ "num_epochs = 10\n",
+ "test_interval = 1\n",
+ "trainer.fit(\n",
+ " train_dataloader=trainloader,\n",
+ " test_dataloader=testloader,\n",
+ " max_epochs=num_epochs,\n",
+ " display_progress=True,\n",
+ " test_interval=test_interval\n",
+ " )"
]
}
- ]
-}
\ No newline at end of file
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "name": "colossal_cifar_demo.ipynb",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/examples/run_trainer.py b/examples/run_trainer.py
index 080713d07..04cfb1c8f 100644
--- a/examples/run_trainer.py
+++ b/examples/run_trainer.py
@@ -3,13 +3,13 @@
import colossalai
from colossalai.core import global_context as gpc
-from colossalai.logging import get_global_dist_logger
+from colossalai.logging import get_dist_logger
from colossalai.trainer import Trainer
def run_trainer():
engine, train_dataloader, test_dataloader = colossalai.initialize()
- logger = get_global_dist_logger()
+ logger = get_dist_logger()
engine.schedule.data_sync = False
logger.info("engine is built", ranks=[0])
diff --git a/examples/vit-b16/README.md b/examples/vit-b16/README.md
index 4437ea060..6d5439523 100644
--- a/examples/vit-b16/README.md
+++ b/examples/vit-b16/README.md
@@ -1,54 +1,40 @@
# Overview
-A common way to speed up AI model training is to implement large-batch training with the help of data parallelism, but this requires expensive supercomputer clusters. In this example, we used a small server with only 4 GPUs to reproduce the large-scale pre-training of Vision Transformer (ViT) on ImageNet-1K in 14 hours.
+Here is an example of training ViT-B/16 on ImageNet-1K with a batch size of 32K.
+We use 8x NVIDIA A100 GPUs in this example.
# How to run
-
-On a single server, you can directly use torch.distributed to start pre-training on multiple GPUs in parallel.
-
-```shell
-python -m torch.distributed.launch --nproc_per_node train_dali​.py --world_size --config
-```
-
-For scaling on a GPU cluster, you can use the [Slurm](https://slurm.schedmd.com/documentation.html) Workload Manager to start the following commands and get running environment information.
+Using [Slurm](https://slurm.schedmd.com/documentation.html):
```shell
srun python train_dali.py --local_rank=$SLURM_PROCID --world_size=$SLURM_NPROCS --host=$HOST --port=29500 --config=vit-b16.py
```
-# Experiments
-
-To facilitate more people to reproduce the experiments with large-scale data parallel, we pre-trained ViT-Base/32 in only 14.58 hours on a small server with 4 NVIDIA A100 GPUs using ImageNet-1K dataset with batch size 32K for 300 epochs maintaining accuracy. For more complex pre-training of ViT-Base/16 and ViT-Large/32, it also takes only 78.58 hours and 37.83 hours to complete. Since the server used in this example is not a standard NVIDIA DGX A100 supercomputing unit, perhaps a better acceleration can be obtained on more professional hardware.
+# Results


-As can be seen from the above figure, the ViT model eventually converges well after training 300 epochs. It is worth noting that, unlike the common small-batch training convergence process, the model performance has a temporary decline in the middle of the large-batch training process. This is due to the difficulty of convergence in large-batch training. As the number of iterations is reduced, a larger learning rate is needed to ensure the final convergence. Since we did not carefully adjust the parameters, perhaps other parameter settings could get better convergence.
-
# Details
-
`vit-b16.py`
-This is a [configuration file](https://colossalai.org/config.html) that defines training parameters used by Colossal-AI, such as model, dataset, training methods (optimizer, learning rate scheduler, number of epochs, etc.). The config content can be accessed through `gpc.config` in the program.
+It is a [config file](https://colossalai.org/config.html) used by Colossal-AI to define all training arguments, such as the model, dataset, and training method (optimizer, lr_scheduler, epochs, etc.). You can access the config content through `gpc.config`.
-In this example, we trained ViT-Base/16 for 300 epochs on the ImageNet-1K dataset. The batch size is expanded to 32K through data parallelism. Since only 4 A100 GPUs on one small server are used, and the GPU memory is limited, the batch size of 32K cannot be used directly. Therefore, the batch size used on each GPU is only 256, and the 256 batch size is equivalently expanded to 8K through gradient accumulation 32 times. Finally, data parallelism is used between 4 GPUs to achieve an equivalent batch size of 32K.
+In this example, we train the ViT-Base/16 model for 300 epochs on ImageNet-1K. The batch size is set to 32K through data parallelism (4K on each GPU, obtained from 16x gradient accumulation with a per-step batch size of 256; the numbers are worked out below). Since this batch size is much larger than common practice and makes convergence harder, we use the
+large-batch optimizer [LAMB](https://arxiv.org/abs/1904.00962), which lets us scale the batch size to 32K with only a small accuracy loss. The learning rate and weight decay of the optimizer are set to 1.8e-2 and 0.1, respectively. We use a linear warmup learning rate scheduler with 150 warmup epochs.
+We use FP16 mixed precision to accelerate training and gradient clipping to help convergence.
+For simplicity and speed, we did not apply `RandAug` and only used [Mixup](https://arxiv.org/abs/1710.09412) for data augmentation.
-Since the batch size of 32K far exceeds the use range of common optimizers and is difficult to train, we use the large-batch optimizer [LAMB](https://arxiv.org/abs/1904.00962) provided by Colossal-AI to achieve a better convergence. The learning rate and weight decay of [LAMB](https://arxiv.org/abs/1904.00962) are set to 1.8e-2 and 0.1, respectively. The learning rate scheduler uses a linear warmup strategy of 150 epochs. We also used FP16 mixed precision to speed up the training process, and introduced gradient clipping to help convergence. For simplicity and speed, we only use [Mixup](https://arxiv.org/abs/1710.09412) instead of `RandAug` in data augmentation.
+If you have enough computing resources, you can conveniently scale this example out to a very large number of GPUs with data parallelism alone, without gradient accumulation, and finish the training process in as little as one hour.
-By tuning the parallelism, this example can be quickly deployed to a single server with several GPUs or to a large cluster with lots of nodes and GPUs. If there are enough computing resources to allow data parallel to be directly extended to hundreds or even thousands of GPUs, the training process of several days on a single A100 GPU can be shortened to less than half an hour.
`imagenet_dali_dataloader.py`
-To accelerate the training process, we use [DALI](https://github.com/NVIDIA/DALI) to read data and require the dataset to be in TFRecord format, which avoids directly reading a large number of raw image files and being limited by the efficiency of the file system.
+To accelerate the training process, we use [DALI](https://github.com/NVIDIA/DALI) as the data loader. Note that it requires the dataset in TFRecord format, which avoids reading a large number of raw image files and overloading the file system.
`train_dali.py`
-We call DALI in this file to read data and start the training process using Colossal-AI.
+We build the DALI data loader and the training process with Colossal-AI here.
`mixup.py`
-Since Mixup is used as data augmentation, we define the loss function of Mixup here.
+Since we use Mixup, we define the Mixup loss function in this file.
`hooks.py`
-We define hook functions that record running information to help debugging.
-
-# How to build TFRecords dataset
-
-As we use [DALI](https://github.com/NVIDIA/DALI) to read data, we use the TFRecords dataset instead of raw Imagenet dataset. If you don't have TFRecords dataset, follow [imagenet-tools](https://github.com/ver217/imagenet-tools) to build one.
-
+We also define useful hooks that log runtime information to help with debugging.
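For clarity, here is how the 32K effective batch size described above decomposes; this is plain arithmetic, not Colossal-AI API usage, and the variable names are illustrative.

```python
per_step_batch_size = 256     # batch size per GPU per forward/backward pass
grad_accum_steps = 16         # gradient accumulation factor (see vit-b16.py below)
num_gpus = 8                  # data-parallel degree

per_gpu_batch = per_step_batch_size * grad_accum_steps   # 4096, the "4K on each GPU"
effective_batch = per_gpu_batch * num_gpus               # 32768, i.e. the 32K global batch
```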
diff --git a/examples/vit-b16/train_dali.py b/examples/vit-b16/train_dali.py
index fed39c3cc..31bd3be4d 100644
--- a/examples/vit-b16/train_dali.py
+++ b/examples/vit-b16/train_dali.py
@@ -3,7 +3,7 @@ import os
import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
-from colossalai.logging import get_global_dist_logger
+from colossalai.logging import get_dist_logger
from colossalai.trainer import Trainer
from colossalai.utils import set_global_multitimer_status
from dataloader.imagenet_dali_dataloader import DaliDataloader
@@ -49,7 +49,7 @@ def main():
train_dataloader=build_dali_train,
test_dataloader=build_dali_test
)
- logger = get_global_dist_logger()
+ logger = get_dist_logger()
set_global_multitimer_status(True)
timer = colossalai.utils.get_global_multitimer()
trainer = Trainer(engine=engine,
diff --git a/examples/vit-b16/vit-b16.py b/examples/vit-b16/vit-b16.py
index b23f78a30..ac51e226e 100755
--- a/examples/vit-b16/vit-b16.py
+++ b/examples/vit-b16/vit-b16.py
@@ -73,6 +73,6 @@ dali = dict(
engine = dict(
schedule=None,
gradient_handlers=None,
- gradient_accumulation=32,
+ gradient_accumulation=16,
gradient_clipping=1.0,
)
diff --git a/model_zoo/__init__.py b/model_zoo/__init__.py
index 9bec0d54b..e69de29bb 100644
--- a/model_zoo/__init__.py
+++ b/model_zoo/__init__.py
@@ -1,2 +0,0 @@
-from .vit import *
-from .mlp_mixer import *
diff --git a/model_zoo/bert/__init__.py b/model_zoo/bert/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/model_zoo/mlp_mixer/__init__.py b/model_zoo/mlp_mixer/__init__.py
index 8b48742eb..e69de29bb 100644
--- a/model_zoo/mlp_mixer/__init__.py
+++ b/model_zoo/mlp_mixer/__init__.py
@@ -1 +0,0 @@
-from .parallel_3d import *
diff --git a/model_zoo/vit/__init__.py b/model_zoo/vit/__init__.py
index 6e009854d..e69de29bb 100644
--- a/model_zoo/vit/__init__.py
+++ b/model_zoo/vit/__init__.py
@@ -1,2 +0,0 @@
-from .parallel_2d import *
-from .parallel_3d import *
diff --git a/model_zoo/vit/parallel_1d/vit.py b/model_zoo/vit/parallel_1d/vit.py
new file mode 100644
index 000000000..e471fed14
--- /dev/null
+++ b/model_zoo/vit/parallel_1d/vit.py
@@ -0,0 +1,208 @@
+import torch
+from torch import nn
+
+from colossalai import nn as col_nn
+from colossalai.context import ParallelMode
+from colossalai.registry import MODELS
+
+__all__ = [
+ 'VisionTransformer1D',
+ 'vit_tiny_1d_patch4_32',
+ 'vit_tiny_1d_patch16_224',
+ 'vit_tiny_1d_patch16_384',
+ 'vit_small_1d_patch16_224',
+ 'vit_small_1d_patch16_384',
+ 'vit_small_1d_patch32_224',
+ 'vit_small_1d_patch32_384',
+ 'vit_base_1d_patch16_224',
+ 'vit_base_1d_patch16_384',
+ 'vit_base_1d_patch32_224',
+ 'vit_base_1d_patch32_384',
+ 'vit_large_1d_patch16_224',
+ 'vit_large_1d_patch16_384',
+ 'vit_large_1d_patch32_224',
+ 'vit_large_1d_patch32_384',
+]
+
+
+class ViTBlock1D(nn.Module):
+ def __init__(self,
+ dim: int,
+ num_heads: int,
+ hidden_dim: int,
+ drop: float = 0.,
+ attn_drop: float = 0.,
+ drop_path: float = 0.):
+ super().__init__()
+ self.norm1 = nn.LayerNorm(dim, eps=1e-6)
+ self.attn = col_nn.ViTSelfAttention1D(dim, num_heads, attn_drop, drop)
+ self.drop_path = col_nn.VanillaViTDropPath(
+ drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = nn.LayerNorm(dim, eps=1e-6)
+ self.mlp = col_nn.ViTMLP1D(dim, 1, drop, 'gelu')
+
+ def forward(self, x):
+ x = x + self.drop_path(self.attn(self.norm1(x)))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+
+@MODELS.register_module
+class VisionTransformer1D(nn.Module):
+ def __init__(self,
+ img_size: int = 224,
+ patch_size: int = 16,
+ in_chans: int = 3,
+ num_classes: int = 1000,
+ depth: int = 12,
+ num_heads: int = 12,
+ embed_dim: int = 768,
+ hidden_dim: int = 3072,
+ drop_rate: float = 0.,
+ attn_drop_rate: float = 0.,
+ drop_path_rate: float = 0.):
+ super().__init__()
+ self.num_classes = num_classes
+ self.num_features = self.embed_dim = embed_dim
+
+ self.patch_embed = col_nn.ViTPatchEmbedding1D(
+ img_size,
+ patch_size,
+ in_chans,
+ embed_dim,
+ drop_rate,
+ )
+
+ # stochastic depth decay rule
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
+ self.blocks = nn.Sequential(*[
+ ViTBlock1D(embed_dim, num_heads, hidden_dim,
+ drop_rate, attn_drop_rate, dpr[i])
+ for i in range(depth)
+ ])
+
+ self.norm = nn.LayerNorm(embed_dim, eps=1e-6)
+
+ self.head = col_nn.ViTHead1D(hidden_dim, num_classes)
+ self.init_weights()
+
+ def init_weights(self):
+ pass
+
+ def forward(self, x):
+ x = self.patch_embed(x)
+ x = self.blocks(x)
+ x = self.norm(x)
+ x = self.head(x)
+ return x
+
+
+def _create_vit_model(**model_kwargs):
+ model = VisionTransformer1D(**model_kwargs)
+ return model
+
+
+@MODELS.register_module
+def vit_tiny_1d_patch4_32(**kwargs):
+ model_kwargs = dict(img_size=32, patch_size=4, embed_dim=512,
+ depth=6, num_heads=8, hidden_dim=512, num_classes=10, **kwargs)
+ return _create_vit_model(**model_kwargs)
+
+
+@MODELS.register_module
+def vit_tiny_1d_patch16_224(**kwargs):
+ model_kwargs = dict(patch_size=16, embed_dim=192,
+ depth=12, num_heads=3, hidden_dim=768, **kwargs)
+ return _create_vit_model(**model_kwargs)
+
+
+@MODELS.register_module
+def vit_tiny_1d_patch16_384(**kwargs):
+ model_kwargs = dict(img_size=384, patch_size=16,
+ embed_dim=192, depth=12, num_heads=3, hidden_dim=768, **kwargs)
+ return _create_vit_model(**model_kwargs)
+
+
+@MODELS.register_module
+def vit_small_1d_patch16_224(**kwargs):
+ model_kwargs = dict(patch_size=16, embed_dim=384,
+ depth=12, num_heads=6, hidden_dim=1536, **kwargs)
+ return _create_vit_model(**model_kwargs)
+
+
+@MODELS.register_module
+def vit_small_1d_patch16_384(**kwargs):
+ model_kwargs = dict(img_size=384, patch_size=16,
+ embed_dim=384, depth=12, num_heads=6, hidden_dim=1536, **kwargs)
+ return _create_vit_model(**model_kwargs)
+
+
+@MODELS.register_module
+def vit_small_1d_patch32_224(**kwargs):
+ model_kwargs = dict(patch_size=32, embed_dim=384,
+ depth=12, num_heads=6, hidden_dim=1536, **kwargs)
+ return _create_vit_model(**model_kwargs)
+
+
+@MODELS.register_module
+def vit_small_1d_patch32_384(**kwargs):
+ model_kwargs = dict(img_size=384, patch_size=32,
+ embed_dim=384, depth=12, num_heads=6, hidden_dim=1536, **kwargs)
+ return _create_vit_model(**model_kwargs)
+
+
+@MODELS.register_module
+def vit_base_1d_patch16_224(**kwargs):
+ model_kwargs = dict(patch_size=16, embed_dim=768,
+ depth=12, num_heads=12, hidden_dim=3072, **kwargs)
+ return _create_vit_model(**model_kwargs)
+
+
+@MODELS.register_module
+def vit_base_1d_patch16_384(**kwargs):
+ model_kwargs = dict(img_size=384, patch_size=16,
+ embed_dim=768, depth=12, num_heads=12, hidden_dim=3072, **kwargs)
+ return _create_vit_model(**model_kwargs)
+
+
+@MODELS.register_module
+def vit_base_1d_patch32_224(**kwargs):
+ model_kwargs = dict(patch_size=32, embed_dim=768,
+ depth=12, num_heads=12, hidden_dim=3072, **kwargs)
+ return _create_vit_model(**model_kwargs)
+
+
+@MODELS.register_module
+def vit_base_1d_patch32_384(**kwargs):
+ model_kwargs = dict(img_size=384, patch_size=32,
+ embed_dim=768, depth=12, num_heads=12, hidden_dim=3072, **kwargs)
+ return _create_vit_model(**model_kwargs)
+
+
+@MODELS.register_module
+def vit_large_1d_patch16_224(**kwargs):
+ model_kwargs = dict(patch_size=16, embed_dim=1024,
+ depth=24, num_heads=16, hidden_dim=4096, **kwargs)
+ return _create_vit_model(**model_kwargs)
+
+
+@MODELS.register_module
+def vit_large_1d_patch16_384(**kwargs):
+ model_kwargs = dict(img_size=384, patch_size=16,
+ embed_dim=1024, depth=24, num_heads=16, hidden_dim=4096, **kwargs)
+ return _create_vit_model(**model_kwargs)
+
+
+@MODELS.register_module
+def vit_large_1d_patch32_224(**kwargs):
+ model_kwargs = dict(patch_size=32, embed_dim=1024,
+ depth=24, num_heads=16, hidden_dim=4096, **kwargs)
+ return _create_vit_model(**model_kwargs)
+
+
+@MODELS.register_module
+def vit_large_1d_patch32_384(**kwargs):
+ model_kwargs = dict(img_size=384, patch_size=32,
+ embed_dim=1024, depth=24, num_heads=16, hidden_dim=4096, **kwargs)
+ return _create_vit_model(**model_kwargs)
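As a rough usage sketch (not verified against this exact revision), one of the registered 1D variants can be instantiated directly once a 1D tensor-parallel context has been launched; the import path simply mirrors the new file's location, and the dummy input/output shapes are assumptions based on the constructor defaults.

```python
import torch
from model_zoo.vit.parallel_1d.vit import vit_tiny_1d_patch4_32

# CIFAR-sized variant: 32x32 images, patch size 4, 10 classes
model = vit_tiny_1d_patch4_32(drop_rate=0.1).cuda()

x = torch.randn(8, 3, 32, 32).cuda()  # dummy batch of 8 RGB images
logits = model(x)                     # expected shape: (8, 10)
```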
diff --git a/colossalai/nn/model/vision_transformer/vision_transformer.py b/model_zoo/vit/vision_transformer_from_config.py
similarity index 95%
rename from colossalai/nn/model/vision_transformer/vision_transformer.py
rename to model_zoo/vit/vision_transformer_from_config.py
index 98f5cae55..af1e32091 100644
--- a/colossalai/nn/model/vision_transformer/vision_transformer.py
+++ b/model_zoo/vit/vision_transformer_from_config.py
@@ -4,11 +4,11 @@
import torch
from colossalai.registry import MODELS
-from ..base_model import BaseModel
+from colossalai.nn.model.model_from_config import ModelFromConfig
@MODELS.register_module
-class VisionTransformerFromConfig(BaseModel):
+class VisionTransformerFromConfig(ModelFromConfig):
"""Vision Transformer from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_.
diff --git a/setup.py b/setup.py
index 8541b0a6c..f7684d4da 100644
--- a/setup.py
+++ b/setup.py
@@ -132,4 +132,4 @@ setup(
ext_modules=ext_modules,
cmdclass={'build_ext': BuildExtension} if ext_modules else {},
install_requires=install_requires,
-)
+)
\ No newline at end of file
diff --git a/tests/test_config/sample_config.py b/tests/test_config/sample_config.py
index e48c70e14..08ca10828 100644
--- a/tests/test_config/sample_config.py
+++ b/tests/test_config/sample_config.py
@@ -1,12 +1,10 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
-import os
-from pathlib import Path
train_data = dict(
dataset=dict(
type='CIFAR10Dataset',
- root=Path(os.environ['DATA']),
+ root='/path/to/data',
download=True,
transform_pipeline=[
dict(type='RandomResizedCrop', size=224),
diff --git a/tests/test_context/test_2d_init.py b/tests/test_context/test_2d_init.py
index 24e0749ae..d373964f8 100644
--- a/tests/test_context/test_2d_init.py
+++ b/tests/test_context/test_2d_init.py
@@ -7,7 +7,7 @@ from pathlib import Path
import pytest
import torch.multiprocessing as mp
-from colossalai import init_dist
+from colossalai import launch
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
@@ -58,22 +58,22 @@ def check_2d_parallel_rank(rank):
assert gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) == 1
-def init_2d(local_rank, world_size, backend, port, host):
+def init_2d(rank, world_size, backend, port, host):
dist_args = dict(
config=CONFIG_PATH,
- local_rank=local_rank,
+ rank=rank,
world_size=world_size,
backend=backend,
port=port,
- host=host
+ host=host,
+ verbose=True
)
- init_dist(**dist_args)
-
- check_tensor_parallel_rank(local_rank)
- check_data_parallel_rank(local_rank)
- check_2d_parallel_rank(local_rank)
- check_pipeline_parallel_rank(local_rank)
+ launch(**dist_args)
+ check_tensor_parallel_rank(rank)
+ check_data_parallel_rank(rank)
+ check_2d_parallel_rank(rank)
+ check_pipeline_parallel_rank(rank)
gpc.destroy()
diff --git a/tests/test_context/test_2p5d_init.py b/tests/test_context/test_2p5d_init.py
index 26de7f7ff..c071d86e7 100644
--- a/tests/test_context/test_2p5d_init.py
+++ b/tests/test_context/test_2p5d_init.py
@@ -9,7 +9,7 @@ import torch.multiprocessing as mp
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
-from colossalai.initialize import init_dist
+from colossalai.initialize import launch
CONFIG_PATH = Path(__file__).parent.joinpath('configs/parallel_2p5d_init.py').absolute()
@@ -82,20 +82,21 @@ def check_2p5d_parallel_rank(rank):
assert xp_rank == i
-def init_2halfd(local_rank, world_size, backend, port, host):
+def init_2halfd(rank, world_size, backend, port, host):
dist_args = dict(
config=CONFIG_PATH,
- local_rank=local_rank,
+ rank=rank,
world_size=world_size,
backend=backend,
port=port,
- host=host
+ host=host,
+ verbose=True
)
- init_dist(**dist_args)
- check_data_parallel_rank(local_rank)
- check_pipeline_parallel_rank(local_rank)
- check_tensor_parallel_rank(local_rank)
- check_2p5d_parallel_rank(local_rank)
+ launch(**dist_args)
+ check_data_parallel_rank(rank)
+ check_pipeline_parallel_rank(rank)
+ check_tensor_parallel_rank(rank)
+ check_2p5d_parallel_rank(rank)
gpc.destroy()
diff --git a/tests/test_context/test_3d_init.py b/tests/test_context/test_3d_init.py
index 0fba98bff..a1c48a9b7 100644
--- a/tests/test_context/test_3d_init.py
+++ b/tests/test_context/test_3d_init.py
@@ -9,7 +9,7 @@ import torch.multiprocessing as mp
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
-from colossalai.initialize import init_dist
+from colossalai.initialize import launch
CONFIG_PATH = Path(__file__).parent.joinpath('configs/parallel_3d_init.py').absolute()
@@ -74,21 +74,21 @@ def check_3d_parallel_rank(rank):
assert op_rank == i
-def init_3d(local_rank, world_size, backend, port, host):
+def init_3d(rank, world_size, backend, port, host):
dist_args = dict(
config=CONFIG_PATH,
- local_rank=local_rank,
+ rank=rank,
world_size=world_size,
backend=backend,
port=port,
- host=host
+ host=host,
+ verbose=True
)
- init_dist(**dist_args)
- check_tensor_parallel_rank(local_rank)
- check_3d_parallel_rank(local_rank)
- check_data_parallel_rank(local_rank)
- check_pipeline_parallel_rank(local_rank)
- print('pass')
+ launch(**dist_args)
+ check_tensor_parallel_rank(rank)
+ check_3d_parallel_rank(rank)
+ check_data_parallel_rank(rank)
+ check_pipeline_parallel_rank(rank)
gpc.destroy()
diff --git a/tests/test_data/test_cifar10_dataset.py b/tests/test_data/test_cifar10_dataset.py
index 10b79dd03..569cea2ca 100644
--- a/tests/test_data/test_cifar10_dataset.py
+++ b/tests/test_data/test_cifar10_dataset.py
@@ -5,39 +5,50 @@ import os
from pathlib import Path
import pytest
+from torchvision import transforms
from torch.utils.data import DataLoader
-from colossalai.builder import build_dataset
+from colossalai.builder import build_dataset, build_transform
from colossalai.context import Config
-train_data = dict(
+TRAIN_DATA = dict(
dataset=dict(
- type='CIFAR10Dataset',
+ type='CIFAR10',
root=Path(os.environ['DATA']),
train=True,
- download=True,
- transform_pipeline=[
- dict(type='ToTensor'),
- dict(type='Normalize',
- mean=(0.5, 0.5, 0.5),
- std=(0.5, 0.5, 0.5))
- ]),
- dataloader=dict(batch_size=4, shuffle=True, num_workers=2)
+ download=True
+ ),
+ dataloader=dict(batch_size=4, shuffle=True, num_workers=2),
+ transform_pipeline=[
+ dict(type='ToTensor'),
+ dict(type='Normalize',
+ mean=(0.5, 0.5, 0.5),
+ std=(0.5, 0.5, 0.5)
+ )
+ ]
)
@pytest.mark.cpu
def test_cifar10_dataset():
- global train_data
- config = Config(train_data)
- dataset = build_dataset(config.dataset)
- dataloader = DataLoader(dataset=dataset, **config.dataloader)
+ config = Config(TRAIN_DATA)
+ dataset_cfg = config.dataset
+ dataloader_cfg = config.dataloader
+ transform_cfg = config.transform_pipeline
+
+ # build transform
+ transform_pipeline = [build_transform(cfg) for cfg in transform_cfg]
+ transform_pipeline = transforms.Compose(transform_pipeline)
+ dataset_cfg['transform'] = transform_pipeline
+
+ # build dataset
+ dataset = build_dataset(dataset_cfg)
+
+ # build dataloader
+ dataloader = DataLoader(dataset=dataset, **dataloader_cfg)
data_iter = iter(dataloader)
img, label = data_iter.next()
- assert isinstance(img, list) and isinstance(label, list), \
- f'expected the img and label to be list but got {type(img)} and {type(label)}'
-
if __name__ == '__main__':
test_cifar10_dataset()
diff --git a/tests/test_data/test_data_parallel_sampler.py b/tests/test_data/test_data_parallel_sampler.py
index 056f0441a..2f2e275c4 100644
--- a/tests/test_data/test_data_parallel_sampler.py
+++ b/tests/test_data/test_data_parallel_sampler.py
@@ -12,54 +12,54 @@ import torch.multiprocessing as mp
from torch.utils.data import DataLoader
import colossalai
-from colossalai.builder import build_dataset, build_data_sampler
-from colossalai.context.parallel_mode import ParallelMode
+from colossalai.builder import build_dataset, build_data_sampler, build_transform
+from torchvision import transforms
+from colossalai.context import ParallelMode, Config
from colossalai.core import global_context as gpc
+from colossalai.utils import get_dataloader
-CONFIG = dict(
- train_data=dict(
- dataset=dict(
- type='CIFAR10Dataset',
- root=Path(os.environ['DATA']),
- train=True,
- download=True,
+CONFIG = Config(
+ dict(
+ train_data=dict(
+ dataset=dict(
+ type='CIFAR10',
+ root=Path(os.environ['DATA']),
+ train=True,
+ download=True,
+ ),
+ dataloader=dict(
+ batch_size=8,
+ ),
transform_pipeline=[
dict(type='ToTensor'),
dict(type='Normalize', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
]
),
- dataloader=dict(
- num_workers=2,
- batch_size=8,
- sampler=dict(
- type='DataParallelSampler',
- )
- )
- ),
- parallel=dict(
- pipeline=dict(size=1),
- tensor=dict(size=1, mode=None),
- ),
- seed=1024,
-)
+ parallel=dict(
+ pipeline=dict(size=1),
+ tensor=dict(size=1, mode=None),
+ ),
+ seed=1024,
+ ))
-def run_data_sampler(local_rank, world_size):
+def run_data_sampler(rank, world_size):
dist_args = dict(
config=CONFIG,
- local_rank=local_rank,
+ rank=rank,
world_size=world_size,
backend='gloo',
port='29503',
host='localhost'
)
- colossalai.init_dist(**dist_args)
+ colossalai.launch(**dist_args)
print('finished initialization')
+ transform_pipeline = [build_transform(cfg) for cfg in gpc.config.train_data.transform_pipeline]
+ transform_pipeline = transforms.Compose(transform_pipeline)
+ gpc.config.train_data.dataset['transform'] = transform_pipeline
dataset = build_dataset(gpc.config.train_data.dataset)
- sampler_cfg = gpc.config.train_data.dataloader.pop('sampler')
- sampler = build_data_sampler(sampler_cfg, dataset)
- dataloader = DataLoader(dataset=dataset, sampler=sampler, **gpc.config.train_data.dataloader)
+ dataloader = get_dataloader(dataset, **gpc.config.train_data.dataloader)
data_iter = iter(dataloader)
img, label = data_iter.next()
img = img[0]
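
The hand-built DataParallelSampler is gone; get_dataloader now constructs the DataLoader and is expected to attach the data-parallel sampler itself (this reading is inferred from the test above, not documented here). A minimal sketch, assuming an already-launched distributed context:

# Sketch only: get_dataloader in place of DataLoader + explicit sampler.
import torch
from torch.utils.data import TensorDataset
from colossalai.utils import get_dataloader

dataset = TensorDataset(torch.randn(64, 3, 32, 32), torch.randint(0, 10, (64,)))
train_dataloader = get_dataloader(dataset,
                                  shuffle=True,      # shuffling is delegated to the sampler when one is attached
                                  add_sampler=True,  # same keyword as in the pipeline script later in this PR
                                  batch_size=8,
                                  num_workers=2,
                                  pin_memory=True)
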
diff --git a/tests/test_data/test_deterministic_dataloader.py b/tests/test_data/test_deterministic_dataloader.py
index 9cfd6c4fc..237c92b77 100644
--- a/tests/test_data/test_deterministic_dataloader.py
+++ b/tests/test_data/test_deterministic_dataloader.py
@@ -9,56 +9,70 @@ import pytest
import torch.cuda
import torch.distributed as dist
import torch.multiprocessing as mp
+from torchvision import transforms
from torch.utils.data import DataLoader
import colossalai
-from colossalai.builder import build_dataset
-from colossalai.context.parallel_mode import ParallelMode
+from colossalai.builder import build_dataset, build_transform
+from colossalai.context import ParallelMode, Config
from colossalai.core import global_context as gpc
-CONFIG = dict(
- train_data=dict(
- dataset=dict(
- type='CIFAR10Dataset',
- root=Path(os.environ['DATA']),
- train=True,
- download=True,
+CONFIG = Config(
+ dict(
+ train_data=dict(
+ dataset=dict(
+ type='CIFAR10',
+ root=Path(os.environ['DATA']),
+ train=True,
+ download=True,
+ ),
+ dataloader=dict(
+ num_workers=2,
+ batch_size=2,
+ shuffle=True
+ ),
transform_pipeline=[
dict(type='ToTensor'),
dict(type='RandomCrop', size=32),
dict(type='Normalize', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
]
),
- dataloader=dict(
- num_workers=2,
- batch_size=2,
- shuffle=True
- )
- ),
- parallel=dict(
- pipeline=dict(size=1),
- tensor=dict(size=1, mode=None),
- ),
- seed=1024,
+ parallel=dict(
+ pipeline=dict(size=1),
+ tensor=dict(size=1, mode=None),
+ ),
+ seed=1024,
+ )
)
-def run_data_sampler(local_rank, world_size):
+def run_data_sampler(rank, world_size):
dist_args = dict(
config=CONFIG,
- local_rank=local_rank,
+ rank=rank,
world_size=world_size,
backend='gloo',
port='29499',
host='localhost'
)
- colossalai.init_dist(**dist_args)
- gpc.set_seed()
-
+ colossalai.launch(**dist_args)
print('finished initialization')
- dataset = build_dataset(gpc.config.train_data.dataset)
- dataloader = DataLoader(dataset=dataset, **gpc.config.train_data.dataloader)
+ dataset_cfg = gpc.config.train_data.dataset
+ dataloader_cfg = gpc.config.train_data.dataloader
+ transform_cfg = gpc.config.train_data.transform_pipeline
+
+ # build transform
+ transform_pipeline = [build_transform(cfg) for cfg in transform_cfg]
+ transform_pipeline = transforms.Compose(transform_pipeline)
+ dataset_cfg['transform'] = transform_pipeline
+
+ # build dataset
+ dataset = build_dataset(dataset_cfg)
+
+ # build dataloader
+ dataloader = DataLoader(dataset=dataset, **dataloader_cfg)
+
data_iter = iter(dataloader)
img, label = data_iter.next()
img = img[0]
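
Note that the explicit gpc.set_seed() call is dropped here; the seed entry in the config is presumably consumed during colossalai.launch (an assumption based on this diff, not a documented guarantee). The relevant config shape:

# Assumption: the seed is picked up by launch from the config, so no separate set_seed() call is needed.
from colossalai.context import Config

CONFIG = Config(dict(
    parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)),
    seed=1024,
))
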
diff --git a/tests/test_data_pipeline_tensor_parallel/configs/vit_2d.py b/tests/test_data_pipeline_tensor_parallel/configs/vit_2d.py
deleted file mode 100644
index c97ed1804..000000000
--- a/tests/test_data_pipeline_tensor_parallel/configs/vit_2d.py
+++ /dev/null
@@ -1,150 +0,0 @@
-import os
-from pathlib import Path
-
-from colossalai.engine import AMP_TYPE
-
-BATCH_SIZE = 256
-IMG_SIZE = 32
-PATCH_SIZE = 4
-DIM = 512
-NUM_ATTENTION_HEADS = 8
-NUM_CLASSES = 10
-DEPTH = 6
-
-train_data = dict(
- dataset=dict(
- type='CIFAR10Dataset',
- root=Path(os.environ['DATA']),
- transform_pipeline=[
- dict(type='RandomCrop', size=IMG_SIZE, padding=4),
- dict(type='RandomHorizontalFlip'),
- dict(type='ToTensor'),
- dict(type='Normalize',
- mean=[0.4914, 0.4822, 0.4465],
- std=[0.2023, 0.1994, 0.2010]),
- ]
- ),
- dataloader=dict(
- batch_size=BATCH_SIZE,
- pin_memory=True,
- )
-)
-
-test_data = dict(
- dataset=dict(
- type='CIFAR10Dataset',
- root=Path(os.environ['DATA']),
- train=False,
- transform_pipeline=[
- dict(type='Resize', size=IMG_SIZE),
- dict(type='ToTensor'),
- dict(type='Normalize',
- mean=[0.4914, 0.4822, 0.4465],
- std=[0.2023, 0.1994, 0.2010]
- ),
- ]
- ),
- dataloader=dict(
- batch_size=BATCH_SIZE,
- pin_memory=True,
- # num_workers=1,
- )
-)
-
-optimizer = dict(
- type='Adam',
- lr=0.001,
- weight_decay=0
-)
-
-loss = dict(
- type='CrossEntropyLoss2D',
-)
-
-model = dict(
- type='VisionTransformerFromConfig',
- tensor_splitting_cfg=dict(
- type='ViTInputSplitter2D',
- ),
- embedding_cfg=dict(
- type='ViTPatchEmbedding2D',
- img_size=IMG_SIZE,
- patch_size=PATCH_SIZE,
- embed_dim=DIM,
- ),
- token_fusion_cfg=dict(
- type='ViTTokenFuser2D',
- img_size=IMG_SIZE,
- patch_size=PATCH_SIZE,
- embed_dim=DIM,
- drop_rate=0.1
- ),
- norm_cfg=dict(
- type='LayerNorm2D',
- normalized_shape=DIM,
- eps=1e-6,
- ),
- block_cfg=dict(
- type='ViTBlock',
- attention_cfg=dict(
- type='ViTSelfAttention2D',
- hidden_size=DIM,
- num_attention_heads=NUM_ATTENTION_HEADS,
- attention_dropout_prob=0.,
- hidden_dropout_prob=0.1,
- ),
- droppath_cfg=dict(
- type='VanillaViTDropPath',
- ),
- mlp_cfg=dict(
- type='ViTMLP2D',
- in_features=DIM,
- dropout_prob=0.1,
- mlp_ratio=1
- ),
- norm_cfg=dict(
- type='LayerNorm2D',
- normalized_shape=DIM,
- eps=1e-6,
- ),
- ),
- head_cfg=dict(
- type='ViTHead2D',
- hidden_size=DIM,
- num_classes=NUM_CLASSES,
- ),
- embed_dim=DIM,
- depth=DEPTH,
- drop_path_rate=0.,
-)
-
-parallel = dict(
- pipeline=dict(size=2),
- tensor=dict(size=4, mode='2d'),
-)
-
-fp16 = dict(
- mode=AMP_TYPE.PARALLEL,
-)
-
-engine = dict(
- schedule=dict(
- num_microbatches=2
- )
-)
-
-hooks = [
- dict(
- type='LRSchedulerHook',
- by_epoch=True,
- lr_scheduler_cfg=dict(
- type='LinearWarmupLR',
- warmup_steps=5
- )
- ),
-]
-num_epochs = 60
-
-logging = dict(
- root_path='test_vit_2d_log'
-)
diff --git a/tests/test_data_pipeline_tensor_parallel/configs/vit_2p5d.py b/tests/test_data_pipeline_tensor_parallel/configs/vit_2p5d.py
deleted file mode 100644
index fd9c89eb4..000000000
--- a/tests/test_data_pipeline_tensor_parallel/configs/vit_2p5d.py
+++ /dev/null
@@ -1,144 +0,0 @@
-import os
-from pathlib import Path
-
-BATCH_SIZE = 250
-IMG_SIZE = 32
-PATCH_SIZE = 4
-DIM = 512
-NUM_ATTENTION_HEADS = 8
-NUM_CLASSES = 10
-DEPTH = 6
-
-train_data = dict(
- dataset=dict(
- type='CIFAR10Dataset',
- root=Path(os.environ['DATA']),
- transform_pipeline=[
- dict(type='RandomCrop', size=IMG_SIZE, padding=4),
- dict(type='RandomHorizontalFlip'),
- dict(type='ToTensor'),
- dict(type='Normalize',
- mean=[0.4914, 0.4822, 0.4465],
- std=[0.2023, 0.1994, 0.2010]),
- ]
- ),
- dataloader=dict(
- batch_size=BATCH_SIZE,
- pin_memory=True,
- num_workers=0,
- shuffle=True
- )
-)
-
-test_data = dict(
- dataset=dict(
- type='CIFAR10Dataset',
- root=Path(os.environ['DATA']),
- train=False,
- transform_pipeline=[
- dict(type='Resize', size=IMG_SIZE),
- dict(type='ToTensor'),
- dict(type='Normalize',
- mean=[0.4914, 0.4822, 0.4465],
- std=[0.2023, 0.1994, 0.2010]
- ),
- ]
- ),
- dataloader=dict(
- batch_size=BATCH_SIZE,
- pin_memory=True,
- num_workers=0,
- shuffle=True
- )
-)
-
-optimizer = dict(
- type='Adam',
- lr=0.001,
- weight_decay=0
-)
-
-loss = dict(
- type='CrossEntropyLoss2p5D',
-)
-
-model = dict(
- type='VisionTransformerFromConfig',
- tensor_splitting_cfg=dict(
- type='ViTInputSplitter2p5D',
- ),
- embedding_cfg=dict(
- type='ViTPatchEmbedding2p5D',
- img_size=IMG_SIZE,
- patch_size=PATCH_SIZE,
- embed_dim=DIM,
- ),
- token_fusion_cfg=dict(
- type='ViTTokenFuser2p5D',
- img_size=IMG_SIZE,
- patch_size=PATCH_SIZE,
- embed_dim=DIM,
- drop_rate=0.1
- ),
- norm_cfg=dict(
- type='LayerNorm2p5D',
- normalized_shape=DIM,
- eps=1e-6,
- ),
- block_cfg=dict(
- type='ViTBlock',
- attention_cfg=dict(
- type='ViTSelfAttention2p5D',
- hidden_size=DIM,
- num_attention_heads=NUM_ATTENTION_HEADS,
- attention_dropout_prob=0.,
- hidden_dropout_prob=0.1,
- ),
- droppath_cfg=dict(
- type='VanillaViTDropPath',
- ),
- mlp_cfg=dict(
- type='ViTMLP2p5D',
- in_features=DIM,
- dropout_prob=0.1,
- mlp_ratio=1
- ),
- norm_cfg=dict(
- type='LayerNorm2p5D',
- normalized_shape=DIM,
- eps=1e-6,
- ),
- ),
- head_cfg=dict(
- type='ViTHead2p5D',
- hidden_size=DIM,
- num_classes=NUM_CLASSES,
- ),
- embed_dim=DIM,
- depth=DEPTH,
- drop_path_rate=0.,
-)
-
-parallel = dict(
- pipeline=dict(size=2),
- tensor=dict(size=4, depth=1, mode='2.5d'),
-)
-
-hooks = [
- dict(
- type='LRSchedulerHook',
- by_epoch=True,
- lr_scheduler_cfg=dict(
- type='LinearWarmupLR',
- warmup_steps=5
- )
- ),
-]
-
-engine = dict(
-schedule = dict(
- num_microbatches=2
-)
-)
-
-num_epochs = 60
diff --git a/tests/test_data_pipeline_tensor_parallel/run_cifar10_vit2d_with_pipeline.py b/tests/test_data_pipeline_tensor_parallel/run_cifar10_vit2d_with_pipeline.py
new file mode 100644
index 000000000..529fedf5a
--- /dev/null
+++ b/tests/test_data_pipeline_tensor_parallel/run_cifar10_vit2d_with_pipeline.py
@@ -0,0 +1,139 @@
+from pathlib import Path
+from colossalai.amp.amp_type import AMP_TYPE
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.logging import get_dist_logger
+import colossalai
+import torch
+import os
+from colossalai.builder import PipelineModelInitializer
+from colossalai.core import global_context as gpc
+from colossalai.utils import get_dataloader, MultiTimer
+from colossalai.nn.loss import CrossEntropyLoss2D
+from colossalai.trainer.metric import Accuracy2D
+from colossalai.trainer import metric, hooks, Trainer
+from colossalai.utils.gradient_accumulation import GradAccumLrSchedulerByStep
+from colossalai.engine.schedule import PipelineSchedule
+from torchvision import transforms
+from torchvision.datasets import CIFAR10
+from colossalai.nn import LinearWarmupLR
+from tqdm import tqdm
+import vit_t_2d
+
+BATCH_SIZE = 16
+NUM_EPOCHS = 60
+WARMUP_EPOCHS = 5
+CONFIG = dict(
+ parallel=dict(
+ pipeline=2,
+ tensor=dict(size=4, mode='2d')
+ ),
+ fp16=dict(
+ mode=AMP_TYPE.TORCH
+ ),
+ gradient_accumulation=2
+)
+
+
+def main():
+ parser = colossalai.get_default_parser()
+ args = parser.parse_args()
+ colossalai.launch_from_slurm(config=CONFIG,
+ host=args.host,
+ port=29500)
+
+ logger = get_dist_logger()
+ # if gpc.get_global_rank() == 0:
+ # logger.log_to_file('./logs/cifar10_2d_vit',
+ # suffix='cifar10_2d_vit_ddp1_torch_amp_grad_accum_2_clip_grad_1', mode='w')
+
+ # build vit-t-32
+ initializer = PipelineModelInitializer(vit_t_2d.model_cfg, num_chunks=1)
+ model = initializer.initialize()
+
+ # build dataloaders
+ train_dataset = CIFAR10(
+ root=Path(os.environ['DATA']),
+ download=True,
+ transform=transforms.Compose(
+ [
+ transforms.RandomCrop(size=32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[
+ 0.2023, 0.1994, 0.2010]),
+ ]
+ )
+ )
+
+ test_dataset = CIFAR10(
+ root=Path(os.environ['DATA']),
+ train=False,
+ transform=transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[
+ 0.2023, 0.1994, 0.2010]),
+ ]
+ )
+ )
+
+ train_dataloader = get_dataloader(dataset=train_dataset,
+ shuffle=True,
+ add_sampler=True,
+ batch_size=BATCH_SIZE,
+ num_workers=1,
+ pin_memory=True,
+ )
+
+ test_dataloader = get_dataloader(dataset=test_dataset,
+ add_sampler=True,
+ batch_size=BATCH_SIZE,
+ num_workers=1,
+ pin_memory=True,
+ )
+
+ # build criterion
+ criterion = CrossEntropyLoss2D()
+
+ # optimizer
+ optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0)
+
+ # lr_scheduler
+ steps_per_epoch = GradAccumLrSchedulerByStep.compute_effective_steps_per_epoch(train_dataloader, accumulate_size=2)
+ total_steps = steps_per_epoch * NUM_EPOCHS
+ warmup_steps = steps_per_epoch * WARMUP_EPOCHS
+ lr_scheduler = LinearWarmupLR(optimizer, total_steps=total_steps, warmup_steps=warmup_steps)
+
+ engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(
+ model, optimizer, criterion, train_dataloader, test_dataloader, lr_scheduler)
+
+ timer = MultiTimer()
+
+ schedule = PipelineSchedule(num_microbatches=4)
+
+ trainer = Trainer(
+ engine=engine,
+ timer=timer,
+ logger=logger,
+ schedule=schedule
+ )
+
+ hook_list = [
+ hooks.LossHook(),
+ hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False),
+ hooks.Accuracy2DHook(),
+ hooks.LogMetricByEpochHook(logger),
+ ]
+
+ trainer.fit(
+ train_dataloader=train_dataloader,
+ epochs=NUM_EPOCHS,
+ test_dataloader=test_dataloader,
+ test_interval=1,
+ hooks=hook_list,
+ display_progress=True
+ )
+
+
+if __name__ == '__main__':
+ main()
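
Two sizing details in the script above are easy to miss: with gradient_accumulation=2 in the config, the LR schedule is counted in optimizer steps via compute_effective_steps_per_epoch rather than raw batches, and PipelineSchedule(num_microbatches=4) feeds each batch through the two pipeline stages in four micro-batches. The batch arithmetic, illustratively:

# Illustrative arithmetic only; mirrors BATCH_SIZE and num_microbatches above.
BATCH_SIZE = 16
NUM_MICROBATCHES = 4
assert BATCH_SIZE % NUM_MICROBATCHES == 0, 'the batch must split evenly into micro-batches'
micro_batch_size = BATCH_SIZE // NUM_MICROBATCHES   # 4 samples traverse the pipeline at a time
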
diff --git a/tests/test_data_pipeline_tensor_parallel/test.sh b/tests/test_data_pipeline_tensor_parallel/test.sh
index 1c6012a52..0796e23cb 100644
--- a/tests/test_data_pipeline_tensor_parallel/test.sh
+++ b/tests/test_data_pipeline_tensor_parallel/test.sh
@@ -1,4 +1,3 @@
#!/usr/bin/env sh
-test_file=$1
-python $test_file --local_rank $SLURM_PROCID --world_size $SLURM_NPROCS --host $HOST --port 29500
+python run_cifar10_vit2d_with_pipeline.py --host $HOST
diff --git a/tests/test_data_pipeline_tensor_parallel/test_vit_2d/test_vit_2d.py b/tests/test_data_pipeline_tensor_parallel/test_vit_2d/test_vit_2d.py
deleted file mode 100644
index b68a58cea..000000000
--- a/tests/test_data_pipeline_tensor_parallel/test_vit_2d/test_vit_2d.py
+++ /dev/null
@@ -1,87 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-from pathlib import Path
-
-import pytest
-import torch.autograd
-
-import colossalai
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.logging import get_global_dist_logger
-from colossalai.nn.layer._parallel_utilities import _gather
-
-CONFIG_PATH = Path(__file__).parent.parent.joinpath('configs/vit_2d.py')
-
-
-def eval(engine, test_dataloader):
- engine.eval()
- accumulated_loss = 0
- correct_sum = 0
- total_sum = 0
- num_steps = len(test_dataloader)
- data_iter = iter(test_dataloader)
-
- for i in range(num_steps):
- output, label, loss = engine.step(data_iter)
-
- if gpc.is_last_rank(ParallelMode.PIPELINE):
- # loss = sum(loss)
- accumulated_loss += loss.detach().cpu().numpy()
-
- output = _gather(
- output,
- ParallelMode.PARALLEL_2D_ROW,
- 1
- )
- output = _gather(
- output,
- ParallelMode.PARALLEL_2D_COL,
- 0,
- )
- output = torch.argmax(output, dim=-1)
- correct = torch.sum(label == output)
- correct_sum += correct
- total_sum += label.size(0)
- avg_loss = accumulated_loss / num_steps
- return correct_sum, total_sum, avg_loss
-
-
-def train(engine, train_dataloader):
- engine.train()
- accumulated_loss = 0
- num_steps = len(train_dataloader)
- data_iter = iter(train_dataloader)
-
- for i in range(num_steps):
- output, label, loss = engine.step(data_iter)
-
- if gpc.is_last_rank(ParallelMode.PIPELINE):
- accumulated_loss += loss.detach().cpu().numpy()
- avg_loss = accumulated_loss / num_steps
- return avg_loss
-
-
-@pytest.mark.dist
-@pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus")
-def test_2d_parallel_vision_transformer():
- # init dist
- engine, train_dataloader, test_dataloader = colossalai.initialize(CONFIG_PATH)
- logger = get_global_dist_logger()
-
- for epoch in range(gpc.config.num_epochs):
- train_loss = train(engine, train_dataloader)
- if gpc.is_last_rank(ParallelMode.PIPELINE):
- logger.info(f'epoch {epoch} - train loss: {train_loss}')
-
- if epoch % 2 == 0:
- correct_sum, total_sum, eval_loss = eval(engine, test_dataloader)
- if gpc.is_last_rank(ParallelMode.PIPELINE):
- logger.info(
- f'epoch {epoch} - eval loss: {eval_loss}, total: {total_sum}, '
- f'correct: {correct_sum}, acc: {correct_sum / total_sum}')
-
-
-if __name__ == '__main__':
- test_2d_parallel_vision_transformer()
diff --git a/tests/test_data_pipeline_tensor_parallel/test_vit_2p5d/test_vit_2p5d.py b/tests/test_data_pipeline_tensor_parallel/test_vit_2p5d/test_vit_2p5d.py
deleted file mode 100644
index 70857f1e8..000000000
--- a/tests/test_data_pipeline_tensor_parallel/test_vit_2p5d/test_vit_2p5d.py
+++ /dev/null
@@ -1,89 +0,0 @@
-from pathlib import Path
-
-import pytest
-import torch.autograd
-
-import colossalai
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.logging import get_global_dist_logger
-from colossalai.nn.layer._parallel_utilities import _gather
-
-CONFIG_PATH = Path(__file__).parent.parent.joinpath('configs/vit_2p5d.py')
-
-
-def eval(engine, test_dataloader):
- engine.eval()
- accumulated_loss = 0
- correct_sum = 0
- total_sum = 0
- num_steps = len(test_dataloader)
- data_iter = iter(test_dataloader)
-
- for i in range(num_steps):
- output, label, loss = engine.step(data_iter)
-
- if gpc.is_last_rank(ParallelMode.PIPELINE):
- accumulated_loss += loss.detach().cpu().numpy()
-
- output = _gather(
- output,
- ParallelMode.PARALLEL_2P5D_ROW,
- 1
- )
- output = _gather(
- output,
- ParallelMode.PARALLEL_2P5D_COL,
- 0,
- )
- output = _gather(
- output,
- ParallelMode.PARALLEL_2P5D_DEP,
- 0,
- )
- output = torch.argmax(output, dim=-1)
- correct = torch.sum(label == output)
- correct_sum += correct
- total_sum += label.size(0)
- avg_loss = accumulated_loss / num_steps
- return correct_sum, total_sum, avg_loss
-
-
-def train(engine, train_dataloader):
- engine.train()
- accumulated_loss = 0
- num_steps = len(train_dataloader)
- data_iter = iter(train_dataloader)
-
- for i in range(num_steps):
- output, label, loss = engine.step(data_iter)
-
- if gpc.is_last_rank(ParallelMode.PIPELINE):
- accumulated_loss += loss.detach().cpu().numpy()
-
- avg_loss = accumulated_loss / num_steps
- return avg_loss
-
-
-@pytest.mark.dist
-@pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus")
-def test_2p5d_parallel_vision_transformer():
- # init dist
- engine, train_dataloader, test_dataloader = colossalai.initialize(CONFIG_PATH)
- logger = get_global_dist_logger()
-
- for epoch in range(gpc.config.num_epochs):
- train_loss = train(engine, train_dataloader)
- if gpc.is_last_rank(ParallelMode.PIPELINE):
- logger.info(f'epoch {epoch} - train loss: {train_loss}')
-
- if epoch % 2 == 0:
- correct_sum, total_sum, eval_loss = eval(engine, test_dataloader)
- if gpc.is_last_rank(ParallelMode.PIPELINE):
- logger.info(
- f'epoch {epoch} - eval loss: {eval_loss}, total: {total_sum}, '
- f'correct: {correct_sum}, acc: {correct_sum / total_sum}')
-
-
-if __name__ == '__main__':
- test_2p5d_parallel_vision_transformer()
diff --git a/tests/test_data_pipeline_tensor_parallel/vit_t_2d.py b/tests/test_data_pipeline_tensor_parallel/vit_t_2d.py
new file mode 100644
index 000000000..5be7a575a
--- /dev/null
+++ b/tests/test_data_pipeline_tensor_parallel/vit_t_2d.py
@@ -0,0 +1,74 @@
+
+import sys
+from pathlib import Path
+repo_path = str(Path(__file__).absolute().parents[2])
+sys.path.append(repo_path)
+
+try:
+ import model_zoo.vit.vision_transformer_from_config
+except ImportError:
+ raise ImportError("model_zoo is not found, please check your path")
+
+IMG_SIZE = 32
+PATCH_SIZE = 4
+DIM = 512
+NUM_ATTENTION_HEADS = 8
+NUM_CLASSES = 10
+DEPTH = 6
+
+model_cfg = dict(
+ type='VisionTransformerFromConfig',
+ tensor_splitting_cfg=dict(
+ type='ViTInputSplitter2D',
+ ),
+ embedding_cfg=dict(
+ type='ViTPatchEmbedding2D',
+ img_size=IMG_SIZE,
+ patch_size=PATCH_SIZE,
+ embed_dim=DIM,
+ ),
+ token_fusion_cfg=dict(
+ type='ViTTokenFuser2D',
+ img_size=IMG_SIZE,
+ patch_size=PATCH_SIZE,
+ embed_dim=DIM,
+ drop_rate=0.1
+ ),
+ norm_cfg=dict(
+ type='LayerNorm2D',
+ normalized_shape=DIM,
+ eps=1e-6,
+ ),
+ block_cfg=dict(
+ type='ViTBlock',
+ attention_cfg=dict(
+ type='ViTSelfAttention2D',
+ hidden_size=DIM,
+ num_attention_heads=NUM_ATTENTION_HEADS,
+ attention_dropout_prob=0.,
+ hidden_dropout_prob=0.1,
+ ),
+ droppath_cfg=dict(
+ type='VanillaViTDropPath',
+ ),
+ mlp_cfg=dict(
+ type='ViTMLP2D',
+ in_features=DIM,
+ dropout_prob=0.1,
+ mlp_ratio=1
+ ),
+ norm_cfg=dict(
+ type='LayerNorm2D',
+ normalized_shape=DIM,
+ eps=1e-6,
+ ),
+ ),
+ head_cfg=dict(
+ type='ViTHead2D',
+ hidden_size=DIM,
+ num_classes=NUM_CLASSES,
+ ),
+ embed_dim=DIM,
+ depth=DEPTH,
+ drop_path_rate=0.,
+)
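
This config module is consumed by the pipeline run script earlier in this PR; its usage amounts to:

# Mirrors run_cifar10_vit2d_with_pipeline.py above.
from colossalai.builder import PipelineModelInitializer
import vit_t_2d

# partition the ViT config across the pipeline stages and build this rank's chunk
model = PipelineModelInitializer(vit_t_2d.model_cfg, num_chunks=1).initialize()
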
diff --git a/tests/test_engine/configs/non_pipeline_resnet_apex_amp.py b/tests/test_engine/configs/non_pipeline_resnet_apex_amp.py
index f845d9842..1415bcb85 100644
--- a/tests/test_engine/configs/non_pipeline_resnet_apex_amp.py
+++ b/tests/test_engine/configs/non_pipeline_resnet_apex_amp.py
@@ -1,7 +1,7 @@
import os
from pathlib import Path
-from colossalai.engine import AMP_TYPE
+from colossalai.amp import AMP_TYPE
BATCH_SIZE = 128
IMG_SIZE = 224
@@ -9,34 +8,9 @@ DIM = 768
NUM_CLASSES = 10
NUM_ATTN_HEADS = 12
-# resnet 18
-model = dict(type='VanillaResNet',
- block_type='ResNetBasicBlock',
- layers=[2, 2, 2, 2],
- num_cls=10)
parallel = dict(
pipeline=dict(size=1),
tensor=dict(size=1, mode=None)
)
-
-train_data = dict(dataset=dict(type='CIFAR10Dataset',
- root=Path(os.environ['DATA']),
- download=True,
- transform_pipeline=[
- dict(type='Resize',
- size=(IMG_SIZE, IMG_SIZE)),
- dict(type='ToTensor'),
- dict(type='Normalize',
- mean=(0.5, 0.5, 0.5),
- std=(0.5, 0.5, 0.5))
- ]),
- dataloader=dict(batch_size=BATCH_SIZE,
- pin_memory=True,
- num_workers=4,
- drop_last=True))
-
-optimizer = dict(type='Adam', lr=0.001)
-
-loss = dict(type='CrossEntropyLoss')
fp16 = dict(mode=AMP_TYPE.APEX)
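
With the AMP refactor, the enum lives in colossalai.amp and the fp16 block in a config only selects the backend. The variants exercised by the engine tests below look roughly like this (keys besides mode follow those tests; other options may exist):

# Sketch of the fp16 config variants used by the engine tests in this PR.
from colossalai.amp import AMP_TYPE

fp16_torch = dict(mode=AMP_TYPE.TORCH)                  # torch native AMP
fp16_apex  = dict(mode=AMP_TYPE.APEX)                   # NVIDIA Apex AMP
fp16_naive = dict(mode=AMP_TYPE.NAIVE, clip_grad=1.0)   # naive FP16, with clipping configured inside the fp16 block
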
diff --git a/tests/test_engine/test.sh b/tests/test_engine/test.sh
index 24d0c5423..0d90c8e55 100644
--- a/tests/test_engine/test.sh
+++ b/tests/test_engine/test.sh
@@ -1,4 +1,4 @@
#!/usr/bin/env sh
test_file=$1
-python $test_file --local_rank $SLURM_PROCID --world_size $SLURM_NPROCS --host $HOST --port 29500
\ No newline at end of file
+python $test_file --world_size $SLURM_NPROCS --host $HOST --port 29500 --rank $SLURM_PROCID
\ No newline at end of file
diff --git a/tests/test_engine/test_engine/test_engine_apex_amp.py b/tests/test_engine/test_engine/test_engine_apex_amp.py
new file mode 100644
index 000000000..ff9c9f9bf
--- /dev/null
+++ b/tests/test_engine/test_engine/test_engine_apex_amp.py
@@ -0,0 +1,114 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+import colossalai
+import os
+import pytest
+import torch
+import os.path as osp
+from pathlib import Path
+import torch.nn as nn
+
+from torchvision import transforms
+from torch.optim import Adam
+from colossalai.core import global_context as gpc
+from colossalai.amp import AMP_TYPE
+from colossalai.logging import get_dist_logger
+from colossalai.utils import report_memory_usage, get_dataloader
+from colossalai.initialize import get_default_parser
+from torchvision.models import resnet18
+from torchvision.datasets import CIFAR10
+
+
+# Config
+BATCH_SIZE = 128
+IMG_SIZE = 224
+DIM = 768
+NUM_CLASSES = 10
+NUM_ATTN_HEADS = 12
+
+CONFIG = dict(
+ parallel=dict(
+ pipeline=dict(size=1),
+ tensor=dict(size=1, mode=None)
+ ),
+ fp16=dict(mode=AMP_TYPE.APEX),
+ clip_grad_norm=1.0
+)
+
+
+def run_no_pipeline():
+ parser = get_default_parser()
+ args = parser.parse_args()
+
+ # init dist env
+ colossalai.launch(
+ config=CONFIG,
+ rank=args.rank,
+ world_size=args.world_size,
+ host=args.host,
+ port=args.port,
+ backend=args.backend
+ )
+
+ # build model
+ model = resnet18(num_classes=10)
+
+ # build dataloaders
+ train_dataset = CIFAR10(
+ root=Path(os.environ['DATA']),
+ download=True,
+ transform=transforms.Compose(
+ [
+ transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
+ ]
+ )
+ )
+ train_dataloader = get_dataloader(dataset=train_dataset,
+ shuffle=True,
+ batch_size=BATCH_SIZE,
+ num_workers=1,
+ pin_memory=True,
+ drop_last=True)
+
+ # build optimizer
+ optimizer = Adam(model.parameters(), lr=0.001)
+ criterion = nn.CrossEntropyLoss()
+
+ engine, train_dataloader, *args = colossalai.initialize(
+ model=model,
+ optimizer=optimizer,
+ criterion=criterion,
+ train_dataloader=train_dataloader
+ )
+ logger = get_dist_logger()
+ rank = torch.distributed.get_rank()
+
+ engine.train()
+ for img, label in train_dataloader:
+ engine.zero_grad()
+ img = img.cuda()
+ label = label.cuda()
+ output = engine(img)
+ loss = engine.criterion(output, label)
+ engine.backward(loss)
+ engine.step()
+ break
+
+ logger.info('Rank {} returns: {}'.format(rank, loss.item()))
+
+ gpc.destroy()
+ logger.info('Test engine finished')
+ report_memory_usage("After testing")
+
+
+@pytest.mark.skip("This test should be invoked using the test.sh provided")
+@pytest.mark.dist
+def test_engine():
+ run_no_pipeline()
+
+
+if __name__ == '__main__':
+ test_engine()
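
The engine tests added in this folder (this one and the three that follow) differ only in their CONFIG; the per-step pattern they all exercise is distilled below.

# Distilled from the engine tests in this folder: a single training step.
def one_training_step(engine, train_dataloader):
    engine.train()
    for img, label in train_dataloader:
        engine.zero_grad()                     # clear grads on the wrapped optimizer
        output = engine(img.cuda())            # forward through the wrapped model
        loss = engine.criterion(output, label.cuda())
        engine.backward(loss)                  # backward via the engine; AMP scaling, if any, happens here
        engine.step()                          # optimizer update, plus grad clipping when configured
        return loss                            # the tests deliberately run only one step
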
diff --git a/tests/test_engine/test_engine/test_engine_naive_amp.py b/tests/test_engine/test_engine/test_engine_naive_amp.py
new file mode 100644
index 000000000..dd75b9359
--- /dev/null
+++ b/tests/test_engine/test_engine/test_engine_naive_amp.py
@@ -0,0 +1,113 @@
+import colossalai
+import os
+import pytest
+import torch
+import os.path as osp
+from pathlib import Path
+import torch.nn as nn
+
+from torchvision import transforms
+from torch.optim import Adam
+from colossalai.core import global_context as gpc
+from colossalai.amp import AMP_TYPE
+from colossalai.logging import get_dist_logger
+from colossalai.utils import report_memory_usage, get_dataloader
+from colossalai.initialize import get_default_parser
+from torchvision.models import resnet18
+from torchvision.datasets import CIFAR10
+
+
+# Config
+BATCH_SIZE = 128
+IMG_SIZE = 224
+DIM = 768
+NUM_CLASSES = 10
+NUM_ATTN_HEADS = 12
+
+CONFIG = dict(
+ parallel=dict(
+ pipeline=dict(size=1),
+ tensor=dict(size=1, mode=None)
+ ),
+ fp16=dict(
+ mode=AMP_TYPE.NAIVE,
+ clip_grad=1.0
+ )
+)
+
+
+def run_no_pipeline():
+ parser = get_default_parser()
+ args = parser.parse_args()
+
+ # init dist env
+ colossalai.launch(
+ config=CONFIG,
+ rank=args.rank,
+ world_size=args.world_size,
+ host=args.host,
+ port=args.port,
+ backend=args.backend
+ )
+
+ # build model
+ model = resnet18(num_classes=10)
+
+ # build dataloaders
+ train_dataset = CIFAR10(
+ root=Path(os.environ['DATA']),
+ download=True,
+ transform=transforms.Compose(
+ [
+ transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
+ ]
+ )
+ )
+ train_dataloader = get_dataloader(dataset=train_dataset,
+ shuffle=True,
+ batch_size=BATCH_SIZE,
+ num_workers=1,
+ pin_memory=True,
+ drop_last=True)
+
+ # build optimizer
+ optimizer = Adam(model.parameters(), lr=0.001)
+ criterion = nn.CrossEntropyLoss()
+
+ engine, train_dataloader, *args = colossalai.initialize(
+ model=model,
+ optimizer=optimizer,
+ criterion=criterion,
+ train_dataloader=train_dataloader
+ )
+ logger = get_dist_logger()
+ rank = torch.distributed.get_rank()
+
+ engine.train()
+ for img, label in train_dataloader:
+ engine.zero_grad()
+ img = img.cuda()
+ label = label.cuda()
+ output = engine(img)
+ loss = engine.criterion(output, label)
+ engine.backward(loss)
+ engine.step()
+ break
+
+ logger.info('Rank {} returns: {}'.format(rank, loss.item()))
+
+ gpc.destroy()
+ logger.info('Test engine finished')
+ report_memory_usage("After testing")
+
+
+@pytest.mark.skip("This test should be invoked using the test.sh provided")
+@pytest.mark.dist
+def test_engine():
+ run_no_pipeline()
+
+
+if __name__ == '__main__':
+ test_engine()
diff --git a/tests/test_engine/test_engine/test_engine_no_amp.py b/tests/test_engine/test_engine/test_engine_no_amp.py
new file mode 100644
index 000000000..f8392c98a
--- /dev/null
+++ b/tests/test_engine/test_engine/test_engine_no_amp.py
@@ -0,0 +1,110 @@
+import colossalai
+import os
+import pytest
+import torch
+import os.path as osp
+from pathlib import Path
+import torch.nn as nn
+
+from torchvision import transforms
+from torch.optim import Adam
+from colossalai.core import global_context as gpc
+from colossalai.amp import AMP_TYPE
+from colossalai.logging import get_dist_logger
+from colossalai.utils import report_memory_usage, get_dataloader
+from colossalai.initialize import get_default_parser
+from torchvision.models import resnet18
+from torchvision.datasets import CIFAR10
+
+
+# Config
+BATCH_SIZE = 128
+IMG_SIZE = 224
+DIM = 768
+NUM_CLASSES = 10
+NUM_ATTN_HEADS = 12
+
+CONFIG = dict(
+ parallel=dict(
+ pipeline=dict(size=1),
+ tensor=dict(size=1, mode=None)
+ ),
+ clip_grad_norm=1.0
+)
+
+
+def run_no_pipeline():
+ parser = get_default_parser()
+ args = parser.parse_args()
+
+ # init dist env
+ colossalai.launch(
+ config=CONFIG,
+ rank=args.rank,
+ world_size=args.world_size,
+ host=args.host,
+ port=args.port,
+ backend=args.backend
+ )
+
+ # build model
+ model = resnet18(num_classes=10)
+
+ # build dataloaders
+ train_dataset = CIFAR10(
+ root=Path(os.environ['DATA']),
+ download=True,
+ transform=transforms.Compose(
+ [
+ transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
+ ]
+ )
+ )
+ train_dataloader = get_dataloader(dataset=train_dataset,
+ shuffle=True,
+ batch_size=BATCH_SIZE,
+ num_workers=1,
+ pin_memory=True,
+ drop_last=True)
+
+ # build optimizer
+ optimizer = Adam(model.parameters(), lr=0.001)
+ criterion = nn.CrossEntropyLoss()
+
+ engine, train_dataloader, *args = colossalai.initialize(
+ model=model,
+ optimizer=optimizer,
+ criterion=criterion,
+ train_dataloader=train_dataloader
+ )
+ logger = get_dist_logger()
+ rank = torch.distributed.get_rank()
+
+ engine.train()
+ for img, label in train_dataloader:
+ engine.zero_grad()
+ img = img.cuda()
+ label = label.cuda()
+ output = engine(img)
+ loss = engine.criterion(output, label)
+ engine.backward(loss)
+ engine.step()
+ break
+
+ logger.info('Rank {} returns: {}'.format(rank, loss.item()))
+
+ gpc.destroy()
+ logger.info('Test engine finished')
+ report_memory_usage("After testing")
+
+
+@pytest.mark.skip("This test should be invoked using the test.sh provided")
+@pytest.mark.dist
+def test_engine():
+ run_no_pipeline()
+
+
+if __name__ == '__main__':
+ test_engine()
diff --git a/tests/test_engine/test_engine/test_engine_torch_amp.py b/tests/test_engine/test_engine/test_engine_torch_amp.py
new file mode 100644
index 000000000..fdafd494c
--- /dev/null
+++ b/tests/test_engine/test_engine/test_engine_torch_amp.py
@@ -0,0 +1,111 @@
+import colossalai
+import os
+import pytest
+import torch
+import os.path as osp
+from pathlib import Path
+import torch.nn as nn
+
+from torchvision import transforms
+from torch.optim import Adam
+from colossalai.core import global_context as gpc
+from colossalai.amp import AMP_TYPE
+from colossalai.logging import get_dist_logger
+from colossalai.utils import report_memory_usage, get_dataloader
+from colossalai.initialize import get_default_parser
+from torchvision.models import resnet18
+from torchvision.datasets import CIFAR10
+
+
+# Config
+BATCH_SIZE = 128
+IMG_SIZE = 224
+DIM = 768
+NUM_CLASSES = 10
+NUM_ATTN_HEADS = 12
+
+CONFIG = dict(
+ parallel=dict(
+ pipeline=dict(size=1),
+ tensor=dict(size=1, mode=None)
+ ),
+ fp16=dict(mode=AMP_TYPE.TORCH),
+ clip_grad_norm=1.0
+)
+
+
+def run_no_pipeline():
+ parser = get_default_parser()
+ args = parser.parse_args()
+
+ # init dist env
+ colossalai.launch(
+ config=CONFIG,
+ rank=args.rank,
+ world_size=args.world_size,
+ host=args.host,
+ port=args.port,
+ backend=args.backend
+ )
+
+ # build model
+ model = resnet18(num_classes=10)
+
+ # build dataloaders
+ train_dataset = CIFAR10(
+ root=Path(os.environ['DATA']),
+ download=True,
+ transform=transforms.Compose(
+ [
+ transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
+ ]
+ )
+ )
+ train_dataloader = get_dataloader(dataset=train_dataset,
+ shuffle=True,
+ batch_size=BATCH_SIZE,
+ num_workers=1,
+ pin_memory=True,
+ drop_last=True)
+
+ # build optimizer
+ optimizer = Adam(model.parameters(), lr=0.001)
+ criterion = nn.CrossEntropyLoss()
+
+ engine, train_dataloader, *args = colossalai.initialize(
+ model=model,
+ optimizer=optimizer,
+ criterion=criterion,
+ train_dataloader=train_dataloader
+ )
+ logger = get_dist_logger()
+ rank = torch.distributed.get_rank()
+
+ engine.train()
+ for img, label in train_dataloader:
+ engine.zero_grad()
+ img = img.cuda()
+ label = label.cuda()
+ output = engine(img)
+ loss = engine.criterion(output, label)
+ engine.backward(loss)
+ engine.step()
+ break
+
+ logger.info('Rank {} returns: {}'.format(rank, loss.item()))
+
+ gpc.destroy()
+ logger.info('Test engine finished')
+ report_memory_usage("After testing")
+
+
+@pytest.mark.skip("This test should be invoked using the test.sh provided")
+@pytest.mark.dist
+def test_engine():
+ run_no_pipeline()
+
+
+if __name__ == '__main__':
+ test_engine()
diff --git a/tests/test_engine/test_non_pipeline_engine/test_engine_apex_amp.py b/tests/test_engine/test_non_pipeline_engine/test_engine_apex_amp.py
deleted file mode 100644
index 98c2b8072..000000000
--- a/tests/test_engine/test_non_pipeline_engine/test_engine_apex_amp.py
+++ /dev/null
@@ -1,46 +0,0 @@
-# !/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-import os.path as osp
-
-import pytest
-import torch
-
-from colossalai import initialize
-from colossalai.core import global_context as gpc
-from colossalai.logging import get_global_dist_logger
-from colossalai.utils import report_memory_usage
-
-NUM_BATCH = 128
-NUM_MICRO = 6
-
-BATCH_SIZE = 32
-SEQ_LENGTH = 128
-HIDDEN_SIZE = 512
-
-DIR_PATH = osp.dirname(osp.realpath(__file__))
-NO_PIPE_CONFIG_PATH = osp.join(DIR_PATH, '../configs/non_pipeline_resnet_apex_amp.py')
-
-
-def run_no_pipeline(config):
- engine, train_dataloader, test_dataloader = initialize(config)
- logger = get_global_dist_logger()
- rank = torch.distributed.get_rank()
-
- engine.train()
- output, label, loss = engine.step(iter(train_dataloader))
- logger.info('Rank {} returns: {}'.format(rank, loss.item()))
-
- gpc.destroy()
- logger.info('Test engine finished')
- report_memory_usage("After testing")
-
-
-@pytest.mark.skip("This test should be invoked using the test.sh provided")
-@pytest.mark.dist
-def test_engine():
- run_no_pipeline(NO_PIPE_CONFIG_PATH)
-
-
-if __name__ == '__main__':
- test_engine()
diff --git a/tests/test_engine/test_non_pipeline_engine/test_engine_no_amp.py b/tests/test_engine/test_non_pipeline_engine/test_engine_no_amp.py
deleted file mode 100644
index effb65e02..000000000
--- a/tests/test_engine/test_non_pipeline_engine/test_engine_no_amp.py
+++ /dev/null
@@ -1,49 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-import os.path as osp
-
-import pytest
-import torch
-
-from colossalai import initialize
-from colossalai.core import global_context as gpc
-from colossalai.logging import get_global_dist_logger
-from colossalai.utils import report_memory_usage
-
-NUM_BATCH = 128
-NUM_MICRO = 6
-
-BATCH_SIZE = 32
-SEQ_LENGTH = 128
-HIDDEN_SIZE = 512
-
-DIR_PATH = osp.dirname(osp.realpath(__file__))
-NO_PIPE_CONFIG_PATH = osp.join(DIR_PATH, '../configs/non_pipeline_resnet.py')
-
-
-def test_no_pipeline(config):
- print('Test no pipeline engine start')
-
- engine, train_dataloader, test_dataloader = initialize(config)
- logger = get_global_dist_logger()
-
- rank = torch.distributed.get_rank()
-
- engine.train()
- output, label, loss = engine.step(iter(train_dataloader))
- logger.info('Rank {} returns: {}'.format(rank, loss.item()))
-
- gpc.destroy()
- logger.info('Test engine finished')
- report_memory_usage("After testing")
-
-
-@pytest.mark.skip("This test should be invoked using the test.sh provided")
-@pytest.mark.dist
-def test_engine():
- test_no_pipeline(NO_PIPE_CONFIG_PATH)
-
-
-if __name__ == '__main__':
- test_engine()
diff --git a/tests/test_engine/test_non_pipeline_engine/test_engine_torch_amp.py b/tests/test_engine/test_non_pipeline_engine/test_engine_torch_amp.py
deleted file mode 100644
index a4c496a7d..000000000
--- a/tests/test_engine/test_non_pipeline_engine/test_engine_torch_amp.py
+++ /dev/null
@@ -1,48 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-import os.path as osp
-
-import pytest
-import torch
-
-from colossalai import initialize
-from colossalai.core import global_context as gpc
-from colossalai.logging import get_global_dist_logger
-from colossalai.utils import report_memory_usage
-
-NUM_BATCH = 128
-NUM_MICRO = 6
-
-BATCH_SIZE = 32
-SEQ_LENGTH = 128
-HIDDEN_SIZE = 512
-
-DIR_PATH = osp.dirname(osp.realpath(__file__))
-NO_PIPE_CONFIG_PATH = osp.join(DIR_PATH, '../configs/non_pipeline_resnet_torch_amp.py')
-
-
-def test_no_pipeline(config):
- print('Test no pipeline engine start')
-
- engine, train_dataloader, test_dataloader = initialize(config)
- logger = get_global_dist_logger()
- rank = torch.distributed.get_rank()
-
- engine.train()
- output, label, loss = engine.step(iter(train_dataloader))
- logger.info('Rank {} returns: {}'.format(rank, loss.item()))
-
- gpc.destroy()
- logger.info('Test engine finished')
- report_memory_usage("After testing")
-
-
-@pytest.mark.skip("This test should be invoked using the test.sh provided")
-@pytest.mark.dist
-def test_engine():
- test_no_pipeline(NO_PIPE_CONFIG_PATH)
-
-
-if __name__ == '__main__':
- test_engine()
diff --git a/tests/test_engine/test_pipeline_engine/test_engine.py b/tests/test_engine/test_pipeline_engine/test_engine.py
deleted file mode 100644
index 9d6c9f59f..000000000
--- a/tests/test_engine/test_pipeline_engine/test_engine.py
+++ /dev/null
@@ -1,45 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-import os.path as osp
-
-import pytest
-import torch
-
-from colossalai import initialize
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.logging import get_global_dist_logger
-
-NUM_BATCH = 128
-
-BATCH_SIZE = 32
-SEQ_LENGTH = 128
-HIDDEN_SIZE = 512
-
-DIR_PATH = osp.dirname(osp.realpath(__file__))
-PIPE_CONFIG_PATH = osp.join(DIR_PATH, '../configs/pipeline_vanilla_resnet.py')
-
-
-def run_pipeline(config):
- engine, train_dataloader, test_dataloader = initialize(config)
- logger = get_global_dist_logger()
- rank = torch.distributed.get_rank()
-
- engine.train()
- outputs, labels, loss = engine.step(iter(train_dataloader))
- if gpc.is_last_rank(ParallelMode.PIPELINE):
- logger.info('losses: {}'.format(rank, loss.item()))
-
- gpc.destroy()
- logger.info('Test engine pipeline finished')
-
-
-@pytest.mark.skip("This test should be invoked using the test.sh provided")
-@pytest.mark.dist
-def test_engine():
- run_pipeline(PIPE_CONFIG_PATH)
-
-
-if __name__ == '__main__':
- test_engine()
diff --git a/tests/test_fp16_optimizer/configs/vit_2d.py b/tests/test_fp16_optimizer/configs/vit_2d.py
deleted file mode 100644
index 6283dea9b..000000000
--- a/tests/test_fp16_optimizer/configs/vit_2d.py
+++ /dev/null
@@ -1,143 +0,0 @@
-import os
-from pathlib import Path
-
-from colossalai.engine import AMP_TYPE
-
-BATCH_SIZE = 512
-IMG_SIZE = 32
-PATCH_SIZE = 4
-DIM = 512
-NUM_ATTENTION_HEADS = 8
-SUMMA_DIM = 2
-NUM_CLASSES = 10
-DEPTH = 6
-
-train_data = dict(
- dataset=dict(
- type='CIFAR10Dataset',
- root=Path(os.environ['DATA']),
- transform_pipeline=[
- dict(type='RandomCrop', size=IMG_SIZE, padding=4),
- dict(type='RandomHorizontalFlip'),
- dict(type='ToTensor'),
- dict(type='Normalize',
- mean=[0.4914, 0.4822, 0.4465],
- std=[0.2023, 0.1994, 0.2010]),
- ]
- ),
- dataloader=dict(
- batch_size=BATCH_SIZE,
- pin_memory=True,
- num_workers=4,
- shuffle=True
- )
-)
-
-test_data = dict(
- dataset=dict(
- type='CIFAR10Dataset',
- root=Path(os.environ['DATA']),
- train=False,
- transform_pipeline=[
- dict(type='Resize', size=IMG_SIZE),
- dict(type='ToTensor'),
- dict(type='Normalize',
- mean=[0.4914, 0.4822, 0.4465],
- std=[0.2023, 0.1994, 0.2010]
- ),
- ]
- ),
- dataloader=dict(
- batch_size=BATCH_SIZE,
- pin_memory=True,
- num_workers=4,
- shuffle=True
- )
-)
-
-optimizer = dict(
- type='Adam',
- lr=0.001,
- weight_decay=0
-)
-
-loss = dict(
- type='CrossEntropyLoss2D',
-)
-
-model = dict(
- type='VisionTransformerFromConfig',
- tensor_splitting_cfg=dict(
- type='ViTInputSplitter2D',
- ),
- embedding_cfg=dict(
- type='ViTPatchEmbedding2D',
- img_size=IMG_SIZE,
- patch_size=PATCH_SIZE,
- embed_dim=DIM,
- ),
- token_fusion_cfg=dict(
- type='ViTTokenFuser2D',
- img_size=IMG_SIZE,
- patch_size=PATCH_SIZE,
- embed_dim=DIM,
- drop_rate=0.1
- ),
- norm_cfg=dict(
- type='LayerNorm2D',
- normalized_shape=DIM,
- eps=1e-6,
- ),
- block_cfg=dict(
- type='ViTBlock',
- attention_cfg=dict(
- type='ViTSelfAttention2D',
- hidden_size=DIM,
- num_attention_heads=NUM_ATTENTION_HEADS,
- attention_dropout_prob=0.,
- hidden_dropout_prob=0.1,
- ),
- droppath_cfg=dict(
- type='VanillaViTDropPath',
- ),
- mlp_cfg=dict(
- type='ViTMLP2D',
- in_features=DIM,
- dropout_prob=0.1,
- mlp_ratio=1
- ),
- norm_cfg=dict(
- type='LayerNorm2D',
- normalized_shape=DIM,
- eps=1e-6,
- ),
- ),
- head_cfg=dict(
- type='ViTHead2D',
- hidden_size=DIM,
- num_classes=NUM_CLASSES,
- ),
- embed_dim=DIM,
- depth=DEPTH,
- drop_path_rate=0.,
-)
-
-parallel = dict(
- pipeline=dict(size=1),
- tensor=dict(size=4, mode='2d'),
-)
-
-fp16 = dict(
- mode=AMP_TYPE.PARALLEL,
- initial_scale=2 ** 4
-)
-
-num_epochs = 60
-
-
-lr_scheduler = dict(
- type='LinearWarmupLR',
- warmup_steps=5,
- total_steps=num_epochs
-)
-
diff --git a/tests/test_fp16_optimizer/test.sh b/tests/test_fp16_optimizer/test.sh
deleted file mode 100644
index 24d0c5423..000000000
--- a/tests/test_fp16_optimizer/test.sh
+++ /dev/null
@@ -1,4 +0,0 @@
-#!/usr/bin/env sh
-test_file=$1
-
-python $test_file --local_rank $SLURM_PROCID --world_size $SLURM_NPROCS --host $HOST --port 29500
\ No newline at end of file
diff --git a/tests/test_fp16_optimizer/test_vit_2d/test_vit_2d.py b/tests/test_fp16_optimizer/test_vit_2d/test_vit_2d.py
deleted file mode 100644
index 45c36f384..000000000
--- a/tests/test_fp16_optimizer/test_vit_2d/test_vit_2d.py
+++ /dev/null
@@ -1,85 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-from pathlib import Path
-
-import pytest
-import torch.autograd
-
-import colossalai
-from colossalai.builder import build_lr_scheduler
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.logging import get_global_dist_logger
-from colossalai.nn.layer._parallel_utilities import _gather
-
-CONFIG_PATH = Path(__file__).parent.parent.joinpath('configs/vit_2d.py')
-
-
-def eval(engine, test_dataloader):
- engine.eval()
- accumulated_loss = 0
- correct_sum = 0
- total_sum = 0
- num_steps = len(test_dataloader)
- data_iter = iter(test_dataloader)
-
- for i in range(num_steps):
- output, label, loss = engine.step(data_iter)
- accumulated_loss += loss.detach().cpu().numpy()
-
- output = _gather(
- output[0],
- ParallelMode.PARALLEL_2D_ROW,
- 1
- )
- output = _gather(
- output,
- ParallelMode.PARALLEL_2D_COL,
- 0,
- )
- output = torch.argmax(output, dim=-1)
- correct = torch.sum(label[0] == output)
- correct_sum += correct
- total_sum += label[0].size(0)
- avg_loss = accumulated_loss / num_steps
- return correct_sum, total_sum, avg_loss
-
-
-def train(engine, train_dataloader, lr_scheduler):
- engine.train()
- accumulated_loss = 0
- num_steps = len(train_dataloader)
- data_iter = iter(train_dataloader)
-
- for i in range(num_steps):
- output, label, loss = engine.step(data_iter)
- accumulated_loss += loss.squeeze(0).detach().cpu().numpy()
- avg_loss = accumulated_loss / num_steps
- lr_scheduler.step()
- return avg_loss
-
-
-@pytest.mark.dist
-@pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus")
-def test_2d_parallel_vision_transformer():
- # init dist
- engine, train_dataloader, test_dataloader = colossalai.initialize(CONFIG_PATH)
- lr_scheduler = build_lr_scheduler(gpc.config.lr_scheduler, engine.optimizer)
- logger = get_global_dist_logger()
-
- logger.info('start training')
- for epoch in range(gpc.config.num_epochs):
- train_loss = train(engine, train_dataloader, lr_scheduler)
-
- logger.info(f'epoch {epoch} - train loss: {train_loss}')
-
- if epoch % 2 == 0:
- correct_sum, total_sum, eval_loss = eval(engine, test_dataloader)
- logger.info(
- f'epoch {epoch} - eval loss: {eval_loss}, total: {total_sum}, '
- f'correct: {correct_sum}, acc: {correct_sum / total_sum}')
-
-
-if __name__ == '__main__':
- test_2d_parallel_vision_transformer()
diff --git a/tests/test_layers/test.sh b/tests/test_layers/test.sh
index 24d0c5423..da5afd5ae 100644
--- a/tests/test_layers/test.sh
+++ b/tests/test_layers/test.sh
@@ -1,4 +1,4 @@
#!/usr/bin/env sh
test_file=$1
-python $test_file --local_rank $SLURM_PROCID --world_size $SLURM_NPROCS --host $HOST --port 29500
\ No newline at end of file
+python $test_file --rank $SLURM_PROCID --world_size $SLURM_NPROCS --host $HOST --port 29500
\ No newline at end of file
diff --git a/tests/test_layers/test_1d/common.py b/tests/test_layers/test_1d/common.py
index 64d4601cb..a17cae9d3 100644
--- a/tests/test_layers/test_1d/common.py
+++ b/tests/test_layers/test_1d/common.py
@@ -6,8 +6,9 @@ import torch
DEPTH = 2
BATCH_SIZE = 8
SEQ_LENGTH = 8
+IMG_SIZE = 16
HIDDEN_SIZE = 8
-
+NUM_CLASSES = 10
def check_equal(A, B):
assert torch.allclose(A, B, rtol=1e-5, atol=1e-2) == True
diff --git a/tests/test_layers/test_1d/test_1d.py b/tests/test_layers/test_1d/test_1d.py
index e89cfe972..533376999 100644
--- a/tests/test_layers/test_1d/test_1d.py
+++ b/tests/test_layers/test_1d/test_1d.py
@@ -4,8 +4,8 @@
import pytest
from colossalai.core import global_context as gpc
-from colossalai.initialize import init_dist
-from test_layer import check_linear_col, check_linear_row
+from colossalai.initialize import launch, get_default_parser
+from test_layer import *
CONFIG = dict(
parallel=dict(
@@ -19,20 +19,31 @@ CONFIG = dict(
def check_layer():
+ # print_rank_0('start check_linear_col')
check_linear_col()
check_linear_row()
- # check_attention()
- # check_mlp()
+ check_attention()
+ check_mlp()
+ check_patch_embedding()
+ check_embed()
+ check_head()
@pytest.mark.dist
@pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus")
-def test_2d():
- init_dist(config=CONFIG)
- gpc.set_seed()
+def test_1d():
+ parser = get_default_parser()
+ args = parser.parse_args()
+ launch(config=CONFIG,
+ rank=args.rank,
+ world_size=args.world_size,
+ host=args.host,
+ port=args.port,
+ backend=args.backend)
+
check_layer()
gpc.destroy()
if __name__ == '__main__':
- test_2d()
+ test_1d()
diff --git a/tests/test_layers/test_1d/test_layer.py b/tests/test_layers/test_1d/test_layer.py
index 59551a5ca..682a4257a 100644
--- a/tests/test_layers/test_1d/test_layer.py
+++ b/tests/test_layers/test_1d/test_layer.py
@@ -1,14 +1,13 @@
+from tests.test_layers.test_3d.common import IMG_SIZE
import torch
import torch.distributed as dist
from torch.nn import Parameter
-
+import time
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
-from colossalai.nn import Linear1D_Col, Linear1D_Row
-# TransformerMLP1D, \
-# TransformerSelfAttention1D, TransformerEncoderLayer1D
+from colossalai.nn import Linear1D_Col, Linear1D_Row, TransformerMLP1D, TransformerSelfAttention1D, ViTMLP1D, ViTSelfAttention1D, ViTPatchEmbedding1D, ViTHead1D, ViTTokenFuser1D
from colossalai.utils import get_current_device, print_rank_0
-from common import HIDDEN_SIZE, DEPTH, BATCH_SIZE, SEQ_LENGTH, check_equal
+from common import HIDDEN_SIZE, DEPTH, BATCH_SIZE, SEQ_LENGTH, NUM_CLASSES, check_equal, IMG_SIZE
def check_linear_col():
@@ -142,70 +141,274 @@ def check_linear_row():
print_rank_0('linear_row no parallel_input backward: pass')
-#
-# def check_attention():
-# device = get_current_device()
-# dtype = torch.float32
-# INPUT_SIZE = HIDDEN_SIZE
-# NUM_ATTENTION_HEADS = 2
-#
-# i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
-#
-# layer = TransformerSelfAttention1D(
-# 1,
-# HIDDEN_SIZE // NUM_ATTENTION_HEADS,
-# HIDDEN_SIZE,
-# NUM_ATTENTION_HEADS,
-# 0.5
-# )
-#
-# A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
-# A_master = torch.randn(A_shape, dtype=dtype, device=device)
-# torch.distributed.broadcast(A_master, src=0)
-# A = A_master.clone()
-# A.requires_grad = True
-#
-# mask_shape = (BATCH_SIZE, NUM_ATTENTION_HEADS // DEPTH, SEQ_LENGTH, SEQ_LENGTH)
-# attention_mask = torch.zeros(mask_shape, dtype=dtype, device=device)
-#
-# out = layer(A, attention_mask)
-# assert out.shape == (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
-# print_rank_0('self attention forward: pass')
-#
-# grad_shape = out.shape
-# grad = torch.randn(grad_shape, dtype=dtype, device=device)
-#
-# out.backward(grad)
-# assert A.grad.shape == A.shape
-# print_rank_0('self attention backward: pass')
-#
-#
-# def check_mlp():
-# device = get_current_device()
-# dtype = torch.float32
-# INPUT_SIZE = HIDDEN_SIZE
-#
-# i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
-#
-# layer = TransformerMLP1D(
-# HIDDEN_SIZE,
-# HIDDEN_SIZE,
-# 4.0
-# )
-#
-# A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
-# A_master = torch.randn(A_shape, dtype=dtype, device=device)
-# torch.distributed.broadcast(A_master, src=0)
-# A = A_master.clone()
-# A.requires_grad = True
-#
-# out = layer(A)
-# assert out.shape == (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
-# print_rank_0('mlp forward: pass')
-#
-# grad_shape = out.shape
-# grad = torch.randn(grad_shape, dtype=dtype, device=device)
-#
-# out.backward(grad)
-# assert A.grad.shape == A.shape
-# print_rank_0('mlp backward: pass')
+
+class Testvithead(torch.nn.Module):
+ def __init__(self, in_features, out_features, bias=True):
+ super().__init__()
+ self.linear = torch.nn.Linear(in_features, out_features, bias=bias)
+
+ def forward(self, x):
+ x = x[:, 0]
+ x = self.linear(x)
+ return x
+
+
+def check_head():
+ device = get_current_device()
+ dtype = torch.float32
+ INPUT_SIZE = HIDDEN_SIZE
+
+ i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+
+ head = ViTHead1D(INPUT_SIZE, NUM_CLASSES, dtype=dtype)
+ torch.nn.init.zeros_(head.linear.bias)
+ torch.nn.init.ones_(head.linear.weight)
+ head = head.to(device)
+
+ layer = Testvithead(INPUT_SIZE, NUM_CLASSES, bias=True)
+ torch.nn.init.zeros_(layer.linear.bias)
+ torch.nn.init.ones_(layer.linear.weight)
+ layer = layer.to(device)
+
+ A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
+ A_master = torch.randn(A_shape, dtype=dtype, device=device)
+ torch.distributed.broadcast(A_master, src=0)
+ A = A_master.clone()
+ A.requires_grad = True
+
+ fwd_start = time.time()
+ out = head(A)
+ fwd_end = time.time()
+ print_rank_0(
+ 'head forward: pass | {0} --> {1} | {2:.3f} s'.format(
+ tuple(A.shape), tuple(out.shape), fwd_end - fwd_start))
+ A_master = A_master.clone()
+ A_master.requires_grad = True
+ C_master = layer(A_master)
+ # C = torch.chunk(C_master, DEPTH, dim=0)[i]
+ print_rank_0('Rank {} head forward: {}'.format(i, check_equal(out, C_master)))
+
+ grad_shape = C_master.shape
+ grad_master = torch.randn(grad_shape,
+ dtype=dtype,
+ device=get_current_device())
+ torch.distributed.broadcast(grad_master, src=0)
+ # grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
+
+ # bwd_start = time.time()
+ out.backward(grad_master)
+ # bwd_end = time.time()
+ # print_rank_0('head backward: pass | {:.3f} s'.format(bwd_end - bwd_start),
+ # logger)
+
+ C_master.backward(grad_master)
+ A_grad = A_master.grad
+ # if j == 0:
+ print_rank_0('Rank {} head backward (input_grad): {}'.format(
+ i, check_equal(A_grad, A.grad)))
+
+
+class Testvitembed(torch.nn.Module):
+ def __init__(self, img_size: int, patch_size: int, in_chans: int,
+ embed_size: int, drop_prob: float) -> None:
+ super().__init__()
+ self.proj = torch.nn.Conv2d(in_chans,
+ embed_size,
+ kernel_size=patch_size,
+ stride=patch_size)
+ num_patches = (img_size // patch_size)**2
+ self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, embed_size))
+ self.pos_embed = torch.nn.Parameter(
+ torch.zeros(1, num_patches + 1, embed_size))
+ self.pos_drop = torch.nn.Dropout(drop_prob)
+
+ def forward(self, x):
+ x = self.proj(x)
+ x = x.flatten(2).transpose(1, 2)
+ cls_token = self.cls_token.expand(x.shape[0], -1, -1)
+ x = torch.cat((cls_token, x), dim=1)
+ x = self.pos_drop(x + self.pos_embed)
+ return x
+
+
+def check_embed():
+ device = get_current_device()
+ dtype = torch.float32
+ i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+
+ layer = ViTPatchEmbedding1D(IMG_SIZE, 4, HIDDEN_SIZE)
+ layer2 = ViTTokenFuser1D(IMG_SIZE, 4, HIDDEN_SIZE)
+ torch.nn.init.zeros_(layer.proj.bias)
+ torch.nn.init.ones_(layer.proj.weight)
+ torch.nn.init.ones_(layer2.cls_token)
+ torch.nn.init.ones_(layer2.pos_embed)
+ layer = layer.to(device)
+ layer2 = layer2.to(device)
+
+ layer_master = Testvitembed(IMG_SIZE, 4, 3, HIDDEN_SIZE, 0.)
+ torch.nn.init.zeros_(layer_master.proj.bias)
+ torch.nn.init.ones_(layer_master.proj.weight)
+ torch.nn.init.ones_(layer_master.cls_token)
+ torch.nn.init.ones_(layer_master.pos_embed)
+ layer_master = layer_master.to(device)
+
+ A_shape = (BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE)
+ A_master = torch.randn(A_shape, dtype=dtype, device=device)
+ torch.distributed.broadcast(A_master, src=0)
+ A = A_master.clone()
+ A.requires_grad = True
+
+ fwd_start = time.time()
+ out = layer2(layer(A))
+ fwd_end = time.time()
+ print_rank_0(
+ 'embedding forward: pass | {0} --> {1} | {2:.3f} s'.format(
+ tuple(A.shape), tuple(out.shape), fwd_end - fwd_start))
+ # out_cls = out[:, 0]
+ # out_tensor = out[:, 1:]
+
+ A_master = A_master.clone()
+ A_master.requires_grad = True
+ C_master = layer_master(A_master)
+ # if j == 0:
+ # C_cls = C_master[:, 0]
+ # C_cls = torch.chunk(C_cls, DEPTH, dim=0)[i]
+ # C_cls = torch.chunk(C_cls, DEPTH, dim=-1)[k]
+ # logger.info('Rank {} embed forward (cls): {}'.format(
+ # rank, check_equal(out_cls, C_cls)))
+ # C = C_master[:, 1:]
+ print_rank_0('Rank {} embed forward: {}'.format(i, check_equal(out, C_master)))
+
+ grad_shape = C_master.shape
+ grad_master = torch.randn(grad_shape,
+ dtype=dtype,
+ device=get_current_device())
+ torch.distributed.broadcast(grad_master, src=0)
+ # cls_grad = grad_master[:, 0]
+ # cls_grad = torch.chunk(cls_grad, DEPTH, dim=0)[i]
+ # cls_grad = torch.chunk(cls_grad, DEPTH, dim=-1)[k]
+ # grad = grad_master[:, 1:]
+ # grad = torch.cat((torch.unsqueeze(cls_grad, 1), grad), dim=1)
+ bwd_start = time.time()
+ out.backward(grad_master)
+ bwd_end = time.time()
+ print_rank_0(
+ 'embedding backward: pass | {:.3f} s'.format(bwd_end - bwd_start))
+
+ C_master.backward(grad_master)
+
+ A_grad = A_master.grad
+ print_rank_0('Rank {} embed backward (input_grad): {}'.format(i, check_equal(A_grad, A.grad)))
+
+ print_rank_0('Rank {} embed backward (cls_grad): {}'.format(
+ i, check_equal(layer_master.cls_token.grad, layer2.cls_token.grad)))
+
+ print_rank_0('Rank {} embed backward (pos_embed_grad): {}'.format(
+ i, check_equal(layer_master.pos_embed.grad, layer2.pos_embed.grad)))
+
+ print_rank_0('Rank {} embed backward (proj_weight_grad): {}'.format(
+ i, check_equal(layer_master.proj.weight.grad, layer.proj.weight.grad)))
+
+ print_rank_0('Rank {} embed backward (proj_bias_grad): {}'.format(
+ i, check_equal(layer_master.proj.bias.grad, layer.proj.bias.grad)))
+
+ return fwd_end - fwd_start, bwd_end - bwd_start
+
+
+def check_attention():
+ device = get_current_device()
+ dtype = torch.float32
+ INPUT_SIZE = HIDDEN_SIZE
+ NUM_ATTENTION_HEADS = 2
+
+ i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+
+ layer = ViTSelfAttention1D(
+ HIDDEN_SIZE,
+ NUM_ATTENTION_HEADS,
+ 0.5,
+ 0.5
+ ).to(device=device)
+
+ A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
+ A_master = torch.randn(A_shape, dtype=dtype, device=device)
+ torch.distributed.broadcast(A_master, src=0)
+ A = A_master.clone()
+ A.requires_grad = True
+
+ mask_shape = (BATCH_SIZE, NUM_ATTENTION_HEADS // DEPTH, SEQ_LENGTH, SEQ_LENGTH)
+ attention_mask = torch.zeros(mask_shape, dtype=dtype, device=device)
+
+ out = layer(A)
+ assert out.shape == (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
+ print_rank_0('self attention forward: pass')
+
+ grad_shape = out.shape
+ grad = torch.randn(grad_shape, dtype=dtype, device=device)
+
+ out.backward(grad)
+ assert A.grad.shape == A.shape
+ print_rank_0('self attention backward: pass')
+
+
+def check_mlp():
+ device = get_current_device()
+ dtype = torch.float32
+ INPUT_SIZE = HIDDEN_SIZE
+
+ i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+
+ layer = ViTMLP1D(
+ HIDDEN_SIZE,
+ 4.0
+ ).to(device=device)
+
+ A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
+ A_master = torch.randn(A_shape, dtype=dtype, device=device)
+ torch.distributed.broadcast(A_master, src=0)
+ A = A_master.clone()
+ A.requires_grad = True
+
+ out = layer(A)
+ assert out.shape == (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
+ print_rank_0('mlp forward: pass')
+
+ grad_shape = out.shape
+ grad = torch.randn(grad_shape, dtype=dtype, device=device)
+
+ out.backward(grad)
+ assert A.grad.shape == A.shape
+ print_rank_0('mlp backward: pass')
+
+
+def check_patch_embedding():
+ device = get_current_device()
+ dtype = torch.float32
+ INPUT_SIZE = 4
+ PATCH_SIZE = 2
+
+ i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+
+ layer = ViTPatchEmbedding1D(
+ INPUT_SIZE,
+ PATCH_SIZE,
+ HIDDEN_SIZE,
+ ).to(device=device)
+
+ A_shape = (BATCH_SIZE, 3, INPUT_SIZE, INPUT_SIZE)
+ A_master = torch.randn(A_shape, dtype=dtype, device=device)
+ torch.distributed.broadcast(A_master, src=0)
+ A = A_master.clone()
+ A.requires_grad = True
+
+ out = layer(A)
+ print('output size: ', out.size())
+ assert out.shape == (BATCH_SIZE, 4, HIDDEN_SIZE)
+ print_rank_0('patch embedding forward: pass')
+
+ grad_shape = out.shape
+ grad = torch.randn(grad_shape, dtype=dtype, device=device)
+
+ out.backward(grad)
+ assert A.grad.shape == A.shape
+ print_rank_0('patch embedding backward: pass')
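The patch-embedding assertion above expects 4 tokens because, under the usual ViT convention, the number of patches is (img_size // patch_size) ** 2 = (4 // 2) ** 2 = 4. A minimal sketch of that arithmetic (the helper name and literal values are illustrative, not part of this diff):

    # Sketch only: the standard ViT patch-count arithmetic assumed by the assertion above.
    def expected_patch_embed_shape(batch_size, img_size, patch_size, hidden_size):
        num_patches = (img_size // patch_size) ** 2
        return (batch_size, num_patches, hidden_size)

    # e.g. with the sizes used in check_patch_embedding (INPUT_SIZE=4, PATCH_SIZE=2):
    assert expected_patch_embed_shape(2, 4, 2, 8) == (2, 4, 8)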
diff --git a/tests/test_layers/test_2d/test_2d.py b/tests/test_layers/test_2d/test_2d.py
index 994b2d37a..f1b683b9f 100644
--- a/tests/test_layers/test_2d/test_2d.py
+++ b/tests/test_layers/test_2d/test_2d.py
@@ -4,7 +4,7 @@
import pytest
from colossalai.core import global_context as gpc
-from colossalai.initialize import init_dist
+from colossalai.initialize import launch, get_default_parser
from test_layer import check_linear, check_layernorm, check_attention, check_mlp, check_transformerlayer
from test_operation import check_AB, check_ABT, check_ATB
@@ -36,8 +36,14 @@ def check_layer():
@pytest.mark.dist
@pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus")
def test_2d():
- init_dist(config=CONFIG)
- gpc.set_seed()
+ parser = get_default_parser()
+ args = parser.parse_args()
+ launch(config=CONFIG,
+ rank=args.rank,
+ world_size=args.world_size,
+ host=args.host,
+ port=args.port,
+ backend=args.backend)
check_operations()
check_layer()
gpc.destroy()
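Each migrated test now repeats the same get_default_parser/launch boilerplate. A small helper like the following (hypothetical, not part of this diff) would capture the pattern once; it relies only on the arguments the tests already read from the parser (rank, world_size, host, port, backend):

    # Hypothetical helper -- a sketch of the launch pattern repeated across these tests.
    from colossalai.initialize import launch, get_default_parser

    def launch_from_cli(config):
        args = get_default_parser().parse_args()
        launch(config=config,
               rank=args.rank,
               world_size=args.world_size,
               host=args.host,
               port=args.port,
               backend=args.backend)

Each test body would then reduce to launch_from_cli(CONFIG) followed by its checks and gpc.destroy().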
diff --git a/tests/test_layers/test_2p5d/test_2p5d.py b/tests/test_layers/test_2p5d/test_2p5d.py
index 488d38d87..bad2a9a04 100644
--- a/tests/test_layers/test_2p5d/test_2p5d.py
+++ b/tests/test_layers/test_2p5d/test_2p5d.py
@@ -1,7 +1,7 @@
import pytest
from colossalai.core import global_context as gpc
-from colossalai.initialize import init_dist
+from colossalai.initialize import launch, get_default_parser
from test_layer import check_linear, check_layernorm, check_attention, check_mlp, check_transformerlayer
from test_operation import check_AB, check_ABT, check_ATB
@@ -30,8 +30,14 @@ def check_layer():
@pytest.mark.dist
@pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus")
def test_2p5d():
- init_dist(config=CONFIG)
- gpc.set_seed()
+ parser = get_default_parser()
+ args = parser.parse_args()
+ launch(config=CONFIG,
+ rank=args.rank,
+ world_size=args.world_size,
+ host=args.host,
+ port=args.port,
+ backend=args.backend)
check_layer()
check_operations()
gpc.destroy()
diff --git a/tests/test_layers/test_2p5d/test_operation.py b/tests/test_layers/test_2p5d/test_operation.py
index 5ffaafe2c..2342db3bb 100644
--- a/tests/test_layers/test_2p5d/test_operation.py
+++ b/tests/test_layers/test_2p5d/test_operation.py
@@ -16,7 +16,7 @@ def check_AB():
pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
ParallelMode.PIPELINE)
tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)
-
+
dtype = torch.float
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
@@ -41,11 +41,10 @@ def check_AB():
out_shape = (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, 4 * HIDDEN_SIZE // TESSERACT_DIM)
out = Matmul_AB_2p5D.apply(
A, B,
- TESSERACT_DIM, TESSERACT_DEP, out_shape,
+ TESSERACT_DIM, out_shape,
i, j, k,
ParallelMode.PARALLEL_2P5D_ROW,
ParallelMode.PARALLEL_2P5D_COL,
- ParallelMode.PARALLEL_2P5D_DEP,
data_parallel_rank,
pipeline_parallel_rank,
pipeline_parallel_size,
@@ -93,7 +92,7 @@ def check_ABT():
pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
ParallelMode.PIPELINE)
tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)
-
+
dtype = torch.float
device = get_current_device()
@@ -119,11 +118,10 @@ def check_ABT():
out = Matmul_ABT_2p5D.apply(
C, B,
- TESSERACT_DIM, TESSERACT_DEP, (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, HIDDEN_SIZE // TESSERACT_DIM),
+ TESSERACT_DIM, (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, HIDDEN_SIZE // TESSERACT_DIM),
i, j, k,
ParallelMode.PARALLEL_2P5D_ROW,
ParallelMode.PARALLEL_2P5D_COL,
- ParallelMode.PARALLEL_2P5D_DEP,
data_parallel_rank,
pipeline_parallel_rank,
pipeline_parallel_size,
@@ -169,7 +167,7 @@ def check_ATB():
pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
ParallelMode.PIPELINE)
tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)
-
+
device = get_current_device()
dtype = torch.float
@@ -195,11 +193,10 @@ def check_ATB():
out = Matmul_ATB_2p5D.apply(
A, C,
- TESSERACT_DIM, TESSERACT_DEP, (HIDDEN_SIZE // TESSERACT_DIM, 4 * HIDDEN_SIZE // TESSERACT_DIM),
+ TESSERACT_DIM, (HIDDEN_SIZE // TESSERACT_DIM, 4 * HIDDEN_SIZE // TESSERACT_DIM),
i, j, k,
ParallelMode.PARALLEL_2P5D_ROW,
ParallelMode.PARALLEL_2P5D_COL,
- ParallelMode.PARALLEL_2P5D_DEP,
data_parallel_rank,
pipeline_parallel_rank,
pipeline_parallel_size,
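With this change the 2.5D matmul autograd functions no longer take TESSERACT_DEP or ParallelMode.PARALLEL_2P5D_DEP. A condensed sketch of the updated call, using only names defined in the surrounding test (the final tensor_parallel_size argument is assumed from the surrounding setup, since the hunk is truncated there):

    # Sketch of the simplified 2.5D matmul call after this diff; the depth-related
    # arguments are no longer passed explicitly.
    out = Matmul_AB_2p5D.apply(
        A, B,
        TESSERACT_DIM, out_shape,
        i, j, k,
        ParallelMode.PARALLEL_2P5D_ROW,
        ParallelMode.PARALLEL_2P5D_COL,
        data_parallel_rank,
        pipeline_parallel_rank,
        pipeline_parallel_size,
        tensor_parallel_size)  # assumed trailing argument, as in the original call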
diff --git a/tests/test_layers/test_3d/common.py b/tests/test_layers/test_3d/common.py
index c85046855..88c0f41c6 100644
--- a/tests/test_layers/test_3d/common.py
+++ b/tests/test_layers/test_3d/common.py
@@ -7,9 +7,9 @@ DEPTH = 2
BATCH_SIZE = 512
SEQ_LENGTH = 128
HIDDEN_SIZE = 512
-NUM_CLASSES = 10
+NUM_CLASSES = 1000
NUM_BLOCKS = 6
-IMG_SIZE = 32
+IMG_SIZE = 224
def check_equal(A, B):
- return torch.allclose(A, B, rtol=1e-5, atol=1e-2)
+ return torch.allclose(A, B, rtol=1e-4, atol=1e-2)
diff --git a/tests/test_layers/test_3d/test_3d.py b/tests/test_layers/test_3d/test_3d.py
index 21c560820..b05fc672a 100644
--- a/tests/test_layers/test_3d/test_3d.py
+++ b/tests/test_layers/test_3d/test_3d.py
@@ -1,27 +1,27 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
-from colossalai.initialize import init_dist
+from colossalai.initialize import launch, get_default_parser
from test_layer import *
from test_operation import *
+from colossalai.logging import get_dist_logger
CONFIG = dict(parallel=dict(pipeline=1, tensor=dict(mode='3d', size=8)),
seed=0)
-def check_operations():
- check_AB()
- check_ABT()
- check_ATB()
- check_add()
- check_mul()
- check_sum()
- # check_pooler()
+# def check_operations():
+# check_AB()
+# check_ABT()
+# check_ATB()
+# check_add()
+# check_mul()
+# check_sum()
def check_layer():
- logger = get_global_dist_logger()
+ logger = get_dist_logger()
liear_fwd_time, linear_bwd_time = check_linear()
norm_fwd_time, norm_bwd_time = check_layernorm()
attn_fwd_time, attn_bwd_time = check_attention()
@@ -40,15 +40,20 @@ def check_layer():
def _test_main():
# init dist
- init_dist(CONFIG)
- logger = get_global_dist_logger()
+ parser = get_default_parser()
+ args = parser.parse_args()
+ launch(config=CONFIG,
+ rank=args.rank,
+ world_size=args.world_size,
+ host=args.host,
+ port=args.port,
+ backend=args.backend)
+ logger = get_dist_logger()
logger.info('Distributed environment is initialized.', ranks=[0])
-
- global_context.set_seed()
torch.backends.cudnn.benchmark = True
# check operation
- check_operations()
+ # check_operations()
# check layers
check_layer()
diff --git a/tests/test_layers/test_3d/test_conn.py b/tests/test_layers/test_3d/test_conn.py
index 83cb32dd5..c88368b93 100644
--- a/tests/test_layers/test_3d/test_conn.py
+++ b/tests/test_layers/test_3d/test_conn.py
@@ -1,19 +1,34 @@
+import time
+
import torch
import torch.distributed as dist
+from colossalai.communication import all_gather, reduce_scatter, all_reduce
+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.initialize import init_dist, parse_args
+from colossalai.utils import get_current_device, print_rank_0
-from colossalai.initialize import parse_args
-from colossalai.utils import get_current_device
+# ARGS = parse_args()
+# size = ARGS.world_size
+# rank = ARGS.rank
-ARGS = parse_args()
-size = ARGS.world_size
-rank = ARGS.local_rank
+# init_method = f'tcp://{ARGS.host}:{ARGS.port}'
+# dist.init_process_group(backend='nccl', rank=rank, world_size=size, init_method=init_method)
+CONFIG = dict(parallel=dict(data=8, pipeline=1, tensor=dict(mode=None, size=1)))
+init_dist(CONFIG)
+
+assert dist.get_rank() == gpc.get_global_rank()
-init_method = f'tcp://{ARGS.host}:{ARGS.port}'
-dist.init_process_group(backend='nccl', rank=rank, world_size=size, init_method=init_method)
print('Rank {} / {}'.format(dist.get_rank(), dist.get_world_size()))
SIZE = 8
tensor = torch.randn(SIZE)
tensor = tensor.to(get_current_device())
-dist.all_reduce(tensor)
-print('Rank {0}: {1}'.format(rank, tensor.detach().cpu().numpy().tolist()))
+print('Before: Rank {0} - {1}'.format(dist.get_rank(), tensor))
+time.sleep(1)
+# tensor, op = all_gather(tensor, 0, ParallelMode.GLOBAL, async_op=True)
+# tensor, op = reduce_scatter(tensor, 0, ParallelMode.GLOBAL, async_op=True)
+tensor, op = all_reduce(tensor, ParallelMode.GLOBAL, async_op=True)
+print_rank_0('After: Rank {0} - {1}'.format(dist.get_rank(), tensor))
+op.wait()
+print_rank_0('Complete: Rank {0} - {1}'.format(dist.get_rank(), tensor))
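The rewritten connectivity check exercises the asynchronous collective wrappers: the call returns immediately with a work handle, and the tensor contents are only guaranteed once the handle has been waited on (which is why the 'After' print above may show pre-reduction values). A minimal usage sketch, assuming the wrappers return (tensor, handle) when async_op=True, as they do above:

    # Minimal sketch of the async collective pattern used in test_conn.py.
    from colossalai.communication import all_reduce
    from colossalai.context import ParallelMode

    def reduced_sum(tensor):
        tensor, handle = all_reduce(tensor, ParallelMode.GLOBAL, async_op=True)
        # ...other work can overlap with the communication here...
        handle.wait()  # reduced values are only valid after wait()
        return tensor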
diff --git a/tests/test_layers/test_3d/test_layer.py b/tests/test_layers/test_3d/test_layer.py
index db5de22a4..92720e42c 100644
--- a/tests/test_layers/test_3d/test_layer.py
+++ b/tests/test_layers/test_3d/test_layer.py
@@ -7,39 +7,55 @@ import time
import numpy as np
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context
-from colossalai.logging import get_global_dist_logger
+from colossalai.logging import get_dist_logger
from colossalai.registry import LAYERS, LOSSES
from colossalai.utils import get_current_device, print_rank_0
+from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
+from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D
from common import *
def check_linear():
rank = torch.distributed.get_rank()
- logger = get_global_dist_logger()
+ logger = get_dist_logger()
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
OUTPUT_SIZE = 2 * HIDDEN_SIZE
- j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
- i = B_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
- k = C_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)
+ input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
+ weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
+ output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
+
+ j = A_rank = global_context.get_local_rank(input_parallel_mode)
+ i = B_rank = global_context.get_local_rank(weight_parallel_mode)
+ k = C_rank = global_context.get_local_rank(output_parallel_mode)
layer = LAYERS.get_module('Linear3D')(INPUT_SIZE,
OUTPUT_SIZE,
- ParallelMode.PARALLEL_3D_INPUT,
- ParallelMode.PARALLEL_3D_WEIGHT,
+ # ParallelMode.PARALLEL_3D_INPUT,
+ # ParallelMode.PARALLEL_3D_WEIGHT,
dtype=dtype,
bias=True)
- torch.nn.init.zeros_(layer.bias)
- torch.nn.init.ones_(layer.weight)
+ # torch.nn.init.zeros_(layer.bias)
+ # torch.nn.init.ones_(layer.weight)
layer = layer.to(device)
layer_master = torch.nn.Linear(INPUT_SIZE, OUTPUT_SIZE)
- torch.nn.init.zeros_(layer_master.bias)
- torch.nn.init.ones_(layer_master.weight)
+ # torch.nn.init.zeros_(layer_master.bias)
+ # torch.nn.init.ones_(layer_master.weight)
layer_master = layer_master.to(device)
+ weight_master = layer_master.weight.data.transpose(0, 1)
+ torch.distributed.broadcast(weight_master, src=0)
+ weight = torch.chunk(weight_master, DEPTH, dim=0)[k]
+ weight = torch.chunk(weight, DEPTH, dim=-1)[j]
+ layer.weight = torch.nn.Parameter(weight)
+ bias_master = layer_master.bias.data
+ torch.distributed.broadcast(bias_master, src=0)
+ bias = torch.chunk(bias_master, DEPTH)[j]
+ layer.bias = torch.nn.Parameter(bias)
+
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
@@ -89,45 +105,52 @@ def check_linear():
B_grad = layer_master.weight.grad.transpose(0, 1)
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k]
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
- B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i]
+ # B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i]
logger.info('Rank {} linear backward (weight_grad): {}'.format(
rank, check_equal(B_grad, layer.weight.grad)))
- if j == k:
- bias_grad = layer_master.bias.grad
- bias_grad = torch.chunk(bias_grad, DEPTH)[j]
- bias_grad = torch.chunk(bias_grad, DEPTH)[i]
- logger.info('Rank {} linear backward (bias_grad): {}'.format(
- rank, check_equal(bias_grad, layer.bias.grad)))
- else:
- logger.info('Rank {} linear backward (bias_grad): {}'.format(
- rank,
- # np.count_nonzero(layer.bias.grad.detach().cpu().numpy()) == 0))
- layer.bias.grad is None))
+ bias_grad = layer_master.bias.grad
+ bias_grad = torch.chunk(bias_grad, DEPTH)[j]
+ logger.info('Rank {} linear backward (bias_grad): {}'.format(
+ rank, check_equal(bias_grad, layer.bias.grad)))
+ # logger.info(f'\nRank {rank} Master:\n{layer_master.bias.grad}\nRank {rank} True:\n{bias_grad}\nRank {rank} Out:\n{layer.bias.grad}')
return fwd_end - fwd_start, bwd_end - bwd_start
def check_layernorm():
rank = torch.distributed.get_rank()
- logger = get_global_dist_logger()
+ logger = get_dist_logger()
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
- j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
- i = B_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
- k = C_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)
+ input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
+ weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
+ output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
+
+ j = A_rank = global_context.get_local_rank(input_parallel_mode)
+ i = B_rank = global_context.get_local_rank(weight_parallel_mode)
+ k = C_rank = global_context.get_local_rank(output_parallel_mode)
norm = LAYERS.get_module('LayerNorm3D')(INPUT_SIZE,
- ParallelMode.PARALLEL_3D_INPUT,
- ParallelMode.PARALLEL_3D_WEIGHT,
+ # ParallelMode.PARALLEL_3D_INPUT,
+ # ParallelMode.PARALLEL_3D_WEIGHT,
eps=1e-6,
dtype=dtype)
norm = norm.to(device)
norm_master = torch.nn.LayerNorm(INPUT_SIZE, eps=1e-6)
norm_master = norm_master.to(device)
+ weight_master = norm_master.weight.data
+ torch.distributed.broadcast(weight_master, src=0)
+ weight = torch.chunk(weight_master, DEPTH)[k]
+ norm.weight = torch.nn.Parameter(weight)
+ bias_master = norm_master.bias.data
+ torch.distributed.broadcast(bias_master, src=0)
+ bias = torch.chunk(bias_master, DEPTH)[k]
+ norm.bias = torch.nn.Parameter(bias)
+
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
@@ -181,29 +204,15 @@ def check_layernorm():
logger.info('Rank {} layernorm backward (input_grad): {}'.format(
rank, check_equal(A_grad, A.grad)))
- if j == k:
- bias_grad = norm_master.weight.grad
- bias_grad = torch.chunk(bias_grad, DEPTH)[j]
- bias_grad = torch.chunk(bias_grad, DEPTH)[i]
- logger.info('Rank {} linear backward (weight_grad): {}'.format(
- rank, check_equal(bias_grad, norm.weight.grad)))
- else:
- logger.info('Rank {} linear backward (weight_grad): {}'.format(
- rank,
- # np.count_nonzero(layer.bias.grad.detach().cpu().numpy()) == 0))
- norm.weight.grad is None))
+ bias_grad = norm_master.weight.grad
+ bias_grad = torch.chunk(bias_grad, DEPTH)[k]
+ logger.info('Rank {} layernorm backward (weight_grad): {}'.format(
+ rank, check_equal(bias_grad, norm.weight.grad)))
- if j == k:
- bias_grad = norm_master.bias.grad
- bias_grad = torch.chunk(bias_grad, DEPTH)[j]
- bias_grad = torch.chunk(bias_grad, DEPTH)[i]
- logger.info('Rank {} linear backward (bias_grad): {}'.format(
- rank, check_equal(bias_grad, norm.bias.grad)))
- else:
- logger.info('Rank {} linear backward (bias_grad): {}'.format(
- rank,
- # np.count_nonzero(layer.bias.grad.detach().cpu().numpy()) == 0))
- norm.bias.grad is None))
+ bias_grad = norm_master.bias.grad
+ bias_grad = torch.chunk(bias_grad, DEPTH)[k]
+ logger.info('Rank {} layernorm backward (bias_grad): {}'.format(
+ rank, check_equal(bias_grad, norm.bias.grad)))
return fwd_end - fwd_start, bwd_end - bwd_start
@@ -211,14 +220,18 @@ def check_layernorm():
def check_attention():
rank = torch.distributed.get_rank()
device = get_current_device()
- logger = get_global_dist_logger()
+ logger = get_dist_logger()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
NUM_ATTENTION_HEADS = 2
- j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
- i = B_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
- k = C_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)
+ input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
+ weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
+ output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
+
+ j = A_rank = global_context.get_local_rank(input_parallel_mode)
+ i = B_rank = global_context.get_local_rank(weight_parallel_mode)
+ k = C_rank = global_context.get_local_rank(output_parallel_mode)
layer = LAYERS.get_module('ViTSelfAttention3D')(HIDDEN_SIZE,
NUM_ATTENTION_HEADS,
@@ -264,13 +277,17 @@ def check_attention():
def check_mlp():
rank = torch.distributed.get_rank()
device = get_current_device()
- logger = get_global_dist_logger()
+ logger = get_dist_logger()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
- j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
- i = B_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
- k = C_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)
+ input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
+ weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
+ output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
+
+ j = A_rank = global_context.get_local_rank(input_parallel_mode)
+ i = B_rank = global_context.get_local_rank(weight_parallel_mode)
+ k = C_rank = global_context.get_local_rank(output_parallel_mode)
layer = LAYERS.get_module('ViTMLP3D')(HIDDEN_SIZE,
1,
@@ -320,28 +337,42 @@ class Testvithead(torch.nn.Module):
def check_head():
rank = torch.distributed.get_rank()
- logger = get_global_dist_logger()
+ logger = get_dist_logger()
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
- j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
- i = B_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
- k = C_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)
+ input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
+ weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
+ output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
+
+ j = A_rank = global_context.get_local_rank(input_parallel_mode)
+ i = B_rank = global_context.get_local_rank(weight_parallel_mode)
+ k = C_rank = global_context.get_local_rank(output_parallel_mode)
head = LAYERS.get_module('ViTHead3D')(INPUT_SIZE,
NUM_CLASSES,
dtype=dtype,
bias=True)
- torch.nn.init.zeros_(head.linear.bias)
- torch.nn.init.ones_(head.linear.weight)
+ # torch.nn.init.zeros_(head.linear.bias)
+ # torch.nn.init.ones_(head.linear.weight)
head = head.to(device)
layer = Testvithead(INPUT_SIZE, NUM_CLASSES, bias=True)
- torch.nn.init.zeros_(layer.linear.bias)
- torch.nn.init.ones_(layer.linear.weight)
+ # torch.nn.init.zeros_(layer.linear.bias)
+ # torch.nn.init.ones_(layer.linear.weight)
layer = layer.to(device)
+ weight_master = layer.linear.weight.data.transpose(0, 1)
+ torch.distributed.broadcast(weight_master, src=0)
+ weight = torch.chunk(weight_master, DEPTH, dim=0)[k]
+ weight = torch.chunk(weight, DEPTH, dim=-1)[j]
+ head.linear.weight = torch.nn.Parameter(weight)
+ bias_master = layer.linear.bias.data
+ torch.distributed.broadcast(bias_master, src=0)
+ bias = torch.chunk(bias_master, DEPTH)[j]
+ head.linear.bias = torch.nn.Parameter(bias)
+
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
@@ -397,31 +428,43 @@ def check_head():
B_grad = layer.linear.weight.grad.transpose(0, 1)
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k]
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
- pad_shape = (B_grad.shape[0], math.ceil(B_grad.shape[-1] / DEPTH) * DEPTH -
- B_grad.shape[-1])
- B_grad = torch.cat(
- [B_grad, torch.zeros(pad_shape, dtype=dtype, device=device)], dim=-1)
- B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i]
+ # B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i]
logger.info('Rank {} head backward (weight_grad): {}'.format(
rank, check_equal(B_grad, head.linear.weight.grad)))
- if j == k:
- bias_grad = layer.linear.bias.grad
- bias_grad = torch.chunk(bias_grad, DEPTH)[j]
- pad_shape = (math.ceil(bias_grad.shape[0] / DEPTH) * DEPTH -
- bias_grad.shape[0], )
- bias_grad = torch.cat(
- [bias_grad,
- torch.zeros(pad_shape, dtype=dtype, device=device)])
- bias_grad = torch.chunk(bias_grad, DEPTH)[i]
- logger.info('Rank {} head backward (bias_grad): {}'.format(
- rank, check_equal(bias_grad, head.linear.bias.grad)))
- else:
- logger.info('Rank {} head backward (bias_grad): {}'.format(
- rank,
- # np.count_nonzero(
- # head.linear.bias.grad.detach().cpu().numpy()) == 0))
- head.linear.bias.grad is None))
+ bias_grad = layer.linear.bias.grad
+ bias_grad = torch.chunk(bias_grad, DEPTH)[j]
+ logger.info('Rank {} head backward (bias_grad): {}'.format(
+ rank, check_equal(bias_grad, head.linear.bias.grad)))
+
+ # B_grad = layer.linear.weight.grad.transpose(0, 1)
+ # B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k]
+ # B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
+ # pad_shape = (B_grad.shape[0], math.ceil(B_grad.shape[-1] / DEPTH) * DEPTH -
+ # B_grad.shape[-1])
+ # B_grad = torch.cat(
+ # [B_grad, torch.zeros(pad_shape, dtype=dtype, device=device)], dim=-1)
+ # B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i]
+ # logger.info('Rank {} head backward (weight_grad): {}'.format(
+ # rank, check_equal(B_grad, head.linear.weight.grad)))
+
+ # if j == k:
+ # bias_grad = layer.linear.bias.grad
+ # bias_grad = torch.chunk(bias_grad, DEPTH)[j]
+ # pad_shape = (math.ceil(bias_grad.shape[0] / DEPTH) * DEPTH -
+ # bias_grad.shape[0], )
+ # bias_grad = torch.cat(
+ # [bias_grad,
+ # torch.zeros(pad_shape, dtype=dtype, device=device)])
+ # bias_grad = torch.chunk(bias_grad, DEPTH)[i]
+ # logger.info('Rank {} head backward (bias_grad): {}'.format(
+ # rank, check_equal(bias_grad, head.linear.bias.grad)))
+ # else:
+ # logger.info('Rank {} head backward (bias_grad): {}'.format(
+ # rank,
+ # # np.count_nonzero(
+ # # head.linear.bias.grad.detach().cpu().numpy()) == 0))
+ # head.linear.bias.grad is None))
return fwd_end - fwd_start, bwd_end - bwd_start
@@ -452,12 +495,16 @@ class Testvitembed(torch.nn.Module):
def check_embed():
rank = torch.distributed.get_rank()
device = get_current_device()
- logger = get_global_dist_logger()
+ logger = get_dist_logger()
dtype = torch.float32
- j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
- i = B_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
- k = C_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)
+ input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
+ weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
+ output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
+
+ j = A_rank = global_context.get_local_rank(input_parallel_mode)
+ i = B_rank = global_context.get_local_rank(weight_parallel_mode)
+ k = C_rank = global_context.get_local_rank(output_parallel_mode)
layer = LAYERS.get_module('ViTPatchEmbedding3D')(IMG_SIZE, 4, 3,
HIDDEN_SIZE, 0.)
@@ -585,16 +632,20 @@ def check_embed():
def check_loss():
rank = torch.distributed.get_rank()
- logger = get_global_dist_logger()
+ logger = get_dist_logger()
device = get_current_device()
dtype = torch.float32
- j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
- i = B_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
- k = C_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)
+ input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
+ weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
+ output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
- criterion = LOSSES.get_module('CrossEntropyLoss3D')(
- ParallelMode.PARALLEL_3D_INPUT, ParallelMode.PARALLEL_3D_WEIGHT)
+ j = A_rank = global_context.get_local_rank(input_parallel_mode)
+ i = B_rank = global_context.get_local_rank(weight_parallel_mode)
+ k = C_rank = global_context.get_local_rank(output_parallel_mode)
+
+ criterion = LOSSES.get_module('CrossEntropyLoss3D')()
+ # ParallelMode.PARALLEL_3D_INPUT, ParallelMode.PARALLEL_3D_WEIGHT)
criterion_master = torch.nn.CrossEntropyLoss()
out_shape = (BATCH_SIZE, NUM_CLASSES)
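The constant-initialization path in these 3D checks is replaced by sharding a broadcast master parameter, so the parallel layer and the reference layer start from identical values. The recurring pattern in check_linear and check_head above is: broadcast the master weight from rank 0, chunk it by DEPTH along the output dimension with the output-mode rank k and along the input dimension with the input-mode rank j, then install the shard as the layer's Parameter (the bias is chunked once by j). A condensed sketch of that pattern (the helper name is illustrative):

    # Condensed restatement of the sharding pattern used in check_linear/check_head above.
    import torch

    def shard_linear_from_master(layer, layer_master, depth, j, k):
        weight = layer_master.weight.data.transpose(0, 1)
        torch.distributed.broadcast(weight, src=0)
        weight = torch.chunk(weight, depth, dim=0)[k]
        weight = torch.chunk(weight, depth, dim=-1)[j]
        layer.weight = torch.nn.Parameter(weight)

        bias = layer_master.bias.data
        torch.distributed.broadcast(bias, src=0)
        layer.bias = torch.nn.Parameter(torch.chunk(bias, depth)[j])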
diff --git a/tests/test_layers/test_3d/test_operation.py b/tests/test_layers/test_3d/test_operation.py
index 05acb7f58..a0c34432c 100644
--- a/tests/test_layers/test_3d/test_operation.py
+++ b/tests/test_layers/test_3d/test_operation.py
@@ -3,7 +3,7 @@
from colossalai.context import ParallelMode
from colossalai.core import global_context
-from colossalai.logging import get_global_dist_logger
+from colossalai.logging import get_dist_logger
from colossalai.nn.layer.parallel_3d._operation import *
from colossalai.utils import get_current_device
@@ -12,7 +12,7 @@ from common import *
def check_AB():
rank = torch.distributed.get_rank()
- logger = get_global_dist_logger()
+ logger = get_dist_logger()
dtype = torch.float
j = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
i = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
@@ -83,7 +83,7 @@ def check_AB():
def check_ABT():
rank = torch.distributed.get_rank()
- logger = get_global_dist_logger()
+ logger = get_dist_logger()
dtype = torch.float
j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
@@ -152,7 +152,7 @@ def check_ABT():
def check_ATB():
rank = torch.distributed.get_rank()
- logger = get_global_dist_logger()
+ logger = get_dist_logger()
device = get_current_device()
dtype = torch.float
@@ -222,7 +222,7 @@ def check_ATB():
def check_add():
rank = torch.distributed.get_rank()
- logger = get_global_dist_logger()
+ logger = get_dist_logger()
dtype = torch.float
j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
@@ -296,7 +296,7 @@ def check_add():
def check_mul():
rank = torch.distributed.get_rank()
- logger = get_global_dist_logger()
+ logger = get_dist_logger()
dtype = torch.float
j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
@@ -370,7 +370,7 @@ def check_mul():
def check_sum():
rank = torch.distributed.get_rank()
- logger = get_global_dist_logger()
+ logger = get_dist_logger()
dtype = torch.float
j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
@@ -417,7 +417,7 @@ def check_sum():
def check_reduce():
rank = torch.distributed.get_rank()
- logger = get_global_dist_logger()
+ logger = get_dist_logger()
dtype = torch.float
j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
diff --git a/tests/test_layers/test_sequence/test_sequence.py b/tests/test_layers/test_sequence/test_sequence.py
index 16122f93a..64a42a653 100644
--- a/tests/test_layers/test_sequence/test_sequence.py
+++ b/tests/test_layers/test_sequence/test_sequence.py
@@ -1,8 +1,8 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
-from colossalai.initialize import init_dist
-from colossalai.logging import get_global_dist_logger
+from colossalai.initialize import launch, get_default_parser
+from colossalai.logging import get_dist_logger
from test_layer import *
CONFIG = dict(
@@ -19,11 +19,17 @@ def check_layer():
def _test_main():
# init dist
- init_dist(CONFIG)
- logger = get_global_dist_logger()
+ parser = get_default_parser()
+ args = parser.parse_args()
+ launch(config=CONFIG,
+ rank=args.rank,
+ world_size=args.world_size,
+ host=args.host,
+ port=args.port,
+ backend=args.backend)
+ logger = get_dist_logger()
logger.info('Distributed environment is initialized.', ranks=[0])
- gpc.set_seed()
torch.backends.cudnn.benchmark = True
# check layers
diff --git a/tests/test_lr_scheduler/test_lr_scheduler.py b/tests/test_lr_scheduler/test_lr_scheduler.py
deleted file mode 100644
index 012ea4476..000000000
--- a/tests/test_lr_scheduler/test_lr_scheduler.py
+++ /dev/null
@@ -1,69 +0,0 @@
-# from colossal.components.optimizer.lr_scheduler import CosineAnnealingLR, CosineAnnealingWarmupLR, FlatAnnealingLR, FlatAnnealingWarmupLR
-# from colossal.components.optimizer.lr_scheduler import LinearWarmupLR
-# from colossal.components.optimizer.lr_scheduler import MultiStepLR, MultiStepWarmupLR
-# from colossal.components.optimizer.lr_scheduler import OneCycleLR
-# from colossal.components.optimizer.lr_scheduler import PolynomialLR, PolynomialWarmupLR
-import matplotlib.pyplot as plt
-import pytest
-from torch.optim import SGD
-from torchvision.models import resnet18
-
-from colossalai.builder import build_lr_scheduler
-
-NUM_EPOCHS = 5
-NUM_STEPS_PER_EPOCH = 10
-
-cfg = {
- 'warmup_steps': 5
-}
-
-
-def init_cfg(name, **kwargs):
- return {
- 'type': name,
- **cfg,
- **kwargs
- }
-
-
-def test_scheduler(optimizer, scheduler_name, **kwargs):
- for group in optimizer.param_groups:
- group['lr'] = 0.1
- config = init_cfg(scheduler_name, **kwargs)
- scheduler = build_lr_scheduler(config,
- optimizer, NUM_EPOCHS * NUM_STEPS_PER_EPOCH, NUM_STEPS_PER_EPOCH)
- x = []
- y = []
- for epoch in range(NUM_EPOCHS):
- for i in range(NUM_STEPS_PER_EPOCH):
- step = epoch * NUM_STEPS_PER_EPOCH + i
- lr = optimizer.param_groups[0]['lr']
- x.append(step)
- y.append(lr)
- scheduler.step()
- print(y)
- plt.plot(x, y)
- plt.show()
-
-
-@pytest.mark.skip("This test is skipped as it requires visualization, "
- "You can visualize the test output plots on your local environment")
-def test():
- model = resnet18()
- optimizer = SGD(model.parameters(), lr=1.0)
- test_scheduler(optimizer, 'CosineAnnealingLR')
- test_scheduler(optimizer, 'CosineAnnealingWarmupLR')
- test_scheduler(optimizer, 'FlatAnnealingLR')
- test_scheduler(optimizer, 'FlatAnnealingWarmupLR')
- test_scheduler(optimizer, 'LinearWarmupLR')
- test_scheduler(optimizer, 'MultiStepLR', milestones=[1, 3])
- test_scheduler(optimizer, 'MultiStepWarmupLR', milestones=[1, 3])
- test_scheduler(optimizer, 'MultiStepWarmupLR',
- milestones=[1, 3], warmup_epochs=1)
- test_scheduler(optimizer, 'PolynomialLR', power=2.0)
- test_scheduler(optimizer, 'PolynomialWarmupLR', power=2.0)
- test_scheduler(optimizer, 'OneCycleLR')
-
-
-if __name__ == '__main__':
- test()
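The deleted scheduler test depended on matplotlib and manual inspection of the plots, which is why it was permanently skipped. A plot-free variant could assert properties of the learning-rate trajectory instead; a minimal sketch (hypothetical replacement, not part of this diff), assuming build_lr_scheduler keeps the signature used in the deleted test:

    # Hypothetical plot-free check -- a sketch only, not part of this diff.
    def collect_lrs(optimizer, scheduler, num_steps):
        lrs = []
        for _ in range(num_steps):
            lrs.append(optimizer.param_groups[0]['lr'])
            scheduler.step()
        return lrs

    # For a warmup scheduler the learning rate should rise toward its base value
    # over the configured warmup steps, e.g.:
    #   lrs = collect_lrs(optimizer, scheduler, NUM_EPOCHS * NUM_STEPS_PER_EPOCH)
    #   assert lrs[0] < lrs[cfg['warmup_steps']]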
diff --git a/tests/test_models/test_vanilla_resnet/test_vanilla_resnet.py b/tests/test_models/test_vanilla_resnet/test_vanilla_resnet.py
deleted file mode 100644
index bc9144fe0..000000000
--- a/tests/test_models/test_vanilla_resnet/test_vanilla_resnet.py
+++ /dev/null
@@ -1,98 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-import pytest
-import torch
-import torchvision.models as models
-
-from colossalai.builder import build_model
-
-NUM_CLS = 10
-
-RESNET18 = dict(
- type='VanillaResNet',
- block_type='ResNetBasicBlock',
- layers=[2, 2, 2, 2],
- num_cls=NUM_CLS
-)
-
-RESNET34 = dict(
- type='VanillaResNet',
- block_type='ResNetBasicBlock',
- layers=[3, 4, 6, 3],
- num_cls=NUM_CLS
-)
-
-RESNET50 = dict(
- type='VanillaResNet',
- block_type='ResNetBottleneck',
- layers=[3, 4, 6, 3],
- num_cls=NUM_CLS
-)
-
-RESNET101 = dict(
- type='VanillaResNet',
- block_type='ResNetBottleneck',
- layers=[3, 4, 23, 3],
- num_cls=NUM_CLS
-)
-
-RESNET152 = dict(
- type='VanillaResNet',
- block_type='ResNetBottleneck',
- layers=[3, 8, 36, 3],
- num_cls=NUM_CLS
-)
-
-
-def compare_model(data, colossal_model, torchvision_model):
- colossal_output = colossal_model(data)
- torchvision_output = torchvision_model(data)
- assert colossal_output[
- 0].shape == torchvision_output.shape, f'{colossal_output[0].shape}, {torchvision_output.shape}'
-
-
-@pytest.mark.cpu
-def test_vanilla_resnet():
- """Compare colossal resnet with torchvision resnet"""
- # data
- x = torch.randn((2, 3, 224, 224))
-
- # resnet 18
- col_resnet18 = build_model(RESNET18)
- col_resnet18.build_from_cfg()
- torchvision_resnet18 = models.resnet18(num_classes=NUM_CLS)
-
- compare_model(x, col_resnet18, torchvision_resnet18)
-
- # resnet 34
- col_resnet34 = build_model(RESNET34)
- col_resnet34.build_from_cfg()
- torchvision_resnet34 = models.resnet34(num_classes=NUM_CLS)
-
- compare_model(x, col_resnet34, torchvision_resnet34)
-
- # resnet 50
- col_resnet50 = build_model(RESNET50)
- col_resnet50.build_from_cfg()
- torchvision_resnet50 = models.resnet50(num_classes=NUM_CLS)
-
- compare_model(x, col_resnet50, torchvision_resnet50)
-
- # resnet 101
- col_resnet101 = build_model(RESNET101)
- col_resnet101.build_from_cfg()
- torchvision_resnet101 = models.resnet101(num_classes=NUM_CLS)
-
- compare_model(x, col_resnet101, torchvision_resnet101)
-
- # # resnet 152
- col_resnet152 = build_model(RESNET152)
- col_resnet152.build_from_cfg()
- torchvision_resnet152 = models.resnet152(num_classes=NUM_CLS)
-
- compare_model(x, col_resnet152, torchvision_resnet152)
-
-
-if __name__ == '__main__':
- test_vanilla_resnet()
diff --git a/tests/test_models/test_vision_transformer/configs/vit_2d.py b/tests/test_models/test_vision_transformer/configs/vit_2d.py
deleted file mode 100644
index 1fd1102fb..000000000
--- a/tests/test_models/test_vision_transformer/configs/vit_2d.py
+++ /dev/null
@@ -1,107 +0,0 @@
-import os
-from pathlib import Path
-
-BATCH_SIZE = 512
-IMG_SIZE = 32
-PATCH_SIZE = 4
-DIM = 512
-NUM_ATTENTION_HEADS = 8
-SUMMA_DIM = 2
-NUM_CLASSES = 10
-DEPTH = 6
-
-train_data = dict(
- dataset=dict(
- type='CIFAR10Dataset',
- root=Path(os.environ['DATA']),
- transform_pipeline=[
- dict(type='RandomCrop', size=IMG_SIZE, padding=4),
- dict(type='RandomHorizontalFlip'),
- dict(type='ToTensor'),
- dict(type='Normalize',
- mean=[0.4914, 0.4822, 0.4465],
- std=[0.2023, 0.1994, 0.2010]),
- ]),
- dataloader=dict(batch_size=BATCH_SIZE,
- pin_memory=True,
- num_workers=4,
- shuffle=True))
-
-test_data = dict(
- dataset=dict(
- type='CIFAR10Dataset',
- root=Path(os.environ['DATA']),
- train=False,
- transform_pipeline=[
- dict(type='Resize', size=IMG_SIZE),
- dict(type='ToTensor'),
- dict(type='Normalize',
- mean=[0.4914, 0.4822, 0.4465],
- std=[0.2023, 0.1994, 0.2010]),
- ]),
- dataloader=dict(batch_size=400,
- pin_memory=True,
- num_workers=4,
- shuffle=True))
-
-optimizer = dict(type='Adam', lr=0.001, weight_decay=0)
-
-loss = dict(type='CrossEntropyLoss2D', )
-
-model = dict(
- type='VisionTransformerFromConfig',
- tensor_splitting_cfg=dict(type='ViTInputSplitter2D', ),
- embedding_cfg=dict(
- type='ViTPatchEmbedding2D',
- img_size=IMG_SIZE,
- patch_size=PATCH_SIZE,
- embed_dim=DIM,
- ),
- token_fusion_cfg=dict(type='ViTTokenFuser2D',
- img_size=IMG_SIZE,
- patch_size=PATCH_SIZE,
- embed_dim=DIM,
- drop_rate=0.1),
- norm_cfg=dict(
- type='LayerNorm2D',
- normalized_shape=DIM,
- eps=1e-6,
- ),
- block_cfg=dict(
- type='ViTBlock',
- attention_cfg=dict(
- type='ViTSelfAttention2D',
- hidden_size=DIM,
- num_attention_heads=NUM_ATTENTION_HEADS,
- attention_dropout_prob=0.,
- hidden_dropout_prob=0.1,
- ),
- droppath_cfg=dict(type='VanillaViTDropPath', ),
- mlp_cfg=dict(type='ViTMLP2D',
- in_features=DIM,
- dropout_prob=0.1,
- mlp_ratio=1),
- norm_cfg=dict(
- type='LayerNorm2D',
- normalized_shape=DIM,
- eps=1e-6,
- ),
- ),
- head_cfg=dict(
- type='ViTHead2D',
- hidden_size=DIM,
- num_classes=NUM_CLASSES,
- ),
- embed_dim=DIM,
- depth=DEPTH,
- drop_path_rate=0.,
-)
-
-parallel = dict(
- pipeline=dict(size=1),
- tensor=dict(size=4, mode='2d'),
-)
-
-num_epochs = 60
-
-lr_scheduler = dict(type='LinearWarmupLR', warmup_steps=5, total_steps=num_epochs)
diff --git a/tests/test_models/test_vision_transformer/configs/vit_2p5d.py b/tests/test_models/test_vision_transformer/configs/vit_2p5d.py
deleted file mode 100644
index 3c16d684a..000000000
--- a/tests/test_models/test_vision_transformer/configs/vit_2p5d.py
+++ /dev/null
@@ -1,130 +0,0 @@
-import os
-from pathlib import Path
-
-BATCH_SIZE = 512
-IMG_SIZE = 32
-PATCH_SIZE = 4
-DIM = 512
-NUM_ATTENTION_HEADS = 8
-SUMMA_DIM = 2
-NUM_CLASSES = 10
-DEPTH = 6
-
-train_data = dict(
- dataset=dict(
- type='CIFAR10Dataset',
- root=Path(os.environ['DATA']),
- transform_pipeline=[
- dict(type='RandomCrop', size=IMG_SIZE, padding=4),
- dict(type='RandomHorizontalFlip'),
- dict(type='ToTensor'),
- dict(type='Normalize',
- mean=[0.4914, 0.4822, 0.4465],
- std=[0.2023, 0.1994, 0.2010]),
- ]
- ),
- dataloader=dict(
- batch_size=BATCH_SIZE,
- pin_memory=True,
- num_workers=0,
- shuffle=True
- )
-)
-
-test_data = dict(
- dataset=dict(
- type='CIFAR10Dataset',
- root=Path(os.environ['DATA']),
- train=False,
- transform_pipeline=[
- dict(type='Resize', size=IMG_SIZE),
- dict(type='ToTensor'),
- dict(type='Normalize',
- mean=[0.4914, 0.4822, 0.4465],
- std=[0.2023, 0.1994, 0.2010]
- ),
- ]
- ),
- dataloader=dict(
- batch_size=400,
- pin_memory=True,
- num_workers=0,
- shuffle=True
- )
-)
-
-optimizer = dict(
- type='Adam',
- lr=0.001,
- weight_decay=0
-)
-
-loss = dict(
- type='CrossEntropyLoss2p5D',
-)
-
-model = dict(
- type='VisionTransformerFromConfig',
- tensor_splitting_cfg=dict(
- type='ViTInputSplitter2p5D',
- ),
- embedding_cfg=dict(
- type='ViTPatchEmbedding2p5D',
- img_size=IMG_SIZE,
- patch_size=PATCH_SIZE,
- embed_dim=DIM,
- ),
- token_fusion_cfg=dict(
- type='ViTTokenFuser2p5D',
- img_size=IMG_SIZE,
- patch_size=PATCH_SIZE,
- embed_dim=DIM,
- drop_rate=0.1
- ),
- norm_cfg=dict(
- type='LayerNorm2p5D',
- normalized_shape=DIM,
- eps=1e-6,
- ),
- block_cfg=dict(
- type='ViTBlock',
- attention_cfg=dict(
- type='ViTSelfAttention2p5D',
- hidden_size=DIM,
- num_attention_heads=NUM_ATTENTION_HEADS,
- attention_dropout_prob=0.,
- hidden_dropout_prob=0.1,
- ),
- droppath_cfg=dict(
- type='VanillaViTDropPath',
- ),
- mlp_cfg=dict(
- type='ViTMLP2p5D',
- in_features=DIM,
- dropout_prob=0.1,
- mlp_ratio=1
- ),
- norm_cfg=dict(
- type='LayerNorm2p5D',
- normalized_shape=DIM,
- eps=1e-6,
- ),
- ),
- head_cfg=dict(
- type='ViTHead2p5D',
- hidden_size=DIM,
- num_classes=NUM_CLASSES,
- ),
- embed_dim=DIM,
- depth=DEPTH,
- drop_path_rate=0.,
-)
-
-parallel = dict(
- pipeline=dict(size=1),
- tensor=dict(size=4, depth=1, mode='2.5d'),
-)
-
-num_epochs = 60
-
-lr_scheduler = dict(type='LinearWarmupLR', warmup_steps=5, total_steps=num_epochs)
diff --git a/tests/test_models/test_vision_transformer/configs/vit_3d.py b/tests/test_models/test_vision_transformer/configs/vit_3d.py
deleted file mode 100644
index ad041efd0..000000000
--- a/tests/test_models/test_vision_transformer/configs/vit_3d.py
+++ /dev/null
@@ -1,135 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-import os
-from pathlib import Path
-
-from colossalai.context import ParallelMode
-
-IMG_SIZE = 32
-PATCH_SIZE = 4
-EMBED_SIZE = 512
-HIDDEN_SIZE = 512
-NUM_HEADS = 8
-NUM_CLASSES = 10
-NUM_BLOCKS = 6
-DROP_RATE = 0.1
-BATCH_SIZE = 512
-LEARNING_RATE = 0.001
-DATASET_PATH = Path(os.environ['DATA'])
-
-model = dict(
- type='VisionTransformerFromConfig',
- embedding_cfg=dict(
- type='ViTPatchEmbedding3D',
- img_size=IMG_SIZE,
- patch_size=PATCH_SIZE,
- in_chans=3,
- embed_size=EMBED_SIZE,
- drop_prob=DROP_RATE,
- ),
- block_cfg=dict(
- type='ViTBlock',
- norm_cfg=dict(
- type='LayerNorm3D',
- normalized_shape=HIDDEN_SIZE,
- eps=1e-6,
- input_parallel_mode=ParallelMode.PARALLEL_3D_INPUT,
- weight_parallel_mode=ParallelMode.PARALLEL_3D_WEIGHT,
- ),
- attention_cfg=dict(
- type='ViTSelfAttention3D',
- hidden_size=HIDDEN_SIZE,
- num_attention_heads=NUM_HEADS,
- attention_probs_dropout_prob=0.,
- hidden_dropout_prob=DROP_RATE,
- ),
- droppath_cfg=dict(type='VanillaViTDropPath', ),
- mlp_cfg=dict(
- type='ViTMLP3D',
- hidden_size=HIDDEN_SIZE,
- mlp_ratio=1,
- hidden_dropout_prob=DROP_RATE,
- hidden_act='gelu',
- ),
- ),
- norm_cfg=dict(type='LayerNorm3D',
- normalized_shape=HIDDEN_SIZE,
- eps=1e-6,
- input_parallel_mode=ParallelMode.PARALLEL_3D_INPUT,
- weight_parallel_mode=ParallelMode.PARALLEL_3D_WEIGHT),
- head_cfg=dict(
- type='ViTHead3D',
- in_features=HIDDEN_SIZE,
- num_classes=NUM_CLASSES,
- ),
- embed_dim=HIDDEN_SIZE,
- depth=NUM_BLOCKS,
- drop_path_rate=0.,
-)
-
-loss = dict(type='CrossEntropyLoss3D',
- input_parallel_mode=ParallelMode.PARALLEL_3D_OUTPUT,
- weight_parallel_mode=ParallelMode.PARALLEL_3D_WEIGHT,
- reduction=True)
-
-optimizer = dict(type='Adam', lr=LEARNING_RATE, weight_decay=0)
-
-train_data = dict(dataset=dict(type='CIFAR10Dataset',
- root=DATASET_PATH,
- transform_pipeline=[
- dict(type='RandomCrop',
- size=IMG_SIZE,
- padding=4),
- dict(type='RandomHorizontalFlip'),
- dict(type='ToTensor'),
- dict(type='Normalize',
- mean=[0.4914, 0.4822, 0.4465],
- std=[0.2023, 0.1994, 0.2010]),
- ]),
- dataloader=dict(batch_size=BATCH_SIZE,
- pin_memory=True,
- shuffle=True,
- num_workers=8))
-
-test_data = dict(dataset=dict(type='CIFAR10Dataset',
- root=DATASET_PATH,
- train=False,
- transform_pipeline=[
- dict(type='Resize', size=IMG_SIZE),
- dict(type='ToTensor'),
- dict(type='Normalize',
- mean=[0.4914, 0.4822, 0.4465],
- std=[0.2023, 0.1994, 0.2010]),
- ]),
- dataloader=dict(batch_size=400,
- pin_memory=True,
- num_workers=8))
-
-hooks = [
- dict(type='LogMetricByEpochHook'),
- dict(type='LogTimingByEpochHook'),
- dict(type='LogMemoryByEpochHook'),
- dict(
- type='Accuracy3DHook',
- input_parallel_mode=ParallelMode.PARALLEL_3D_OUTPUT,
- weight_parallel_mode=ParallelMode.PARALLEL_3D_WEIGHT,
- ),
- dict(type='LossHook'),
- dict(
- type='LRSchedulerHook',
- by_epoch=True,
- lr_scheduler_cfg=dict(
- type='LinearWarmupLR',
- warmup_steps=5
- )
- ),
-]
-
-parallel = dict(
- data=1,
- pipeline=1,
- tensor=dict(mode='3d', size=8),
-)
-
-num_epochs = 60
diff --git a/tests/test_models/test_vision_transformer/configs/vit_vanilla.py b/tests/test_models/test_vision_transformer/configs/vit_vanilla.py
deleted file mode 100644
index 7602fd0c8..000000000
--- a/tests/test_models/test_vision_transformer/configs/vit_vanilla.py
+++ /dev/null
@@ -1,56 +0,0 @@
-import torch.nn as nn
-
-IMG_SIZE = 224
-DIM = 768
-NUM_CLASSES = 1000
-NUM_ATTN_HEADS = 12
-
-model = dict(
- type='VisionTransformerFromConfig',
- embedding_cfg=dict(
- type='VanillaViTPatchEmbedding',
- img_size=IMG_SIZE,
- patch_size=16,
- in_chans=3,
- embed_dim=DIM
- ),
- norm_cfg=dict(
- type='LayerNorm',
- eps=1e-6,
- normalized_shape=DIM
- ),
- block_cfg=dict(
- type='ViTBlock',
- checkpoint=True,
- attention_cfg=dict(
- type='VanillaViTAttention',
- dim=DIM,
- num_heads=NUM_ATTN_HEADS,
- qkv_bias=True,
- attn_drop=0.,
- proj_drop=0.
- ),
- droppath_cfg=dict(
- type='VanillaViTDropPath',
- ),
- mlp_cfg=dict(
- type='VanillaViTMLP',
- in_features=DIM,
- hidden_features=DIM * 4,
- act_layer=nn.GELU,
- drop=0.
- ),
- norm_cfg=dict(
- type='LayerNorm',
- normalized_shape=DIM
- ),
- ),
- head_cfg=dict(
- type='VanillaViTHead',
- in_features=DIM,
- intermediate_features=DIM * 2,
- out_features=NUM_CLASSES
- ),
- depth=12,
- drop_path_rate=0.,
-)
diff --git a/tests/test_models/test_vision_transformer/test.sh b/tests/test_models/test_vision_transformer/test.sh
deleted file mode 100644
index 1c6012a52..000000000
--- a/tests/test_models/test_vision_transformer/test.sh
+++ /dev/null
@@ -1,4 +0,0 @@
-#!/usr/bin/env sh
-test_file=$1
-
-python $test_file --local_rank $SLURM_PROCID --world_size $SLURM_NPROCS --host $HOST --port 29500
diff --git a/tests/test_models/test_vision_transformer/test_vit_2d/exp_logs/2d-nproc4-lr1e-3/acc-2D-lr1e-3.jpg b/tests/test_models/test_vision_transformer/test_vit_2d/exp_logs/2d-nproc4-lr1e-3/acc-2D-lr1e-3.jpg
deleted file mode 100644
index 541ef9c55..000000000
Binary files a/tests/test_models/test_vision_transformer/test_vit_2d/exp_logs/2d-nproc4-lr1e-3/acc-2D-lr1e-3.jpg and /dev/null differ
diff --git a/tests/test_models/test_vision_transformer/test_vit_2d/exp_logs/2d-nproc4-lr1e-3/alignment.o3475503 b/tests/test_models/test_vision_transformer/test_vit_2d/exp_logs/2d-nproc4-lr1e-3/alignment.o3475503
deleted file mode 100644
index 87b8ac8d2..000000000
--- a/tests/test_models/test_vision_transformer/test_vit_2d/exp_logs/2d-nproc4-lr1e-3/alignment.o3475503
+++ /dev/null
@@ -1,177 +0,0 @@
-Tue Aug 31 23:19:11 CDT 2021
-TACC: Starting up job 3475503
-TACC: Starting parallel tasks...
-warning: variables which starts with __, is a module or class declaration are omitted
-process rank 0 is bound to device 0
-distributed environment is initialzied
-model is created
-warning: variables which starts with __, is a module or class declaration are omitted
-process rank 3 is bound to device 3
-Files already downloaded and verified
-Files already downloaded and verified
-warning: variables which starts with __, is a module or class declaration are omitted
-process rank 2 is bound to device 2
-Files already downloaded and verified
-Files already downloaded and verified
-Files already downloaded and verified
-Files already downloaded and verified
-training and testing dataloaders are created
-loss is created
-optimizer is created
-start training
-warning: variables which starts with __, is a module or class declaration are omitted
-process rank 1 is bound to device 1
-Files already downloaded and verified
-Files already downloaded and verified
-epoch: 0, train loss: 1.9477875804414555
-epoch: 0, eval loss: 1.8044581711292267, correct: 3584, total: 10000, acc = 0.35839998722076416
-epoch: 1, train loss: 1.6170539077447386
-epoch: 2, train loss: 1.4285206673096638
-epoch: 2, eval loss: 1.3448221325874328, correct: 5167, total: 10000, acc = 0.5166999697685242
-epoch: 3, train loss: 1.3227316007322194
-epoch: 4, train loss: 1.2758926104526132
-epoch: 4, eval loss: 1.2780393660068512, correct: 5460, total: 10000, acc = 0.5460000038146973
-epoch: 5, train loss: 1.2221618829941263
-epoch: 6, train loss: 1.1857815640313285
-epoch: 6, eval loss: 1.1175921618938447, correct: 6023, total: 10000, acc = 0.6022999882698059
-epoch: 7, train loss: 1.1659576710389585
-epoch: 8, train loss: 1.1457150134505059
-epoch: 8, eval loss: 1.0789835333824158, correct: 6113, total: 10000, acc = 0.611299991607666
-epoch: 9, train loss: 1.1156543700062498
-epoch: 10, train loss: 1.0950473242876482
-epoch: 10, eval loss: 1.058586174249649, correct: 6170, total: 10000, acc = 0.6169999837875366
-epoch: 11, train loss: 1.0976360866001673
-epoch: 12, train loss: 1.07803391193857
-epoch: 12, eval loss: 1.0039635241031646, correct: 6351, total: 10000, acc = 0.6351000070571899
-epoch: 13, train loss: 1.0680764615535736
-epoch: 14, train loss: 1.0364759442757587
-epoch: 14, eval loss: 0.9748250603675842, correct: 6486, total: 10000, acc = 0.6485999822616577
-epoch: 15, train loss: 1.023898609438721
-epoch: 16, train loss: 0.9982165353638786
-epoch: 16, eval loss: 0.9612966269254685, correct: 6591, total: 10000, acc = 0.6590999960899353
-epoch: 17, train loss: 0.9698412771127662
-epoch: 18, train loss: 0.9523191050607331
-epoch: 18, eval loss: 0.8974281877279282, correct: 6810, total: 10000, acc = 0.6809999942779541
-epoch: 19, train loss: 0.9171817661548147
-epoch: 20, train loss: 0.8905259948603961
-epoch: 20, eval loss: 0.8580602705478668, correct: 6965, total: 10000, acc = 0.6965000033378601
-epoch: 21, train loss: 0.86673782917918
-epoch: 22, train loss: 0.8339344001546198
-epoch: 22, eval loss: 0.8263293951749802, correct: 7107, total: 10000, acc = 0.7106999754905701
-epoch: 23, train loss: 0.8074834510988119
-epoch: 24, train loss: 0.7840324482139276
-epoch: 24, eval loss: 0.752952727675438, correct: 7317, total: 10000, acc = 0.7317000031471252
-epoch: 25, train loss: 0.7541018596717289
-epoch: 26, train loss: 0.7357191905683401
-epoch: 26, eval loss: 0.7338999301195145, correct: 7410, total: 10000, acc = 0.7409999966621399
-epoch: 27, train loss: 0.7107210451242875
-epoch: 28, train loss: 0.6785972909051545
-epoch: 28, eval loss: 0.7020785599946976, correct: 7523, total: 10000, acc = 0.752299964427948
-epoch: 29, train loss: 0.660102152094549
-epoch: 30, train loss: 0.6498027924372225
-epoch: 30, eval loss: 0.6610008627176285, correct: 7661, total: 10000, acc = 0.7660999894142151
-epoch: 31, train loss: 0.6297167344969146
-epoch: 32, train loss: 0.6150159224563715
-epoch: 32, eval loss: 0.6350889533758164, correct: 7802, total: 10000, acc = 0.7802000045776367
-epoch: 33, train loss: 0.5912032842027898
-epoch: 34, train loss: 0.5761601137263435
-epoch: 34, eval loss: 0.6296706795692444, correct: 7786, total: 10000, acc = 0.7785999774932861
-epoch: 35, train loss: 0.5586571322411907
-epoch: 36, train loss: 0.5488096165413759
-epoch: 36, eval loss: 0.6041992783546448, correct: 7913, total: 10000, acc = 0.7912999987602234
-epoch: 37, train loss: 0.5273334958723613
-epoch: 38, train loss: 0.5074144468015555
-epoch: 38, eval loss: 0.5868680268526077, correct: 7984, total: 10000, acc = 0.7983999848365784
-epoch: 39, train loss: 0.4930413775906271
-epoch: 40, train loss: 0.47384805977344513
-epoch: 40, eval loss: 0.6013937592506409, correct: 7945, total: 10000, acc = 0.7944999933242798
-epoch: 41, train loss: 0.4618621742238804
-epoch: 42, train loss: 0.4452754973757024
-epoch: 42, eval loss: 0.5606920897960663, correct: 8093, total: 10000, acc = 0.8093000054359436
-epoch: 43, train loss: 0.4361336164328517
-epoch: 44, train loss: 0.4188923318775333
-epoch: 44, eval loss: 0.5567828729748726, correct: 8042, total: 10000, acc = 0.8041999936103821
-epoch: 45, train loss: 0.4047189655960823
-epoch: 46, train loss: 0.3873833852763079
-epoch: 46, eval loss: 0.5404785141348839, correct: 8166, total: 10000, acc = 0.81659996509552
-epoch: 47, train loss: 0.3707445412874222
-epoch: 48, train loss: 0.3631058514726405
-epoch: 48, eval loss: 0.5541519388556481, correct: 8201, total: 10000, acc = 0.820099949836731
-epoch: 49, train loss: 0.34395075604623676
-epoch: 50, train loss: 0.3290589987015238
-epoch: 50, eval loss: 0.5442438080906868, correct: 8169, total: 10000, acc = 0.8168999552726746
-epoch: 51, train loss: 0.3188562990755451
-epoch: 52, train loss: 0.2986554713273535
-epoch: 52, eval loss: 0.5515974283218383, correct: 8203, total: 10000, acc = 0.8202999830245972
-epoch: 53, train loss: 0.29044121671087886
-epoch: 54, train loss: 0.27310980613134345
-epoch: 54, eval loss: 0.5587902516126633, correct: 8195, total: 10000, acc = 0.8194999694824219
-epoch: 55, train loss: 0.2637303553673686
-epoch: 56, train loss: 0.2521531299060705
-epoch: 56, eval loss: 0.5885633528232574, correct: 8202, total: 10000, acc = 0.8201999664306641
-epoch: 57, train loss: 0.23304983274060853
-epoch: 58, train loss: 0.22784664101746618
-epoch: 58, eval loss: 0.5882876932621002, correct: 8245, total: 10000, acc = 0.8244999647140503
-epoch: 59, train loss: 0.21604956868959932
-epoch: 60, train loss: 0.20325114882113982
-epoch: 60, eval loss: 0.5910753712058068, correct: 8248, total: 10000, acc = 0.8247999548912048
-epoch: 61, train loss: 0.19390226033877353
-epoch: 62, train loss: 0.18323212360240976
-epoch: 62, eval loss: 0.6264512360095977, correct: 8272, total: 10000, acc = 0.8271999955177307
-epoch: 63, train loss: 0.1680474430322647
-epoch: 64, train loss: 0.16121925512442783
-epoch: 64, eval loss: 0.640467157959938, correct: 8283, total: 10000, acc = 0.8282999992370605
-epoch: 65, train loss: 0.14981685054241395
-epoch: 66, train loss: 0.14731310475237516
-epoch: 66, eval loss: 0.6354441046714783, correct: 8303, total: 10000, acc = 0.830299973487854
-epoch: 67, train loss: 0.13300996874364054
-epoch: 68, train loss: 0.12739452506814683
-epoch: 68, eval loss: 0.6673313498497009, correct: 8282, total: 10000, acc = 0.8281999826431274
-epoch: 69, train loss: 0.11627298555507952
-epoch: 70, train loss: 0.10940710728874012
-epoch: 70, eval loss: 0.692647397518158, correct: 8302, total: 10000, acc = 0.8301999568939209
-epoch: 71, train loss: 0.10183572788171623
-epoch: 72, train loss: 0.09634554986746943
-epoch: 72, eval loss: 0.695426219701767, correct: 8299, total: 10000, acc = 0.8298999667167664
-epoch: 73, train loss: 0.09228058896806775
-epoch: 74, train loss: 0.08581420976896675
-epoch: 74, eval loss: 0.694861987233162, correct: 8340, total: 10000, acc = 0.8339999914169312
-epoch: 75, train loss: 0.07914869715364611
-epoch: 76, train loss: 0.0742536057547039
-epoch: 76, eval loss: 0.7130348771810532, correct: 8347, total: 10000, acc = 0.8346999883651733
-epoch: 77, train loss: 0.06935026941402835
-epoch: 78, train loss: 0.0665280031306403
-epoch: 78, eval loss: 0.7465721786022186, correct: 8341, total: 10000, acc = 0.8341000080108643
-epoch: 79, train loss: 0.05928862589050313
-epoch: 80, train loss: 0.05455683164146482
-epoch: 80, eval loss: 0.776301947236061, correct: 8314, total: 10000, acc = 0.8313999772071838
-epoch: 81, train loss: 0.05638634926658504
-epoch: 82, train loss: 0.05360411343221762
-epoch: 82, eval loss: 0.7883096963167191, correct: 8332, total: 10000, acc = 0.8331999778747559
-epoch: 83, train loss: 0.04867944630737207
-epoch: 84, train loss: 0.0474467751931171
-epoch: 84, eval loss: 0.7960963994264603, correct: 8340, total: 10000, acc = 0.8339999914169312
-epoch: 85, train loss: 0.044186076149344444
-epoch: 86, train loss: 0.041499203527156185
-epoch: 86, eval loss: 0.7910951197147369, correct: 8364, total: 10000, acc = 0.8363999724388123
-epoch: 87, train loss: 0.03865447499770291
-epoch: 88, train loss: 0.03929391104195799
-epoch: 88, eval loss: 0.8027831196784974, correct: 8371, total: 10000, acc = 0.8370999693870544
-epoch: 89, train loss: 0.03619915826664287
-epoch: 90, train loss: 0.034103386617284646
-epoch: 90, eval loss: 0.8021850943565368, correct: 8357, total: 10000, acc = 0.8356999754905701
-epoch: 91, train loss: 0.037686211741244306
-epoch: 92, train loss: 0.033469487000636906
-epoch: 92, eval loss: 0.8143919885158539, correct: 8359, total: 10000, acc = 0.8359000086784363
-epoch: 93, train loss: 0.03238337778733397
-epoch: 94, train loss: 0.03141044888987529
-epoch: 94, eval loss: 0.8093269526958465, correct: 8385, total: 10000, acc = 0.8384999632835388
-epoch: 95, train loss: 0.031111840225223984
-epoch: 96, train loss: 0.032168653251945366
-epoch: 96, eval loss: 0.8102991491556167, correct: 8379, total: 10000, acc = 0.8378999829292297
-epoch: 97, train loss: 0.03043455306478605
-epoch: 98, train loss: 0.03200419174925405
-epoch: 98, eval loss: 0.8105081558227539, correct: 8373, total: 10000, acc = 0.8373000025749207
-epoch: 99, train loss: 0.031662615329711416
-finish training
-TACC: Shutdown complete. Exiting.
diff --git a/tests/test_models/test_vision_transformer/test_vit_2d/exp_logs/2d-nproc4-lr1e-3/loss-2D-lr1e-3.jpg b/tests/test_models/test_vision_transformer/test_vit_2d/exp_logs/2d-nproc4-lr1e-3/loss-2D-lr1e-3.jpg
deleted file mode 100644
index ced7315a3..000000000
Binary files a/tests/test_models/test_vision_transformer/test_vit_2d/exp_logs/2d-nproc4-lr1e-3/loss-2D-lr1e-3.jpg and /dev/null differ
diff --git a/tests/test_models/test_vision_transformer/test_vit_2d/exp_logs/2d-nproc4-lr1e-4/acc-2D-lr1e-4.jpg b/tests/test_models/test_vision_transformer/test_vit_2d/exp_logs/2d-nproc4-lr1e-4/acc-2D-lr1e-4.jpg
deleted file mode 100644
index d5546561a..000000000
Binary files a/tests/test_models/test_vision_transformer/test_vit_2d/exp_logs/2d-nproc4-lr1e-4/acc-2D-lr1e-4.jpg and /dev/null differ
diff --git a/tests/test_models/test_vision_transformer/test_vit_2d/exp_logs/2d-nproc4-lr1e-4/alignment.o3472937 b/tests/test_models/test_vision_transformer/test_vit_2d/exp_logs/2d-nproc4-lr1e-4/alignment.o3472937
deleted file mode 100644
index ccda9881c..000000000
--- a/tests/test_models/test_vision_transformer/test_vit_2d/exp_logs/2d-nproc4-lr1e-4/alignment.o3472937
+++ /dev/null
@@ -1,177 +0,0 @@
-Tue Aug 31 12:28:41 CDT 2021
-TACC: Starting up job 3472937
-TACC: Starting parallel tasks...
-warning: variables which start with __, or are module or class declarations, are omitted
-process rank 0 is bound to device 0
-distributed environment is initialized
-model is created
-warning: variables which start with __, or are module or class declarations, are omitted
-process rank 3 is bound to device 3
-Files already downloaded and verified
-Files already downloaded and verified
-warning: variables which start with __, or are module or class declarations, are omitted
-process rank 2 is bound to device 2
-Files already downloaded and verified
-Files already downloaded and verified
-Files already downloaded and verified
-Files already downloaded and verified
-training and testing dataloaders are created
-loss is created
-warning: variables which start with __, or are module or class declarations, are omitted
-process rank 1 is bound to device 1
-Files already downloaded and verified
-Files already downloaded and verified
-optimizer is created
-start training
-epoch: 0, train loss: 2.0869219929901597
-epoch: 0, eval loss: 1.9415993988513947, correct: 2875, total: 10000, acc = 0.2874999940395355
-epoch: 1, train loss: 1.832751432644952
-epoch: 2, train loss: 1.6953342194409715
-epoch: 2, eval loss: 1.6502822101116181, correct: 4026, total: 10000, acc = 0.4025999903678894
-epoch: 3, train loss: 1.583214813900977
-epoch: 4, train loss: 1.4921425851349979
-epoch: 4, eval loss: 1.4688542783260345, correct: 4773, total: 10000, acc = 0.4772999882698059
-epoch: 5, train loss: 1.3872919402171655
-epoch: 6, train loss: 1.3123903028743784
-epoch: 6, eval loss: 1.328972715139389, correct: 5275, total: 10000, acc = 0.5274999737739563
-epoch: 7, train loss: 1.2454541012183906
-epoch: 8, train loss: 1.1953427422906935
-epoch: 8, eval loss: 1.2376527905464172, correct: 5579, total: 10000, acc = 0.5579000115394592
-epoch: 9, train loss: 1.1491977308214325
-epoch: 10, train loss: 1.1148795012346249
-epoch: 10, eval loss: 1.1297606527805328, correct: 5975, total: 10000, acc = 0.5974999666213989
-epoch: 11, train loss: 1.076469630310216
-epoch: 12, train loss: 1.0476364874348199
-epoch: 12, eval loss: 1.029269078373909, correct: 6333, total: 10000, acc = 0.6333000063896179
-epoch: 13, train loss: 1.0117879393174476
-epoch: 14, train loss: 0.9859390357106003
-epoch: 14, eval loss: 0.97474505007267, correct: 6494, total: 10000, acc = 0.649399995803833
-epoch: 15, train loss: 0.9595183336857668
-epoch: 16, train loss: 0.9384779051407096
-epoch: 16, eval loss: 0.9172703564167023, correct: 6716, total: 10000, acc = 0.6715999841690063
-epoch: 17, train loss: 0.9127772370564569
-epoch: 18, train loss: 0.889132705545917
-epoch: 18, eval loss: 0.8939311623573303, correct: 6809, total: 10000, acc = 0.680899977684021
-epoch: 19, train loss: 0.8719241373317758
-epoch: 20, train loss: 0.8456196920158937
-epoch: 20, eval loss: 0.8584266930818558, correct: 6944, total: 10000, acc = 0.6943999528884888
-epoch: 21, train loss: 0.8258345379042871
-epoch: 22, train loss: 0.8185748826597155
-epoch: 22, eval loss: 0.8427778095006943, correct: 7020, total: 10000, acc = 0.7019999623298645
-epoch: 23, train loss: 0.794703829534275
-epoch: 24, train loss: 0.777785701235545
-epoch: 24, eval loss: 0.801164984703064, correct: 7182, total: 10000, acc = 0.7181999683380127
-epoch: 25, train loss: 0.760752295095896
-epoch: 26, train loss: 0.7453707229230822
-epoch: 26, eval loss: 0.7841533124446869, correct: 7209, total: 10000, acc = 0.7208999991416931
-epoch: 27, train loss: 0.7267675215436011
-epoch: 28, train loss: 0.7131575210807249
-epoch: 28, eval loss: 0.7708685129880906, correct: 7254, total: 10000, acc = 0.7253999710083008
-epoch: 29, train loss: 0.7007347437524304
-epoch: 30, train loss: 0.6834727574869529
-epoch: 30, eval loss: 0.7591335833072662, correct: 7356, total: 10000, acc = 0.7355999946594238
-epoch: 31, train loss: 0.6712760894568925
-epoch: 32, train loss: 0.655675129177644
-epoch: 32, eval loss: 0.7655339151620865, correct: 7314, total: 10000, acc = 0.7313999533653259
-epoch: 33, train loss: 0.6421149262447947
-epoch: 34, train loss: 0.6301654601834484
-epoch: 34, eval loss: 0.7450480967760086, correct: 7350, total: 10000, acc = 0.73499995470047
-epoch: 35, train loss: 0.6189313580080406
-epoch: 36, train loss: 0.6047559282214371
-epoch: 36, eval loss: 0.7468931972980499, correct: 7392, total: 10000, acc = 0.7391999959945679
-epoch: 37, train loss: 0.5878085592358383
-epoch: 38, train loss: 0.5731440121980057
-epoch: 38, eval loss: 0.7349929332733154, correct: 7434, total: 10000, acc = 0.743399977684021
-epoch: 39, train loss: 0.5633921856732712
-epoch: 40, train loss: 0.5499549056451345
-epoch: 40, eval loss: 0.7258913427591324, correct: 7483, total: 10000, acc = 0.7482999563217163
-epoch: 41, train loss: 0.5403583102005044
-epoch: 42, train loss: 0.5286270272485989
-epoch: 42, eval loss: 0.7170430123806, correct: 7528, total: 10000, acc = 0.7527999877929688
-epoch: 43, train loss: 0.5166667195939526
-epoch: 44, train loss: 0.5098928068716502
-epoch: 44, eval loss: 0.7244295090436935, correct: 7531, total: 10000, acc = 0.7530999779701233
-epoch: 45, train loss: 0.4917362458312634
-epoch: 46, train loss: 0.48251094676784634
-epoch: 46, eval loss: 0.728115001320839, correct: 7557, total: 10000, acc = 0.7556999921798706
-epoch: 47, train loss: 0.47845434067175563
-epoch: 48, train loss: 0.4637242813700253
-epoch: 48, eval loss: 0.7259155690670014, correct: 7559, total: 10000, acc = 0.755899965763092
-epoch: 49, train loss: 0.4557308668328315
-epoch: 50, train loss: 0.4414065560114752
-epoch: 50, eval loss: 0.7056828439235687, correct: 7648, total: 10000, acc = 0.764799952507019
-epoch: 51, train loss: 0.43054875792916286
-epoch: 52, train loss: 0.4196087404624703
-epoch: 52, eval loss: 0.7131796926259995, correct: 7670, total: 10000, acc = 0.7669999599456787
-epoch: 53, train loss: 0.41613124971537246
-epoch: 54, train loss: 0.4016842494920357
-epoch: 54, eval loss: 0.7215427845716477, correct: 7641, total: 10000, acc = 0.7640999555587769
-epoch: 55, train loss: 0.39098499054761277
-epoch: 56, train loss: 0.3805114098430909
-epoch: 56, eval loss: 0.7281092345714569, correct: 7672, total: 10000, acc = 0.7671999931335449
-epoch: 57, train loss: 0.3724562412070245
-epoch: 58, train loss: 0.37037558162335266
-epoch: 58, eval loss: 0.7282122701406479, correct: 7707, total: 10000, acc = 0.7706999778747559
-epoch: 59, train loss: 0.3584493641386327
-epoch: 60, train loss: 0.35091825858833864
-epoch: 60, eval loss: 0.7441833585500717, correct: 7653, total: 10000, acc = 0.7652999758720398
-epoch: 61, train loss: 0.3469349926279992
-epoch: 62, train loss: 0.33631756533052504
-epoch: 62, eval loss: 0.7398703306913376, correct: 7679, total: 10000, acc = 0.7678999900817871
-epoch: 63, train loss: 0.3287597510618033
-epoch: 64, train loss: 0.3201854192104536
-epoch: 64, eval loss: 0.7452850311994552, correct: 7676, total: 10000, acc = 0.7675999999046326
-epoch: 65, train loss: 0.3196122018025093
-epoch: 66, train loss: 0.30826768893556494
-epoch: 66, eval loss: 0.7634350836277009, correct: 7637, total: 10000, acc = 0.763700008392334
-epoch: 67, train loss: 0.30273781855081777
-epoch: 68, train loss: 0.29493943732423883
-epoch: 68, eval loss: 0.7755635917186737, correct: 7679, total: 10000, acc = 0.7678999900817871
-epoch: 69, train loss: 0.2938486310010104
-epoch: 70, train loss: 0.2860709805156767
-epoch: 70, eval loss: 0.7754869312047958, correct: 7690, total: 10000, acc = 0.7689999938011169
-epoch: 71, train loss: 0.27915918873142953
-epoch: 72, train loss: 0.2728954671891694
-epoch: 72, eval loss: 0.78065524995327, correct: 7693, total: 10000, acc = 0.7692999839782715
-epoch: 73, train loss: 0.2666821759386161
-epoch: 74, train loss: 0.2651688180018946
-epoch: 74, eval loss: 0.7740301787853241, correct: 7739, total: 10000, acc = 0.7738999724388123
-epoch: 75, train loss: 0.2613655937086676
-epoch: 76, train loss: 0.2508497520820382
-epoch: 76, eval loss: 0.7777235865592956, correct: 7777, total: 10000, acc = 0.7777000069618225
-epoch: 77, train loss: 0.2466566213934692
-epoch: 78, train loss: 0.24763428181717076
-epoch: 78, eval loss: 0.7764069586992264, correct: 7765, total: 10000, acc = 0.7764999866485596
-epoch: 79, train loss: 0.24774935457509817
-epoch: 80, train loss: 0.23876388401714796
-epoch: 80, eval loss: 0.7927991092205048, correct: 7740, total: 10000, acc = 0.7739999890327454
-epoch: 81, train loss: 0.23702984618157455
-epoch: 82, train loss: 0.2349560634069836
-epoch: 82, eval loss: 0.794422161579132, correct: 7734, total: 10000, acc = 0.7734000086784363
-epoch: 83, train loss: 0.2302393423220546
-epoch: 84, train loss: 0.22696133203727684
-epoch: 84, eval loss: 0.7939992696046829, correct: 7768, total: 10000, acc = 0.7767999768257141
-epoch: 85, train loss: 0.22804359692273682
-epoch: 86, train loss: 0.21921918088013365
-epoch: 86, eval loss: 0.792869821190834, correct: 7768, total: 10000, acc = 0.7767999768257141
-epoch: 87, train loss: 0.22169384437123524
-epoch: 88, train loss: 0.21990801271089574
-epoch: 88, eval loss: 0.7874982714653015, correct: 7761, total: 10000, acc = 0.7760999798774719
-epoch: 89, train loss: 0.2174029858763685
-epoch: 90, train loss: 0.21091166722405816
-epoch: 90, eval loss: 0.7935442298650741, correct: 7765, total: 10000, acc = 0.7764999866485596
-epoch: 91, train loss: 0.214123952788176
-epoch: 92, train loss: 0.2140680920217455
-epoch: 92, eval loss: 0.7855452030897141, correct: 7787, total: 10000, acc = 0.7786999940872192
-epoch: 93, train loss: 0.2146580269963471
-epoch: 94, train loss: 0.2091474826495672
-epoch: 94, eval loss: 0.7892280638217926, correct: 7783, total: 10000, acc = 0.7782999873161316
-epoch: 95, train loss: 0.21155885015566325
-epoch: 96, train loss: 0.21236088549353413
-epoch: 96, eval loss: 0.7883010923862457, correct: 7785, total: 10000, acc = 0.778499960899353
-epoch: 97, train loss: 0.20943679852583974
-epoch: 98, train loss: 0.20945941495526696
-epoch: 98, eval loss: 0.7873147040605545, correct: 7783, total: 10000, acc = 0.7782999873161316
-epoch: 99, train loss: 0.2085563373012641
-finish training
-TACC: Shutdown complete. Exiting.
diff --git a/tests/test_models/test_vision_transformer/test_vit_2d/exp_logs/2d-nproc4-lr1e-4/loss-2D-lr1e-4.jpg b/tests/test_models/test_vision_transformer/test_vit_2d/exp_logs/2d-nproc4-lr1e-4/loss-2D-lr1e-4.jpg
deleted file mode 100644
index f58df05c5..000000000
Binary files a/tests/test_models/test_vision_transformer/test_vit_2d/exp_logs/2d-nproc4-lr1e-4/loss-2D-lr1e-4.jpg and /dev/null differ
diff --git a/tests/test_models/test_vision_transformer/test_vit_2d/exp_logs/vanilla-nproc1-lr1e-3/acc-vanilla-lr1e-3.jpg b/tests/test_models/test_vision_transformer/test_vit_2d/exp_logs/vanilla-nproc1-lr1e-3/acc-vanilla-lr1e-3.jpg
deleted file mode 100644
index c06130523..000000000
Binary files a/tests/test_models/test_vision_transformer/test_vit_2d/exp_logs/vanilla-nproc1-lr1e-3/acc-vanilla-lr1e-3.jpg and /dev/null differ
diff --git a/tests/test_models/test_vision_transformer/test_vit_2d/exp_logs/vanilla-nproc1-lr1e-3/alignment.o3476018 b/tests/test_models/test_vision_transformer/test_vit_2d/exp_logs/vanilla-nproc1-lr1e-3/alignment.o3476018
deleted file mode 100644
index 6b027b5ab..000000000
--- a/tests/test_models/test_vision_transformer/test_vit_2d/exp_logs/vanilla-nproc1-lr1e-3/alignment.o3476018
+++ /dev/null
@@ -1,165 +0,0 @@
-Wed Sep 1 01:07:01 CDT 2021
-TACC: Starting up job 3476018
-TACC: Starting parallel tasks...
-warning: variables which start with __, or are module or class declarations, are omitted
-process rank 0 is bound to device 0
-distributed environment is initialized
-model is created
-Files already downloaded and verified
-Files already downloaded and verified
-training and testing dataloaders are created
-loss is created
-optimizer is created
-start training
-epoch: 0, train loss: 1.9497510997616514
-epoch: 0, eval loss: 1.754234939813614, correct: 3521, total: 10000, acc = 0.3520999848842621
-epoch: 1, train loss: 1.6049139609142227
-epoch: 2, train loss: 1.3857794501343552
-epoch: 2, eval loss: 1.2831632316112518, correct: 5410, total: 10000, acc = 0.5410000085830688
-epoch: 3, train loss: 1.3016913873808724
-epoch: 4, train loss: 1.2616293649284207
-epoch: 4, eval loss: 1.2658930838108062, correct: 5409, total: 10000, acc = 0.5408999919891357
-epoch: 5, train loss: 1.2320433721250417
-epoch: 6, train loss: 1.181612290898148
-epoch: 6, eval loss: 1.1402096092700957, correct: 5881, total: 10000, acc = 0.5880999565124512
-epoch: 7, train loss: 1.1643818397911228
-epoch: 8, train loss: 1.128499301112428
-epoch: 8, eval loss: 1.0965303361415863, correct: 6053, total: 10000, acc = 0.6053000092506409
-epoch: 9, train loss: 1.114193707704544
-epoch: 10, train loss: 1.0830892950904614
-epoch: 10, eval loss: 1.0390974164009095, correct: 6258, total: 10000, acc = 0.6258000135421753
-epoch: 11, train loss: 1.0508871960396668
-epoch: 12, train loss: 1.0322130365031106
-epoch: 12, eval loss: 0.9689173698425293, correct: 6482, total: 10000, acc = 0.6481999754905701
-epoch: 13, train loss: 1.0006194637746226
-epoch: 14, train loss: 0.9652800906677635
-epoch: 14, eval loss: 0.9150958389043808, correct: 6713, total: 10000, acc = 0.6712999939918518
-epoch: 15, train loss: 0.9430981692002744
-epoch: 16, train loss: 0.9156872307767674
-epoch: 16, eval loss: 0.8703682094812393, correct: 6913, total: 10000, acc = 0.6912999749183655
-epoch: 17, train loss: 0.8822251515729087
-epoch: 18, train loss: 0.8485424190151448
-epoch: 18, eval loss: 0.8234190821647644, correct: 7120, total: 10000, acc = 0.7119999527931213
-epoch: 19, train loss: 0.8285953049757042
-epoch: 20, train loss: 0.8009484337300671
-epoch: 20, eval loss: 0.7808267176151276, correct: 7228, total: 10000, acc = 0.7227999567985535
-epoch: 21, train loss: 0.7774611741912608
-epoch: 22, train loss: 0.7435575358721674
-epoch: 22, eval loss: 0.7523189872503281, correct: 7367, total: 10000, acc = 0.7366999983787537
-epoch: 23, train loss: 0.7315681789602552
-epoch: 24, train loss: 0.70117900627
-epoch: 24, eval loss: 0.6928718358278274, correct: 7580, total: 10000, acc = 0.7579999566078186
-epoch: 25, train loss: 0.677533069435431
-epoch: 26, train loss: 0.6627033298112908
-epoch: 26, eval loss: 0.6921748876571655, correct: 7586, total: 10000, acc = 0.7585999965667725
-epoch: 27, train loss: 0.6410714266251545
-epoch: 28, train loss: 0.6192339707394036
-epoch: 28, eval loss: 0.6416671514511109, correct: 7719, total: 10000, acc = 0.7718999981880188
-epoch: 29, train loss: 0.6093639281331277
-epoch: 30, train loss: 0.582532714520182
-epoch: 30, eval loss: 0.6166591048240662, correct: 7809, total: 10000, acc = 0.7809000015258789
-epoch: 31, train loss: 0.572193189847226
-epoch: 32, train loss: 0.5541256200902316
-epoch: 32, eval loss: 0.5951347410678863, correct: 7922, total: 10000, acc = 0.792199969291687
-epoch: 33, train loss: 0.5345369838938421
-epoch: 34, train loss: 0.5273816007740644
-epoch: 34, eval loss: 0.5837202191352844, correct: 7972, total: 10000, acc = 0.7971999645233154
-epoch: 35, train loss: 0.5059237045292951
-epoch: 36, train loss: 0.48622317095192114
-epoch: 36, eval loss: 0.5698897138237953, correct: 8024, total: 10000, acc = 0.8023999929428101
-epoch: 37, train loss: 0.47362951143663756
-epoch: 38, train loss: 0.46030426907296085
-epoch: 38, eval loss: 0.5610475659370422, correct: 8049, total: 10000, acc = 0.8048999905586243
-epoch: 39, train loss: 0.44165324921510657
-epoch: 40, train loss: 0.4327346086502075
-epoch: 40, eval loss: 0.5642214670777321, correct: 8095, total: 10000, acc = 0.809499979019165
-epoch: 41, train loss: 0.41423581935921494
-epoch: 42, train loss: 0.40917488780556893
-epoch: 42, eval loss: 0.5602998435497284, correct: 8131, total: 10000, acc = 0.8130999803543091
-epoch: 43, train loss: 0.39171184477757437
-epoch: 44, train loss: 0.3744060835059808
-epoch: 44, eval loss: 0.5633655220270157, correct: 8134, total: 10000, acc = 0.8133999705314636
-epoch: 45, train loss: 0.36267226934432983
-epoch: 46, train loss: 0.3420030690577565
-epoch: 46, eval loss: 0.5533872425556183, correct: 8157, total: 10000, acc = 0.8156999945640564
-epoch: 47, train loss: 0.3287143409252167
-epoch: 48, train loss: 0.316296321396925
-epoch: 48, eval loss: 0.5576229721307755, correct: 8209, total: 10000, acc = 0.8208999633789062
-epoch: 49, train loss: 0.3068045072105466
-epoch: 50, train loss: 0.2929732614025778
-epoch: 50, eval loss: 0.5654072970151901, correct: 8227, total: 10000, acc = 0.8226999640464783
-epoch: 51, train loss: 0.2795026940958841
-epoch: 52, train loss: 0.26673941375041493
-epoch: 52, eval loss: 0.5736668109893799, correct: 8227, total: 10000, acc = 0.8226999640464783
-epoch: 53, train loss: 0.2506744866164363
-epoch: 54, train loss: 0.24351145980917677
-epoch: 54, eval loss: 0.5846156671643257, correct: 8204, total: 10000, acc = 0.8203999996185303
-epoch: 55, train loss: 0.2253616195248098
-epoch: 56, train loss: 0.2177750574690955
-epoch: 56, eval loss: 0.5943332687020302, correct: 8246, total: 10000, acc = 0.8245999813079834
-epoch: 57, train loss: 0.20670234989755007
-epoch: 58, train loss: 0.1973607996288611
-epoch: 58, eval loss: 0.6195310011506081, correct: 8245, total: 10000, acc = 0.8244999647140503
-epoch: 59, train loss: 0.19024320448539694
-epoch: 60, train loss: 0.17597664877468225
-epoch: 60, eval loss: 0.6139472931623459, correct: 8294, total: 10000, acc = 0.8294000029563904
-epoch: 61, train loss: 0.1674150490791214
-epoch: 62, train loss: 0.15718420511301684
-epoch: 62, eval loss: 0.6285309329628944, correct: 8261, total: 10000, acc = 0.8260999917984009
-epoch: 63, train loss: 0.1480691913439303
-epoch: 64, train loss: 0.1384550367234921
-epoch: 64, eval loss: 0.6587671056389809, correct: 8263, total: 10000, acc = 0.8262999653816223
-epoch: 65, train loss: 0.13241269834795777
-epoch: 66, train loss: 0.12871786830376605
-epoch: 66, eval loss: 0.6718123883008957, correct: 8303, total: 10000, acc = 0.830299973487854
-epoch: 67, train loss: 0.11577517866176001
-epoch: 68, train loss: 0.11130036151378739
-epoch: 68, eval loss: 0.6887702852487564, correct: 8332, total: 10000, acc = 0.8331999778747559
-epoch: 69, train loss: 0.09883711646710124
-epoch: 70, train loss: 0.09635799735480426
-epoch: 70, eval loss: 0.7159708231687546, correct: 8307, total: 10000, acc = 0.8306999802589417
-epoch: 71, train loss: 0.09449125119313902
-epoch: 72, train loss: 0.08857650914210446
-epoch: 72, eval loss: 0.7160102307796479, correct: 8351, total: 10000, acc = 0.835099995136261
-epoch: 73, train loss: 0.08085554241373831
-epoch: 74, train loss: 0.07873564483407809
-epoch: 74, eval loss: 0.7119918942451477, correct: 8393, total: 10000, acc = 0.8392999768257141
-epoch: 75, train loss: 0.07206312137446841
-epoch: 76, train loss: 0.06772394200824962
-epoch: 76, eval loss: 0.7328802436590195, correct: 8351, total: 10000, acc = 0.835099995136261
-epoch: 77, train loss: 0.061777200397788265
-epoch: 78, train loss: 0.05721901174710722
-epoch: 78, eval loss: 0.7407010316848754, correct: 8385, total: 10000, acc = 0.8384999632835388
-epoch: 79, train loss: 0.056560877406475495
-epoch: 80, train loss: 0.0528045150318316
-epoch: 80, eval loss: 0.7767532706260681, correct: 8354, total: 10000, acc = 0.8353999853134155
-epoch: 81, train loss: 0.050682742870887934
-epoch: 82, train loss: 0.04895328068915678
-epoch: 82, eval loss: 0.7942879348993301, correct: 8368, total: 10000, acc = 0.8367999792098999
-epoch: 83, train loss: 0.04686643050185272
-epoch: 84, train loss: 0.04325723648071289
-epoch: 84, eval loss: 0.7906839996576309, correct: 8356, total: 10000, acc = 0.835599958896637
-epoch: 85, train loss: 0.040166335769605876
-epoch: 86, train loss: 0.039296497894945194
-epoch: 86, eval loss: 0.8033982694149018, correct: 8376, total: 10000, acc = 0.8375999927520752
-epoch: 87, train loss: 0.038185219698566565
-epoch: 88, train loss: 0.03735689769441984
-epoch: 88, eval loss: 0.8039661139249802, correct: 8377, total: 10000, acc = 0.8376999497413635
-epoch: 89, train loss: 0.03383794939145446
-epoch: 90, train loss: 0.03318257091034736
-epoch: 90, eval loss: 0.8097118645906448, correct: 8389, total: 10000, acc = 0.8388999700546265
-epoch: 91, train loss: 0.03290939923109753
-epoch: 92, train loss: 0.030776230903456405
-epoch: 92, eval loss: 0.8237936168909072, correct: 8401, total: 10000, acc = 0.8400999903678894
-epoch: 93, train loss: 0.033349379108344415
-epoch: 94, train loss: 0.031906195783189366
-epoch: 94, eval loss: 0.8250258564949036, correct: 8401, total: 10000, acc = 0.8400999903678894
-epoch: 95, train loss: 0.03031293043334569
-epoch: 96, train loss: 0.029958056238460904
-epoch: 96, eval loss: 0.8200247555971145, correct: 8402, total: 10000, acc = 0.8402000069618225
-epoch: 97, train loss: 0.029532150564981357
-epoch: 98, train loss: 0.029668816346295025
-epoch: 98, eval loss: 0.821219089627266, correct: 8399, total: 10000, acc = 0.8398999571800232
-epoch: 99, train loss: 0.02980129667842875
-finish training
-TACC: Shutdown complete. Exiting.
diff --git a/tests/test_models/test_vision_transformer/test_vit_2d/exp_logs/vanilla-nproc1-lr1e-3/loss-vanilla-lr1e-3.jpg b/tests/test_models/test_vision_transformer/test_vit_2d/exp_logs/vanilla-nproc1-lr1e-3/loss-vanilla-lr1e-3.jpg
deleted file mode 100644
index 3f47b07b8..000000000
Binary files a/tests/test_models/test_vision_transformer/test_vit_2d/exp_logs/vanilla-nproc1-lr1e-3/loss-vanilla-lr1e-3.jpg and /dev/null differ
diff --git a/tests/test_models/test_vision_transformer/test_vit_2d/test_vit_2d.py b/tests/test_models/test_vision_transformer/test_vit_2d/test_vit_2d.py
deleted file mode 100644
index 487ba335b..000000000
--- a/tests/test_models/test_vision_transformer/test_vit_2d/test_vit_2d.py
+++ /dev/null
@@ -1,84 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-from pathlib import Path
-
-import pytest
-import torch.autograd
-
-import colossalai
-from colossalai.builder import build_lr_scheduler
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.logging import get_global_dist_logger
-from colossalai.nn.layer._parallel_utilities import _gather
-
-CONFIG_PATH = Path(__file__).parent.parent.joinpath('configs/vit_2d.py')
-
-
-def eval(engine, test_dataloader):
- engine.eval()
- accumulated_loss = 0
- correct_sum = 0
- total_sum = 0
- num_steps = len(test_dataloader)
- data_iter = iter(test_dataloader)
-
- for i in range(num_steps):
- output, label, loss = engine.step(data_iter)
- accumulated_loss += loss.detach().cpu().numpy()
-
- output = _gather(
- output[0],
- ParallelMode.PARALLEL_2D_ROW,
- 1
- )
- output = _gather(
- output,
- ParallelMode.PARALLEL_2D_COL,
- 0,
- )
- output = torch.argmax(output, dim=-1)
- correct = torch.sum(label[0] == output)
- correct_sum += correct
- total_sum += label[0].size(0)
- avg_loss = accumulated_loss / num_steps
- return correct_sum, total_sum, avg_loss
-
-
-def train(engine, train_dataloader, lr_scheduler):
- engine.train()
- accumulated_loss = 0
- num_steps = len(train_dataloader)
- data_iter = iter(train_dataloader)
-
- for i in range(num_steps):
- output, label, loss = engine.step(data_iter)
- accumulated_loss += loss.detach().cpu().numpy()
- avg_loss = accumulated_loss / num_steps
- lr_scheduler.step()
- return avg_loss
-
-
-@pytest.mark.dist
-@pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus")
-def test_2d_parallel_vision_transformer():
- # init dist
- engine, train_dataloader, test_dataloader = colossalai.initialize(CONFIG_PATH)
- lr_scheduler = build_lr_scheduler(gpc.config.lr_scheduler, engine.optimizer)
- logger = get_global_dist_logger()
-
- logger.info('start training')
- for epoch in range(gpc.config.num_epochs):
- train_loss = train(engine, train_dataloader, lr_scheduler)
- logger.info(f'epoch {epoch} - train loss: {train_loss}')
-
- if epoch % 2 == 0:
- correct_sum, total_sum, eval_loss = eval(engine, test_dataloader)
- logger.info(
- f'epoch {epoch} - eval loss: {eval_loss}, total: {total_sum}, '
- f'correct: {correct_sum}, acc: {correct_sum / total_sum}')
-
-
-if __name__ == '__main__':
- test_2d_parallel_vision_transformer()
diff --git a/tests/test_models/test_vision_transformer/test_vit_2p5d/log/111log1e-3.txt b/tests/test_models/test_vision_transformer/test_vit_2p5d/log/111log1e-3.txt
deleted file mode 100644
index 54ecbf869..000000000
--- a/tests/test_models/test_vision_transformer/test_vit_2p5d/log/111log1e-3.txt
+++ /dev/null
@@ -1,103 +0,0 @@
-TACC: Starting up job 3498212
-TACC: Starting parallel tasks...
-warning: variables which start with __, or are module or class declarations, are omitted
-process rank 0 is bound to device 0
-distributed environment is initialized
-model is created
-Files already downloaded and verified
-Files already downloaded and verified
-training and testing dataloaders are created
-loss is created
-optimizer is created
-start training
-epoch: 0, train loss: 1.9590576728995965
-epoch: 1, train loss: 1.6275222167676808
-epoch: 1, eval loss: 1.5277319371700286, correct: 4435, total: 10000, acc = 0.44349998235702515
-epoch: 2, train loss: 1.4355541419009774
-epoch: 3, train loss: 1.3253967445723864
-epoch: 3, eval loss: 1.309086227416992, correct: 5283, total: 10000, acc = 0.5282999873161316
-epoch: 4, train loss: 1.2578775298838714
-epoch: 5, train loss: 1.2231916554120121
-epoch: 5, eval loss: 1.1699816286563873, correct: 5695, total: 10000, acc = 0.5694999694824219
-epoch: 6, train loss: 1.1872552669778162
-epoch: 7, train loss: 1.1616783823285783
-epoch: 7, eval loss: 1.069484794139862, correct: 6183, total: 10000, acc = 0.6182999610900879
-epoch: 8, train loss: 1.1155579333402672
-epoch: 9, train loss: 1.0878059365311448
-epoch: 9, eval loss: 1.0522838592529298, correct: 6202, total: 10000, acc = 0.620199978351593
-epoch: 10, train loss: 1.0780728623575093
-epoch: 11, train loss: 1.0522098152004942
-epoch: 11, eval loss: 1.0902862310409547, correct: 6148, total: 10000, acc = 0.614799976348877
-epoch: 12, train loss: 1.0366473337825464
-epoch: 13, train loss: 1.0067467458394108
-epoch: 13, eval loss: 0.9696728616952897, correct: 6531, total: 10000, acc = 0.6530999541282654
-epoch: 14, train loss: 0.9676224273078295
-epoch: 15, train loss: 0.9494374029490412
-epoch: 15, eval loss: 0.9511896312236786, correct: 6646, total: 10000, acc = 0.6645999550819397
-epoch: 16, train loss: 0.9231320935852674
-epoch: 17, train loss: 0.9023846679804276
-epoch: 17, eval loss: 0.8728409796953202, correct: 6866, total: 10000, acc = 0.6865999698638916
-epoch: 18, train loss: 0.8684309854799387
-epoch: 19, train loss: 0.836099565637355
-epoch: 19, eval loss: 0.8208363801240921, correct: 7091, total: 10000, acc = 0.7091000080108643
-epoch: 20, train loss: 0.8285067890371595
-epoch: 21, train loss: 0.7930980793067387
-epoch: 21, eval loss: 0.7793890535831451, correct: 7235, total: 10000, acc = 0.7234999537467957
-epoch: 22, train loss: 0.762698369366782
-epoch: 23, train loss: 0.7376812471418964
-epoch: 23, eval loss: 0.746866625547409, correct: 7340, total: 10000, acc = 0.7339999675750732
-epoch: 24, train loss: 0.7071484223920472
-epoch: 25, train loss: 0.6905171658311572
-epoch: 25, eval loss: 0.6909466415643692, correct: 7526, total: 10000, acc = 0.7525999546051025
-epoch: 26, train loss: 0.6608500091397033
-epoch: 27, train loss: 0.65504517907999
-epoch: 27, eval loss: 0.6612646311521531, correct: 7697, total: 10000, acc = 0.7696999907493591
-epoch: 28, train loss: 0.6234641969203949
-epoch: 29, train loss: 0.6107665622720913
-epoch: 29, eval loss: 0.666494044661522, correct: 7704, total: 10000, acc = 0.7703999876976013
-epoch: 30, train loss: 0.5875011883219894
-epoch: 31, train loss: 0.5739485697478665
-epoch: 31, eval loss: 0.6217960953712464, correct: 7828, total: 10000, acc = 0.7827999591827393
-epoch: 32, train loss: 0.548510205684876
-epoch: 33, train loss: 0.5237194764979032
-epoch: 33, eval loss: 0.6254391580820083, correct: 7842, total: 10000, acc = 0.7841999530792236
-epoch: 34, train loss: 0.5154265892140719
-epoch: 35, train loss: 0.494700480176478
-epoch: 35, eval loss: 0.5981663644313813, correct: 7963, total: 10000, acc = 0.7962999939918518
-epoch: 36, train loss: 0.4785171020395902
-epoch: 37, train loss: 0.46277919259606576
-epoch: 37, eval loss: 0.6061880439519882, correct: 7958, total: 10000, acc = 0.795799970626831
-epoch: 38, train loss: 0.4398626606075131
-epoch: 39, train loss: 0.4206806777083144
-epoch: 39, eval loss: 0.6158866941928863, correct: 7959, total: 10000, acc = 0.7958999872207642
-epoch: 40, train loss: 0.40768756550185536
-epoch: 41, train loss: 0.39494050035671313
-epoch: 41, eval loss: 0.5725498422980309, correct: 8132, total: 10000, acc = 0.8131999969482422
-epoch: 42, train loss: 0.3742571521778496
-epoch: 43, train loss: 0.3583034301290707
-epoch: 43, eval loss: 0.5765605017542839, correct: 8155, total: 10000, acc = 0.8154999613761902
-epoch: 44, train loss: 0.3342630756752832
-epoch: 45, train loss: 0.31316718063792404
-epoch: 45, eval loss: 0.583588008582592, correct: 8199, total: 10000, acc = 0.8198999762535095
-epoch: 46, train loss: 0.30922748148441315
-epoch: 47, train loss: 0.2906164434187266
-epoch: 47, eval loss: 0.5934860140085221, correct: 8143, total: 10000, acc = 0.814300000667572
-epoch: 48, train loss: 0.2741488078419043
-epoch: 49, train loss: 0.2597196321098172
-epoch: 49, eval loss: 0.5978868633508683, correct: 8195, total: 10000, acc = 0.8194999694824219
-epoch: 50, train loss: 0.2440016470393356
-epoch: 51, train loss: 0.2293997729311184
-epoch: 51, eval loss: 0.5915440261363983, correct: 8232, total: 10000, acc = 0.823199987411499
-epoch: 52, train loss: 0.2132072006257213
-epoch: 53, train loss: 0.19785404767917128
-epoch: 53, eval loss: 0.6171442106366157, correct: 8258, total: 10000, acc = 0.8258000016212463
-epoch: 54, train loss: 0.1838149410121295
-epoch: 55, train loss: 0.17691133977199086
-epoch: 55, eval loss: 0.623777586221695, correct: 8275, total: 10000, acc = 0.8274999856948853
-epoch: 56, train loss: 0.16595362697024735
-epoch: 57, train loss: 0.1531825682946614
-epoch: 57, eval loss: 0.6466041743755341, correct: 8243, total: 10000, acc = 0.8242999911308289
-epoch: 58, train loss: 0.14334788979316243
-epoch: 59, train loss: 0.13799503377201605
-epoch: 59, eval loss: 0.6496601745486259, correct: 8249, total: 10000, acc = 0.8248999714851379
-finish training
diff --git a/tests/test_models/test_vision_transformer/test_vit_2p5d/log/111log1e-3hxmodel.txt b/tests/test_models/test_vision_transformer/test_vit_2p5d/log/111log1e-3hxmodel.txt
deleted file mode 100644
index 9bb1bf4bb..000000000
--- a/tests/test_models/test_vision_transformer/test_vit_2p5d/log/111log1e-3hxmodel.txt
+++ /dev/null
@@ -1,196 +0,0 @@
-
-c196-011[rtx](1013)$ bash ./test.sh 1 1 1 0.001
-TACC: Starting up job 3503164
-TACC: Starting parallel tasks...
-warning: variables which start with __, or are module or class declarations, are omitted
-process rank 0 is bound to device 0
-distributed environment is initialized
-USE_VANILLA model
-model is created
-Files already downloaded and verified
-Files already downloaded and verified
-training and testing dataloaders are created
-loss is created
-optimizer is created
-start training
-epoch: 0, train loss: 1.9408839624755236
-epoch: 0, eval loss: 1.7896566271781922, correct: 3488, total: 10000, acc = 0.34880000352859497
-epoch time: 40.82966494560242
-epoch: 1, train loss: 1.6500030257263962
-epoch: 1, eval loss: 1.5464953780174255, correct: 4545, total: 10000, acc = 0.4544999897480011
-epoch time: 40.01254224777222
-epoch: 2, train loss: 1.422887429899099
-epoch: 2, eval loss: 1.37536381483078, correct: 5074, total: 10000, acc = 0.5073999762535095
-epoch time: 40.107905864715576
-epoch: 3, train loss: 1.3217590207956276
-epoch: 3, eval loss: 1.3036327004432677, correct: 5377, total: 10000, acc = 0.5376999974250793
-epoch time: 40.12306189537048
-epoch: 4, train loss: 1.262234352072891
-epoch: 4, eval loss: 1.2568134129047395, correct: 5475, total: 10000, acc = 0.5475000143051147
-epoch time: 40.10755228996277
-epoch: 5, train loss: 1.2381379117771072
-epoch: 5, eval loss: 1.1941023647785187, correct: 5676, total: 10000, acc = 0.5676000118255615
-epoch time: 40.119303464889526
-epoch: 6, train loss: 1.2061052650821453
-epoch: 6, eval loss: 1.1313925206661224, correct: 5938, total: 10000, acc = 0.5938000082969666
-epoch time: 40.07719683647156
-epoch: 7, train loss: 1.1659562563409611
-epoch: 7, eval loss: 1.125486546754837, correct: 5958, total: 10000, acc = 0.59579998254776
-epoch time: 40.1702299118042
-epoch: 8, train loss: 1.1378972846634534
-epoch: 8, eval loss: 1.082760637998581, correct: 6102, total: 10000, acc = 0.6101999878883362
-epoch time: 40.22099733352661
-epoch: 9, train loss: 1.1073276430976635
-epoch: 9, eval loss: 1.1077564001083373, correct: 6038, total: 10000, acc = 0.6037999987602234
-epoch time: 40.1106858253479
-epoch: 10, train loss: 1.087894769347444
-epoch: 10, eval loss: 1.0400531351566316, correct: 6311, total: 10000, acc = 0.6310999989509583
-epoch time: 40.20973324775696
-epoch: 11, train loss: 1.0556547295074075
-epoch: 11, eval loss: 1.0295817345380782, correct: 6359, total: 10000, acc = 0.6358999609947205
-epoch time: 40.23791980743408
-epoch: 12, train loss: 1.0299884901971232
-epoch: 12, eval loss: 1.003737959265709, correct: 6380, total: 10000, acc = 0.6380000114440918
-epoch time: 40.08779859542847
-epoch: 13, train loss: 0.9972386627781148
-epoch: 13, eval loss: 0.9707699298858643, correct: 6499, total: 10000, acc = 0.649899959564209
-epoch time: 40.10878801345825
-epoch: 14, train loss: 0.9784559072280417
-epoch: 14, eval loss: 0.9253897607326508, correct: 6641, total: 10000, acc = 0.6640999913215637
-epoch time: 40.13168978691101
-epoch: 15, train loss: 0.9409253481699495
-epoch: 15, eval loss: 0.9120320588350296, correct: 6759, total: 10000, acc = 0.6758999824523926
-epoch time: 40.162830114364624
-epoch: 16, train loss: 0.925923115136672
-epoch: 16, eval loss: 0.8850776582956315, correct: 6870, total: 10000, acc = 0.6869999766349792
-epoch time: 40.145774602890015
-epoch: 17, train loss: 0.8923340841215484
-epoch: 17, eval loss: 0.8570599347352982, correct: 6950, total: 10000, acc = 0.6949999928474426
-epoch time: 40.18058943748474
-epoch: 18, train loss: 0.8638542884466599
-epoch: 18, eval loss: 0.838410159945488, correct: 6971, total: 10000, acc = 0.6970999836921692
-epoch time: 40.110822439193726
-epoch: 19, train loss: 0.8400422529298432
-epoch: 19, eval loss: 0.8189669162034988, correct: 7097, total: 10000, acc = 0.7096999883651733
-epoch time: 40.066970109939575
-epoch: 20, train loss: 0.8072922752828015
-epoch: 20, eval loss: 0.7772788077592849, correct: 7240, total: 10000, acc = 0.7239999771118164
-epoch time: 40.045086145401
-epoch: 21, train loss: 0.788195074821005
-epoch: 21, eval loss: 0.7793144911527634, correct: 7261, total: 10000, acc = 0.726099967956543
-epoch time: 40.05983781814575
-epoch: 22, train loss: 0.7574447350842612
-epoch: 22, eval loss: 0.7660320281982422, correct: 7272, total: 10000, acc = 0.7271999716758728
-epoch time: 40.11693739891052
-epoch: 23, train loss: 0.7402738150285215
-epoch: 23, eval loss: 0.7264292597770691, correct: 7418, total: 10000, acc = 0.7418000102043152
-epoch time: 40.18724513053894
-epoch: 24, train loss: 0.7125097580102026
-epoch: 24, eval loss: 0.7105035990476608, correct: 7506, total: 10000, acc = 0.7505999803543091
-epoch time: 40.1254940032959
-epoch: 25, train loss: 0.6900304744438249
-epoch: 25, eval loss: 0.6911167114973068, correct: 7562, total: 10000, acc = 0.7561999559402466
-epoch time: 40.103896617889404
-epoch: 26, train loss: 0.6648721482072558
-epoch: 26, eval loss: 0.6780407190322876, correct: 7624, total: 10000, acc = 0.7623999714851379
-epoch time: 40.18161463737488
-epoch: 27, train loss: 0.6446310062797702
-epoch: 27, eval loss: 0.6820667266845704, correct: 7612, total: 10000, acc = 0.761199951171875
-epoch time: 40.19018864631653
-epoch: 28, train loss: 0.6262476389505425
-epoch: 28, eval loss: 0.6506347745656967, correct: 7704, total: 10000, acc = 0.7703999876976013
-epoch time: 40.23526978492737
-epoch: 29, train loss: 0.5968854001590184
-epoch: 29, eval loss: 0.6507940381765366, correct: 7727, total: 10000, acc = 0.7726999521255493
-epoch time: 40.26889181137085
-epoch: 30, train loss: 0.587430303194085
-epoch: 30, eval loss: 0.6333519726991653, correct: 7788, total: 10000, acc = 0.7787999510765076
-epoch time: 40.28285789489746
-epoch: 31, train loss: 0.5701514035463333
-epoch: 31, eval loss: 0.6348810195922852, correct: 7799, total: 10000, acc = 0.7798999547958374
-epoch time: 40.199995040893555
-epoch: 32, train loss: 0.5482188679125845
-epoch: 32, eval loss: 0.6192457497119903, correct: 7833, total: 10000, acc = 0.78329998254776
-epoch time: 40.270729780197144
-epoch: 33, train loss: 0.534268391375639
-epoch: 33, eval loss: 0.6381673783063888, correct: 7790, total: 10000, acc = 0.7789999842643738
-epoch time: 40.36342120170593
-epoch: 34, train loss: 0.5104483384258893
-epoch: 34, eval loss: 0.6173199415206909, correct: 7867, total: 10000, acc = 0.7866999506950378
-epoch time: 40.34266257286072
-epoch: 35, train loss: 0.4968841674984718
-epoch: 35, eval loss: 0.604002220928669, correct: 7916, total: 10000, acc = 0.7915999889373779
-epoch time: 40.39444589614868
-epoch: 36, train loss: 0.4773432207959039
-epoch: 36, eval loss: 0.5884111285209656, correct: 7965, total: 10000, acc = 0.7964999675750732
-epoch time: 40.40647268295288
-epoch: 37, train loss: 0.4621481445370888
-epoch: 37, eval loss: 0.5748852327466011, correct: 8047, total: 10000, acc = 0.8046999573707581
-epoch time: 40.29281520843506
-epoch: 38, train loss: 0.4431859048045411
-epoch: 38, eval loss: 0.5874941781163215, correct: 7995, total: 10000, acc = 0.7994999885559082
-epoch time: 40.40029954910278
-epoch: 39, train loss: 0.4305852785402415
-epoch: 39, eval loss: 0.5991648495197296, correct: 7972, total: 10000, acc = 0.7971999645233154
-epoch time: 40.399904012680054
-epoch: 40, train loss: 0.4092241589512144
-epoch: 40, eval loss: 0.5725525215268135, correct: 8069, total: 10000, acc = 0.8068999648094177
-epoch time: 40.32663059234619
-epoch: 41, train loss: 0.39218547179990887
-epoch: 41, eval loss: 0.5886161357164383, correct: 8068, total: 10000, acc = 0.8068000078201294
-epoch time: 40.32424521446228
-epoch: 42, train loss: 0.3773612398274091
-epoch: 42, eval loss: 0.5762413635849952, correct: 8126, total: 10000, acc = 0.8125999569892883
-epoch time: 40.44430422782898
-epoch: 43, train loss: 0.3593267098981507
-epoch: 43, eval loss: 0.5729024946689606, correct: 8107, total: 10000, acc = 0.810699999332428
-epoch time: 40.488121032714844
-epoch: 44, train loss: 0.3396431426612698
-epoch: 44, eval loss: 0.5944831907749176, correct: 8072, total: 10000, acc = 0.8071999549865723
-epoch time: 40.41803979873657
-epoch: 45, train loss: 0.32412939716358574
-epoch: 45, eval loss: 0.5849291861057282, correct: 8171, total: 10000, acc = 0.8170999884605408
-epoch time: 40.428131341934204
-epoch: 46, train loss: 0.3099915471916296
-epoch: 46, eval loss: 0.5797522723674774, correct: 8121, total: 10000, acc = 0.8120999932289124
-epoch time: 40.623990058898926
-epoch: 47, train loss: 0.29422828676749246
-epoch: 47, eval loss: 0.5898703813552857, correct: 8175, total: 10000, acc = 0.8174999952316284
-epoch time: 40.71224045753479
-epoch: 48, train loss: 0.27581544600579205
-epoch: 48, eval loss: 0.5950756087899208, correct: 8170, total: 10000, acc = 0.8169999718666077
-epoch time: 40.53409385681152
-epoch: 49, train loss: 0.26118586242807157
-epoch: 49, eval loss: 0.5998703584074974, correct: 8213, total: 10000, acc = 0.8212999701499939
-epoch time: 40.564385175704956
-epoch: 50, train loss: 0.2513351797753451
-epoch: 50, eval loss: 0.6011391341686249, correct: 8226, total: 10000, acc = 0.8226000070571899
-epoch time: 40.55033254623413
-epoch: 51, train loss: 0.22965944299892505
-epoch: 51, eval loss: 0.5979882061481476, correct: 8233, total: 10000, acc = 0.8233000040054321
-epoch time: 40.54532980918884
-epoch: 52, train loss: 0.21661002188920975
-epoch: 52, eval loss: 0.6121026620268821, correct: 8220, total: 10000, acc = 0.8219999670982361
-epoch time: 40.649473667144775
-epoch: 53, train loss: 0.20266114950788264
-epoch: 53, eval loss: 0.6016955643892288, correct: 8260, total: 10000, acc = 0.8259999752044678
-epoch time: 40.752054929733276
-epoch: 54, train loss: 0.19287180794136866
-epoch: 54, eval loss: 0.6043265879154205, correct: 8284, total: 10000, acc = 0.8283999562263489
-epoch time: 40.68043255805969
-epoch: 55, train loss: 0.175087109208107
-epoch: 55, eval loss: 0.6146622076630592, correct: 8316, total: 10000, acc = 0.8315999507904053
-epoch time: 40.58446717262268
-epoch: 56, train loss: 0.16749868762432313
-epoch: 56, eval loss: 0.6235148012638092, correct: 8313, total: 10000, acc = 0.8312999606132507
-epoch time: 40.62826180458069
-epoch: 57, train loss: 0.15567801619062618
-epoch: 57, eval loss: 0.6325852945446968, correct: 8308, total: 10000, acc = 0.8307999968528748
-epoch time: 40.72224497795105
-epoch: 58, train loss: 0.1484297229623308
-epoch: 58, eval loss: 0.6329193383455276, correct: 8325, total: 10000, acc = 0.8324999809265137
-epoch time: 40.750558614730835
-epoch: 59, train loss: 0.14238623818572688
-epoch: 59, eval loss: 0.6318104699254036, correct: 8329, total: 10000, acc = 0.8328999876976013
-epoch time: 40.77172636985779
-finish training
\ No newline at end of file
diff --git a/tests/test_models/test_vision_transformer/test_vit_2p5d/log/111log1e-4.txt b/tests/test_models/test_vision_transformer/test_vit_2p5d/log/111log1e-4.txt
deleted file mode 100644
index d7404eea6..000000000
--- a/tests/test_models/test_vision_transformer/test_vit_2p5d/log/111log1e-4.txt
+++ /dev/null
@@ -1,103 +0,0 @@
-TACC: Starting up job 3498663
-TACC: Starting parallel tasks...
-warning: variables which start with __, or are module or class declarations, are omitted
-process rank 0 is bound to device 0
-distributed environment is initialized
-model is created
-Files already downloaded and verified
-Files already downloaded and verified
-training and testing dataloaders are created
-loss is created
-optimizer is created
-start training
-epoch: 0, train loss: 2.095031557034473
-epoch: 1, train loss: 1.8454539605549403
-epoch: 1, eval loss: 1.7768513083457946, correct: 3564, total: 10000, acc = 0.3563999831676483
-epoch: 2, train loss: 1.7044833728245325
-epoch: 3, train loss: 1.5999061124665397
-epoch: 3, eval loss: 1.5574450254440309, correct: 4389, total: 10000, acc = 0.4388999938964844
-epoch: 4, train loss: 1.4929670217085858
-epoch: 5, train loss: 1.401450170546162
-epoch: 5, eval loss: 1.4644017696380616, correct: 4857, total: 10000, acc = 0.48569998145103455
-epoch: 6, train loss: 1.319102376091237
-epoch: 7, train loss: 1.2555806539496597
-epoch: 7, eval loss: 1.2475590467453004, correct: 5486, total: 10000, acc = 0.5485999584197998
-epoch: 8, train loss: 1.1992503173497258
-epoch: 9, train loss: 1.1600336493278036
-epoch: 9, eval loss: 1.1786625683307648, correct: 5834, total: 10000, acc = 0.5834000110626221
-epoch: 10, train loss: 1.1214540807568296
-epoch: 11, train loss: 1.0808329728184913
-epoch: 11, eval loss: 1.096825110912323, correct: 6072, total: 10000, acc = 0.6071999669075012
-epoch: 12, train loss: 1.0521019423494533
-epoch: 13, train loss: 1.0262362957000732
-epoch: 13, eval loss: 1.056444275379181, correct: 6268, total: 10000, acc = 0.626800000667572
-epoch: 14, train loss: 0.9932536555796253
-epoch: 15, train loss: 0.9653559442685575
-epoch: 15, eval loss: 0.9576991081237793, correct: 6582, total: 10000, acc = 0.6581999659538269
-epoch: 16, train loss: 0.9465620943478176
-epoch: 17, train loss: 0.9181081974992946
-epoch: 17, eval loss: 0.9245584070682525, correct: 6747, total: 10000, acc = 0.6746999621391296
-epoch: 18, train loss: 0.8987109752333894
-epoch: 19, train loss: 0.8840238646585115
-epoch: 19, eval loss: 0.8989996433258056, correct: 6787, total: 10000, acc = 0.6786999702453613
-epoch: 20, train loss: 0.8591911811001447
-epoch: 21, train loss: 0.843510093129411
-epoch: 21, eval loss: 0.8595858901739121, correct: 6969, total: 10000, acc = 0.6969000101089478
-epoch: 22, train loss: 0.8306782276046519
-epoch: 23, train loss: 0.8181647640101763
-epoch: 23, eval loss: 0.8600298583507537, correct: 7005, total: 10000, acc = 0.7005000114440918
-epoch: 24, train loss: 0.7964763343334198
-epoch: 25, train loss: 0.7840689718723297
-epoch: 25, eval loss: 0.824479615688324, correct: 7073, total: 10000, acc = 0.7073000073432922
-epoch: 26, train loss: 0.7709570752114666
-epoch: 27, train loss: 0.7591698108887186
-epoch: 27, eval loss: 0.7967212647199631, correct: 7196, total: 10000, acc = 0.7195999622344971
-epoch: 28, train loss: 0.7438001352913526
-epoch: 29, train loss: 0.7341659853653032
-epoch: 29, eval loss: 0.8041222035884857, correct: 7168, total: 10000, acc = 0.7167999744415283
-epoch: 30, train loss: 0.7254330929444761
-epoch: 31, train loss: 0.710246913895315
-epoch: 31, eval loss: 0.7848481118679047, correct: 7287, total: 10000, acc = 0.7286999821662903
-epoch: 32, train loss: 0.6976562008565786
-epoch: 33, train loss: 0.6906438475968887
-epoch: 33, eval loss: 0.7644171923398971, correct: 7370, total: 10000, acc = 0.7369999885559082
-epoch: 34, train loss: 0.6795850834067987
-epoch: 35, train loss: 0.6724951656497254
-epoch: 35, eval loss: 0.7515032321214676, correct: 7368, total: 10000, acc = 0.736799955368042
-epoch: 36, train loss: 0.6527298372619006
-epoch: 37, train loss: 0.651018523440069
-epoch: 37, eval loss: 0.7381327033042908, correct: 7449, total: 10000, acc = 0.7448999881744385
-epoch: 38, train loss: 0.6365304406808348
-epoch: 39, train loss: 0.6372388047831399
-epoch: 39, eval loss: 0.7342826008796692, correct: 7453, total: 10000, acc = 0.7452999949455261
-epoch: 40, train loss: 0.6199644664112403
-epoch: 41, train loss: 0.6101092303894005
-epoch: 41, eval loss: 0.7353240340948105, correct: 7466, total: 10000, acc = 0.7465999722480774
-epoch: 42, train loss: 0.6093496211937496
-epoch: 43, train loss: 0.6019633388032719
-epoch: 43, eval loss: 0.7350291252136231, correct: 7479, total: 10000, acc = 0.7479000091552734
-epoch: 44, train loss: 0.5928211437196148
-epoch: 45, train loss: 0.5840530048827736
-epoch: 45, eval loss: 0.7301350146532058, correct: 7525, total: 10000, acc = 0.7524999976158142
-epoch: 46, train loss: 0.578370426078232
-epoch: 47, train loss: 0.5703256440405943
-epoch: 47, eval loss: 0.7226948082447052, correct: 7526, total: 10000, acc = 0.7525999546051025
-epoch: 48, train loss: 0.5622531275968162
-epoch: 49, train loss: 0.5543749076979501
-epoch: 49, eval loss: 0.7278151929378509, correct: 7536, total: 10000, acc = 0.753600001335144
-epoch: 50, train loss: 0.5494355583677486
-epoch: 51, train loss: 0.5427058047177841
-epoch: 51, eval loss: 0.7180711388587951, correct: 7608, total: 10000, acc = 0.7608000040054321
-epoch: 52, train loss: 0.5323820530760045
-epoch: 53, train loss: 0.5341374232452742
-epoch: 53, eval loss: 0.7136827558279037, correct: 7618, total: 10000, acc = 0.7617999911308289
-epoch: 54, train loss: 0.5295403867351766
-epoch: 55, train loss: 0.5226148692320804
-epoch: 55, eval loss: 0.7158426463603973, correct: 7624, total: 10000, acc = 0.7623999714851379
-epoch: 56, train loss: 0.5206544593888887
-epoch: 57, train loss: 0.5186455438331682
-epoch: 57, eval loss: 0.7141193479299546, correct: 7611, total: 10000, acc = 0.7610999941825867
-epoch: 58, train loss: 0.5130856335163116
-epoch: 59, train loss: 0.5103850683995655
-epoch: 59, eval loss: 0.7077989399433136, correct: 7628, total: 10000, acc = 0.7627999782562256
-finish training
diff --git a/tests/test_models/test_vision_transformer/test_vit_2p5d/log/111log1e-4hxmodel.txt b/tests/test_models/test_vision_transformer/test_vit_2p5d/log/111log1e-4hxmodel.txt
deleted file mode 100644
index 72889a455..000000000
--- a/tests/test_models/test_vision_transformer/test_vit_2p5d/log/111log1e-4hxmodel.txt
+++ /dev/null
@@ -1,195 +0,0 @@
-c196-012[rtx](1006)$ bash ./test.sh 1 1 1 0.0001
-TACC: Starting up job 3503177
-TACC: Starting parallel tasks...
-warning: variables which start with __, or are module or class declarations, are omitted
-process rank 0 is bound to device 0
-distributed environment is initialized
-USE_VANILLA model
-model is created
-Files already downloaded and verified
-Files already downloaded and verified
-training and testing dataloaders are created
-loss is created
-optimizer is created
-start training
-epoch: 0, train loss: 2.07912605757616
-epoch: 0, eval loss: 1.9337591707706452, correct: 2845, total: 10000, acc = 0.28450000286102295
-epoch time: 48.79993748664856
-epoch: 1, train loss: 1.8506990890113675
-epoch: 1, eval loss: 1.7832269430160523, correct: 3506, total: 10000, acc = 0.350600004196167
-epoch time: 39.10968255996704
-epoch: 2, train loss: 1.707400695401795
-epoch: 2, eval loss: 1.6983122050762176, correct: 3935, total: 10000, acc = 0.3935000002384186
-epoch time: 39.205119609832764
-epoch: 3, train loss: 1.5925798574272467
-epoch: 3, eval loss: 1.6361137092113496, correct: 4276, total: 10000, acc = 0.4275999963283539
-epoch time: 39.220152378082275
-epoch: 4, train loss: 1.4817699790000916
-epoch: 4, eval loss: 1.4869949519634247, correct: 4706, total: 10000, acc = 0.4705999791622162
-epoch time: 39.297648191452026
-epoch: 5, train loss: 1.3685331247290786
-epoch: 5, eval loss: 1.4110832333564758, correct: 5043, total: 10000, acc = 0.5042999982833862
-epoch time: 39.31484127044678
-epoch: 6, train loss: 1.283743022655954
-epoch: 6, eval loss: 1.317776972055435, correct: 5320, total: 10000, acc = 0.5320000052452087
-epoch time: 39.31891870498657
-epoch: 7, train loss: 1.2292176107971036
-epoch: 7, eval loss: 1.2397323846817017, correct: 5619, total: 10000, acc = 0.5618999600410461
-epoch time: 39.31014013290405
-epoch: 8, train loss: 1.1705418606193698
-epoch: 8, eval loss: 1.2041720151901245, correct: 5696, total: 10000, acc = 0.569599986076355
-epoch time: 39.29190945625305
-epoch: 9, train loss: 1.1253369718181843
-epoch: 9, eval loss: 1.1219275832176208, correct: 6039, total: 10000, acc = 0.6038999557495117
-epoch time: 39.314892053604126
-epoch: 10, train loss: 1.0875617825255102
-epoch: 10, eval loss: 1.1398449420928956, correct: 5921, total: 10000, acc = 0.5920999646186829
-epoch time: 39.29768466949463
-epoch: 11, train loss: 1.055325626110544
-epoch: 11, eval loss: 1.0739773243665696, correct: 6212, total: 10000, acc = 0.6211999654769897
-epoch time: 39.26834416389465
-epoch: 12, train loss: 1.0238730627663282
-epoch: 12, eval loss: 1.0526267528533935, correct: 6244, total: 10000, acc = 0.6243999600410461
-epoch time: 39.30522894859314
-epoch: 13, train loss: 0.9906492087305808
-epoch: 13, eval loss: 1.0342225402593612, correct: 6295, total: 10000, acc = 0.6294999718666077
-epoch time: 39.28985071182251
-epoch: 14, train loss: 0.968360669758855
-epoch: 14, eval loss: 0.9747557610273361, correct: 6498, total: 10000, acc = 0.6498000025749207
-epoch time: 39.33563685417175
-epoch: 15, train loss: 0.9413909072778663
-epoch: 15, eval loss: 0.9359912216663361, correct: 6659, total: 10000, acc = 0.6658999919891357
-epoch time: 39.332377672195435
-epoch: 16, train loss: 0.9215109226654987
-epoch: 16, eval loss: 0.9215879321098328, correct: 6693, total: 10000, acc = 0.6692999601364136
-epoch time: 39.35148882865906
-epoch: 17, train loss: 0.9036085179873875
-epoch: 17, eval loss: 0.8947311192750931, correct: 6787, total: 10000, acc = 0.6786999702453613
-epoch time: 39.31995511054993
-epoch: 18, train loss: 0.8774841433885147
-epoch: 18, eval loss: 0.8880111247301101, correct: 6844, total: 10000, acc = 0.6843999624252319
-epoch time: 39.32100558280945
-epoch: 19, train loss: 0.8607137598553483
-epoch: 19, eval loss: 0.8770220369100571, correct: 6883, total: 10000, acc = 0.6882999539375305
-epoch time: 39.3321533203125
-epoch: 20, train loss: 0.8482279163234088
-epoch: 20, eval loss: 0.8661656975746155, correct: 6926, total: 10000, acc = 0.6926000118255615
-epoch time: 39.319167613983154
-epoch: 21, train loss: 0.8280732814146547
-epoch: 21, eval loss: 0.8369802534580231, correct: 7041, total: 10000, acc = 0.7040999531745911
-epoch time: 39.32543706893921
-epoch: 22, train loss: 0.8162973212952517
-epoch: 22, eval loss: 0.8281545102596283, correct: 7096, total: 10000, acc = 0.7095999717712402
-epoch time: 39.344929695129395
-epoch: 23, train loss: 0.8043988426120914
-epoch: 23, eval loss: 0.8369941651821137, correct: 7070, total: 10000, acc = 0.7069999575614929
-epoch time: 39.342397928237915
-epoch: 24, train loss: 0.788704516328111
-epoch: 24, eval loss: 0.8305304765701294, correct: 7040, total: 10000, acc = 0.7039999961853027
-epoch time: 39.349589347839355
-epoch: 25, train loss: 0.7747861517935383
-epoch: 25, eval loss: 0.8025588423013688, correct: 7164, total: 10000, acc = 0.7163999676704407
-epoch time: 39.35692596435547
-epoch: 26, train loss: 0.7557641073149077
-epoch: 26, eval loss: 0.7929455429315567, correct: 7204, total: 10000, acc = 0.7203999757766724
-epoch time: 39.36091661453247
-epoch: 27, train loss: 0.7422851062550837
-epoch: 27, eval loss: 0.7790816932916641, correct: 7249, total: 10000, acc = 0.7249000072479248
-epoch time: 39.355828046798706
-epoch: 28, train loss: 0.7305653861590794
-epoch: 28, eval loss: 0.7937072366476059, correct: 7204, total: 10000, acc = 0.7203999757766724
-epoch time: 39.3598473072052
-epoch: 29, train loss: 0.719313730998915
-epoch: 29, eval loss: 0.7657937437295914, correct: 7320, total: 10000, acc = 0.7319999933242798
-epoch time: 39.353551626205444
-epoch: 30, train loss: 0.7127084263733455
-epoch: 30, eval loss: 0.7556168884038925, correct: 7341, total: 10000, acc = 0.7340999841690063
-epoch time: 39.37097501754761
-epoch: 31, train loss: 0.7044506967067719
-epoch: 31, eval loss: 0.7438590109348298, correct: 7359, total: 10000, acc = 0.7358999848365784
-epoch time: 39.37364745140076
-epoch: 32, train loss: 0.6920064693810989
-epoch: 32, eval loss: 0.7408553540706635, correct: 7419, total: 10000, acc = 0.7418999671936035
-epoch time: 39.372353076934814
-epoch: 33, train loss: 0.6790882920732304
-epoch: 33, eval loss: 0.7541307628154754, correct: 7332, total: 10000, acc = 0.733199954032898
-epoch time: 39.310251235961914
-epoch: 34, train loss: 0.6666433202977083
-epoch: 34, eval loss: 0.7413494348526001, correct: 7401, total: 10000, acc = 0.7400999665260315
-epoch time: 39.394805908203125
-epoch: 35, train loss: 0.6561720742254841
-epoch: 35, eval loss: 0.7245241671800613, correct: 7483, total: 10000, acc = 0.7482999563217163
-epoch time: 39.34455704689026
-epoch: 36, train loss: 0.6433814526820669
-epoch: 36, eval loss: 0.7294039458036423, correct: 7483, total: 10000, acc = 0.7482999563217163
-epoch time: 39.337549924850464
-epoch: 37, train loss: 0.6366085136423305
-epoch: 37, eval loss: 0.7336494833230972, correct: 7462, total: 10000, acc = 0.7461999654769897
-epoch time: 39.338196754455566
-epoch: 38, train loss: 0.6294400272320728
-epoch: 38, eval loss: 0.719609409570694, correct: 7532, total: 10000, acc = 0.7531999945640564
-epoch time: 39.33430027961731
-epoch: 39, train loss: 0.6179663903859197
-epoch: 39, eval loss: 0.7210630685091018, correct: 7507, total: 10000, acc = 0.7506999969482422
-epoch time: 39.33643341064453
-epoch: 40, train loss: 0.6102935781284254
-epoch: 40, eval loss: 0.6994094282388688, correct: 7569, total: 10000, acc = 0.7568999528884888
-epoch time: 39.38672637939453
-epoch: 41, train loss: 0.5990810029360712
-epoch: 41, eval loss: 0.7133035778999328, correct: 7550, total: 10000, acc = 0.7549999952316284
-epoch time: 39.374757528305054
-epoch: 42, train loss: 0.5964441865074391
-epoch: 42, eval loss: 0.7060712993144989, correct: 7577, total: 10000, acc = 0.7576999664306641
-epoch time: 39.4019033908844
-epoch: 43, train loss: 0.5878602710305428
-epoch: 43, eval loss: 0.7106044471263886, correct: 7580, total: 10000, acc = 0.7579999566078186
-epoch time: 39.408252477645874
-epoch: 44, train loss: 0.5797601254010687
-epoch: 44, eval loss: 0.7093768745660782, correct: 7568, total: 10000, acc = 0.7567999958992004
-epoch time: 39.40289378166199
-epoch: 45, train loss: 0.5684604742089097
-epoch: 45, eval loss: 0.7075642883777619, correct: 7612, total: 10000, acc = 0.761199951171875
-epoch time: 39.35792422294617
-epoch: 46, train loss: 0.5617077308041709
-epoch: 46, eval loss: 0.707081851363182, correct: 7576, total: 10000, acc = 0.7576000094413757
-epoch time: 39.37784481048584
-epoch: 47, train loss: 0.5572127462649832
-epoch: 47, eval loss: 0.7069586098194123, correct: 7606, total: 10000, acc = 0.7605999708175659
-epoch time: 39.33794188499451
-epoch: 48, train loss: 0.5519619742218329
-epoch: 48, eval loss: 0.6923990368843078, correct: 7679, total: 10000, acc = 0.7678999900817871
-epoch time: 39.39500594139099
-epoch: 49, train loss: 0.5454421751961416
-epoch: 49, eval loss: 0.7032370567321777, correct: 7626, total: 10000, acc = 0.7626000046730042
-epoch time: 39.38570594787598
-epoch: 50, train loss: 0.5419908360559114
-epoch: 50, eval loss: 0.6949253618717194, correct: 7669, total: 10000, acc = 0.7669000029563904
-epoch time: 39.334325551986694
-epoch: 51, train loss: 0.5299993215166793
-epoch: 51, eval loss: 0.6966427147388459, correct: 7654, total: 10000, acc = 0.7653999924659729
-epoch time: 39.337984561920166
-epoch: 52, train loss: 0.5282451452649369
-epoch: 52, eval loss: 0.6932955116033555, correct: 7664, total: 10000, acc = 0.7663999795913696
-epoch time: 39.34237813949585
-epoch: 53, train loss: 0.5234840703862054
-epoch: 53, eval loss: 0.6988086104393005, correct: 7654, total: 10000, acc = 0.7653999924659729
-epoch time: 39.364726066589355
-epoch: 54, train loss: 0.5139317989957576
-epoch: 54, eval loss: 0.6950253814458847, correct: 7643, total: 10000, acc = 0.7642999887466431
-epoch time: 39.40451097488403
-epoch: 55, train loss: 0.5158528734226616
-epoch: 55, eval loss: 0.6978882610797882, correct: 7672, total: 10000, acc = 0.7671999931335449
-epoch time: 39.38926696777344
-epoch: 56, train loss: 0.5082419429506574
-epoch: 56, eval loss: 0.6909049898386002, correct: 7692, total: 10000, acc = 0.7691999673843384
-epoch time: 39.42493271827698
-epoch: 57, train loss: 0.5027476120360044
-epoch: 57, eval loss: 0.6897687911987305, correct: 7695, total: 10000, acc = 0.7694999575614929
-epoch time: 39.35954570770264
-epoch: 58, train loss: 0.5053188776483342
-epoch: 58, eval loss: 0.6899506479501725, correct: 7667, total: 10000, acc = 0.7666999697685242
-epoch time: 39.44884634017944
-epoch: 59, train loss: 0.4997740634241883
-epoch: 59, eval loss: 0.687486720085144, correct: 7678, total: 10000, acc = 0.767799973487854
-epoch time: 39.391881465911865
-finish training
diff --git a/tests/test_models/test_vision_transformer/test_vit_2p5d/log/421log1e-3.txt b/tests/test_models/test_vision_transformer/test_vit_2p5d/log/421log1e-3.txt
deleted file mode 100644
index 213cc80fe..000000000
--- a/tests/test_models/test_vision_transformer/test_vit_2p5d/log/421log1e-3.txt
+++ /dev/null
@@ -1,115 +0,0 @@
-TACC: Starting up job 3497142
-TACC: Starting parallel tasks...
-warning: variables which starts with __, is a module or class declaration are omitted
-process rank 0 is bound to device 0
-distributed environment is initialzied
-model is created
-warning: variables which starts with __, is a module or class declaration are omitted
-process rank 2 is bound to device 2
-Files already downloaded and verified
-Files already downloaded and verified
-warning: variables which starts with __, is a module or class declaration are omitted
-process rank 3 is bound to device 3
-Files already downloaded and verified
-Files already downloaded and verified
-Files already downloaded and verified
-Files already downloaded and verified
-training and testing dataloaders are created
-loss is created
-optimizer is created
-start training
-warning: variables which starts with __, is a module or class declaration are omitted
-process rank 1 is bound to device 1
-Files already downloaded and verified
-Files already downloaded and verified
-epoch: 0, train loss: 1.9320369898056498
-epoch: 1, train loss: 1.6352128605453335
-epoch: 1, eval loss: 1.5123237550258637, correct: 4542, total: 10000, acc = 0.45419999957084656
-epoch: 2, train loss: 1.4457968728882926
-epoch: 3, train loss: 1.3382204977833494
-epoch: 3, eval loss: 1.2539702713489533, correct: 5451, total: 10000, acc = 0.5450999736785889
-epoch: 4, train loss: 1.2739947474732691
-epoch: 5, train loss: 1.2285400483073021
-epoch: 5, eval loss: 1.1386113047599793, correct: 5908, total: 10000, acc = 0.5907999873161316
-epoch: 6, train loss: 1.1903334296479517
-epoch: 7, train loss: 1.1711674235305007
-epoch: 7, eval loss: 1.1258068561553956, correct: 5967, total: 10000, acc = 0.5967000126838684
-epoch: 8, train loss: 1.1419668745021432
-epoch: 9, train loss: 1.1143895728247506
-epoch: 9, eval loss: 1.040754759311676, correct: 6224, total: 10000, acc = 0.6223999857902527
-epoch: 10, train loss: 1.1041023871120141
-epoch: 11, train loss: 1.089750115968743
-epoch: 11, eval loss: 1.0472844064235687, correct: 6265, total: 10000, acc = 0.6265000104904175
-epoch: 12, train loss: 1.064698440687997
-epoch: 13, train loss: 1.038266262229608
-epoch: 13, eval loss: 1.0117274671792984, correct: 6415, total: 10000, acc = 0.6414999961853027
-epoch: 14, train loss: 1.029945282303557
-epoch: 15, train loss: 1.0171620620756734
-epoch: 15, eval loss: 0.9712629705667496, correct: 6519, total: 10000, acc = 0.6518999934196472
-epoch: 16, train loss: 0.9928132119227429
-epoch: 17, train loss: 0.9921575498824217
-epoch: 17, eval loss: 0.9429782271385193, correct: 6641, total: 10000, acc = 0.6640999913215637
-epoch: 18, train loss: 0.9607366293060536
-epoch: 19, train loss: 0.9427766927650997
-epoch: 19, eval loss: 0.9346068739891052, correct: 6623, total: 10000, acc = 0.6622999906539917
-epoch: 20, train loss: 0.9219280481338501
-epoch: 21, train loss: 0.8945026689646195
-epoch: 21, eval loss: 0.8710516095161438, correct: 6909, total: 10000, acc = 0.6908999681472778
-epoch: 22, train loss: 0.8807675826306246
-epoch: 23, train loss: 0.851514169756247
-epoch: 23, eval loss: 0.8239740908145905, correct: 7052, total: 10000, acc = 0.7051999568939209
-epoch: 24, train loss: 0.8388774534877466
-epoch: 25, train loss: 0.8265813291072845
-epoch: 25, eval loss: 0.8102335959672928, correct: 7137, total: 10000, acc = 0.713699996471405
-epoch: 26, train loss: 0.8057564490911912
-epoch: 27, train loss: 0.7816558753957554
-epoch: 27, eval loss: 0.7648743063211441, correct: 7292, total: 10000, acc = 0.729200005531311
-epoch: 28, train loss: 0.766656969883004
-epoch: 29, train loss: 0.7515677390049915
-epoch: 29, eval loss: 0.7517296761274338, correct: 7360, total: 10000, acc = 0.7360000014305115
-epoch: 30, train loss: 0.7300611174836451
-epoch: 31, train loss: 0.7038229193006244
-epoch: 31, eval loss: 0.7385401755571366, correct: 7375, total: 10000, acc = 0.7374999523162842
-epoch: 32, train loss: 0.6928578931458143
-epoch: 33, train loss: 0.672958068093475
-epoch: 33, eval loss: 0.6915913820266724, correct: 7596, total: 10000, acc = 0.7595999836921692
-epoch: 34, train loss: 0.6505378533382805
-epoch: 35, train loss: 0.6292881539889744
-epoch: 35, eval loss: 0.7068031072616577, correct: 7567, total: 10000, acc = 0.7566999793052673
-epoch: 36, train loss: 0.6092992303322773
-epoch: 37, train loss: 0.5922880838720166
-epoch: 37, eval loss: 0.6735526144504547, correct: 7662, total: 10000, acc = 0.7662000060081482
-epoch: 38, train loss: 0.5777627850065425
-epoch: 39, train loss: 0.562178050376931
-epoch: 39, eval loss: 0.6323211371898652, correct: 7799, total: 10000, acc = 0.7798999547958374
-epoch: 40, train loss: 0.5385949274106901
-epoch: 41, train loss: 0.5233490755971597
-epoch: 41, eval loss: 0.6360922038555146, correct: 7806, total: 10000, acc = 0.7805999517440796
-epoch: 42, train loss: 0.50960702373057
-epoch: 43, train loss: 0.48859657985823496
-epoch: 43, eval loss: 0.607847985625267, correct: 7914, total: 10000, acc = 0.7913999557495117
-epoch: 44, train loss: 0.47382923291654006
-epoch: 45, train loss: 0.45052725380780745
-epoch: 45, eval loss: 0.5986941397190094, correct: 8012, total: 10000, acc = 0.8011999726295471
-epoch: 46, train loss: 0.43711013392526277
-epoch: 47, train loss: 0.42507915229213483
-epoch: 47, eval loss: 0.5871582478284836, correct: 8002, total: 10000, acc = 0.8001999855041504
-epoch: 48, train loss: 0.40591827947266246
-epoch: 49, train loss: 0.3911267008100237
-epoch: 49, eval loss: 0.5832945287227631, correct: 8047, total: 10000, acc = 0.8046999573707581
-epoch: 50, train loss: 0.3770884950550235
-epoch: 51, train loss: 0.3587312725733738
-epoch: 51, eval loss: 0.5942261666059494, correct: 8073, total: 10000, acc = 0.8072999715805054
-epoch: 52, train loss: 0.34132662324272856
-epoch: 53, train loss: 0.3267737687850485
-epoch: 53, eval loss: 0.5920912757515907, correct: 8118, total: 10000, acc = 0.8118000030517578
-epoch: 54, train loss: 0.3116904997399875
-epoch: 55, train loss: 0.30321489380938665
-epoch: 55, eval loss: 0.5957943320274353, correct: 8082, total: 10000, acc = 0.8082000017166138
-epoch: 56, train loss: 0.2874147834218278
-epoch: 57, train loss: 0.27991348140093747
-epoch: 57, eval loss: 0.5895262002944947, correct: 8165, total: 10000, acc = 0.8165000081062317
-epoch: 58, train loss: 0.274563160173747
-epoch: 59, train loss: 0.2600744918596988
-epoch: 59, eval loss: 0.5934095367789268, correct: 8150, total: 10000, acc = 0.8149999976158142
-finish training
diff --git a/tests/test_models/test_vision_transformer/test_vit_2p5d/log/421log1e-4.txt b/tests/test_models/test_vision_transformer/test_vit_2p5d/log/421log1e-4.txt
deleted file mode 100644
index 513037271..000000000
--- a/tests/test_models/test_vision_transformer/test_vit_2p5d/log/421log1e-4.txt
+++ /dev/null
@@ -1,115 +0,0 @@
-TACC: Starting up job 3498509
-TACC: Starting parallel tasks...
-warning: variables which starts with __, is a module or class declaration are omitted
-process rank 0 is bound to device 0
-distributed environment is initialzied
-model is created
-Files already downloaded and verified
-Files already downloaded and verified
-training and testing dataloaders are created
-loss is created
-optimizer is created
-start training
-warning: variables which starts with __, is a module or class declaration are omitted
-process rank 2 is bound to device 2
-Files already downloaded and verified
-Files already downloaded and verified
-warning: variables which starts with __, is a module or class declaration are omitted
-process rank 3 is bound to device 3
-Files already downloaded and verified
-Files already downloaded and verified
-warning: variables which starts with __, is a module or class declaration are omitted
-process rank 1 is bound to device 1
-Files already downloaded and verified
-Files already downloaded and verified
-epoch: 0, train loss: 2.107759721425115
-epoch: 1, train loss: 1.8388929500871776
-epoch: 1, eval loss: 1.7622965753078461, correct: 3535, total: 10000, acc = 0.35349997878074646
-epoch: 2, train loss: 1.7141443588295762
-epoch: 3, train loss: 1.6003259931291853
-epoch: 3, eval loss: 1.608506625890732, correct: 4263, total: 10000, acc = 0.4262999892234802
-epoch: 4, train loss: 1.5016733225511045
-epoch: 5, train loss: 1.4050611877927974
-epoch: 5, eval loss: 1.386299443244934, correct: 4984, total: 10000, acc = 0.4983999729156494
-epoch: 6, train loss: 1.3264902623332278
-epoch: 7, train loss: 1.2681689250225923
-epoch: 7, eval loss: 1.3251740992069245, correct: 5295, total: 10000, acc = 0.5295000076293945
-epoch: 8, train loss: 1.2236176984650748
-epoch: 9, train loss: 1.172800781775494
-epoch: 9, eval loss: 1.1429427027702332, correct: 5966, total: 10000, acc = 0.5965999960899353
-epoch: 10, train loss: 1.1335287532027887
-epoch: 11, train loss: 1.0974334563527788
-epoch: 11, eval loss: 1.1024536848068238, correct: 6107, total: 10000, acc = 0.6107000112533569
-epoch: 12, train loss: 1.0638826300903244
-epoch: 13, train loss: 1.0406859383291127
-epoch: 13, eval loss: 1.0324654281139374, correct: 6282, total: 10000, acc = 0.6281999945640564
-epoch: 14, train loss: 1.0157714376644211
-epoch: 15, train loss: 0.990898135365272
-epoch: 15, eval loss: 0.9790050059556961, correct: 6539, total: 10000, acc = 0.6538999676704407
-epoch: 16, train loss: 0.963820260398242
-epoch: 17, train loss: 0.9404383374720203
-epoch: 17, eval loss: 0.9367435872554779, correct: 6691, total: 10000, acc = 0.6690999865531921
-epoch: 18, train loss: 0.9299906589546982
-epoch: 19, train loss: 0.9038882474510037
-epoch: 19, eval loss: 0.9210823565721512, correct: 6709, total: 10000, acc = 0.6708999872207642
-epoch: 20, train loss: 0.8825302799137271
-epoch: 21, train loss: 0.8686576388320144
-epoch: 21, eval loss: 0.8791542768478393, correct: 6913, total: 10000, acc = 0.6912999749183655
-epoch: 22, train loss: 0.8509396040926174
-epoch: 23, train loss: 0.8375457452268017
-epoch: 23, eval loss: 0.8651147484779358, correct: 6948, total: 10000, acc = 0.6947999596595764
-epoch: 24, train loss: 0.8163802222329744
-epoch: 25, train loss: 0.8068491317787949
-epoch: 25, eval loss: 0.8353333532810211, correct: 7089, total: 10000, acc = 0.708899974822998
-epoch: 26, train loss: 0.7894753631280393
-epoch: 27, train loss: 0.7779296344640304
-epoch: 27, eval loss: 0.8161472469568253, correct: 7143, total: 10000, acc = 0.7142999768257141
-epoch: 28, train loss: 0.763744876092794
-epoch: 29, train loss: 0.7521962505214068
-epoch: 29, eval loss: 0.7903082758188248, correct: 7219, total: 10000, acc = 0.7218999862670898
-epoch: 30, train loss: 0.7443178624522929
-epoch: 31, train loss: 0.7280340212948468
-epoch: 31, eval loss: 0.7877005040645599, correct: 7233, total: 10000, acc = 0.7232999801635742
-epoch: 32, train loss: 0.7196985489251663
-epoch: 33, train loss: 0.7108793039711154
-epoch: 33, eval loss: 0.7838329076766968, correct: 7292, total: 10000, acc = 0.729200005531311
-epoch: 34, train loss: 0.6965019471791326
-epoch: 35, train loss: 0.6875918537986522
-epoch: 35, eval loss: 0.7513678789138794, correct: 7392, total: 10000, acc = 0.7391999959945679
-epoch: 36, train loss: 0.6793362346230721
-epoch: 37, train loss: 0.6741023343436572
-epoch: 37, eval loss: 0.7752945452928544, correct: 7316, total: 10000, acc = 0.7315999865531921
-epoch: 38, train loss: 0.6629589072295597
-epoch: 39, train loss: 0.6507086388918818
-epoch: 39, eval loss: 0.7758691757917404, correct: 7322, total: 10000, acc = 0.7321999669075012
-epoch: 40, train loss: 0.6381483582817778
-epoch: 41, train loss: 0.6374095179596726
-epoch: 41, eval loss: 0.7589699536561966, correct: 7386, total: 10000, acc = 0.738599956035614
-epoch: 42, train loss: 0.6251792050137812
-epoch: 43, train loss: 0.6148473596086308
-epoch: 43, eval loss: 0.7495014071464539, correct: 7478, total: 10000, acc = 0.7477999925613403
-epoch: 44, train loss: 0.6119371378908351
-epoch: 45, train loss: 0.6012086509441843
-epoch: 45, eval loss: 0.725347763299942, correct: 7515, total: 10000, acc = 0.7515000104904175
-epoch: 46, train loss: 0.597867566103838
-epoch: 47, train loss: 0.5913592832429069
-epoch: 47, eval loss: 0.7254288077354432, correct: 7529, total: 10000, acc = 0.7529000043869019
-epoch: 48, train loss: 0.5801522807807339
-epoch: 49, train loss: 0.575563525666996
-epoch: 49, eval loss: 0.7291093468666077, correct: 7533, total: 10000, acc = 0.7532999515533447
-epoch: 50, train loss: 0.573031121674849
-epoch: 51, train loss: 0.5667383588698446
-epoch: 51, eval loss: 0.7240727603435516, correct: 7570, total: 10000, acc = 0.7569999694824219
-epoch: 52, train loss: 0.5578772419569443
-epoch: 53, train loss: 0.5526659309255834
-epoch: 53, eval loss: 0.7226850330829621, correct: 7576, total: 10000, acc = 0.7576000094413757
-epoch: 54, train loss: 0.5473246245968099
-epoch: 55, train loss: 0.5443006860358375
-epoch: 55, eval loss: 0.720612645149231, correct: 7596, total: 10000, acc = 0.7595999836921692
-epoch: 56, train loss: 0.5361242987671677
-epoch: 57, train loss: 0.5323515981435776
-epoch: 57, eval loss: 0.7203025311231613, correct: 7580, total: 10000, acc = 0.7579999566078186
-epoch: 58, train loss: 0.5297852404871766
-epoch: 59, train loss: 0.5288004583241989
-epoch: 59, eval loss: 0.7189624041318894, correct: 7605, total: 10000, acc = 0.7604999542236328
-finish training
diff --git a/tests/test_models/test_vision_transformer/test_vit_2p5d/log/822log1e-3.txt b/tests/test_models/test_vision_transformer/test_vit_2p5d/log/822log1e-3.txt
deleted file mode 100644
index cda0d59ef..000000000
--- a/tests/test_models/test_vision_transformer/test_vit_2p5d/log/822log1e-3.txt
+++ /dev/null
@@ -1,131 +0,0 @@
-TACC: Starting up job 3496458
-TACC: Starting parallel tasks...
-warning: variables which starts with __, is a module or class declaration are omitted
-process rank 0 is bound to device 0
-distributed environment is initialzied
-model is created
-warning: variables which starts with __, is a module or class declaration are omitted
-process rank 3 is bound to device 3
-Files already downloaded and verified
-Files already downloaded and verified
-warning: variables which starts with __, is a module or class declaration are omitted
-process rank 2 is bound to device 2
-Files already downloaded and verified
-Files already downloaded and verified
-Files already downloaded and verified
-Files already downloaded and verified
-training and testing dataloaders are created
-loss is created
-warning: variables which starts with __, is a module or class declaration are omitted
-process rank 7 is bound to device 3
-Files already downloaded and verified
-Files already downloaded and verified
-warning: variables which starts with __, is a module or class declaration are omitted
-process rank 6 is bound to device 2
-Files already downloaded and verified
-Files already downloaded and verified
-optimizer is created
-start training
-warning: variables which starts with __, is a module or class declaration are omitted
-process rank 4 is bound to device 0
-Files already downloaded and verified
-Files already downloaded and verified
-warning: variables which starts with __, is a module or class declaration are omitted
-process rank 5 is bound to device 1
-Files already downloaded and verified
-Files already downloaded and verified
-warning: variables which starts with __, is a module or class declaration are omitted
-process rank 1 is bound to device 1
-Files already downloaded and verified
-Files already downloaded and verified
-epoch: 0, train loss: 1.936693473738067
-epoch: 1, train loss: 1.627108974116189
-epoch: 1, eval loss: 1.5279120564460755, correct: 4576, total: 10000, acc = 0.4575999975204468
-epoch: 2, train loss: 1.438910031805233
-epoch: 3, train loss: 1.3184991053172521
-epoch: 3, eval loss: 1.3557079970836639, correct: 5129, total: 10000, acc = 0.5128999948501587
-epoch: 4, train loss: 1.271946340191121
-epoch: 5, train loss: 1.2340542175331894
-epoch: 5, eval loss: 1.207822185754776, correct: 5703, total: 10000, acc = 0.5702999830245972
-epoch: 6, train loss: 1.187913371592152
-epoch: 7, train loss: 1.154962458172623
-epoch: 7, eval loss: 1.0685692846775054, correct: 6100, total: 10000, acc = 0.6100000143051147
-epoch: 8, train loss: 1.1158924905621275
-epoch: 9, train loss: 1.0909727805731249
-epoch: 9, eval loss: 1.0345157146453858, correct: 6328, total: 10000, acc = 0.6327999830245972
-epoch: 10, train loss: 1.0725988399009316
-epoch: 11, train loss: 1.0453423085261364
-epoch: 11, eval loss: 0.9778846323490142, correct: 6543, total: 10000, acc = 0.6542999744415283
-epoch: 12, train loss: 1.0397504823548454
-epoch: 13, train loss: 1.011059400986652
-epoch: 13, eval loss: 0.9668682873249054, correct: 6446, total: 10000, acc = 0.644599974155426
-epoch: 14, train loss: 0.9938353963044225
-epoch: 15, train loss: 0.9691349967401854
-epoch: 15, eval loss: 0.9465512812137604, correct: 6657, total: 10000, acc = 0.6656999588012695
-epoch: 16, train loss: 0.9470896617490419
-epoch: 17, train loss: 0.927201622602891
-epoch: 17, eval loss: 0.8875106543302536, correct: 6837, total: 10000, acc = 0.6836999654769897
-epoch: 18, train loss: 0.8975223132542202
-epoch: 19, train loss: 0.8810242603019792
-epoch: 19, eval loss: 0.8688296616077423, correct: 6832, total: 10000, acc = 0.6832000017166138
-epoch: 20, train loss: 0.8482622784011218
-epoch: 21, train loss: 0.8266285700457436
-epoch: 21, eval loss: 0.7801274597644806, correct: 7205, total: 10000, acc = 0.7204999923706055
-epoch: 22, train loss: 0.8038581859092323
-epoch: 23, train loss: 0.7879118153027126
-epoch: 23, eval loss: 0.7779350578784943, correct: 7203, total: 10000, acc = 0.7202999591827393
-epoch: 24, train loss: 0.7542270896386127
-epoch: 25, train loss: 0.7369782894241567
-epoch: 25, eval loss: 0.7534965008497239, correct: 7362, total: 10000, acc = 0.7361999750137329
-epoch: 26, train loss: 0.7095995545387268
-epoch: 27, train loss: 0.6873777825005201
-epoch: 27, eval loss: 0.7344318777322769, correct: 7381, total: 10000, acc = 0.738099992275238
-epoch: 28, train loss: 0.6713967414534822
-epoch: 29, train loss: 0.650338428969286
-epoch: 29, eval loss: 0.677948921918869, correct: 7653, total: 10000, acc = 0.7652999758720398
-epoch: 30, train loss: 0.6301205882004329
-epoch: 31, train loss: 0.5990057824825754
-epoch: 31, eval loss: 0.6719370454549789, correct: 7643, total: 10000, acc = 0.7642999887466431
-epoch: 32, train loss: 0.590088236696866
-epoch: 33, train loss: 0.5689327443132595
-epoch: 33, eval loss: 0.6191721886396409, correct: 7807, total: 10000, acc = 0.7806999683380127
-epoch: 34, train loss: 0.5426055670392756
-epoch: 35, train loss: 0.5270413601276825
-epoch: 35, eval loss: 0.6150132775306701, correct: 7879, total: 10000, acc = 0.7878999710083008
-epoch: 36, train loss: 0.5215025428606539
-epoch: 37, train loss: 0.4952395400222467
-epoch: 37, eval loss: 0.628344652056694, correct: 7868, total: 10000, acc = 0.786799967288971
-epoch: 38, train loss: 0.47989121687655545
-epoch: 39, train loss: 0.46510300618045186
-epoch: 39, eval loss: 0.5977057978510857, correct: 7944, total: 10000, acc = 0.7943999767303467
-epoch: 40, train loss: 0.4441945254802704
-epoch: 41, train loss: 0.4285763985648447
-epoch: 41, eval loss: 0.5695438250899315, correct: 8023, total: 10000, acc = 0.802299976348877
-epoch: 42, train loss: 0.41337763776584546
-epoch: 43, train loss: 0.3940146170100387
-epoch: 43, eval loss: 0.5688270673155784, correct: 8091, total: 10000, acc = 0.8090999722480774
-epoch: 44, train loss: 0.37741332303504554
-epoch: 45, train loss: 0.36565779605690313
-epoch: 45, eval loss: 0.5831407308578491, correct: 8104, total: 10000, acc = 0.8104000091552734
-epoch: 46, train loss: 0.3468657017362361
-epoch: 47, train loss: 0.32949359198005834
-epoch: 47, eval loss: 0.5751512110233307, correct: 8097, total: 10000, acc = 0.8096999526023865
-epoch: 48, train loss: 0.3140165246262842
-epoch: 49, train loss: 0.29480520498995877
-epoch: 49, eval loss: 0.5712087765336037, correct: 8184, total: 10000, acc = 0.818399965763092
-epoch: 50, train loss: 0.2766021394303867
-epoch: 51, train loss: 0.26527753776433516
-epoch: 51, eval loss: 0.5643855139613152, correct: 8218, total: 10000, acc = 0.8217999935150146
-epoch: 52, train loss: 0.2525861115784061
-epoch: 53, train loss: 0.23714738658496312
-epoch: 53, eval loss: 0.5732526823878288, correct: 8249, total: 10000, acc = 0.8248999714851379
-epoch: 54, train loss: 0.2238179413335664
-epoch: 55, train loss: 0.2119908875652722
-epoch: 55, eval loss: 0.5957901775836945, correct: 8261, total: 10000, acc = 0.8260999917984009
-epoch: 56, train loss: 0.19989302222217833
-epoch: 57, train loss: 0.1875186789096618
-epoch: 57, eval loss: 0.5905491337180138, correct: 8290, total: 10000, acc = 0.8289999961853027
-epoch: 58, train loss: 0.18436841180129926
-epoch: 59, train loss: 0.17459663231762088
-epoch: 59, eval loss: 0.589044263958931, correct: 8313, total: 10000, acc = 0.8312999606132507
-finish training
diff --git a/tests/test_models/test_vision_transformer/test_vit_2p5d/log/822log1e-4.txt b/tests/test_models/test_vision_transformer/test_vit_2p5d/log/822log1e-4.txt
deleted file mode 100644
index 6f69c17cd..000000000
--- a/tests/test_models/test_vision_transformer/test_vit_2p5d/log/822log1e-4.txt
+++ /dev/null
@@ -1,131 +0,0 @@
-TACC: Starting up job 3498327
-TACC: Starting parallel tasks...
-warning: variables which starts with __, is a module or class declaration are omitted
-process rank 0 is bound to device 0
-distributed environment is initialzied
-model is created
-Files already downloaded and verified
-Files already downloaded and verified
-training and testing dataloaders are created
-loss is created
-optimizer is created
-start training
-warning: variables which starts with __, is a module or class declaration are omitted
-process rank 2 is bound to device 2
-Files already downloaded and verified
-Files already downloaded and verified
-warning: variables which starts with __, is a module or class declaration are omitted
-process rank 3 is bound to device 3
-Files already downloaded and verified
-Files already downloaded and verified
-warning: variables which starts with __, is a module or class declaration are omitted
-process rank 4 is bound to device 0
-Files already downloaded and verified
-Files already downloaded and verified
-warning: variables which starts with __, is a module or class declaration are omitted
-process rank 5 is bound to device 1
-Files already downloaded and verified
-Files already downloaded and verified
-warning: variables which starts with __, is a module or class declaration are omitted
-process rank 7 is bound to device 3
-Files already downloaded and verified
-Files already downloaded and verified
-warning: variables which starts with __, is a module or class declaration are omitted
-process rank 6 is bound to device 2
-Files already downloaded and verified
-Files already downloaded and verified
-warning: variables which starts with __, is a module or class declaration are omitted
-process rank 1 is bound to device 1
-Files already downloaded and verified
-Files already downloaded and verified
-epoch: 0, train loss: 2.1005014667705613
-epoch: 1, train loss: 1.8539113086097094
-epoch: 1, eval loss: 1.7973519027233125, correct: 3362, total: 10000, acc = 0.3361999988555908
-epoch: 2, train loss: 1.7149482040989155
-epoch: 3, train loss: 1.5927067617980801
-epoch: 3, eval loss: 1.5848429083824158, correct: 4344, total: 10000, acc = 0.4343999922275543
-epoch: 4, train loss: 1.4912729798531046
-epoch: 5, train loss: 1.3957378158763962
-epoch: 5, eval loss: 1.4951884388923644, correct: 4841, total: 10000, acc = 0.48409998416900635
-epoch: 6, train loss: 1.3090402642074896
-epoch: 7, train loss: 1.2566283296565621
-epoch: 7, eval loss: 1.2464738070964814, correct: 5562, total: 10000, acc = 0.5561999678611755
-epoch: 8, train loss: 1.2084139476017075
-epoch: 9, train loss: 1.1706127719003327
-epoch: 9, eval loss: 1.162048089504242, correct: 5876, total: 10000, acc = 0.5875999927520752
-epoch: 10, train loss: 1.120817175933293
-epoch: 11, train loss: 1.084984731309268
-epoch: 11, eval loss: 1.0764922022819519, correct: 6155, total: 10000, acc = 0.6154999732971191
-epoch: 12, train loss: 1.0559214432628787
-epoch: 13, train loss: 1.0261321286765896
-epoch: 13, eval loss: 1.0338306188583375, correct: 6334, total: 10000, acc = 0.6333999633789062
-epoch: 14, train loss: 0.992842432187528
-epoch: 15, train loss: 0.9660871296512837
-epoch: 15, eval loss: 1.0059030145406722, correct: 6458, total: 10000, acc = 0.645799994468689
-epoch: 16, train loss: 0.9467733100968965
-epoch: 17, train loss: 0.9243187673237859
-epoch: 17, eval loss: 0.9469569176435471, correct: 6610, total: 10000, acc = 0.6609999537467957
-epoch: 18, train loss: 0.9059403721167116
-epoch: 19, train loss: 0.8819177935318071
-epoch: 19, eval loss: 0.9196836709976196, correct: 6727, total: 10000, acc = 0.6726999878883362
-epoch: 20, train loss: 0.8721987532109631
-epoch: 21, train loss: 0.8469706013494608
-epoch: 21, eval loss: 0.8634845405817032, correct: 6976, total: 10000, acc = 0.6976000070571899
-epoch: 22, train loss: 0.8352831839298716
-epoch: 23, train loss: 0.8124590455269327
-epoch: 23, eval loss: 0.8418784946203232, correct: 7034, total: 10000, acc = 0.7033999562263489
-epoch: 24, train loss: 0.7961219853284408
-epoch: 25, train loss: 0.7883704268202489
-epoch: 25, eval loss: 0.8191130340099335, correct: 7116, total: 10000, acc = 0.7116000056266785
-epoch: 26, train loss: 0.7733409623710477
-epoch: 27, train loss: 0.7561956893424598
-epoch: 27, eval loss: 0.8028618812561035, correct: 7200, total: 10000, acc = 0.7199999690055847
-epoch: 28, train loss: 0.7479740460308231
-epoch: 29, train loss: 0.7343520899208225
-epoch: 29, eval loss: 0.7829996794462204, correct: 7256, total: 10000, acc = 0.725600004196167
-epoch: 30, train loss: 0.7244430549290716
-epoch: 31, train loss: 0.7121965617549663
-epoch: 31, eval loss: 0.765428164601326, correct: 7299, total: 10000, acc = 0.7299000024795532
-epoch: 32, train loss: 0.6988190838268825
-epoch: 33, train loss: 0.6908610359746583
-epoch: 33, eval loss: 0.7602580636739731, correct: 7395, total: 10000, acc = 0.7394999861717224
-epoch: 34, train loss: 0.6785666395206841
-epoch: 35, train loss: 0.6664504153387887
-epoch: 35, eval loss: 0.7671193510293961, correct: 7345, total: 10000, acc = 0.734499990940094
-epoch: 36, train loss: 0.6639333245705585
-epoch: 37, train loss: 0.6509425913800999
-epoch: 37, eval loss: 0.7612941324710846, correct: 7382, total: 10000, acc = 0.7382000088691711
-epoch: 38, train loss: 0.6416311720196082
-epoch: 39, train loss: 0.6312643265237614
-epoch: 39, eval loss: 0.7380059510469437, correct: 7496, total: 10000, acc = 0.7495999932289124
-epoch: 40, train loss: 0.620578939209179
-epoch: 41, train loss: 0.6195461816933691
-epoch: 41, eval loss: 0.7172901630401611, correct: 7550, total: 10000, acc = 0.7549999952316284
-epoch: 42, train loss: 0.6013389248020795
-epoch: 43, train loss: 0.6049416010477104
-epoch: 43, eval loss: 0.7145429253578186, correct: 7569, total: 10000, acc = 0.7568999528884888
-epoch: 44, train loss: 0.5950779300563189
-epoch: 45, train loss: 0.5786038743598121
-epoch: 45, eval loss: 0.7171747118234635, correct: 7569, total: 10000, acc = 0.7568999528884888
-epoch: 46, train loss: 0.5752052083915594
-epoch: 47, train loss: 0.5669339743195748
-epoch: 47, eval loss: 0.7040806382894516, correct: 7601, total: 10000, acc = 0.7601000070571899
-epoch: 48, train loss: 0.5596802952338238
-epoch: 49, train loss: 0.5521421706189915
-epoch: 49, eval loss: 0.7221358746290207, correct: 7592, total: 10000, acc = 0.7591999769210815
-epoch: 50, train loss: 0.5504364164508119
-epoch: 51, train loss: 0.5363630725412952
-epoch: 51, eval loss: 0.710089972615242, correct: 7650, total: 10000, acc = 0.7649999856948853
-epoch: 52, train loss: 0.5382009008709265
-epoch: 53, train loss: 0.5292040118757559
-epoch: 53, eval loss: 0.7044323921203614, correct: 7672, total: 10000, acc = 0.7671999931335449
-epoch: 54, train loss: 0.5289747638970005
-epoch: 55, train loss: 0.5239191630056926
-epoch: 55, eval loss: 0.6983724802732467, correct: 7694, total: 10000, acc = 0.7694000005722046
-epoch: 56, train loss: 0.5177402243930467
-epoch: 57, train loss: 0.5132759012738053
-epoch: 57, eval loss: 0.7066506981849671, correct: 7671, total: 10000, acc = 0.7670999765396118
-epoch: 58, train loss: 0.5119742675095188
-epoch: 59, train loss: 0.5074386891661858
-epoch: 59, eval loss: 0.7012903690338135, correct: 7693, total: 10000, acc = 0.7692999839782715
-finish training
diff --git a/tests/test_models/test_vision_transformer/test_vit_2p5d/test_vit_2p5d.py b/tests/test_models/test_vision_transformer/test_vit_2p5d/test_vit_2p5d.py
deleted file mode 100644
index a8361d2e6..000000000
--- a/tests/test_models/test_vision_transformer/test_vit_2p5d/test_vit_2p5d.py
+++ /dev/null
@@ -1,86 +0,0 @@
-from pathlib import Path
-
-import pytest
-import torch.autograd
-
-import colossalai
-from colossalai.builder import build_lr_scheduler
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.logging import get_global_dist_logger
-from colossalai.nn.layer._parallel_utilities import _gather
-
-CONFIG_PATH = Path(__file__).parent.parent.joinpath('configs/vit_2p5d.py')
-
-
-def eval(engine, test_dataloader):
- engine.eval()
- accumulated_loss = 0
- correct_sum = 0
- total_sum = 0
- num_steps = len(test_dataloader)
- data_iter = iter(test_dataloader)
-
- for i in range(num_steps):
- output, label, loss = engine.step(data_iter)
- accumulated_loss += loss.detach().cpu().numpy()
-
- output = _gather(
- output[0],
- ParallelMode.PARALLEL_2P5D_ROW,
- 1
- )
- output = _gather(
- output,
- ParallelMode.PARALLEL_2P5D_COL,
- 0,
- )
- output = _gather(
- output,
- ParallelMode.PARALLEL_2P5D_DEP,
- 0,
- )
- output = torch.argmax(output, dim=-1)
- correct = torch.sum(label[0] == output)
- correct_sum += correct
- total_sum += label[0].size(0)
- avg_loss = accumulated_loss / num_steps
- return correct_sum, total_sum, avg_loss
-
-
-def train(engine, train_dataloader, lr_scheduler):
- engine.train()
- accumulated_loss = 0
- num_steps = len(train_dataloader)
- data_iter = iter(train_dataloader)
-
- for i in range(num_steps):
- output, label, loss = engine.step(data_iter)
- accumulated_loss += loss.detach().cpu().numpy()
- avg_loss = accumulated_loss / num_steps
- lr_scheduler.step()
- return avg_loss
-
-
-@pytest.mark.dist
-@pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus")
-def test_2p5d_parallel_vision_transformer():
- # init dist
- engine, train_dataloader, test_dataloader = colossalai.initialize(CONFIG_PATH)
- lr_scheduler = build_lr_scheduler(gpc.config.lr_scheduler, engine.optimizer)
- logger = get_global_dist_logger()
-
- logger.info('start training')
- for epoch in range(gpc.config.num_epochs):
- train_loss = train(engine, train_dataloader, lr_scheduler)
- logger.info(f'epoch {epoch} - train loss: {train_loss}')
-
- if epoch % 2 == 0:
- correct_sum, total_sum, eval_loss = eval(engine, test_dataloader)
- logger.info(
- f'epoch {epoch} - eval loss: {eval_loss}, total: {total_sum}, '
- f'correct: {correct_sum}, acc: {correct_sum / total_sum}')
-
-
-if __name__ == '__main__':
- test_2p5d_parallel_vision_transformer()
diff --git a/tests/test_models/test_vision_transformer/test_vit_3d/test_vit_3d.py b/tests/test_models/test_vision_transformer/test_vit_3d/test_vit_3d.py
deleted file mode 100644
index 7bee2c78b..000000000
--- a/tests/test_models/test_vision_transformer/test_vit_3d/test_vit_3d.py
+++ /dev/null
@@ -1,105 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-import time
-from pathlib import Path
-
-import torch
-from tqdm import tqdm
-
-import colossalai
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.logging import get_global_dist_logger
-from colossalai.trainer import Trainer
-from colossalai.trainer.metric import Accuracy3D
-from colossalai.utils import print_rank_0
-
-CONFIG_PATH = Path(__file__).parent.parent.joinpath('configs/vit_3d.py')
-
-
-def _train_epoch(epoch, engine):
- logger = get_global_dist_logger()
- print_rank_0('[Epoch %d] training start' % (epoch), logger)
- engine.train()
-
- train_loss = 0
- batch_cnt = 0
- num_samples = 0
- now = time.time()
- epoch_start = now
- progress = range(engine._schedule.num_steps)
- if gpc.get_global_rank() == 0:
- progress = tqdm(progress, desc='[Epoch %d]' % epoch, miniters=1)
- for step in progress:
- cur_lr = engine.get_lr()
-
- _, targets, loss = engine.step()
-
- batch_size = targets[0].size(0)
- train_loss += loss.item()
- num_samples += batch_size
- batch_cnt += 1
-
- batch_time = time.time() - now
- now = time.time()
- if gpc.get_global_rank() == 0:
- print_features = dict(lr='%g' % cur_lr,
- loss='%.3f' % (train_loss / (step + 1)),
- throughput='%.3f (images/sec)' %
- (batch_size / (batch_time + 1e-12)))
- progress.set_postfix(**print_features)
-
- epoch_end = time.time()
- epoch_loss = train_loss / batch_cnt
- epoch_throughput = num_samples / (epoch_end - epoch_start + 1e-12)
- print_rank_0(
- '[Epoch %d] Loss: %.3f | Throughput: %.3f (samples/sec)' %
- (epoch, epoch_loss, epoch_throughput), logger)
-
-
-def _eval(epoch, engine):
- logger = get_global_dist_logger()
- engine.eval()
-
- eval_loss = 0
- acc = Accuracy3D(True, ParallelMode.PARALLEL_3D_OUTPUT,
- ParallelMode.PARALLEL_3D_WEIGHT)
- total = 0
- with torch.no_grad():
- for _ in range(engine._schedule.num_steps):
- outputs, targets, loss = engine.step()
- if isinstance(outputs, (list, tuple)):
- outputs = outputs[0]
- if isinstance(targets, (list, tuple)):
- targets = targets[0]
- eval_loss += loss.item()
- acc.update(outputs, targets)
- total += targets.size(0)
-
- print_rank_0(
- '[Epoch %d] Evaluation loss: %.3f | Acc: %.3f%%' %
- (epoch, eval_loss / engine._schedule.num_steps,
- acc.get_accumulated_value() * 100), logger)
-
-
-def train():
- # init dist
- engine, train_dataloader, test_dataloader = colossalai.initialize(CONFIG_PATH)
- logger = get_global_dist_logger()
-
- logger.info("Engine is built", ranks=[0])
-
- trainer = Trainer(engine=engine, verbose=True)
- logger.info("Trainer is built", ranks=[0])
-
- logger.info("Train start", ranks=[0])
- trainer.fit(train_dataloader=train_dataloader,
- test_dataloader=test_dataloader,
- epochs=gpc.config.num_epochs,
- hooks_cfg=gpc.config.hooks,
- display_progress=True,
- test_interval=1)
-
-
-if __name__ == '__main__':
- train()
diff --git a/tests/test_models/test_vision_transformer/test_vit_vanilla.py b/tests/test_models/test_vision_transformer/test_vit_vanilla.py
deleted file mode 100644
index f52161748..000000000
--- a/tests/test_models/test_vision_transformer/test_vit_vanilla.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-from pathlib import Path
-
-import pytest
-import torch
-
-from colossalai.builder import build_model
-from colossalai.context import Config
-
-CONFIG_PATH = Path(__file__).parent.joinpath('configs/vanilla_vit.py')
-
-
-@pytest.mark.cpu
-def test_with_vanilla_vit_config():
- config = Config.from_file(CONFIG_PATH)
- model = build_model(config.model)
- model.build_from_cfg()
-
- img = torch.randn(1, 3, config.IMG_SIZE, config.IMG_SIZE)
- out = model(img)
- loss = out.mean()
- loss.backward()
-
-
-if __name__ == '__main__':
- test_with_vanilla_vit_config()
diff --git a/tests/test_trainer/configs/test_trainer_resnet.py b/tests/test_trainer/configs/test_trainer_resnet.py
index ff48d4e6c..bd69dc475 100644
--- a/tests/test_trainer/configs/test_trainer_resnet.py
+++ b/tests/test_trainer/configs/test_trainer_resnet.py
@@ -1,77 +1,6 @@
import os
from pathlib import Path
-BATCH_SIZE = 128
-IMG_SIZE = 32
-num_epochs = 200
-
-# resnet 50
-model = dict(
- type='VanillaResNet',
- block_type='ResNetBottleneck',
- layers=[3, 4, 6, 3],
- num_cls=10
-)
-
-train_data = dict(
- dataset=dict(
- type='CIFAR10Dataset',
- root=Path(os.environ['DATA']),
- transform_pipeline=[
- dict(type='Resize', size=IMG_SIZE),
- dict(type='RandomCrop', size=IMG_SIZE, padding=4),
- dict(type='RandomHorizontalFlip'),
- dict(type='ToTensor'),
- dict(type='Normalize',
- mean=[0.4914, 0.4822, 0.4465],
- std=[0.2023, 0.1994, 0.2010]),
- ]
- ),
- dataloader=dict(
- batch_size=BATCH_SIZE,
- pin_memory=True,
- num_workers=4,
- shuffle=True
- )
-)
-
-test_data = dict(
- dataset=dict(
- type='CIFAR10Dataset',
- root=Path(os.environ['DATA']),
- train=False,
- transform_pipeline=[
- dict(type='Resize', size=IMG_SIZE),
- dict(type='ToTensor'),
- dict(type='Normalize',
- mean=[0.4914, 0.4822, 0.4465],
- std=[0.2023, 0.1994, 0.2010]
- ),
- ]
- ),
- dataloader=dict(
- batch_size=BATCH_SIZE,
- pin_memory=True,
- num_workers=4,
- shuffle=True
- )
-)
-
-optimizer = dict(
- type='SGD',
- lr=0.2,
- momentum=0.9,
- weight_decay=5e-4
-)
-
-loss = dict(
- type='CrossEntropyLoss',
-)
-
-parallel = dict(
- pipeline=dict(size=1),
- tensor=dict(size=1, mode=None),
-)
hooks = [
dict(type='LogMetricByEpochHook'),
@@ -88,4 +17,3 @@ hooks = [
),
dict(type='SaveCheckpointHook', interval=5, checkpoint_dir='./ckpt'),
]
-
diff --git a/tests/test_trainer/test.sh b/tests/test_trainer/test.sh
index 65c4fc4bd..fa0ae78d5 100644
--- a/tests/test_trainer/test.sh
+++ b/tests/test_trainer/test.sh
@@ -1,5 +1,4 @@
#!/usr/bin/env sh
test_file=$1
-config_file=$2
-python $test_file --local_rank $SLURM_PROCID --world_size $SLURM_NPROCS --host $HOST --port 29500 --config $config_file
+python $test_file --rank $SLURM_PROCID --world_size $SLURM_NPROCS --host $HOST --port 29500
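Note: the launcher now passes --rank instead of --local_rank and no longer forwards a config file; the rewritten tests below build their configuration as an in-code dict and read these flags through get_default_parser(). A minimal sketch of that handshake (the flag names are assumed to match the arguments consumed by the tests below):

    # sketch: how a test consumes the flags passed by test.sh above
    # (mirrors the rewritten tests below; get_default_parser is assumed to define
    #  --rank, --world_size, --host, --port and --backend)
    import colossalai
    from colossalai.initialize import get_default_parser

    def main():
        parser = get_default_parser()
        args = parser.parse_args()
        colossalai.launch(config=dict(),   # configuration is now an in-code dict, not a file path
                          rank=args.rank,
                          world_size=args.world_size,
                          host=args.host,
                          port=args.port,
                          backend=args.backend)

    if __name__ == '__main__':
        main()
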
diff --git a/tests/test_engine/test_pipeline/debug_schedule.py b/tests/test_trainer/test_pipeline/debug_schedule.py
similarity index 100%
rename from tests/test_engine/test_pipeline/debug_schedule.py
rename to tests/test_trainer/test_pipeline/debug_schedule.py
diff --git a/tests/test_engine/test_pipeline/test_p2p.py b/tests/test_trainer/test_pipeline/test_p2p.py
similarity index 97%
rename from tests/test_engine/test_pipeline/test_p2p.py
rename to tests/test_trainer/test_pipeline/test_p2p.py
index aa1a0f5e1..39cfa1003 100644
--- a/tests/test_engine/test_pipeline/test_p2p.py
+++ b/tests/test_trainer/test_pipeline/test_p2p.py
@@ -13,7 +13,7 @@ from colossalai.communication import (recv_backward, recv_forward,
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import init_dist, parse_args
-from colossalai.logging import get_global_dist_logger
+from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
BATCH_SIZE = 32
@@ -65,7 +65,7 @@ def check_forward_backward(output_tensor, output_grad, rank, logger):
tensor = send_backward_recv_forward(output_grad, output_tensor.shape)
logger.info(
'Rank {} sent backward received forward. Correct tensor: {}'.
- format(rank, check_equal(tensor, output_tensor)))
+ format(rank, check_equal(tensor, output_tensor)))
if not gpc.is_last_rank(ParallelMode.PIPELINE):
grad = send_forward_recv_backward(output_tensor, output_grad.shape)
logger.info(
@@ -128,7 +128,7 @@ def test_main():
world_size = args.world_size
init_dist(CONFIG)
- logger = get_global_dist_logger()
+ logger = get_dist_logger()
rank = gpc.get_global_rank()
prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
up_ranks = gpc.get_ranks_in_group(ParallelMode.PIPELINE_PREV)
diff --git a/tests/test_engine/test_pipeline/test_partition.py b/tests/test_trainer/test_pipeline/test_partition.py
similarity index 91%
rename from tests/test_engine/test_pipeline/test_partition.py
rename to tests/test_trainer/test_pipeline/test_partition.py
index 65c108162..d3c811657 100644
--- a/tests/test_engine/test_pipeline/test_partition.py
+++ b/tests/test_trainer/test_pipeline/test_partition.py
@@ -7,7 +7,7 @@ from torch.utils.data import DataLoader
from colossalai.builder import build_dataset, ModelInitializer
from colossalai.core import global_context
from colossalai.initialize import init_dist
-from colossalai.logging import get_global_dist_logger
+from colossalai.logging import get_dist_logger
DIR_PATH = osp.dirname(osp.realpath(__file__))
CONFIG_PATH = osp.join(DIR_PATH, '../configs/pipeline_vanilla_resnet.py')
@@ -17,7 +17,7 @@ CONFIG_PATH = osp.join(DIR_PATH, '../configs/pipeline_vanilla_resnet.py')
@pytest.mark.dist
def test_partition():
init_dist(CONFIG_PATH)
- logger = get_global_dist_logger()
+ logger = get_dist_logger()
logger.info('finished initialization')
# build model
diff --git a/tests/test_engine/test_pipeline/test_schedule.py b/tests/test_trainer/test_pipeline/test_schedule.py
similarity index 92%
rename from tests/test_engine/test_pipeline/test_schedule.py
rename to tests/test_trainer/test_pipeline/test_schedule.py
index 9125fb3ee..7e2f32017 100644
--- a/tests/test_engine/test_pipeline/test_schedule.py
+++ b/tests/test_trainer/test_pipeline/test_schedule.py
@@ -8,7 +8,7 @@ import pytest
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import initialize
-from colossalai.logging import get_global_dist_logger
+from colossalai.logging import get_dist_logger
NUM_BATCH = 128
@@ -24,7 +24,7 @@ CONFIG_PATH = osp.join(DIR_PATH, '../configs/pipeline_vanilla_resnet.py')
@pytest.mark.dist
def test_schedule():
engine, train_dataloader, test_dataloader = initialize(CONFIG_PATH)
- logger = get_global_dist_logger()
+ logger = get_dist_logger()
model = engine.model
optimizer = engine.optimizer
diff --git a/tests/test_trainer/test_trainer.py b/tests/test_trainer/test_trainer.py
deleted file mode 100644
index 6a7681d00..000000000
--- a/tests/test_trainer/test_trainer.py
+++ /dev/null
@@ -1,29 +0,0 @@
-import colossalai
-from colossalai.core import global_context as gpc
-from colossalai.logging import get_global_dist_logger
-from colossalai.trainer import Trainer
-
-
-def test_trainer():
- engine, train_dataloader, test_dataloader = colossalai.initialize()
- logger = get_global_dist_logger()
-
- logger.info("engine is built", ranks=[0])
-
- trainer = Trainer(engine=engine,
- verbose=True)
- logger.info("trainer is built", ranks=[0])
-
- logger.info("start training", ranks=[0])
- trainer.fit(
- train_dataloader=train_dataloader,
- test_dataloader=test_dataloader,
- hooks_cfg=gpc.config.hooks,
- epochs=gpc.config.num_epochs,
- display_progress=False,
- test_interval=5
- )
-
-
-if __name__ == '__main__':
- test_trainer()
diff --git a/tests/test_trainer/test_trainer_with_non_pipe_schedule.py b/tests/test_trainer/test_trainer_with_non_pipe_schedule.py
new file mode 100644
index 000000000..170f38087
--- /dev/null
+++ b/tests/test_trainer/test_trainer_with_non_pipe_schedule.py
@@ -0,0 +1,113 @@
+import colossalai
+import os
+from colossalai.amp.amp_type import AMP_TYPE
+import torch.nn as nn
+
+from pathlib import Path
+from torchvision import transforms
+from torch.optim import Adam
+from colossalai.initialize import get_default_parser
+from colossalai.core import global_context as gpc
+from colossalai.logging import get_dist_logger
+from colossalai.trainer import Trainer
+from colossalai.utils import get_dataloader
+from torchvision.models import resnet18
+from torchvision.datasets import CIFAR10
+
+BATCH_SIZE = 128
+IMG_SIZE = 32
+NUM_EPOCHS = 200
+
+CONFIG = dict(
+ # use torch AMP for mixed precision
+ fp16=dict(
+ mode=AMP_TYPE.TORCH
+ )
+)
+
+
+def test_trainer():
+ parser = get_default_parser()
+ args = parser.parse_args()
+ colossalai.launch(
+ config=CONFIG,
+ rank=args.rank,
+ world_size=args.world_size,
+ host=args.host,
+ port=args.port,
+ backend=args.backend
+ )
+
+ # build model
+ model = resnet18(num_classes=10)
+
+ # build dataloaders
+ train_dataset = CIFAR10(
+ root=Path(os.environ['DATA']),
+ download=True,
+ transform=transforms.Compose(
+ [
+ transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
+ ]
+ )
+ )
+
+ test_dataset = CIFAR10(
+ root=Path(os.environ['DATA']),
+ train=False,
+ download=True,
+ transform=transforms.Compose(
+ [
+ transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
+ ]
+ )
+ )
+
+ train_dataloader = get_dataloader(dataset=train_dataset,
+ shuffle=True,
+ batch_size=BATCH_SIZE,
+ num_workers=1,
+ pin_memory=True,
+ drop_last=True)
+
+ test_dataloader = get_dataloader(dataset=test_dataset,
+ batch_size=BATCH_SIZE,
+ num_workers=1,
+ pin_memory=True,
+ drop_last=True)
+
+ # build optimizer
+ optimizer = Adam(model.parameters(), lr=0.001)
+ criterion = nn.CrossEntropyLoss()
+
+ engine, train_dataloader, *args = colossalai.initialize(
+ model=model,
+ optimizer=optimizer,
+ criterion=criterion,
+ train_dataloader=train_dataloader
+ )
+
+ logger = get_dist_logger()
+ logger.info("engine is built", ranks=[0])
+
+ trainer = Trainer(engine=engine,
+ logger=logger)
+ logger.info("trainer is built", ranks=[0])
+
+ logger.info("start training", ranks=[0])
+ trainer.fit(
+ train_dataloader=train_dataloader,
+ test_dataloader=test_dataloader,
+ epochs=NUM_EPOCHS,
+ max_steps=100,
+ display_progress=True,
+ test_interval=5
+ )
+
+
+if __name__ == '__main__':
+ test_trainer()
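The same flow can also be driven without SLURM. A hedged sketch of a local launch, borrowing the mp.spawn pattern from tests/test_utils/test_gradient_accumluation.py further below (a world size of 4 and the NCCL backend are assumptions):

    # sketch: launching the same test locally on 4 GPUs with mp.spawn instead of test.sh,
    # following the pattern used in tests/test_utils/test_gradient_accumluation.py below
    from functools import partial

    import colossalai
    import torch.multiprocessing as mp
    from colossalai.amp.amp_type import AMP_TYPE

    def run(rank, world_size):
        colossalai.launch(config=dict(fp16=dict(mode=AMP_TYPE.TORCH)),
                          rank=rank,
                          world_size=world_size,
                          host='localhost',
                          port=29500,
                          backend='nccl')
        # ... build the model, dataloaders, optimizer and engine exactly as in
        # test_trainer() above, then call trainer.fit()

    if __name__ == '__main__':
        mp.spawn(partial(run, world_size=4), nprocs=4)
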
diff --git a/tests/test_trainer/test_trainer_with_pipe_schedule.py b/tests/test_trainer/test_trainer_with_pipe_schedule.py
new file mode 100644
index 000000000..63a22f6ec
--- /dev/null
+++ b/tests/test_trainer/test_trainer_with_pipe_schedule.py
@@ -0,0 +1,146 @@
+import colossalai
+import os
+import torch
+from colossalai.amp.amp_type import AMP_TYPE
+from colossalai.context.parallel_mode import ParallelMode
+import torch.nn as nn
+
+from pathlib import Path
+from torchvision import transforms
+from torch.optim import Adam
+from colossalai.initialize import get_default_parser
+from colossalai.core import global_context as gpc
+from colossalai.logging import get_dist_logger
+from colossalai.trainer import Trainer
+from colossalai.utils import get_dataloader
+from colossalai.engine.schedule import PipelineSchedule
+from torchvision.models import resnet18
+from torchvision.datasets import CIFAR10
+
+BATCH_SIZE = 32
+IMG_SIZE = 32
+NUM_EPOCHS = 200
+
+CONFIG = dict(
+ parallel=dict(
+ pipeline=2,
+ ),
+ # use torch AMP for mixed precision
+ fp16=dict(
+ mode=AMP_TYPE.TORCH
+ )
+)
+
+
+def test_trainer():
+ parser = get_default_parser()
+ args = parser.parse_args()
+ colossalai.launch(
+ config=CONFIG,
+ rank=args.rank,
+ world_size=args.world_size,
+ host=args.host,
+ port=args.port,
+ backend=args.backend
+ )
+
+ # build model
+ model = resnet18(num_classes=10)
+
+ if gpc.get_local_rank(ParallelMode.PIPELINE) == 0:
+ model = nn.Sequential(
+ model.conv1,
+ model.bn1,
+ model.relu,
+ model.maxpool,
+ model.layer1,
+ model.layer2
+ )
+ elif gpc.get_local_rank(ParallelMode.PIPELINE) == 1:
+ from functools import partial
+
+ class Flatten(nn.Module):
+
+ def forward(self, x):
+ return torch.flatten(x, 1)
+
+ model = nn.Sequential(
+ model.layer3,
+ model.layer4,
+ model.avgpool,
+ Flatten(),
+ model.fc
+ )
+
+ # build dataloaders
+ train_dataset = CIFAR10(
+ root=Path(os.environ['DATA']),
+ download=True,
+ transform=transforms.Compose(
+ [
+ transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
+ ]
+ )
+ )
+
+ test_dataset = CIFAR10(
+ root=Path(os.environ['DATA']),
+ train=False,
+ download=True,
+ transform=transforms.Compose(
+ [
+ transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
+ ]
+ )
+ )
+
+ train_dataloader = get_dataloader(dataset=train_dataset,
+ shuffle=True,
+ batch_size=BATCH_SIZE,
+ num_workers=1,
+ pin_memory=True,
+ drop_last=True)
+
+ test_dataloader = get_dataloader(dataset=test_dataset,
+ batch_size=BATCH_SIZE,
+ num_workers=1,
+ pin_memory=True,
+ drop_last=True)
+
+ # build optimizer
+ optimizer = Adam(model.parameters(), lr=0.001)
+ criterion = nn.CrossEntropyLoss()
+
+ engine, train_dataloader, *args = colossalai.initialize(
+ model=model,
+ optimizer=optimizer,
+ criterion=criterion,
+ train_dataloader=train_dataloader
+ )
+
+ logger = get_dist_logger()
+ logger.info("engine is built", ranks=[0])
+ pipe_schedule = PipelineSchedule(num_microbatches=4)
+ trainer = Trainer(engine=engine,
+ schedule=pipe_schedule,
+ logger=logger)
+ logger.info("trainer is built", ranks=[0])
+
+ logger.info("start training", ranks=[0])
+
+ trainer.fit(
+ train_dataloader=train_dataloader,
+ test_dataloader=test_dataloader,
+ epochs=NUM_EPOCHS,
+ max_steps=100,
+ display_progress=True,
+ test_interval=5
+ )
+
+
+if __name__ == '__main__':
+ test_trainer()
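For reference, this pipeline test splits each global batch into microbatches: with BATCH_SIZE = 32 and PipelineSchedule(num_microbatches=4), every microbatch carries 8 samples, and parallel=dict(pipeline=2) places the two nn.Sequential halves built above on consecutive pipeline ranks. A small sanity-check sketch of that arithmetic (a reading of the constants in this test, not of ColossalAI internals):

    # sketch: microbatch arithmetic implied by the constants above
    BATCH_SIZE = 32
    NUM_MICROBATCHES = 4
    PIPELINE_STAGES = 2

    assert BATCH_SIZE % NUM_MICROBATCHES == 0, 'batch size must divide evenly into microbatches'
    microbatch_size = BATCH_SIZE // NUM_MICROBATCHES
    print(f'{PIPELINE_STAGES} stages, {NUM_MICROBATCHES} microbatches of {microbatch_size} samples each')
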
diff --git a/tests/test_utils/test_gradient_accumluation.py b/tests/test_utils/test_gradient_accumluation.py
new file mode 100644
index 000000000..4f7ccd09b
--- /dev/null
+++ b/tests/test_utils/test_gradient_accumluation.py
@@ -0,0 +1,117 @@
+import colossalai
+import os
+import pytest
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+
+from functools import partial
+from pathlib import Path
+from torchvision import transforms
+from torch.optim import Adam
+from colossalai.core import global_context as gpc
+from colossalai.logging import get_dist_logger
+from colossalai.utils import report_memory_usage, get_dataloader
+from colossalai.initialize import get_default_parser
+from torchvision.models import resnet18
+from torchvision.datasets import CIFAR10
+
+
+# Config
+BATCH_SIZE = 16
+IMG_SIZE = 224
+NUM_CLASSES = 10
+
+CONFIG = dict(
+ parallel=dict(
+ pipeline=dict(size=1),
+ tensor=dict(size=1, mode=None)
+ ),
+ clip_grad_norm=1.0,
+ gradient_accumulation=4
+)
+
+
+def run_no_pipeline(rank, world_size):
+
+ # init dist env
+ colossalai.launch(
+ config=CONFIG,
+ rank=rank,
+ world_size=world_size,
+ host='localhost',
+ port=29500,
+ backend='nccl'
+ )
+
+ # build model
+ model = resnet18(num_classes=10)
+
+ # build dataloaders
+ train_dataset = CIFAR10(
+ root=Path(os.environ['DATA']),
+ download=True,
+ transform=transforms.Compose(
+ [
+ transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
+ ]
+ )
+ )
+ train_dataloader = get_dataloader(dataset=train_dataset,
+ shuffle=True,
+ batch_size=BATCH_SIZE,
+ pin_memory=True,
+ drop_last=True)
+
+ # build optimizer
+ optimizer = Adam(model.parameters(), lr=0.001)
+ criterion = nn.CrossEntropyLoss()
+
+ engine, train_dataloader, *args = colossalai.initialize(
+ model=model,
+ optimizer=optimizer,
+ criterion=criterion,
+ train_dataloader=train_dataloader
+ )
+ logger = get_dist_logger()
+ rank = torch.distributed.get_rank()
+ param_track = []
+ grad_track = []
+ next(model.parameters()).retain_grad()
+
+ engine.train()
+ step = 0
+ for img, label in train_dataloader:
+ engine.zero_grad()
+ img = img.cuda()
+ label = label.cuda()
+ output = engine(img)
+ loss = engine.criterion(output, label)
+ engine.backward(loss)
+ engine.step()
+
+ # check
+ param_track.append(next(model.parameters())[0].clone())
+ grad_track.append(next(model.parameters()).grad[0].clone())
+ step += 1
+ if step == CONFIG['gradient_accumulation']:
+ break
+
+ assert not torch.all(grad_track[0] == grad_track[-1]), 'grad should be different in different iterations'
+ assert torch.all(param_track[0] == param_track[1]) and not torch.all(param_track[0] == param_track[-1]), \
+ 'param should be the same in the first few iterations and only changed in the last iteration'
+
+ gpc.destroy()
+
+
+@pytest.mark.skip("This test should be invoked using the test.sh provided")
+@pytest.mark.dist
+def test_engine():
+ func = partial(run_no_pipeline, world_size=4)
+ mp.spawn(func, nprocs=4)
+
+
+if __name__ == '__main__':
+ test_engine()
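The assertions above encode the expected accumulate-then-step behavior: with gradient_accumulation=4, gradients keep changing on every iteration while parameters stay fixed until the fourth engine.step() triggers the real optimizer update. A plain-PyTorch sketch of that contract (illustrative only, not ColossalAI's internal implementation):

    # sketch: the behavior the assertions above check, written in plain PyTorch
    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    accum_steps = 4

    for step in range(accum_steps):
        x = torch.randn(8, 4)
        loss = model(x).sum()
        loss.backward()                      # gradients accumulate across iterations
        if (step + 1) % accum_steps == 0:    # parameters change only on the final step
            optimizer.step()
            optimizer.zero_grad()
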
diff --git a/tests/test_zero_data_parallel/config.py b/tests/test_zero_data_parallel/config.py
index 3e9d081d1..8e263505b 100644
--- a/tests/test_zero_data_parallel/config.py
+++ b/tests/test_zero_data_parallel/config.py
@@ -2,90 +2,3 @@
# -*- encoding: utf-8 -*-
import os
from pathlib import Path
-
-BATCH_SIZE = 128
-IMG_SIZE = 224
-NUM_CLS = 1000
-
-# resnet 18
-model = dict(
- type='VanillaResNet',
- block_type='ResNetBottleneck',
- layers=[3, 4, 6, 3],
- num_cls=NUM_CLS
-)
-
-train_data = dict(
- dataset=dict(
- type='CIFAR10Dataset',
- root=Path(os.environ['DATA']),
- transform_pipeline=[
- dict(type='RandomResizedCrop', size=IMG_SIZE),
- dict(type='RandomHorizontalFlip'),
- dict(type='ToTensor'),
- dict(type='Normalize', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
- ]
- ),
- dataloader=dict(
- batch_size=64,
- pin_memory=True,
- num_workers=4,
- sampler=dict(
- type='DataParallelSampler',
- shuffle=True,
- )
- )
-)
-
-test_data = dict(
- dataset=dict(
- type='CIFAR10Dataset',
- root=Path(os.environ['DATA']),
- train=False,
- transform_pipeline=[
- dict(type='Resize', size=(IMG_SIZE, IMG_SIZE)),
- dict(type='ToTensor'),
- dict(type='Normalize', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
- ]
- ),
- dataloader=dict(
- batch_size=BATCH_SIZE,
- pin_memory=True,
- num_workers=4,
- )
-)
-
-dist_initializer = [
- dict(type='DataParallelInitializer'),
-]
-
-parallelization = dict(
- pipeline=1,
- tensor=1,
- sequence=-1
-)
-
-optimizer = dict(
- type='Adam',
- lr=0.01
-)
-
-loss = dict(
- type='CrossEntropyLoss'
-)
-
-trainer = dict(
- max_epochs=5,
- max_iters=1000
-)
-
-amp = dict(
- fp16=None,
-)
-
-level = 2
-
-parallel = dict(
- pipeline=dict(size=1),
- tensor=dict(size=1, mode=None)
-)
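The declarative config removed above is superseded by the inline CONFIG dict in the rewritten test_zero.py that follows. Based on the commented-out block in that dict, a ZeRO level-2 variant would presumably look like this (only the keys visible in the diff are known; everything else is an assumption):

    # sketch: a ZeRO level-2 variant of the CONFIG dict used in test_zero.py below
    CONFIG = dict(
        fp16=dict(mode=None),
        zero=dict(
            level=2,            # level-2 keys taken from the commented-out block below
            cpu_offload=True,
            verbose=False,
        ),
        parallel=dict(
            pipeline=dict(size=1),
            tensor=dict(size=1, mode=None),
        ),
    )
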
diff --git a/tests/test_zero_data_parallel/test_zero.py b/tests/test_zero_data_parallel/test_zero.py
index e47ca61a5..6331a9a2b 100644
--- a/tests/test_zero_data_parallel/test_zero.py
+++ b/tests/test_zero_data_parallel/test_zero.py
@@ -1,146 +1,118 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
-import os.path as osp
-
+import os
import pytest
import torch
-from torch.utils.data import DataLoader
+
+from pathlib import Path
import colossalai
-from colossalai.builder import build_dataset, build_loss, build_data_sampler, build_model
-from colossalai.core import global_context
-from colossalai.engine.gradient_handler import DataParallelGradientHandler
-from colossalai.nn.optimizer import ZeroRedundancyOptimizer_Level_1, ZeroRedundancyOptimizer_Level_3, \
- ZeroRedundancyOptimizer_Level_2
-from colossalai.utils import print_rank_0
+from colossalai.initialize import get_default_parser
+from colossalai.core import global_context as gpc
+from colossalai.utils import get_dataloader
+from torchvision import transforms
+from torchvision.models import resnet18
+from torchvision.datasets import CIFAR10
-DIR_PATH = osp.dirname(osp.abspath(__file__))
-CONFIG_PATH = osp.join(DIR_PATH, 'config.py')
+BATCH_SIZE = 128
+IMG_SIZE = 224
+NUM_CLS = 1000
+
+CONFIG = dict(
+ fp16=dict(
+ mode=None,
+ ),
+ zero=dict(
+ # ==============
+ # level 2 config
+ # ==============
+ # level=2,
+ # cpu_offload=True,
+ # verbose=False,
+
+ # ==============
+ # level 3 config
+ # ==============
+ level=3,
+ verbose=False,
+ offload_optimizer_config=dict(
+ device='cpu',
+ pin_memory=True,
+ buffer_count=5,
+ fast_init=False
+ ),
+ offload_param_config=dict(
+ device='cpu',
+ pin_memory=True,
+ buffer_count=5,
+ buffer_size=1e8,
+ max_in_cpu=1e9
+ )
+ ),
+ parallel=dict(
+ pipeline=dict(size=1),
+ tensor=dict(size=1, mode=None)
+ )
+)
def run_dist():
- colossalai.init_dist(CONFIG_PATH)
+ parser = get_default_parser()
+ args = parser.parse_args()
- # build resnet model
- model = build_model(global_context.config.model)
- model.build_from_cfg()
- model = model.cuda()
+ colossalai.launch(config=CONFIG,
+ rank=args.rank,
+ world_size=args.world_size,
+ host=args.host,
+ port=args.port,
+ backend=args.backend)
- level = global_context.config.level
+ # build model
+ model = resnet18(num_classes=10)
- if level > 1:
- model = model.half()
-
- # test init cuda memory
- _ = torch.rand(1).cuda()
- torch.cuda.synchronize()
- max_alloc = torch.cuda.max_memory_allocated()
- max_reserved = torch.cuda.max_memory_reserved()
- print(f'before run: max_allocation = {max_alloc}, max_reserved = {max_reserved}')
-
- # build dataloader
- train_dataset = build_dataset(global_context.config.train_data.dataset)
-
- sampler_cfg = global_context.config.train_data.dataloader.pop('sampler', None)
- if sampler_cfg is None:
- train_dataloader = DataLoader(dataset=train_dataset, **global_context.config.train_data.dataloader)
- else:
- sampler = build_data_sampler(sampler_cfg, train_dataset)
- train_dataloader = DataLoader(dataset=train_dataset, sampler=sampler,
- **global_context.config.train_data.dataloader)
-
- test_dataset = build_dataset(global_context.config.test_data.dataset)
- test_dataloader = DataLoader(dataset=test_dataset, **global_context.config.test_data.dataloader)
+ # build dataloader
+ train_dataset = CIFAR10(
+ root=Path(os.environ['DATA']),
+ download=True,
+ transform=transforms.Compose(
+ [
+ transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
+ ]
+ )
+ )
+ train_dataloader = get_dataloader(dataset=train_dataset,
+ shuffle=True,
+ batch_size=BATCH_SIZE,
+ num_workers=1,
+ pin_memory=True,
+ drop_last=True)
# build optimizer and loss
# optimizer = build_optimizer(global_context.config.optimizer, model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
- if level == 1:
- zero_optim = ZeroRedundancyOptimizer_Level_1(init_optimizer=optimizer, verbose=False)
- elif level == 2:
- zero_optim = ZeroRedundancyOptimizer_Level_2(init_optimizer=optimizer, cpu_offload=True, verbose=False)
- elif level == 3:
- zero_optim = ZeroRedundancyOptimizer_Level_3(init_optimizer=optimizer,
- module=model,
- verbose=False,
- offload_optimizer_config=dict(
- device='cpu',
- pin_memory=True,
- buffer_count=5,
- fast_init=False
- ),
- offload_param_config=dict(
- device='cpu',
- pin_memory=True,
- buffer_count=5,
- buffer_size=1e8,
- max_in_cpu=1e9
- )
- )
+ criterion = torch.nn.CrossEntropyLoss()
- loss_fn = build_loss(global_context.config.loss)
- gradient_handler = DataParallelGradientHandler(model, zero_optim)
+ engine, train_dataloader, *args = colossalai.initialize(model=model,
+ optimizer=optimizer,
+ criterion=criterion,
+ train_dataloader=train_dataloader)
# train
- for epoch in range(100):
- model.train()
+ model.train()
+ for idx, (data, label) in enumerate(train_dataloader):
+ engine.zero_grad()
+ data = data.cuda()
+ label = label.cuda()
- # train
- avg_train_loss = 0
- train_iter = 0
+ output = engine(data)
+ loss = engine.criterion(output, label)
- for idx, (data, label) in enumerate(train_dataloader):
- # model = model.half()
- data = data[0].cuda()
- label = label[0].cuda()
-
- if level > 1:
- data = data.half()
-
- output = model(data)
- loss = loss_fn(output[0], label)
-
- if level > 1:
- zero_optim.backward(loss)
- zero_optim.overlapping_partition_gradients_reduce_epilogue()
- else:
- loss.backward()
- gradient_handler.handle_gradient()
-
- zero_optim.step()
- zero_optim.zero_grad()
-
- avg_train_loss += loss.detach().cpu().numpy()
- train_iter += 1
-
- print_rank_0(f'epoch: {epoch}, train loss: {avg_train_loss / train_iter}')
-
- if epoch % 2 == 0:
- model.eval()
- avg_eval_loss = 0
- correct = 0
- total = 0
- eval_iters = 0
-
- for idx, (data, label) in enumerate(test_dataloader):
- with torch.no_grad():
- data = data[0].cuda()
- label = label[0].cuda()
-
- if level > 1:
- data = data.half()
-
- output = model(data)
- loss = loss_fn(output[0], label)
-
- avg_eval_loss += loss.detach().cpu().numpy()
- preds = torch.argmax(output[0], dim=1)
- total += data.size(0)
- correct += sum(preds == label)
- eval_iters += 1
-
- print_rank_0(f'epoch: {epoch}, eval loss: {avg_eval_loss / eval_iters}, acc: {correct / total}')
+ engine.backward(loss)
+ engine.step()
+ break
@pytest.mark.skip("This test should be invoked manually using the script provided")
diff --git a/tests/test_zero_data_parallel/test_zero.sh b/tests/test_zero_data_parallel/test_zero.sh
index b725f52aa..c1effa2d1 100644
--- a/tests/test_zero_data_parallel/test_zero.sh
+++ b/tests/test_zero_data_parallel/test_zero.sh
@@ -1,4 +1,4 @@
#!/bin/bash
test_file="test_zero.py"
-python $test_file --local_rank $SLURM_PROCID --world_size $SLURM_NPROCS --host $HOST --port 29500
\ No newline at end of file
+python $test_file --rank $SLURM_PROCID --world_size $SLURM_NPROCS --host $HOST --port 29500
\ No newline at end of file
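The shell scripts now pass `--rank` instead of `--local_rank`. As a hedged sketch only (the real `get_default_parser` may define more options and different defaults), the arguments forwarded to `colossalai.launch` in these tests amount to:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--rank', type=int)         # renamed from --local_rank in this PR
parser.add_argument('--world_size', type=int)
parser.add_argument('--host', type=str)
parser.add_argument('--port', type=int)
parser.add_argument('--backend', type=str, default='nccl')  # default value is an assumption
args = parser.parse_args()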
diff --git a/tests/test_zero_tensor_parallel/components.py b/tests/test_zero_tensor_parallel/components.py
new file mode 100644
index 000000000..8421f2c8f
--- /dev/null
+++ b/tests/test_zero_tensor_parallel/components.py
@@ -0,0 +1,76 @@
+
+import sys
+from pathlib import Path
+repo_path = Path(__file__).absolute().parents[2]
+sys.path.append(str(repo_path))
+
+try:
+ import model_zoo.vit.vision_transformer_from_config
+except ImportError:
+ raise ImportError("model_zoo is not found, please check your path")
+
+BATCH_SIZE = 8
+IMG_SIZE = 32
+PATCH_SIZE = 4
+DIM = 512
+NUM_ATTENTION_HEADS = 8
+SUMMA_DIM = 2
+NUM_CLASSES = 10
+DEPTH = 6
+
+model_cfg = dict(
+ type='VisionTransformerFromConfig',
+ tensor_splitting_cfg=dict(
+ type='ViTInputSplitter2D',
+ ),
+ embedding_cfg=dict(
+ type='ViTPatchEmbedding2D',
+ img_size=IMG_SIZE,
+ patch_size=PATCH_SIZE,
+ embed_dim=DIM,
+ ),
+ token_fusion_cfg=dict(
+ type='ViTTokenFuser2D',
+ img_size=IMG_SIZE,
+ patch_size=PATCH_SIZE,
+ embed_dim=DIM,
+ drop_rate=0.1
+ ),
+ norm_cfg=dict(
+ type='LayerNorm2D',
+ normalized_shape=DIM,
+ eps=1e-6,
+ ),
+ block_cfg=dict(
+ type='ViTBlock',
+ attention_cfg=dict(
+ type='ViTSelfAttention2D',
+ hidden_size=DIM,
+ num_attention_heads=NUM_ATTENTION_HEADS,
+ attention_dropout_prob=0.,
+ hidden_dropout_prob=0.1,
+ ),
+ droppath_cfg=dict(
+ type='VanillaViTDropPath',
+ ),
+ mlp_cfg=dict(
+ type='ViTMLP2D',
+ in_features=DIM,
+ dropout_prob=0.1,
+ mlp_ratio=1
+ ),
+ norm_cfg=dict(
+ type='LayerNorm2D',
+ normalized_shape=DIM,
+ eps=1e-6,
+ ),
+ ),
+ head_cfg=dict(
+ type='ViTHead2D',
+ hidden_size=DIM,
+ num_classes=NUM_CLASSES,
+ ),
+ embed_dim=DIM,
+ depth=DEPTH,
+ drop_path_rate=0.,
+)
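components.py centralises the ViT hyper-parameters and `model_cfg` so every test in this folder can import them. A minimal usage sketch, mirroring how test_vit_2d.py below consumes it:

from colossalai.builder import build_model
from components import model_cfg, BATCH_SIZE, IMG_SIZE

model = build_model(model_cfg)   # returns a config-driven module
model.build_from_cfg()           # instantiates the layers described in model_cfg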
diff --git a/tests/test_zero_tensor_parallel/configs/vit_2d_zero1.py b/tests/test_zero_tensor_parallel/configs/vit_2d_zero1.py
deleted file mode 100644
index 61efa61ed..000000000
--- a/tests/test_zero_tensor_parallel/configs/vit_2d_zero1.py
+++ /dev/null
@@ -1,159 +0,0 @@
-import os
-from pathlib import Path
-
-import torch
-
-BATCH_SIZE = 512
-IMG_SIZE = 32
-PATCH_SIZE = 4
-DIM = 512
-NUM_ATTENTION_HEADS = 8
-SUMMA_DIM = 2
-NUM_CLASSES = 10
-DEPTH = 6
-
-train_data = dict(
- dataset=dict(
- type='CIFAR10Dataset',
- root=Path(os.environ['DATA']),
- transform_pipeline=[
- dict(type='RandomCrop', size=IMG_SIZE, padding=4),
- dict(type='RandomHorizontalFlip'),
- dict(type='ToTensor'),
- dict(type='Normalize',
- mean=[0.4914, 0.4822, 0.4465],
- std=[0.2023, 0.1994, 0.2010]),
- ]
- ),
- dataloader=dict(
- batch_size=BATCH_SIZE,
- pin_memory=True,
- num_workers=4,
- shuffle=True
- )
-)
-
-test_data = dict(
- dataset=dict(
- type='CIFAR10Dataset',
- root=Path(os.environ['DATA']),
- train=False,
- transform_pipeline=[
- dict(type='Resize', size=IMG_SIZE),
- dict(type='ToTensor'),
- dict(type='Normalize',
- mean=[0.4914, 0.4822, 0.4465],
- std=[0.2023, 0.1994, 0.2010]
- ),
- ]
- ),
- dataloader=dict(
- batch_size=BATCH_SIZE,
- pin_memory=True,
- num_workers=4,
- shuffle=True
- )
-)
-
-optimizer = dict(
- type='ZeroRedundancyOptimizer',
- optimizer_class=torch.optim.Adam,
- lr=0.001,
- weight_decay=0
-)
-
-optimizer = dict(
- type='Adam',
- lr=0.001,
- weight_decay=0
-)
-
-loss = dict(
- type='CrossEntropyLoss2D',
-)
-
-model = dict(
- type='VisionTransformerFromConfig',
- tensor_splitting_cfg=dict(
- type='ViTInputSplitter2D',
- ),
- embedding_cfg=dict(
- type='ViTPatchEmbedding2D',
- img_size=IMG_SIZE,
- patch_size=PATCH_SIZE,
- embed_dim=DIM,
- ),
- token_fusion_cfg=dict(
- type='ViTTokenFuser2D',
- img_size=IMG_SIZE,
- patch_size=PATCH_SIZE,
- embed_dim=DIM,
- drop_rate=0.1
- ),
- norm_cfg=dict(
- type='LayerNorm2D',
- normalized_shape=DIM,
- eps=1e-6,
- ),
- block_cfg=dict(
- type='ViTBlock',
- attention_cfg=dict(
- type='ViTSelfAttention2D',
- hidden_size=DIM,
- num_attention_heads=NUM_ATTENTION_HEADS,
- attention_dropout_prob=0.,
- hidden_dropout_prob=0.1,
- ),
- droppath_cfg=dict(
- type='VanillaViTDropPath',
- ),
- mlp_cfg=dict(
- type='ViTMLP2D',
- in_features=DIM,
- dropout_prob=0.1,
- mlp_ratio=1
- ),
- norm_cfg=dict(
- type='LayerNorm2D',
- normalized_shape=DIM,
- eps=1e-6,
- ),
- ),
- head_cfg=dict(
- type='ViTHead2D',
- hidden_size=DIM,
- num_classes=NUM_CLASSES,
- ),
- embed_dim=DIM,
- depth=DEPTH,
- drop_path_rate=0.,
-)
-
-parallel = dict(
- pipeline=dict(size=1),
- tensor=dict(size=4, mode='2d'),
-)
-
-from colossalai.engine import AMP_TYPE
-
-fp16 = dict(
- mode=AMP_TYPE.PARALLEL,
- initial_scale=2 ** 4
-)
-
-#
-# fp16 = dict(
-# mode=None,
-# )
-
-# both level 2 and 3 work
-# zero = dict(
-# type='ZeroRedundancyOptimizer_Level_1',
-# )
-
-lr_scheduler = dict(
- type='LinearWarmupLR',
- warmup_epochs=5
-)
-
-num_epochs = 60
diff --git a/tests/test_zero_tensor_parallel/configs/vit_2d_zero2.py b/tests/test_zero_tensor_parallel/configs/vit_2d_zero2.py
index 2ce42a88c..80c450a47 100644
--- a/tests/test_zero_tensor_parallel/configs/vit_2d_zero2.py
+++ b/tests/test_zero_tensor_parallel/configs/vit_2d_zero2.py
@@ -1,149 +1,12 @@
-import os
-from pathlib import Path
-
-BATCH_SIZE = 512
-IMG_SIZE = 32
-PATCH_SIZE = 4
-DIM = 512
-NUM_ATTENTION_HEADS = 8
-SUMMA_DIM = 2
-NUM_CLASSES = 10
-DEPTH = 6
-
-train_data = dict(
- dataset=dict(
- type='CIFAR10Dataset',
- root=Path(os.environ['DATA']),
- transform_pipeline=[
- dict(type='RandomCrop', size=IMG_SIZE, padding=4),
- dict(type='RandomHorizontalFlip'),
- dict(type='ToTensor'),
- dict(type='Normalize',
- mean=[0.4914, 0.4822, 0.4465],
- std=[0.2023, 0.1994, 0.2010]),
- ]
- ),
- dataloader=dict(
- batch_size=BATCH_SIZE,
- pin_memory=True,
- num_workers=4,
- shuffle=True
- )
-)
-
-test_data = dict(
- dataset=dict(
- type='CIFAR10Dataset',
- root=Path(os.environ['DATA']),
- train=False,
- transform_pipeline=[
- dict(type='Resize', size=IMG_SIZE),
- dict(type='ToTensor'),
- dict(type='Normalize',
- mean=[0.4914, 0.4822, 0.4465],
- std=[0.2023, 0.1994, 0.2010]
- ),
- ]
- ),
- dataloader=dict(
- batch_size=BATCH_SIZE,
- pin_memory=True,
- num_workers=4,
- shuffle=True
- )
-)
-
-optimizer = dict(
- type='Adam',
- lr=0.001,
- weight_decay=0
-)
-
-loss = dict(
- type='CrossEntropyLoss2D',
-)
-
-model = dict(
- type='VisionTransformerFromConfig',
- tensor_splitting_cfg=dict(
- type='ViTInputSplitter2D',
- ),
- embedding_cfg=dict(
- type='ViTPatchEmbedding2D',
- img_size=IMG_SIZE,
- patch_size=PATCH_SIZE,
- embed_dim=DIM,
- ),
- token_fusion_cfg=dict(
- type='ViTTokenFuser2D',
- img_size=IMG_SIZE,
- patch_size=PATCH_SIZE,
- embed_dim=DIM,
- drop_rate=0.1
- ),
- norm_cfg=dict(
- type='LayerNorm2D',
- normalized_shape=DIM,
- eps=1e-6,
- ),
- block_cfg=dict(
- type='ViTBlock',
- attention_cfg=dict(
- type='ViTSelfAttention2D',
- hidden_size=DIM,
- num_attention_heads=NUM_ATTENTION_HEADS,
- attention_dropout_prob=0.,
- hidden_dropout_prob=0.1,
- ),
- droppath_cfg=dict(
- type='VanillaViTDropPath',
- ),
- mlp_cfg=dict(
- type='ViTMLP2D',
- in_features=DIM,
- dropout_prob=0.1,
- mlp_ratio=1
- ),
- norm_cfg=dict(
- type='LayerNorm2D',
- normalized_shape=DIM,
- eps=1e-6,
- ),
- ),
- head_cfg=dict(
- type='ViTHead2D',
- hidden_size=DIM,
- num_classes=NUM_CLASSES,
- ),
- embed_dim=DIM,
- depth=DEPTH,
- drop_path_rate=0.,
-)
-
parallel = dict(
pipeline=dict(size=1),
tensor=dict(size=4, mode='2d'),
)
-# from colossalai.engine import AMP_TYPE
-#
-# fp16 = dict(
-# mode=AMP_TYPE.PARALLEL,
-# initial_scale=2 ** 4
-# )
-
fp16 = dict(
mode=None,
)
-# both level 2 and 3 work
zero = dict(
- type='ZeroRedundancyOptimizer_Level_2'
+ level=2
)
-
-lr_scheduler = dict(
- type='LinearWarmupLR',
- warmup_epochs=5
-)
-
-num_epochs = 60
diff --git a/tests/test_zero_tensor_parallel/configs/vit_2d_zero3.py b/tests/test_zero_tensor_parallel/configs/vit_2d_zero3.py
index 61f2a46f3..58e026347 100644
--- a/tests/test_zero_tensor_parallel/configs/vit_2d_zero3.py
+++ b/tests/test_zero_tensor_parallel/configs/vit_2d_zero3.py
@@ -1,149 +1,12 @@
-import os
-from pathlib import Path
-
-BATCH_SIZE = 512
-IMG_SIZE = 32
-PATCH_SIZE = 4
-DIM = 512
-NUM_ATTENTION_HEADS = 8
-SUMMA_DIM = 2
-NUM_CLASSES = 10
-DEPTH = 6
-
-train_data = dict(
- dataset=dict(
- type='CIFAR10Dataset',
- root=Path(os.environ['DATA']),
- transform_pipeline=[
- dict(type='RandomCrop', size=IMG_SIZE, padding=4),
- dict(type='RandomHorizontalFlip'),
- dict(type='ToTensor'),
- dict(type='Normalize',
- mean=[0.4914, 0.4822, 0.4465],
- std=[0.2023, 0.1994, 0.2010]),
- ]
- ),
- dataloader=dict(
- batch_size=BATCH_SIZE,
- pin_memory=True,
- num_workers=4,
- shuffle=True
- )
-)
-
-test_data = dict(
- dataset=dict(
- type='CIFAR10Dataset',
- root=Path(os.environ['DATA']),
- train=False,
- transform_pipeline=[
- dict(type='Resize', size=IMG_SIZE),
- dict(type='ToTensor'),
- dict(type='Normalize',
- mean=[0.4914, 0.4822, 0.4465],
- std=[0.2023, 0.1994, 0.2010]
- ),
- ]
- ),
- dataloader=dict(
- batch_size=BATCH_SIZE,
- pin_memory=True,
- num_workers=4,
- shuffle=True
- )
-)
-
-optimizer = dict(
- type='Adam',
- lr=0.001,
- weight_decay=0
-)
-
-loss = dict(
- type='CrossEntropyLoss2D',
-)
-
-model = dict(
- type='VisionTransformerFromConfig',
- tensor_splitting_cfg=dict(
- type='ViTInputSplitter2D',
- ),
- embedding_cfg=dict(
- type='ViTPatchEmbedding2D',
- img_size=IMG_SIZE,
- patch_size=PATCH_SIZE,
- embed_dim=DIM,
- ),
- token_fusion_cfg=dict(
- type='ViTTokenFuser2D',
- img_size=IMG_SIZE,
- patch_size=PATCH_SIZE,
- embed_dim=DIM,
- drop_rate=0.1
- ),
- norm_cfg=dict(
- type='LayerNorm2D',
- normalized_shape=DIM,
- eps=1e-6,
- ),
- block_cfg=dict(
- type='ViTBlock',
- attention_cfg=dict(
- type='ViTSelfAttention2D',
- hidden_size=DIM,
- num_attention_heads=NUM_ATTENTION_HEADS,
- attention_dropout_prob=0.,
- hidden_dropout_prob=0.1,
- ),
- droppath_cfg=dict(
- type='VanillaViTDropPath',
- ),
- mlp_cfg=dict(
- type='ViTMLP2D',
- in_features=DIM,
- dropout_prob=0.1,
- mlp_ratio=1
- ),
- norm_cfg=dict(
- type='LayerNorm2D',
- normalized_shape=DIM,
- eps=1e-6,
- ),
- ),
- head_cfg=dict(
- type='ViTHead2D',
- hidden_size=DIM,
- num_classes=NUM_CLASSES,
- ),
- embed_dim=DIM,
- depth=DEPTH,
- drop_path_rate=0.,
-)
-
parallel = dict(
pipeline=dict(size=1),
tensor=dict(size=4, mode='2d'),
)
-# from colossalai.engine import AMP_TYPE
-
-# fp16 = dict(
-# mode=AMP_TYPE.PARALLEL,
-# initial_scale=2 ** 4
-# )
-
fp16 = dict(
mode=None,
)
-# both level 2 and 3 work
zero = dict(
- type='ZeroRedundancyOptimizer_Level_3'
+ level=3
)
-
-lr_scheduler = dict(
- type='LinearWarmupLR',
- warmup_epochs=5
-)
-
-num_epochs = 60
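After this cleanup, the whole of configs/vit_2d_zero3.py reduces to the three dicts kept by the diff above (vit_2d_zero2.py is identical apart from `level=2`); shown here in full for reference:

parallel = dict(
    pipeline=dict(size=1),
    tensor=dict(size=4, mode='2d'),
)

fp16 = dict(
    mode=None,
)

zero = dict(
    level=3
)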
diff --git a/tests/test_zero_tensor_parallel/test.sh b/tests/test_zero_tensor_parallel/test.sh
index 24d0c5423..da5afd5ae 100644
--- a/tests/test_zero_tensor_parallel/test.sh
+++ b/tests/test_zero_tensor_parallel/test.sh
@@ -1,4 +1,4 @@
#!/usr/bin/env sh
test_file=$1
-python $test_file --local_rank $SLURM_PROCID --world_size $SLURM_NPROCS --host $HOST --port 29500
\ No newline at end of file
+python $test_file --rank $SLURM_PROCID --world_size $SLURM_NPROCS --host $HOST --port 29500
\ No newline at end of file
diff --git a/tests/test_zero_tensor_parallel/test_vit_2d.py b/tests/test_zero_tensor_parallel/test_vit_2d.py
new file mode 100644
index 000000000..ef77e9f2e
--- /dev/null
+++ b/tests/test_zero_tensor_parallel/test_vit_2d.py
@@ -0,0 +1,102 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+import os
+from pathlib import Path
+
+import pytest
+import torch.autograd
+
+import colossalai
+import torch
+from colossalai.initialize import get_default_parser
+from colossalai.builder import build_model
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.logging import get_dist_logger
+from colossalai.utils import get_dataloader
+from colossalai.nn.layer._parallel_utilities import _gather
+from colossalai.nn import CrossEntropyLoss2D
+from torchvision import transforms
+from torchvision.datasets import CIFAR10
+from components import *
+
+level = os.environ['LEVEL']
+CONFIG_PATH = Path(__file__).parent.joinpath(f'configs/vit_2d_zero{level}.py')  # configs/ sits alongside this test after the file move
+
+
+def train_epoch(engine, train_dataloader):
+ engine.train()
+ accumulated_loss = 0
+ num_steps = len(train_dataloader)
+ data_iter = iter(train_dataloader)
+ for i in range(num_steps):
+ output, label, loss = engine.step(data_iter)
+ accumulated_loss += loss.detach().cpu().numpy()
+ avg_loss = accumulated_loss / num_steps
+ return avg_loss
+
+
+@pytest.mark.dist
+@pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus")
+def test_2d_parallel_vision_transformer():
+ parser = get_default_parser()
+ args = parser.parse_args()
+ colossalai.launch(
+ config=CONFIG_PATH,
+ rank=args.rank,
+ world_size=args.world_size,
+ host=args.host,
+ port=args.port,
+ backend=args.backend
+ )
+
+ # build model
+ model = build_model(model_cfg)
+ model.build_from_cfg()
+
+ # build dataloader
+ train_dataset = CIFAR10(
+ root=Path(os.environ['DATA']),
+ download=True,
+ transform=transforms.Compose(
+ [
+ transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
+ ]
+ )
+ )
+ train_dataloader = get_dataloader(dataset=train_dataset,
+ shuffle=True,
+ batch_size=BATCH_SIZE,
+ num_workers=1,
+ pin_memory=True,
+ drop_last=True)
+
+ # build optimizer and loss
+ optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+ criterion = CrossEntropyLoss2D()
+
+ engine, train_dataloader, *args = colossalai.initialize(model=model,
+ optimizer=optimizer,
+ criterion=criterion,
+ train_dataloader=train_dataloader)
+ logger = get_dist_logger()
+
+ logger.info('start training')
+ engine.train()
+
+ for img, label in train_dataloader:
+ engine.zero_grad()
+ img = img.cuda()
+ label = label.cuda()
+ out = engine(img)
+ loss = engine.criterion(out, label)
+ engine.backward(loss)
+ engine.step()
+ break
+
+
+if __name__ == '__main__':
+ test_2d_parallel_vision_transformer()
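Like the other distributed tests in this PR, test_vit_2d.py is skipped under plain pytest and is meant to be driven through test.sh from inside a SLURM job step. A hedged sketch of doing that from Python (all paths and the hostname are placeholders):

import os
import subprocess

# DATA points at the CIFAR10 root, LEVEL selects the ZeRO config (2 or 3),
# HOST is the master address forwarded to colossalai.launch by test.sh.
env = dict(os.environ, DATA='/path/to/data', LEVEL='3', HOST='node-0')
subprocess.run(['bash', 'test.sh', 'test_vit_2d.py'], env=env, check=True)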
diff --git a/tests/test_zero_tensor_parallel/test_vit_2d/test_vit_2d.py b/tests/test_zero_tensor_parallel/test_vit_2d/test_vit_2d.py
deleted file mode 100644
index 5c78dfcc2..000000000
--- a/tests/test_zero_tensor_parallel/test_vit_2d/test_vit_2d.py
+++ /dev/null
@@ -1,84 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-import os
-from pathlib import Path
-
-import pytest
-import torch.autograd
-
-import colossalai
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.engine import Engine
-from colossalai.logging import get_global_dist_logger
-from colossalai.nn.layer._parallel_utilities import _gather
-
-level = os.environ['LEVEL']
-CONFIG_PATH = Path(__file__).parent.parent.joinpath(f'configs/vit_2d_zero{level}.py')
-
-
-def eval_epoch(engine: Engine, test_dataloader):
- engine.eval()
- accumulated_loss = 0
- correct_sum = 0
- total_sum = 0
- num_steps = len(test_dataloader)
- data_iter = iter(test_dataloader)
-
- for i in range(num_steps):
- output, label, loss = engine.step(data_iter)
- accumulated_loss += loss.detach().cpu().numpy()
-
- output = _gather(
- output[0],
- ParallelMode.PARALLEL_2D_ROW,
- 1
- )
- output = _gather(
- output,
- ParallelMode.PARALLEL_2D_COL,
- 0,
- )
- output = torch.argmax(output, dim=-1)
- correct = torch.sum(label[0] == output)
- correct_sum += correct
- total_sum += label[0].size(0)
- avg_loss = accumulated_loss / num_steps
- return correct_sum, total_sum, avg_loss
-
-
-def train_epoch(engine, train_dataloader):
- engine.train()
- accumulated_loss = 0
- num_steps = len(train_dataloader)
- data_iter = iter(train_dataloader)
- for i in range(num_steps):
- output, label, loss = engine.step(data_iter)
- accumulated_loss += loss.detach().cpu().numpy()
- avg_loss = accumulated_loss / num_steps
- return avg_loss
-
-
-@pytest.mark.dist
-@pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus")
-def test_2d_parallel_vision_transformer():
- # init dist
- engine, train_dataloader, test_dataloader = colossalai.initialize(CONFIG_PATH)
- logger = get_global_dist_logger()
-
- logger.info('start training')
- for epoch in range(gpc.config.num_epochs):
- train_loss = train_epoch(engine, train_dataloader)
-
- logger.info(f'epoch {epoch} - train loss: {train_loss}')
-
- if epoch % 2 == 0:
- correct_sum, total_sum, eval_loss = eval_epoch(engine, test_dataloader)
- logger.info(
- f'epoch {epoch} - eval loss: {eval_loss}, total: {total_sum}, '
- f'correct: {correct_sum}, acc: {correct_sum / total_sum}')
-
-
-if __name__ == '__main__':
- test_2d_parallel_vision_transformer()