Develop/experiments (#59)

* Add gradient accumulation, fix lr scheduler

* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)

* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes

* fixed trainer

* Revert "fixed trainer"

This reverts commit 2e0b0b7699.

* improved consistency between trainer, engine and schedule (#23)

Co-authored-by: 1SAA <c2h214748@gmail.com>

* Split conv2d, class token, and positional embedding in 2D; fix random number in DDP;
fix convergence on CIFAR-10 and ImageNet-1000

* Integrate 1d tensor parallel in Colossal-AI (#39)

* fixed 1D and 2D convergence (#38)

* optimized 2D operations

* fixed 1D ViT convergence problem

* Feature/ddp (#49)

* remove redundancy func in setup (#19) (#20)

* use env to control the language of doc (#24) (#25)

* Support TP-compatible Torch AMP and Update trainer API (#27)

* Add gradient accumulation, fix lr scheduler

* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)

* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes

* fixed trainer

* Revert "fixed trainer"

This reverts commit 2e0b0b7699.

* improved consistency between trainer, engine and schedule (#23)

Co-authored-by: 1SAA <c2h214748@gmail.com>

Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>

* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)

* add explanation for ViT example (#35) (#36)

* support torch ddp

* fix loss accumulation

* add log for ddp

* change seed

* modify timing hook

Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>

* Feature/pipeline (#40)

* remove redundancy func in setup (#19) (#20)

* use env to control the language of doc (#24) (#25)

* Support TP-compatible Torch AMP and Update trainer API (#27)

* Add gradient accumulation, fix lr scheduler

* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)

* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes

* fixed trainer

* Revert "fixed trainer"

This reverts commit 2e0b0b7699.

* improved consistency between trainer, engine and schedule (#23)

Co-authored-by: 1SAA <c2h214748@gmail.com>

Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>

* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)

* add explanation for ViT example (#35) (#36)

* optimize communication of pipeline parallel

* fix grad clip for pipeline

Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>

* optimized 3D layers to fix slow computation; tested ImageNet performance with 3D; reworked lr_scheduler config definition; fixed launch args; fixed some printing issues; simplified APIs of 3D layers (#51)

* Update 2.5d layer code to get a similar accuracy on imagenet-1k dataset

* update api for better usability (#58)

update api for better usability

Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: puck_WCR <46049915+WANG-CR@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
Author: Frank Lee
Date: 2021-12-09 15:08:29 +08:00 (committed by GitHub)
Parent commit: eb2f8b1f6b
Commit: da01c234e1
229 changed files with 6532 additions and 8741 deletions

View File

@@ -1,4 +1,4 @@
-from .initialize import init_dist, initialize
-from .nn import *
+from .initialize import (initialize, launch, launch_from_openmpi,
+                         launch_from_slurm, launch_from_torch, get_default_parser)
 
 __version__ = '0.0.1'

View File

@@ -0,0 +1,32 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from .amp_type import AMP_TYPE
from colossalai.context import Config
import torch.nn as nn
from torch.optim import Optimizer
from torch.nn.modules.loss import _Loss
from .torch_amp import convert_to_torch_amp
from .apex_amp import convert_to_apex_amp
from .naive_amp import convert_to_naive_amp
def convert_to_amp(model: nn.Module,
optimizer: Optimizer,
criterion: _Loss,
mode: AMP_TYPE,
amp_config: Config = None):
assert isinstance(mode, AMP_TYPE), \
f'expected the argument mode be AMP_TYPE, but got {type(mode)}'
if amp_config is None:
amp_config = Config()
if mode == AMP_TYPE.TORCH:
model, optimizer, criterion = convert_to_torch_amp(model, optimizer, criterion, amp_config)
elif mode == AMP_TYPE.APEX:
model, optimizer = convert_to_apex_amp(model, optimizer, amp_config)
elif mode == AMP_TYPE.NAIVE:
model, optimizer = convert_to_naive_amp(model, optimizer, amp_config)
return model, optimizer, criterion
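
A minimal usage sketch of the convert_to_amp entry point added above, assuming the module is exposed as colossalai.amp; the toy model, optimizer settings and init_scale value are illustrative and not part of this commit.

# Hedged usage sketch (assumes colossalai.amp re-exports convert_to_amp and AMP_TYPE,
# and that Config accepts a plain dict).
import torch
import torch.nn as nn
from colossalai.amp import convert_to_amp, AMP_TYPE   # assumed import path
from colossalai.context import Config

model = nn.Linear(16, 4).cuda()                        # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# AMP_TYPE.TORCH routes through convert_to_torch_amp; the amp_config is forwarded
# to torch.cuda.amp.GradScaler, so init_scale here is just an example value.
model, optimizer, criterion = convert_to_amp(model,
                                             optimizer,
                                             criterion,
                                             mode=AMP_TYPE.TORCH,
                                             amp_config=Config(dict(init_scale=2**16)))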

View File

@@ -7,4 +7,4 @@ from enum import Enum
 class AMP_TYPE(Enum):
     APEX = 'apex'
     TORCH = 'torch'
-    PARALLEL = 'parallel'
+    NAIVE = 'naive'

View File

@@ -0,0 +1,15 @@
from .apex_amp import ApexAMPOptimizer
import torch.nn as nn
from torch.optim import Optimizer
import apex.amp as apex_amp
def convert_to_apex_amp(model: nn.Module,
optimizer: Optimizer,
amp_config):
model, optimizer = apex_amp.initialize(model, optimizer, **amp_config)
optimizer = ApexAMPOptimizer(optimizer)
return model, optimizer
__all__ = ['convert_to_apex_amp', 'ApexAMPOptimizer']

View File

@@ -0,0 +1,23 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.nn as nn
try:
import apex.amp as apex_amp
except:
pass
from torch import Tensor
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils import clip_grad_norm_fp32
class ApexAMPOptimizer(ColossalaiOptimizer):
def backward(self, loss: Tensor):
with apex_amp.scale_loss(loss, self.optim) as scaled_loss:
scaled_loss.backward()
def clip_grad_norm(self, model: nn.Module, max_norm: float):
if max_norm > 0:
clip_grad_norm_fp32(apex_amp.master_params(self.optim), max_norm)

View File

@@ -0,0 +1,20 @@
import torch.nn as nn
from torch.optim import Optimizer
from colossalai.utils import is_no_pp_or_last_stage
from .naive_amp import NaiveAMPOptimizer, NaiveAMPModel
def convert_to_naive_amp(model: nn.Module,
optimizer: Optimizer,
amp_config):
if is_no_pp_or_last_stage():
model = NaiveAMPModel(model, output_to_fp32=True)
else:
model = NaiveAMPModel(model, output_to_fp32=False)
optimizer = NaiveAMPOptimizer(optimizer, **amp_config)
return model, optimizer
__all__ = ['convert_to_naive_amp', 'NaiveAMPOptimizer']

View File

@@ -12,11 +12,9 @@ from torch.optim import Optimizer
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.logging import get_global_dist_logger
-from colossalai.registry import OPTIMIZER_WRAPPERS
-from colossalai.utils import print_rank_0
-from ._utils import copy_tensor_parallel_attributes, clip_grad_norm_fp32, count_zeros_fp32
-from ..multi_tensor_apply import multi_tensor_applier
+from colossalai.logging import get_dist_logger
+from colossalai.utils import (print_rank_0, copy_tensor_parallel_attributes,
+                              clip_grad_norm_fp32, count_zeros_fp32, multi_tensor_applier)
 
 
 def _zero_grad_group_helper(group, set_to_none):
@@ -92,7 +90,7 @@ class DynamicGradScaler:
         self._growth_tracker = 0
         self._hysteresis_tracker = self.hysteresis
 
-        self._logger = get_global_dist_logger()
+        self._logger = get_dist_logger()
 
     @property
     def scale(self):
@@ -113,7 +111,7 @@ class DynamicGradScaler:
             if self._hysteresis_tracker <= 0:
                 self._scale = torch.max(self._scale * self.backoff_factor,
                                         self.min_scale)
-                self._logger.info(f'overflow occurs, loss scale is adjusted to {self._scale}')
+                self._logger.info(f'overflow occurs, loss scale is adjusted to {self._scale}', ranks=[0])
         else:
             # If there is no nan/inf, increment the growth tracker.
             self._growth_tracker += 1
@@ -125,10 +123,10 @@ class DynamicGradScaler:
                 # and scale up the loss scale.
                 if self._max_scale is not None and self._scale >= self._max_scale:
                     self._logger.info(
-                        f'Current loss scale {self._scale} has reached the max scale {self._max_scale} allowed')
+                        f'Current loss scale {self._scale} has reached the max scale {self._max_scale} allowed', ranks=[0])
                 else:
                     self._scale = self._scale * self.growth_factor
-                    self._logger.info(f'no consecutive overflow, loss scale is adjusted to {self._scale}')
+                    self._logger.info(f'no consecutive overflow, loss scale is adjusted to {self._scale}', ranks=[0])
 
     def state_dict(self):
         state_dict = {}
@@ -145,7 +143,6 @@ class DynamicGradScaler:
         self._max_scale = state_dict['max_scale']
 
 
-@OPTIMIZER_WRAPPERS.register_module
 class FP16Optimizer(Optimizer):
     """Float16 optimizer for fp16 and bf16 data types.
@@ -184,13 +181,13 @@ class FP16Optimizer(Optimizer):
                  max_scale: int = 2 ** 32):
         # default args for compatibility
         bf16 = False
-        params_have_main_grad = False
+        params_have_main_grad = True
 
         # have a defaults for compatibility with pytorch optim
         self.defaults = optimizer.defaults
 
         # log config
-        self._logger = get_global_dist_logger()
+        self._logger = get_dist_logger()
         self._logger.info(f"\n========= FP16 Optimizer Config =========\n"
                           f"Optimizer: {optimizer.__class__.__name__}\n"
                           f"clip_grad = {clip_grad}\n"
@@ -328,6 +325,7 @@ class FP16Optimizer(Optimizer):
                 else:
                     if model_param.grad is not None:
                         main_param.grad = model_param.grad.float()
+
         # For fp32 grads, we need to reset the grads to main grad.
         if self.params_have_main_grad:
             for model_group in self.fp32_from_fp32_groups:
@@ -387,10 +385,6 @@ class FP16Optimizer(Optimizer):
     @torch.no_grad()
     def step(self):
-        # for param_group in self.float16_groups:
-        #     for param in param_group:
-        #         print(param.grad is None)
-
         # Copy gradients from model params to main params.
         self._copy_model_grads_to_main_grads()

View File

@@ -0,0 +1,65 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
import torch.nn as nn
from torch import Tensor
from typing import Union, List, Any, Dict
from torch.optim import Optimizer
import torch.cuda.amp as torch_amp
from colossalai.nn.optimizer import ColossalaiOptimizer
from ._fp16_optimizer import FP16Optimizer
class NaiveAMPOptimizer(ColossalaiOptimizer):
def __init__(self, optim: Optimizer, *args, **kwargs):
optim = FP16Optimizer(optimizer=optim, *args, **kwargs)
super().__init__(optim)
def backward(self, loss: Tensor):
loss = self.optim.scale_loss(loss)
loss.backward()
def step(self):
self.optim.step()
def clip_grad_norm(self, model: nn.Module, max_norm: float):
pass
class NaiveAMPModel(nn.Module):
def __init__(self,
model: nn.Module,
output_to_fp32: bool = True):
super().__init__()
self.model = model.half()
self._output_to_fp32 = output_to_fp32
def _convert_to_fp16(self, input_: Any):
if isinstance(input_, Tensor) and input_.dtype == torch.float32:
input_ = input_.half()
return input_
def _convert_to_fp32(self, input_: Any):
if isinstance(input_, Tensor) and input_.dtype == torch.float16:
input_ = input_.float()
return input_
def forward(self, *args, **kwargs):
if args:
args = [self._convert_to_fp16(arg) for arg in args]
if kwargs:
for k, v in kwargs.items():
kwargs[k] = self._convert_to_fp16(v)
out = self.model(*args, **kwargs)
if self._output_to_fp32:
if isinstance(out, Tensor):
out = self._convert_to_fp32(out)
elif isinstance(out, (tuple, list)):
out = [self._convert_to_fp32(val) for val in out]
return out
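
A quick sketch of what the NaiveAMPModel wrapper above does to dtypes, using a toy module; the import path is an assumption based on this diff.

# Dtype round trip through NaiveAMPModel (toy module; import path assumed).
import torch
import torch.nn as nn
from colossalai.amp.naive_amp import NaiveAMPModel   # assumed import path

wrapped = NaiveAMPModel(nn.Linear(8, 8).cuda(), output_to_fp32=True)
x = torch.randn(4, 8, device='cuda')                 # fp32 input
y = wrapped(x)                                       # cast to fp16 before the wrapped model runs
assert y.dtype == torch.float32                      # cast back to fp32 on the way out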

View File

@@ -0,0 +1,18 @@
import torch.nn as nn
from torch.optim import Optimizer
from torch.nn.modules.loss import _Loss
from colossalai.context import Config
from .torch_amp import TorchAMPOptimizer, TorchAMPModel, TorchAMPLoss
def convert_to_torch_amp(model: nn.Module,
optimizer: Optimizer,
criterion: _Loss,
amp_config: Config):
model = TorchAMPModel(model)
optimizer = TorchAMPOptimizer(optimizer, **amp_config)
criterion = TorchAMPLoss(criterion)
return model, optimizer, criterion
__all__ = ['convert_to_torch_amp', 'TorchAMPModel', 'TorchAMPLoss', 'TorchAMPOptimizer']

View File

@@ -1,4 +1,8 @@
-# modified from https://github.com/pytorch/pytorch/blob/master/torch/cuda/amp/grad_scaler.p
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+# modified from https://github.com/pytorch/pytorch/blob/master/torch/cuda/amp/grad_scaler.py
+# to support tensor parallel
 import torch
 from collections import defaultdict, abc
 import warnings

View File

@@ -0,0 +1,54 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.nn as nn
import torch.cuda.amp as torch_amp
from torch import Tensor
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from ._grad_scaler import GradScaler
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils import clip_grad_norm_fp32
class TorchAMPOptimizer(ColossalaiOptimizer):
def __init__(self, optim: Optimizer, *args, **kwargs):
super().__init__(optim)
self.scaler = GradScaler(*args, **kwargs)
def backward(self, loss: Tensor):
self.scaler.scale(loss).backward()
def step(self):
self.scaler.step(self.optim)
self.scaler.update()
def clip_grad_norm(self, model: nn.Module, max_norm: float):
if max_norm > 0.0:
self.scaler.unscale_(self.optim)
clip_grad_norm_fp32(model.parameters(), max_norm)
class TorchAMPModel(nn.Module):
def __init__(self, model: nn.Module) -> None:
super().__init__()
self.model = model
@torch_amp.autocast()
def forward(self, *args, **kwargs):
return self.model(*args, **kwargs)
class TorchAMPLoss(nn.Module):
def __init__(self, loss: _Loss):
super().__init__()
self.loss = loss
@torch_amp.autocast()
def forward(self, *args, **kwargs):
return self.loss(*args, **kwargs)
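
Putting the two torch_amp files together, a hedged sketch of a single fp16 training step; the import path, toy model, and clip value are assumptions, and clip_grad_norm_fp32 may additionally expect an initialized parallel context in multi-GPU runs.

# One training step with the TorchAMP wrappers (illustrative values only).
import torch
import torch.nn as nn
from colossalai.amp.torch_amp import convert_to_torch_amp   # assumed import path
from colossalai.context import Config

model = nn.Linear(32, 2).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
# an empty Config is assumed to be accepted and yields a default GradScaler
model, optimizer, criterion = convert_to_torch_amp(model, optimizer, criterion, Config())

data = torch.randn(8, 32, device='cuda')
label = torch.randint(0, 2, (8,), device='cuda')

model.zero_grad()                              # avoids assuming zero_grad forwarding on the wrapper
loss = criterion(model(data), label)           # forward and loss both run under autocast
optimizer.backward(loss)                       # scales the loss before backward
optimizer.clip_grad_norm(model, max_norm=1.0)  # unscales grads, then clips
optimizer.step()                               # scaler.step + scaler.update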

View File

@@ -1,10 +1,10 @@
-from .builder import (build_schedule, build_lr_scheduler, build_model, build_optimizer, build_optimizer_wrapper,
-                      build_layer, build_loss, build_hooks, build_dataset, build_transform, build_data_sampler,
+from .builder import (build_schedule, build_lr_scheduler, build_model, build_optimizer, build_layer,
+                      build_loss, build_hooks, build_dataset, build_transform, build_data_sampler,
                       build_gradient_handler)
-from .pipeline import ModelInitializer
+from .pipeline import PipelineModelInitializer
 
 __all__ = [
-    'build_schedule', 'build_lr_scheduler', 'build_model', 'build_optimizer', 'build_optimizer_wrapper',
+    'build_schedule', 'build_lr_scheduler', 'build_model', 'build_optimizer',
     'build_layer', 'build_loss', 'build_hooks', 'build_dataset', 'build_transform', 'build_data_sampler',
-    'build_gradient_handler', 'ModelInitializer'
+    'build_gradient_handler', 'PipelineModelInitializer'
 ]

View File

@@ -106,7 +106,7 @@ def build_dataset(config):
     return build_from_registry(config, DATASETS)
 
 
-def build_optimizer(config, model, params: Iterable = None, need_module=False):
+def build_optimizer(config, model):
     """Returns an optimizer object of :class:`torch.optim.Optimizer` constructed from `config`,
     'model' and 'params'.
@@ -115,23 +115,12 @@ def build_optimizer(config, model, params: Iterable = None, need_module=False):
     :type config: dict or :class:`colossalai.context.Config`
     :param model: A model containing parameters for the optimizer
     :type model: :class:`nn.Module`
-    :param params: A dict containing parameters for the optimizer
-    :type params: dict, optional
-    :param need_module: Indicates whether the optimizer needs a module
-    :type params: bool, optional
-    :raises AssertionError: Raises an AssertionError if both `model` and `params` are None
     :return: An object of :class:`torch.optim.Optimizer`
     :rtype: :class:`torch.optim.Optimizer`
     """
-    assert model is not None or params is not None, 'arguments model and params can not both be None'
-    if need_module:
-        config['module'] = model
-    elif model is not None:
-        config['params'] = model.parameters()
-    elif params is not None:
-        config['params'] = params
-
-    return build_from_registry(config, OPTIMIZERS)
+    config_ = config.copy()
+    config_['params'] = model.parameters()
+    return build_from_registry(config_, OPTIMIZERS)
 
 
 def build_gradient_handler(config, model, optimizer):
@@ -149,8 +138,9 @@ def build_gradient_handler(config, model, optimizer):
     :rtype: :class:`BaseGradientHandler`
     """
     config_ = config.copy()
-    mod_type = config_.pop('type')
-    return GRADIENT_HANDLER.get_module(mod_type)(model, optimizer, **config_)
+    config_['model'] = model
+    config_['optimizer'] = optimizer
+    return build_from_registry(config_, GRADIENT_HANDLER)
 
 
 def build_hooks(config, trainer):
@@ -164,8 +154,9 @@ def build_hooks(config, trainer):
     :return: An object of :class:`BaseHook`
     :rtype: :class:`BaseHook`
     """
-    config['trainer'] = trainer
-    return build_from_registry(config, HOOKS)
+    config_ = config.copy()
+    config_['trainer'] = trainer
+    return build_from_registry(config_, HOOKS)
 
 
 def build_transform(config):
@@ -195,32 +186,8 @@ def build_data_sampler(config, dataset):
     :rtype: :class:`colossalai.nn.data.sampler.BaseSampler`
     """
     config_ = config.copy()
-    mod_type = config_.pop('type')
-    return SAMPLERS.get_module(mod_type)(dataset, **config_)
-
-
-def build_optimizer_wrapper(config, optimizer, model=None):
-    """Returns an optimizer wrapper object of :class:`torch.optim.Optimizer` constructed
-    from `config`, `model` and `optimizer`.
-
-    :param config: A python dict or a :class:`colossalai.context.Config` object
-        containing information used in the construction of the return object
-    :type config: dict or :class:`colossalai.context.Config`
-    :param optimizer: An optimizer object containing parameters for the gradient handler
-    :type optimizer: :class:`torch.optim.Optimizer`
-    :param model: A model containing parameters for the gradient handler
-    :type model: :class:`nn.Module`, optional
-    :return: An object of :class:`torch.optim.Optimizer`
-    :rtype: :class:`torch.optim.Optimizer`
-    """
-    config_ = config.copy()
-    mod_type = config_.pop('type')
-    # LSG: special treatment for zeor level 3
-    if mod_type == 'ZeroRedundancyOptimizer_Level_3':
-        return OPTIMIZER_WRAPPERS.get_module(mod_type)(model, optimizer, **config_)
-    else:
-        return OPTIMIZER_WRAPPERS.get_module(mod_type)(optimizer, **config_)
+    config_['dataset'] = dataset
+    return build_from_registry(config_, DATA_SAMPLERS)
 
 
 def build_lr_scheduler(config, optimizer):
@@ -241,8 +208,8 @@ def build_lr_scheduler(config, optimizer):
     :rtype: :class:`torch.optim.lr_scheduler`
     """
     config_ = config.copy()
-    mod_type = config_.pop('type')
-    return LR_SCHEDULERS.get_module(mod_type)(optimizer, **config_)
+    config_['optimizer'] = optimizer
+    return build_from_registry(config_, LR_SCHEDULERS)
 
 
 def build_schedule(config):
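
The refactor above standardizes every builder on the same pattern: copy the config, inject the dependency under a named key, and hand the result to build_from_registry, which consumes the 'type' field. A hedged sketch of that calling convention follows; it assumes SGD and CosineAnnealingLR are registered under those names in the OPTIMIZERS and LR_SCHEDULERS registries, and the hyperparameters are illustrative.

# Config-driven construction after the builder refactor (registry names assumed).
import torch.nn as nn
from colossalai.builder import build_optimizer, build_lr_scheduler

model = nn.Linear(16, 16)

optimizer_cfg = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=1e-4)
optimizer = build_optimizer(optimizer_cfg, model)            # injects model.parameters() as 'params'

scheduler_cfg = dict(type='CosineAnnealingLR', T_max=100)
lr_scheduler = build_lr_scheduler(scheduler_cfg, optimizer)  # injects 'optimizer'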

View File

@@ -4,7 +4,7 @@ import heapq
 from colossalai.builder import build_model, build_layer
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.logging import get_global_dist_logger
+from colossalai.logging import get_dist_logger
 from colossalai.utils import set_to_cuda
@@ -111,21 +111,21 @@ def _binary_search(weights, num):
     return intervals
 
 
-def _partition_uniform(num_items, num_parts, num_chunks):
+def _partition_uniform(num_items, pipeline_parallel_size, num_chunks):
     assert num_items % num_chunks == 0, \
         "Layer length should be divided by the number of chunks, otherwise parameter method is recomended"
 
-    logger = get_global_dist_logger()
-    parts = [[] for _ in range(num_parts)]
+    logger = get_dist_logger()
+    parts = [[] for _ in range(pipeline_parallel_size)]
     partition_items = num_items // num_chunks
     for idx in range(num_chunks):
         base_idx = idx * partition_items
-        chunk_size = partition_items // num_parts
-        left = num_parts - partition_items % num_parts
+        chunk_size = partition_items // pipeline_parallel_size
+        left = pipeline_parallel_size - partition_items % pipeline_parallel_size
         if chunk_size == 0:
             logger.warning("Some nodes in Pipeline have no requests")
 
-        for p in range(num_parts):
+        for p in range(pipeline_parallel_size):
             st = base_idx
             base_idx += chunk_size + (p >= left)
             parts[p].append((st, base_idx))
@@ -133,34 +133,34 @@ def _partition_uniform(num_items, num_parts, num_chunks):
     return parts
 
 
-def _partition_balanced(weights, num_parts, num_chunks):
-    num_total = num_parts * num_chunks
+def _partition_balanced(weights, pipeline_parallel_size, num_chunks):
+    num_total = pipeline_parallel_size * num_chunks
     num_items = len(weights)
     if num_items <= num_total:
-        return _partition_uniform(num_items, num_parts, num_chunks)
+        return _partition_uniform(num_items, pipeline_parallel_size, num_chunks)
 
     intervals = _binary_search(weights, num_total)
 
     current = 0
-    parts = [[] for _ in range(num_parts)]
+    parts = [[] for _ in range(pipeline_parallel_size)]
     for inter in intervals:
         parts[current].append(inter)
-        current = (current + 1) % num_parts
+        current = (current + 1) % pipeline_parallel_size
 
     return parts
 
 
-class ModelInitializer():
+class PipelineModelInitializer():
     def __init__(self, config, num_chunks, verbose=False):
         self.num_chunks = num_chunks
         self.ori_model = build_model(config)
         self.layers = self.ori_model.layers_cfg
         layer_length = len(self.layers)
         self.verbose = verbose
-        self._logger = get_global_dist_logger()
+        self._logger = get_dist_logger()
         self._logger.info(f"The total length of layers is {layer_length}", ranks=[0])
 
-    def model_initialize(self, partition_method='parameter'):
+    def initialize(self, partition_method='parameter'):
         # Some space for initializing comunication groups
         self._interval = None
         self._partition_layers(method=partition_method)
@@ -198,7 +198,7 @@ class ModelInitializer():
             for st, ed in self.parts[stage]:
                 for idx, layer in enumerate(self.layers[st: ed]):
                     log_str += f'\t{idx + st:2d}: {layer}\n'
-            self._logger.info(log_str)
+            self._logger.info(log_str, ranks=[0])
 
         # Save the partition
         self._interval = self.parts[pipeline_rank]

View File

@@ -1,4 +1,4 @@
-from .collective import all_gather, reduce_scatter, scatter
+from .collective import all_gather, reduce_scatter, all_reduce
 from .p2p import (send_forward, send_forward_recv_forward, send_backward_recv_forward,
                   send_backward, send_backward_recv_backward, send_forward_recv_backward,
                   send_forward_backward_recv_forward_backward, recv_forward, recv_backward)
@@ -6,7 +6,7 @@ from .ring import ring_forward
 from .utils import send_tensor_meta, recv_tensor_meta
 
 __all__ = [
-    'all_gather', 'reduce_scatter', 'scatter',
+    'all_gather', 'reduce_scatter', 'all_reduce',
     'send_forward', 'send_forward_recv_forward', 'send_forward_backward_recv_forward_backward',
     'send_backward', 'send_backward_recv_backward', 'send_backward_recv_forward',
     'send_forward_recv_backward', 'recv_backward', 'recv_forward',

View File

@@ -11,7 +11,7 @@ from colossalai.utils import get_current_device
 
 def all_gather(tensor: Tensor, dim: int,
-               parallel_mode: ParallelMode) -> Tensor:
+               parallel_mode: ParallelMode, async_op=False) -> Tensor:
     """Gathers all tensors from the parallel group and concatenates them in a
     specific dimension.
@@ -26,18 +26,28 @@ def all_gather(tensor: Tensor, dim: int,
     """
     depth = gpc.get_world_size(parallel_mode)
     temp = tensor.clone()
-    shape = list(temp.shape)
-    shape[dim] *= depth
-    out = torch.empty(shape, dtype=temp.dtype, device=get_current_device())
-    out = list(torch.chunk(out, depth, dim=dim))
-    out = [val.contiguous() for val in out]
-    dist.all_gather(out, temp, group=gpc.get_group(parallel_mode))
-    out = torch.cat(out, dim=dim)
-    return out
+    # shape = list(temp.shape)
+    # shape[dim] *= depth
+    # out = torch.zeros(shape, dtype=temp.dtype, device=get_current_device())
+    # out = list(torch.chunk(out, depth, dim=dim))
+    # out = [val.contiguous() for val in out]
+    shape = [1] * len(tensor.shape)
+    shape[dim] = depth
+    out = tensor.repeat(shape)
+    out = list(map(lambda x: x.contiguous(), torch.chunk(out, depth, dim=dim)))
+    op = dist.all_gather(tensor_list=out,
+                         tensor=temp,
+                         group=gpc.get_group(parallel_mode),
+                         async_op=async_op)
+    # out = torch.cat(out, dim=dim)
+    if async_op:
+        return out, op
+    else:
+        return out
 
 
 def reduce_scatter(tensor: Tensor, dim: int,
-                   parallel_mode: ParallelMode) -> Tensor:
+                   parallel_mode: ParallelMode, async_op=False) -> Tensor:
     """Reduces all tensors then scatters it in a specific dimension to all
     members in the parallel group.
@@ -51,34 +61,52 @@ def reduce_scatter(tensor: Tensor, dim: int,
     :rtype: Tensor
     """
     depth = gpc.get_world_size(parallel_mode)
-    temp = list(torch.chunk(tensor, depth, dim=dim))
-    temp = [val.contiguous() for val in temp]
-    out = torch.empty(temp[0].shape,
-                      dtype=temp[0].dtype,
-                      device=get_current_device())
-    dist.reduce_scatter(output=out,
-                        input_list=temp,
-                        group=gpc.get_group(parallel_mode))
-    return out
+    # temp = list(torch.chunk(tensor, depth, dim=dim))
+    # temp = [val.contiguous() for val in temp]
+    # out = torch.zeros(temp[0].shape,
+    #                   dtype=temp[0].dtype,
+    #                   device=get_current_device())
+    temp = list(map(lambda x: x.contiguous(), torch.chunk(tensor, depth, dim=dim)))
+    out = temp[0].clone()
+    op = dist.reduce_scatter(output=out,
+                             input_list=temp,
+                             group=gpc.get_group(parallel_mode),
+                             async_op=async_op)
+    if async_op:
+        return out, op
+    else:
+        return out
 
 
-def scatter(tensor: Tensor, src: int, dim: int,
-            parallel_mode: ParallelMode) -> Tensor:
-    """Scatters in a specific dimension from source rank to all ranks in
-    the parallel group.
+def all_reduce(tensor: Tensor,
+               parallel_mode: ParallelMode,
+               async_op=False) -> Tensor:
+    op = dist.all_reduce(tensor,
+                         group=gpc.get_group(parallel_mode),
+                         async_op=async_op)
+    if async_op:
+        return tensor, op
+    else:
+        return tensor
 
-    :param tensor: Tensor to be scattered
-    :param dim: The dimension scattering in
-    :param parallel_mode: Parallel group mode used in this communication
-    :type tensor: Tensor
-    :type dim: int
-    :type parallel_mode: ParallelMode
-    :return: The tensor generated by scatter
-    :rtype: Tensor
-    """
-    depth = gpc.get_world_size(parallel_mode)
-    temp = tensor.clone()
-    dist.broadcast(temp, src=src, group=gpc.get_group(parallel_mode))
-    rank = gpc.get_local_rank(parallel_mode)
-    out = torch.chunk(temp, depth, dim=dim)[rank].contiguous()
-    return out
+
+# def scatter(tensor: Tensor, src: int, dim: int,
+#             parallel_mode: ParallelMode) -> Tensor:
+#     """Scatters in a specific dimension from source rank to all ranks in
+#     the parallel group.
+#
+#     :param tensor: Tensor to be scattered
+#     :param dim: The dimension scattering in
+#     :param parallel_mode: Parallel group mode used in this communication
+#     :type tensor: Tensor
+#     :type dim: int
+#     :type parallel_mode: ParallelMode
+#     :return: The tensor generated by scatter
+#     :rtype: Tensor
+#     """
+#     depth = gpc.get_world_size(parallel_mode)
+#     temp = tensor.clone()
+#     dist.broadcast(temp, src=src, group=gpc.get_group(parallel_mode))
+#     rank = gpc.get_local_rank(parallel_mode)
+#     out = torch.chunk(temp, depth, dim=dim)[rank].contiguous()
+#     return out
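
A short usage sketch of the reworked collective helpers; it has to run inside an initialized Colossal-AI parallel context, and ParallelMode.DATA plus the tensor shape are only examples. Note that in this version all_gather returns the list of gathered chunks (the final torch.cat is commented out), and async_op=True additionally returns the communication handle.

# Usage sketch for the new collective signatures (requires an initialized context).
import torch
from colossalai.communication import all_gather, all_reduce
from colossalai.context import ParallelMode
from colossalai.utils import get_current_device

x = torch.randn(4, 8, device=get_current_device())

chunks = all_gather(x, dim=0, parallel_mode=ParallelMode.DATA)    # list of per-rank chunks
gathered = torch.cat(chunks, dim=0)                               # the caller concatenates now

out, handle = all_reduce(x, parallel_mode=ParallelMode.DATA, async_op=True)
handle.wait()                                                     # overlap-friendly async form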

View File

@@ -17,8 +17,6 @@ def _communicate(tensor_send_next=None,
                  recv_next_shape=None,
                  prev_rank=None,
                  next_rank=None,
-                 up_group=None,
-                 down_group=None,
                  dtype=None):
     """
     Adapted from megatron.p2p_communication.
@@ -59,60 +57,44 @@ def _communicate(tensor_send_next=None,
         if prev_rank is None:
             prev_rank = gpc.get_prev_global_rank(
                 ParallelMode.PIPELINE)
-        if up_group is None:
-            up_group = gpc.get_group(ParallelMode.PIPELINE_PREV)
 
     if tensor_send_next is not None or recv_next:
         if next_rank is None:
             next_rank = gpc.get_next_global_rank(
                 ParallelMode.PIPELINE)
-        if down_group is None:
-            down_group = gpc.get_group(ParallelMode.PIPELINE_NEXT)
 
     # rank = dist.get_rank()
     rank = gpc.get_global_rank()
 
     ops = []
     if tensor_send_prev is not None:
-        send_prev_op = dist.broadcast(tensor_send_prev,
-                                      src=rank,
-                                      group=up_group,
-                                      async_op=True)
+        send_prev_op = dist.P2POp(dist.isend, tensor_send_prev, prev_rank)
         ops.append(send_prev_op)
     if tensor_recv_prev is not None:
-        recv_prev_op = dist.broadcast(tensor_recv_prev,
-                                      src=prev_rank,
-                                      group=up_group,
-                                      async_op=True)
+        recv_prev_op = dist.P2POp(dist.irecv, tensor_recv_prev, prev_rank)
         ops.append(recv_prev_op)
     if tensor_recv_next is not None:
-        recv_next_op = dist.broadcast(tensor_recv_next,
-                                      src=next_rank,
-                                      group=down_group,
-                                      async_op=True)
+        recv_next_op = dist.P2POp(dist.irecv, tensor_recv_next, next_rank)
        ops.append(recv_next_op)
     if tensor_send_next is not None:
-        send_next_op = dist.broadcast(tensor_send_next,
-                                      src=rank,
-                                      group=down_group,
-                                      async_op=True)
+        send_next_op = dist.P2POp(dist.isend, tensor_send_next, next_rank)
         ops.append(send_next_op)
-    for req in ops:
-        req.wait()
+    if len(ops) > 0:
+        reqs = dist.batch_isend_irecv(ops)
+        for req in reqs:
+            req.wait()
     # To protect against race condition when using batch_isend_irecv().
     torch.cuda.synchronize()
     return tensor_recv_prev, tensor_recv_next
 
 
-def recv_forward(input_tensor_shape, prev_rank=None, up_group=None):
+def recv_forward(input_tensor_shape, prev_rank=None):
     """Receives the input tensor from the previous member in pipeline.
 
     :param input_tensor_shape: The shape of the tensor to be recieved
     :param prev_rank: The rank of the source of the tensor
-    :param up_group: Communication group including the previous member in pipeline parallel group
     :type input_tensor_shape: torch.Size
     :type prev_rank: int, optional
-    :type up_group: ProcessGroup, optional
     :return: The input tensor in forward step
     :rtype: Tensor
     """
@@ -121,20 +103,17 @@ def recv_forward(input_tensor_shape, prev_rank=None, up_group=None):
     else:
         input_tensor, _ = _communicate(recv_prev=True,
                                        recv_prev_shape=input_tensor_shape,
-                                       prev_rank=prev_rank,
-                                       up_group=up_group)
+                                       prev_rank=prev_rank)
     return input_tensor
 
 
-def recv_backward(output_grad_shape, next_rank=None, down_group=None):
+def recv_backward(output_grad_shape, next_rank=None):
    """Receives the grad tensor from the next member in pipeline.
 
     :param output_grad_shape: The shape of the tensor to be recieved
     :param next_rank: The rank of the source of the tensor
-    :param down_group: Communication group including the next member in pipeline parallel group
     :type output_grad_shape: torch.Size
     :type next_rank: int, optional
-    :type down_group: ProcessGroup, optional
     :return: The grad of output tensor in forward step
     :rtype: Tensor
     """
@@ -143,56 +122,44 @@ def recv_backward(output_grad_shape, next_rank=None, down_group=None):
     else:
         _, output_tensor_grad = _communicate(recv_next=True,
                                              recv_next_shape=output_grad_shape,
-                                             next_rank=next_rank,
-                                             down_group=down_group)
+                                             next_rank=next_rank)
     return output_tensor_grad
 
 
-def send_forward(output_tensor,
-                 next_rank=None,
-                 down_group=None):
+def send_forward(output_tensor, next_rank=None):
     """Sends the input tensor to the next member in pipeline.
 
     :param output_tensor: Tensor to be sent
     :param next_rank: The rank of the recipient of the tensor
-    :param down_group: Communication group including the next member in pipeline parallel group
     :type output_tensor: Tensor
     :type next_rank: int, optional
-    :type down_group: ProcessGroup, optional
     """
     if not gpc.is_last_rank(ParallelMode.PIPELINE):
         _communicate(tensor_send_next=output_tensor,
-                     next_rank=next_rank,
-                     down_group=down_group)
+                     next_rank=next_rank)
 
 
-def send_backward(input_tensor_grad,
-                  prev_rank=None,
-                  up_group=None):
+def send_backward(input_tensor_grad, prev_rank=None):
     """Sends the grad tensor to the previous member in pipeline.
 
     :param input_tensor_grad: Tensor to be sent
     :param prev_rank: The rank of the recipient of the tensor
-    :param up_group: Communication group including the previous member in pipeline parallel group
     :type input_tensor_grad: Tensor
     :type prev_rank: int, optional
-    :type up_group: ProcessGroup, optional
     """
     if not gpc.is_first_rank(ParallelMode.PIPELINE):
         _communicate(tensor_send_prev=input_tensor_grad,
-                     prev_rank=prev_rank,
-                     up_group=up_group)
+                     prev_rank=prev_rank)
 
 
 def send_forward_recv_backward(output_tensor,
                                output_grad_shape,
                                recv_next=True,
-                               next_rank=None,
-                               down_group=None):
+                               next_rank=None):
     """Batched communication operation. Sends the input tensor to the
     next member in pipeline, while recieves the grad tensor from the
     next member in pipeline.
 
     :param output_tensor: Tensor to be sent
     :param output_grad_shape: The shape of the tensor to be recieved
     :type output_tensor: Tensor
@@ -206,20 +173,18 @@ def send_forward_recv_backward(output_tensor,
         _, output_tensor_grad = _communicate(tensor_send_next=output_tensor,
                                              recv_next=recv_next,
                                              recv_next_shape=output_grad_shape,
-                                             next_rank=next_rank,
-                                             down_group=down_group)
+                                             next_rank=next_rank)
     return output_tensor_grad
 
 
 def send_backward_recv_forward(input_tensor_grad,
                                input_tensor_shape,
                                recv_prev=True,
-                               prev_rank=None,
-                               up_group=None):
+                               prev_rank=None):
     """Batched communication operation. Sends the grad tensor to the
     previous member in pipeline, while recieves the input tensor from the
     previous member in pipeline.
 
     :param input_tensor_grad: Tensor to be sent
     :param input_tensor_shape: The shape of the tensor to be recieved
     :type input_tensor_grad: Tensor
@@ -233,8 +198,7 @@ def send_backward_recv_forward(input_tensor_grad,
         input_tensor, _ = _communicate(tensor_send_prev=input_tensor_grad,
                                        recv_prev=recv_prev,
                                        recv_prev_shape=input_tensor_shape,
-                                       prev_rank=prev_rank,
-                                       up_group=up_group)
+                                       prev_rank=prev_rank)
     return input_tensor
 
@@ -242,13 +206,11 @@ def send_forward_recv_forward(output_tensor,
                               input_tensor_shape,
                               recv_prev=True,
                               prev_rank=None,
-                              next_rank=None,
-                              up_group=None,
-                              down_group=None):
+                              next_rank=None):
     """Batched communication operation. Sends the input tensor to the
     next member in pipeline, while recieves the input tensor from the
     previous member in pipeline.
 
     :param output_tensor: Tensor to be sent
     :param input_tensor_shape: The shape of the tensor to be recieved
     :type output_tensor: Tensor
@@ -260,9 +222,7 @@ def send_forward_recv_forward(output_tensor,
                                    recv_prev=recv_prev,
                                    recv_prev_shape=input_tensor_shape,
                                    prev_rank=prev_rank,
-                                   next_rank=next_rank,
-                                   up_group=up_group,
-                                   down_group=down_group)
+                                   next_rank=next_rank)
     return input_tensor
 
@@ -270,13 +230,11 @@ def send_backward_recv_backward(input_tensor_grad,
                                 output_grad_shape,
                                 recv_next=True,
                                 prev_rank=None,
-                                next_rank=None,
-                                up_group=None,
-                                down_group=None):
+                                next_rank=None):
     """Batched communication operation. Sends the grad tensor to the
     previous member in pipeline, while recieves the grad tensor from the
     next member in pipeline.
 
     :param input_tensor_grad: Tensor to be sent
     :param output_grad_shape: The shape of the tensor to be recieved
     :type input_tensor_grad: Tensor
@@ -288,9 +246,7 @@ def send_backward_recv_backward(input_tensor_grad,
                                              recv_next=recv_next,
                                              recv_next_shape=output_grad_shape,
                                              prev_rank=prev_rank,
-                                             next_rank=next_rank,
-                                             up_group=up_group,
-                                             down_group=down_group)
+                                             next_rank=next_rank)
     return output_tensor_grad
 
@@ -301,13 +257,11 @@ def send_forward_backward_recv_forward_backward(output_tensor,
                                                 recv_prev=True,
                                                 recv_next=True,
                                                 prev_rank=None,
-                                                next_rank=None,
-                                                up_group=None,
-                                                down_group=None):
+                                                next_rank=None):
     """Batched communication operation. Sends the input tensor to the next and
     the grad tensor to the previous, while recieves the grad tensor from the
     next and the input tensor from the previous.
 
     :param output_tensor: Tensor sent to the next
     :param input_tensor_grad: Tensor sent to the previous
     :param input_tensor_shape: The shape of the tensor recieved from the previous
@@ -327,7 +281,5 @@ def send_forward_backward_recv_forward_backward(output_tensor,
                                                        recv_prev_shape=input_tensor_shape,
                                                        recv_next_shape=output_grad_shape,
                                                        prev_rank=prev_rank,
-                                                       next_rank=next_rank,
-                                                       up_group=up_group,
-                                                       down_group=down_group)
+                                                       next_rank=next_rank)
     return input_tensor, output_tensor_grad
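
For context on the hunk above: the pipeline P2P layer moves from group broadcasts to torch.distributed.P2POp batched with batch_isend_irecv. A standalone plain-PyTorch sketch of that pattern, for two ranks launched with torchrun --nproc_per_node=2 (buffer shapes are illustrative).

# Minimal batched isend/irecv exchange between two ranks (plain torch.distributed).
import torch
import torch.distributed as dist

dist.init_process_group(backend='nccl')
rank = dist.get_rank()
torch.cuda.set_device(rank)
device = torch.device('cuda', rank)

send_buf = torch.full((4,), float(rank), device=device)
recv_buf = torch.empty(4, device=device)
peer = 1 - rank                      # with exactly two ranks, each talks to the other

ops = [dist.P2POp(dist.isend, send_buf, peer),
       dist.P2POp(dist.irecv, recv_buf, peer)]
for req in dist.batch_isend_irecv(ops):
    req.wait()
torch.cuda.synchronize()             # guard against the race noted in the diff
print(f'rank {rank} received {recv_buf.tolist()}')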

View File

@@ -6,7 +6,7 @@ from colossalai.core import global_context as gpc
 from colossalai.utils import get_current_device
 
 
-def send_tensor_meta(tensor, need_meta=True, down_group=None):
+def send_tensor_meta(tensor, need_meta=True, next_rank=None):
     """Sends tensor meta information before sending a specific tensor.
     Since the recipient must know the shape of the tensor in p2p communications,
     meta information of the tensor should be sent before communications. This function
@@ -14,31 +14,34 @@ def send_tensor_meta(tensor, need_meta=True, down_group=None):
 
     :param tensor: Tensor to be sent
     :param need_meta: If False, meta information won't be sent
-    :param down_group: Communication group including the next member in pipeline parallel group
+    :param next_rank: The rank of the next member in pipeline parallel group
     :type tensor: Tensor
     :type need_meta: bool, optional
-    :type down_group: ProcessGroup, optional
+    :type next_rank: int
     :return: False
     :rtype: bool
     """
     if need_meta:
-        rank = gpc.get_global_rank()
-
-        if down_group is None:
-            down_group = gpc.get_group(ParallelMode.PIPELINE_NEXT)
+        if next_rank is None:
+            next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
 
         tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()}
 
         send_shape = torch.tensor(tensor.size(), **tensor_kwargs)
         send_ndims = torch.tensor(len(tensor.size()), **tensor_kwargs)
-        dist.broadcast(send_ndims, src=rank, group=down_group)
-        dist.broadcast(send_shape, src=rank, group=down_group)
+        ops = [
+            dist.P2POp(dist.isend, send_ndims, next_rank),
+            dist.P2POp(dist.isend, send_shape, next_rank)
+        ]
+        reqs = dist.batch_isend_irecv(ops)
+        for req in reqs:
+            req.wait()
+        torch.cuda.synchronize()
 
     return False
 
 
-def recv_tensor_meta(tensor_shape, prev_rank=None, up_group=None):
+def recv_tensor_meta(tensor_shape, prev_rank=None):
     """Recieves tensor meta information before recieving a specific tensor.
     Since the recipient must know the shape of the tensor in p2p communications,
     meta information of the tensor should be recieved before communications. This function
@@ -46,27 +49,21 @@ def recv_tensor_meta(tensor_shape, prev_rank=None, up_group=None):
 
     :param tensor_shape: The shape of the tensor to be recieved
     :param prev_rank: The rank of the source of the tensor
-    :param up_group: Communication group including the previous member in pipeline parallel group
     :type tensor_shape: torch.Size
     :type prev_rank: int, optional
-    :type up_group: ProcessGroup, optional
     :return: The shape of the tensor to be recieved
     :rtype: torch.Size
     """
     if tensor_shape is None:
         if prev_rank is None:
-            prev_rank = gpc.get_prev_global_rank(
-                ParallelMode.PIPELINE)
-        if up_group is None:
-            up_group = gpc.get_group(ParallelMode.PIPELINE_PREV)
+            prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
 
         tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()}
 
         recv_ndims = torch.empty((), **tensor_kwargs)
-        dist.broadcast(recv_ndims, src=prev_rank, group=up_group)
+        dist.recv(recv_ndims, prev_rank)
         recv_shape = torch.empty(recv_ndims, **tensor_kwargs)
-        dist.broadcast(recv_shape, src=prev_rank, group=up_group)
+        dist.recv(recv_shape, prev_rank)
 
         tensor_shape = torch.Size(recv_shape)

View File

@@ -25,7 +25,11 @@ TESSERACT_DEP = 'TESSERACT_DEP'
 
 # 3D parallel
 DEPTH_3D = 'DEPTH_3D'
+INPUT_GROUP_3D = 'PARALLEL_3D_INPUT'
+WEIGHT_GROUP_3D = 'PARALLEL_3D_WEIGHT'
+OUTPUT_GROUP_3D = 'PARALLEL_3D_OUTPUT'
 
 # Tensor parallel attributes
 IS_TENSOR_PARALLEL = 'is_tensor_parallel'
-TENSOR_PARALLEL_ATTRIBUTES = [IS_TENSOR_PARALLEL]
+NUM_PARTITIONS = 'num_partitions'
+TENSOR_PARALLEL_ATTRIBUTES = [IS_TENSOR_PARALLEL, NUM_PARTITIONS]

View File

@@ -1,5 +1,5 @@
-from .config import Config
+from .config import Config, ConfigException
 from .parallel_context import ParallelContext
-from .parallel_context import ParallelMode
+from .parallel_mode import ParallelMode
 from .process_group_initializer import *
 from .random import *

View File

@@ -1,70 +0,0 @@
import math
def set_parallel_size(obj, config: dict, key: str, attr_name: str):
if key in config:
ele = config[key]
if isinstance(ele, int):
setattr(obj, attr_name, ele)
elif isinstance(ele, dict):
setattr(obj, attr_name, ele['size'])
else:
raise NotImplementedError(
f"Parallel configuration does not support this kind of argument, please use int or dict"
)
def add_tensor_pg(pg_init, mode, size, depth=None):
if mode == '1d':
pg_init.append(dict(
type='Initializer1D',
parallel_size=size
))
elif mode == '2d':
dim = math.floor(math.sqrt(size))
pg_init.append(dict(
type='Initializer2D_Col',
summa_dim=dim
))
pg_init.append(dict(
type='Initializer2D_Row',
summa_dim=dim
))
elif mode == '2.5d':
dim = math.floor(math.sqrt(size // depth))
pg_init.append(dict(
type='Initializer_Tesseract_ROW',
tesseract_dim=dim,
tesseract_dep=depth
))
pg_init.append(dict(
type='Initializer_Tesseract_COL',
tesseract_dim=dim,
tesseract_dep=depth
))
pg_init.append(dict(
type='Initializer_Tesseract_DEP',
tesseract_dim=dim,
tesseract_dep=depth
))
pg_init.append(dict(
type='Initializer_Tesseract_XZ',
tesseract_dim=dim,
tesseract_dep=depth
))
elif mode == '3d':
dim = math.floor(math.pow(size, 1.0 / 3.0) + 0.5)
pg_init.append(dict(
type='ParallelInitializer3D_Input',
depth=dim
))
pg_init.append(dict(
type='ParallelInitializer3D_Weight',
depth=dim
))
pg_init.append(dict(
type='ParallelInitializer3D_Output',
depth=dim
))
else:
raise NotImplementedError("This kind of tensor splitting has not been implemented yet")

View File

@@ -97,3 +97,7 @@ class Config(dict):
             sys.path.pop(0)
 
         return config
+
+
+class ConfigException(Exception):
+    pass

View File

@ -1,7 +1,6 @@
#!/usr/bin/env python #!/usr/bin/env python
# -*- encoding: utf-8 -*- # -*- encoding: utf-8 -*-
import os
import random import random
from typing import Union from typing import Union
@ -11,8 +10,8 @@ import torch.distributed as dist
from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING
from colossalai.context.config import Config from colossalai.context.config import Config
from colossalai.logging import get_dist_logger
from colossalai.registry import DIST_GROUP_INITIALIZER from colossalai.registry import DIST_GROUP_INITIALIZER
from ._utils import set_parallel_size
from .parallel_mode import ParallelMode from .parallel_mode import ParallelMode
from .random import add_seed, get_seeds, set_mode from .random import add_seed, get_seeds, set_mode
@ -21,11 +20,24 @@ class ParallelContext:
"""This class provides interface functions for users to get the parallel context, """This class provides interface functions for users to get the parallel context,
such as the global rank, the local rank, the world size, etc. of each device. such as the global rank, the local rank, the world size, etc. of each device.
:param args: The distributed arguments in the system
:type args: dict
""" """
def __init__(self, args=None): __instance = None
@staticmethod
def get_instance():
if ParallelContext.__instance is None:
ParallelContext()
return ParallelContext.__instance
def __init__(self):
# create a singleton instance
if ParallelContext.__instance is not None:
raise Exception(
'ParallelContext is a singleton class, you should get the instance by colossalai.core.global_context')
else:
ParallelContext.__instance = self
# distributed settings # distributed settings
self._global_ranks = dict() self._global_ranks = dict()
self._local_ranks = dict() self._local_ranks = dict()
@ -34,7 +46,6 @@ class ParallelContext:
self._ranks_in_group = dict() self._ranks_in_group = dict()
# load config from file # load config from file
self._dist_args = args
self._config = None self._config = None
# default 3D parallel args, will be overwritten during process group intialization # default 3D parallel args, will be overwritten during process group intialization
@ -43,10 +54,22 @@ class ParallelContext:
self.pipeline_parallel_size = 1 self.pipeline_parallel_size = 1
self.tensor_parallel_size = 1 self.tensor_parallel_size = 1
# logging
self._verbose = False
self._logger = get_dist_logger()
@property @property
def config(self): def config(self):
return self._config return self._config
@property
def verbose(self):
return self._verbose
@verbose.setter
def verbose(self, verbose_: bool):
self._verbose = verbose_
def load_config(self, config: Union[dict, str]): def load_config(self, config: Union[dict, str]):
"""Loads the configuration from either a dict or a file. """Loads the configuration from either a dict or a file.
@ -62,14 +85,6 @@ class ParallelContext:
else: else:
raise TypeError("Invalid type for config, only dictionary or string is supported") raise TypeError("Invalid type for config, only dictionary or string is supported")
def set_dist_args(self, args):
"""Sets the distributed arguments.
:param args: The distributed arguments in the system
:type args: dict
"""
self._dist_args = args
@staticmethod @staticmethod
def _check_parallel_mode(parallel_mode: ParallelMode): def _check_parallel_mode(parallel_mode: ParallelMode):
assert isinstance(parallel_mode, ParallelMode) assert isinstance(parallel_mode, ParallelMode)
@ -268,32 +283,36 @@ class ParallelContext:
self._check_parallel_mode(parallel_mode) self._check_parallel_mode(parallel_mode)
self._ranks_in_group[parallel_mode] = ranks self._ranks_in_group[parallel_mode] = ranks
-    def init_global_dist(self, addr=None, port=None):
-        """Initializes the global distributed environment.
-
-        :param addr: The IP address of the current device
-        :type addr: str, optional
-        :param port: The port to be used in the system of the current device
-        :type port: int, optional
+    def init_global_dist(self,
+                         rank: int,
+                         world_size: int,
+                         backend: str,
+                         host: str,
+                         port: int
+                         ):
+        """Initializes the global distributed environment
+
+        :param rank: rank for the default process group
+        :type rank: int
+        :param world_size: world size of the default process group
+        :type world_size: int
+        :param host: the master address for distributed training
+        :type host: str
+        :param port: the master port for distributed training
+        :type port: str
+        :param backend: backend for torch.distributed
+        :type backend: str
         """
-        # get config
-        rank = self._dist_args.local_rank
-        world_size = self._dist_args.world_size
-        # default env config, overwrite by exporting
-        # them in your bash script
-        addr = os.getenv('MASTER_ADDR', 'localhost') if addr is None else addr
-        port = os.getenv('MASTER_PORT', '8008') if port is None else port
-        init_method = f'tcp://{addr}:{port}'
-        dist.init_process_group(backend=self._dist_args.backend,
-                                rank=rank,
+        # initialize the default process group
+        init_method = f'tcp://{host}:{port}'
+        dist.init_process_group(rank=rank,
                                 world_size=world_size,
+                                backend=backend,
                                 init_method=init_method)

         # None will give the default global process group for pytorch dist operations
         self._register_dist(rank, world_size, None,
                             list(range(world_size)), ParallelMode.GLOBAL)
-        self._global_ranks[ParallelMode.GLOBAL] = rank
+        self.add_global_rank(ParallelMode.GLOBAL, rank)
     def _register_dist(self, local_rank, world_size,
                        process_group, ranks_in_group, mode):
@@ -312,7 +331,20 @@ class ParallelContext:
         pps = self.pipeline_parallel_size
         tps = self.tensor_parallel_size
         ws = self.world_size
-        assert ws == dps * pps * tps, f"Expected the world size {ws} to be equal to data parallel size ({dps}) * pipeline parallel size ({pps}) * tensor parallel size ({tps})"
+        assert ws == dps * pps * \
+            tps, f"Expected the world size {ws} to be equal to data parallel size ({dps}) * pipeline parallel size ({pps}) * tensor parallel size ({tps})"
+
+    def _set_parallel_size_from_config(self, config: dict, key: str, attr_name: str):
+        if key in config:
+            ele = config[key]
+            if isinstance(ele, int):
+                setattr(self, attr_name, ele)
+            elif isinstance(ele, dict):
+                setattr(self, attr_name, ele['size'])
+            else:
+                raise NotImplementedError(
+                    f"Parallel configuration does not support this kind of argument, please use int or dict"
+                )
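For reference, the size-resolution rule added above accepts either a bare int or a dict carrying a 'size' field. A minimal standalone sketch of that rule (illustrative only; resolve_parallel_size is a hypothetical helper, not library API):

# Standalone sketch of the size-resolution rule in _set_parallel_size_from_config.
def resolve_parallel_size(parallel_config: dict, key: str, default: int = 1) -> int:
    ele = parallel_config.get(key, default)
    if isinstance(ele, int):
        return ele
    if isinstance(ele, dict):
        return ele['size']
    raise NotImplementedError("use an int or a dict with a 'size' field")

# both config styles below resolve to pipeline=2, tensor=4
cfg_int = dict(pipeline=2, tensor=dict(size=4, mode='2d'))
cfg_dict = dict(pipeline=dict(size=2), tensor=dict(size=4, mode='2d'))
for cfg in (cfg_int, cfg_dict):
    print(resolve_parallel_size(cfg, 'pipeline'), resolve_parallel_size(cfg, 'tensor'))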
     def init_parallel_groups(self):
         """Initializes the parallel groups.
@@ -325,21 +357,20 @@ class ParallelContext:
         world_size = self.get_world_size(ParallelMode.GLOBAL)
         self.world_size = world_size

-        assert hasattr(self.config, 'parallel'), 'Expected the field parallel to be present in the config file'
         # set parallel size as attributes for global context
-        parallel_config = self.config.parallel
-        set_parallel_size(self, parallel_config, 'pipeline',
-                          'pipeline_parallel_size')
-        set_parallel_size(self, parallel_config, 'tensor',
-                          'tensor_parallel_size')
+        parallel_config = self.config.get('parallel', None)
+        if parallel_config is not None:
+            self._set_parallel_size_from_config(parallel_config, 'pipeline', 'pipeline_parallel_size')
+            self._set_parallel_size_from_config(parallel_config, 'tensor', 'tensor_parallel_size')

         # the user should not set the data parallel size manually
         # instead, it should be calculated based on other parallel config
         self.data_parallel_size = self.world_size // (self.pipeline_parallel_size * self.tensor_parallel_size)

         # get the tensor parallel mode and check
-        tensor_parallel_mode = parallel_config['tensor'].get('mode', None)
+        tensor_parallel_mode = None
+        if parallel_config is not None and 'tensor' in parallel_config and 'mode' in parallel_config['tensor']:
+            tensor_parallel_mode = parallel_config['tensor']['mode']
         assert tensor_parallel_mode in ALLOWED_MODES, f"mode in the parallel config must be set to one of {ALLOWED_MODES}"

         self.check_sanity()
@@ -400,23 +431,21 @@ class ParallelContext:
         # destroy global process group
         dist.destroy_process_group()

-    def set_device(self):
+    def set_device(self, device_ordinal: int = None):
         """Sets distributed processes to be bound to devices.
         """
-        devices_per_node = torch.cuda.device_count()
         global_rank = self.get_global_rank()
-        device = global_rank % devices_per_node
-        torch.cuda.set_device(device)
-        print(f'process rank {global_rank} is bound to device {device}')
+        if device_ordinal is None:
+            devices_per_node = torch.cuda.device_count()
+            device_ordinal = global_rank % devices_per_node
+
+        torch.cuda.set_device(device_ordinal)
+        if self._verbose:
+            self._logger.info(f'process rank {global_rank} is bound to device {device_ordinal}')

-    def set_seed(self):
+    def set_seed(self, seed: int):
         """Sets seeds for all random libraries.
         """
-        if hasattr(self.config, 'seed'):
-            seed = getattr(self.config, 'seed')
-        else:
-            seed = 2  # default seed
         random.seed(seed)
         np.random.seed(seed)
         torch.manual_seed(seed)
@@ -444,11 +473,18 @@ class ParallelContext:
             seeds = get_seeds()
             seed_str = ', '.join([f'{k}: {v}' for k, v in seeds.items()])
-            print(f"initialized seed on rank {global_rank}, "
-                  f"numpy: {seed}, python random: {seed}, {seed_str},"
-                  f"the default parallel seed is {ParallelMode.DATA}.", flush=True)
+            if self._verbose:
+                self._logger.info(
+                    f"initialized seed on rank {global_rank}, "
+                    f"numpy: {seed}, python random: {seed}, {seed_str},"
+                    f"the default parallel seed is {ParallelMode.DATA}.",
+                    ranks=[0])
         else:
-            print(f"initialized seed on rank {global_rank}, "
-                  f"numpy: {seed}, python random: {seed}, pytorch: {seed}", flush=True)
-            print('WARNING: CUDA is not available, thus CUDA RNG cannot be used to track CUDA random number states',
-                  flush=True)
+            if self._verbose:
+                self._logger.info(
+                    f"initialized seed on rank {global_rank}, "
+                    f"numpy: {seed}, python random: {seed}, pytorch: {seed}",
+                    ranks=[0])
+                self._logger.info(
+                    'WARNING: CUDA is not available, thus CUDA RNG cannot be used to track CUDA random number states',
+                    ranks=[0])
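Taken together, the ParallelContext changes above replace the old argument-object flow with explicit parameters. A hedged usage sketch mirroring what launch() (further down in this commit) now does internally; all values are placeholders and the normal entry point remains colossalai.launch() rather than these calls:

# Hedged sketch of the reworked ParallelContext API; values are placeholders.
import torch
from colossalai.core import global_context as gpc

gpc.verbose = True                  # new property controlling rank-0 info logs
gpc.load_config(dict(parallel=dict(pipeline=1, tensor=dict(size=1, mode=None))))
gpc.init_global_dist(rank=0, world_size=1, backend='gloo', host='localhost', port=29500)
gpc.init_parallel_groups()
if torch.cuda.is_available():
    gpc.set_device(None)            # ordinal derived from the global rank when None
gpc.set_seed(1024)                  # the seed is now passed in instead of read from config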


@ -4,7 +4,6 @@
import torch.distributed as dist import torch.distributed as dist
from colossalai.context import Config from colossalai.context import Config
from colossalai.core import global_context as gpc
from colossalai.registry import DIST_GROUP_INITIALIZER from colossalai.registry import DIST_GROUP_INITIALIZER
from .process_group_initializer import ProcessGroupInitializer from .process_group_initializer import ProcessGroupInitializer
from ..parallel_mode import ParallelMode from ..parallel_mode import ParallelMode


@ -8,7 +8,6 @@ import torch.distributed as dist
from colossalai.constants import TESSERACT_DIM, TESSERACT_DEP from colossalai.constants import TESSERACT_DIM, TESSERACT_DEP
from colossalai.context import Config from colossalai.context import Config
from colossalai.core import global_context as gpc
from colossalai.registry import DIST_GROUP_INITIALIZER from colossalai.registry import DIST_GROUP_INITIALIZER
from .process_group_initializer import ProcessGroupInitializer from .process_group_initializer import ProcessGroupInitializer
from ..parallel_mode import ParallelMode from ..parallel_mode import ParallelMode
@ -42,8 +41,6 @@ class Initializer_2p5D_ROW(ProcessGroupInitializer):
tesseract_dep: int, tesseract_dep: int,
*args): *args):
super(Initializer_2p5D_ROW, self).__init__(*args) super(Initializer_2p5D_ROW, self).__init__(*args)
self.tensor_parallel_size = gpc.tensor_parallel_size
self.num_group = self.world_size // self.tensor_parallel_size self.num_group = self.world_size // self.tensor_parallel_size
self.tesseract_dep = tesseract_dep self.tesseract_dep = tesseract_dep
self.tesseract_dim = tesseract_dim self.tesseract_dim = tesseract_dim
@ -66,7 +63,7 @@ class Initializer_2p5D_ROW(ProcessGroupInitializer):
for j in range(self.tesseract_dim): for j in range(self.tesseract_dim):
for k in range(self.tesseract_dep): for k in range(self.tesseract_dep):
ranks = [h * self.tensor_parallel_size + i + self.tesseract_dim * ( ranks = [h * self.tensor_parallel_size + i + self.tesseract_dim * (
j + self.tesseract_dim * k) for i in range(self.tesseract_dim)] j + self.tesseract_dim * k) for i in range(self.tesseract_dim)]
group = dist.new_group(ranks) group = dist.new_group(ranks)
if self.rank in ranks: if self.rank in ranks:
@ -81,13 +78,12 @@ class Initializer_2p5D_ROW(ProcessGroupInitializer):
class Initializer_2p5D_Col(ProcessGroupInitializer): class Initializer_2p5D_Col(ProcessGroupInitializer):
'''2p5d tensor parallel initialization among cols. '''2p5d tensor parallel initialization among cols.
''' '''
def __init__(self, def __init__(self,
tesseract_dim: int, tesseract_dim: int,
tesseract_dep: int, tesseract_dep: int,
*args): *args):
super(Initializer_2p5D_Col, self).__init__(*args) super(Initializer_2p5D_Col, self).__init__(*args)
self.tensor_parallel_size = gpc.tensor_parallel_size
self.num_group = self.world_size // self.tensor_parallel_size self.num_group = self.world_size // self.tensor_parallel_size
self.tesseract_dep = tesseract_dep self.tesseract_dep = tesseract_dep
self.tesseract_dim = tesseract_dim self.tesseract_dim = tesseract_dim
@ -110,7 +106,7 @@ class Initializer_2p5D_Col(ProcessGroupInitializer):
for i in range(self.tesseract_dim): for i in range(self.tesseract_dim):
for k in range(self.tesseract_dep): for k in range(self.tesseract_dep):
ranks = [h * self.tensor_parallel_size + i + self.tesseract_dim * ( ranks = [h * self.tensor_parallel_size + i + self.tesseract_dim * (
j + self.tesseract_dim * k) for j in range(self.tesseract_dim)] j + self.tesseract_dim * k) for j in range(self.tesseract_dim)]
group = dist.new_group(ranks) group = dist.new_group(ranks)
if self.rank in ranks: if self.rank in ranks:
@ -125,13 +121,12 @@ class Initializer_2p5D_Col(ProcessGroupInitializer):
class Initializer_2p5D_Dep(ProcessGroupInitializer): class Initializer_2p5D_Dep(ProcessGroupInitializer):
'''2p5D tensor parallel initialization among depths. '''2p5D tensor parallel initialization among depths.
''' '''
def __init__(self, def __init__(self,
tesseract_dim: int, tesseract_dim: int,
tesseract_dep: int, tesseract_dep: int,
*args): *args):
super(Initializer_2p5D_Dep, self).__init__(*args) super(Initializer_2p5D_Dep, self).__init__(*args)
self.tensor_parallel_size = gpc.tensor_parallel_size
self.num_group = self.world_size // self.tensor_parallel_size self.num_group = self.world_size // self.tensor_parallel_size
self.tesseract_dep = tesseract_dep self.tesseract_dep = tesseract_dep
self.tesseract_dim = tesseract_dim self.tesseract_dim = tesseract_dim
@ -154,7 +149,7 @@ class Initializer_2p5D_Dep(ProcessGroupInitializer):
for i in range(self.tesseract_dim): for i in range(self.tesseract_dim):
for j in range(self.tesseract_dim): for j in range(self.tesseract_dim):
ranks = [h * self.tensor_parallel_size + i + self.tesseract_dim * ( ranks = [h * self.tensor_parallel_size + i + self.tesseract_dim * (
j + self.tesseract_dim * k) for k in range(self.tesseract_dep)] j + self.tesseract_dim * k) for k in range(self.tesseract_dep)]
group = dist.new_group(ranks) group = dist.new_group(ranks)
if self.rank in ranks: if self.rank in ranks:
@ -170,13 +165,12 @@ class Initializer_2p5D_Dep(ProcessGroupInitializer):
class Initializer_2p5D_XZ(ProcessGroupInitializer): class Initializer_2p5D_XZ(ProcessGroupInitializer):
'''2p5d tensor parallel initialization among cols times dep. '''2p5d tensor parallel initialization among cols times dep.
''' '''
def __init__(self, def __init__(self,
tesseract_dim: int, tesseract_dim: int,
tesseract_dep: int, tesseract_dep: int,
*args): *args):
super(Initializer_2p5D_XZ, self).__init__(*args) super(Initializer_2p5D_XZ, self).__init__(*args)
self.tensor_parallel_size = gpc.tensor_parallel_size
self.num_group = self.world_size // self.tensor_parallel_size self.num_group = self.world_size // self.tensor_parallel_size
self.tesseract_dep = tesseract_dep self.tesseract_dep = tesseract_dep
self.tesseract_dim = tesseract_dim self.tesseract_dim = tesseract_dim
@ -198,8 +192,8 @@ class Initializer_2p5D_XZ(ProcessGroupInitializer):
for h in range(self.num_group): for h in range(self.num_group):
for i in range(self.tesseract_dim): for i in range(self.tesseract_dim):
ranks = [h * self.tensor_parallel_size + i + self.tesseract_dim * ( ranks = [h * self.tensor_parallel_size + i + self.tesseract_dim * (
j + self.tesseract_dim * k) for k in range(self.tesseract_dep) for j in j + self.tesseract_dim * k) for k in range(self.tesseract_dep) for j in
range(self.tesseract_dim)] range(self.tesseract_dim)]
group = dist.new_group(ranks) group = dist.new_group(ranks)
if self.rank in ranks: if self.rank in ranks:
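All four 2.5D initializers above share the same rank formula. A small standalone sketch enumerating the row groups for tesseract_dim = 2 and tesseract_dep = 2 (so a tensor parallel size of 8), purely for illustration:

# Standalone illustration of the 2.5D rank formula used above (not library code).
tesseract_dim, tesseract_dep = 2, 2
tensor_parallel_size = tesseract_dim * tesseract_dim * tesseract_dep  # 8 ranks per group
world_size = 16
num_group = world_size // tensor_parallel_size

for h in range(num_group):
    for j in range(tesseract_dim):
        for k in range(tesseract_dep):
            ranks = [h * tensor_parallel_size + i + tesseract_dim * (j + tesseract_dim * k)
                     for i in range(tesseract_dim)]
            print(f'group {h}, row j={j}, dep k={k}: ranks {ranks}')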


@ -5,7 +5,7 @@ import math
import os import os
import torch.distributed as dist import torch.distributed as dist
from colossalai.constants import DEPTH_3D from colossalai.constants import DEPTH_3D, INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D
from colossalai.registry import DIST_GROUP_INITIALIZER from colossalai.registry import DIST_GROUP_INITIALIZER
from ..parallel_mode import ParallelMode from ..parallel_mode import ParallelMode
@ -18,7 +18,7 @@ def _check_depth_env_var(depth):
if env_depth: if env_depth:
assert int(env_depth) == depth, \ assert int(env_depth) == depth, \
'SUMMA_DIM has been set in the current environment and ' \ 'DEPTH_3D has been set in the current environment and ' \
'does not match with the value passed to this initialized' 'does not match with the value passed to this initialized'
else: else:
os.environ[DEPTH_3D] = str(depth) os.environ[DEPTH_3D] = str(depth)
@ -43,6 +43,7 @@ class Initializer_3D_Input(ProcessGroupInitializer):
process_group = None process_group = None
group_world_size = None group_world_size = None
mode = ParallelMode.PARALLEL_3D_INPUT mode = ParallelMode.PARALLEL_3D_INPUT
os.environ[INPUT_GROUP_3D] = INPUT_GROUP_3D
for h in range(self.num_group): for h in range(self.num_group):
for i in range(self.depth): for i in range(self.depth):
@ -82,6 +83,7 @@ class Initializer_3D_Weight(ProcessGroupInitializer):
process_group = None process_group = None
group_world_size = None group_world_size = None
mode = ParallelMode.PARALLEL_3D_WEIGHT mode = ParallelMode.PARALLEL_3D_WEIGHT
os.environ[WEIGHT_GROUP_3D] = WEIGHT_GROUP_3D
for h in range(self.num_group): for h in range(self.num_group):
for k in range(self.depth): for k in range(self.depth):
@ -121,6 +123,7 @@ class Initializer_3D_Output(ProcessGroupInitializer):
process_group = None process_group = None
group_world_size = None group_world_size = None
mode = ParallelMode.PARALLEL_3D_OUTPUT mode = ParallelMode.PARALLEL_3D_OUTPUT
os.environ[OUTPUT_GROUP_3D] = OUTPUT_GROUP_3D
for h in range(self.num_group): for h in range(self.num_group):
for i in range(self.depth): for i in range(self.depth):


@@ -3,14 +3,4 @@
 from colossalai.context import ParallelContext

-global_context = ParallelContext()
-
-
-def set_global_context(context: ParallelContext):
-    '''Reset global context to be identical to a given :class:ParallelContext.
-
-    :param context: Parallel context to generate our global parallel context.
-    :type context: ParallelContext
-    '''
-    global global_context
-    global_context = context
+global_context = ParallelContext.get_instance()
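global_context now comes from ParallelContext.get_instance() instead of constructing a fresh context and patching it in with set_global_context(). The diff does not show get_instance() itself; a minimal sketch of the singleton pattern it implies (ParallelContextSketch is a stand-in, not the real class):

# Minimal singleton sketch; the real ParallelContext.get_instance() is not in this diff.
class ParallelContextSketch:
    _instance = None

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

assert ParallelContextSketch.get_instance() is ParallelContextSketch.get_instance()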


@@ -1,7 +1,5 @@
 from ._base_engine import Engine
 from .gradient_handler import *
-from .schedule import *
-from .amp import *

 __all__ = ['Engine']


@@ -1,17 +1,17 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

+import torch
+from typing import List
 from torch.nn import Module
 from torch.nn.modules.loss import _Loss
 from torch.optim import Optimizer
 from colossalai.builder import build_gradient_handler
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.logging import get_global_dist_logger
-from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
-                           ZeroRedundancyOptimizer_Level_3)
-from .schedule import BaseSchedule
+from colossalai.logging import get_dist_logger
+from colossalai.utils import is_using_ddp, is_using_pp
+from torch import Tensor


 class Engine:
@@ -20,74 +20,40 @@ class Engine:
     It controls a iteration in training.

     :param model: The neural network model
+    :type model: ``torch.nn.Module``
     :param optimizer: Optimizer for updating the parameters
-    :param step_schedule: Running schedule in :meth:`step`
-    :param gradient_accumulation: Steps of gradient accumulation
+    :type optimizer: ``torch.optim.Optimizer``
+    :param criterion: Loss function for calculating loss
+    :type criterion: ``torch.nn.modules.loss._Loss``
     :param gradient_clipping: The norm of gradient clipping
-    :type model: Module
-    :type optimizer: Optimizer
-    :type step_schedule: BaseSchedule, optional
-    :type gradient_accumulation: int, optional
     :type gradient_clipping: float, optional
+    :param verbose: whether to display log info
+    :type verbose: bool
     """

     def __init__(self,
                  model: Module,
                  optimizer: Optimizer,
                  criterion: _Loss,
-                 step_schedule: BaseSchedule,
-                 gradient_handlers: list = None,
-                 gradient_accumulation: int = 1,
-                 gradient_clipping: float = 0.0,
+                 gradient_handlers: List = None,
+                 clip_grad_norm: float = 0.0,
+                 verbose: bool = True
                  ):
         self._model = model
         self._optimizer = optimizer
         self._criterion = criterion
-        self._schedule = step_schedule
-
-        # schedule initialize
-        self._schedule.initialize(model, optimizer)
+        self._clip_grad_norm = clip_grad_norm
+        self._verbose = verbose
+        self._logger = get_dist_logger()

         # state
         self.training = True  # default

-        # gradient accumulation
-        assert gradient_accumulation > 0, 'gradient accumulation size must be larger than 0'
-        self._grad_accum_size = gradient_accumulation
-        self._grad_clip = gradient_clipping
-        self._logger = get_global_dist_logger()
-
         # build gradient handler
-        self._gradient_handlers = []
-
-        if gradient_handlers is not None:
-            assert isinstance(gradient_handlers, list), \
-                f'argument gradient_handler_cfg expected type list, ' \
-                f'but got type {type(gradient_handlers)}'
-        elif isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2,
-                                    ZeroRedundancyOptimizer_Level_3)):
-            gradient_handlers = [dict(type='ZeROGradientHandler')]
-            self._logger.info(
-                "Training with zero is detected, ZeROGradientHandler is automatically "
-                "added even though not specified in the configuration",
-                ranks=[0])
-        elif gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(
-                ParallelMode.DATA) > 1:
-            gradient_handlers = [dict(type='DataParallelGradientHandler')]
-            self._logger.info(
-                "Data parallel training is detected, DataParallelGradientHandler is automatically "
-                "added even though not specified in the configuration",
-                ranks=[0])
-
-        if gradient_handlers is None:
-            self._logger.warning(
-                "No gradient handler is set up, please make sure you do not need "
-                "to all-reduce the gradients after a training step.",
-                ranks=[0])
+        if gradient_handlers:
+            self._gradient_handlers = gradient_handlers
         else:
-            for cfg in gradient_handlers:
-                handler = build_gradient_handler(cfg, model, optimizer)
-                self._gradient_handlers.append(handler)
+            self._gradient_handlers = []

     @property
     def model(self):
@@ -105,11 +71,27 @@ class Engine:
     def schedule(self):
         return self._schedule

-    @property
-    def gradient_accumulation(self):
-        return self._grad_accum_size
+    def zero_grad(self):
+        self.optimizer.zero_grad()

-    def handle_gradient(self):
+    def step(self):
+        self._all_reduce_gradients()
+        self.optimizer.clip_grad_norm(self.model, self._clip_grad_norm)
+        self.optimizer.step()
+
+    def backward(self, loss: Tensor):
+        return self.optimizer.backward(loss)
+
+    def backward_by_grad(self, tensor, grad):
+        return self.optimizer.backward_by_grad(tensor, grad)
+
+    def calc_loss(self, *args, **kwargs):
+        return self.criterion(*args, **kwargs)
+
+    def __call__(self, *args, **kwargs):
+        return self.model(*args, **kwargs)
+
+    def _all_reduce_gradients(self):
         """Handles all-reduce operations of gradients across different parallel groups.
         """
         for handler in self._gradient_handlers:
@@ -126,51 +108,3 @@ class Engine:
         """
         self.training = False
         self._model.eval()
def step(self,
data_iter,
is_last_iteration: bool = False,
return_loss=True):
"""A running step based on the schedule. Usually, it runs a training or
evaluation over a batch of dataset.
:param data_iter: Data iterator of the dataset
:param is_last_iteration: If True, this iteration is the last iteration in the epoch
:param return_loss: loss will be returned if True
:type data_iter: Iterator
:type is_last_iteration: bool, optional
:type return_loss: bool, optional
:return: (output, lablel, loss)
"""
if self.training:
self._optimizer.zero_grad()
# differentiate training and eval with grad accum
if self.training:
for i in range(self._grad_accum_size):
output, label, loss = self._schedule.forward_backward_step(
data_iter, self._model, self._criterion, self._optimizer,
forward_only=False,
grad_accum_size=self._grad_accum_size,
return_loss=return_loss)
if i == self._grad_accum_size - 1:
# all reduce gradients
self.handle_gradient()
self._schedule.optimizer_step(self._model, self._optimizer, self._grad_clip)
else:
output, label, loss = self._schedule.forward_backward_step(
data_iter, self._model, self._criterion, self._optimizer,
forward_only=True,
grad_accum_size=1,
return_loss=return_loss)
# consume the remaining dataset left out due to gradient accumulation
if is_last_iteration:
while True:
try:
_ = next(data_iter)
except StopIteration:
break
return output, label, loss
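With step(), backward() and zero_grad() moved off the removed Engine.step() above and onto the Engine itself, one manual training iteration now reads roughly as follows. This is a hedged sketch: engine and data_iter are assumed to come from colossalai.initialize() and a dataloader, respectively.

# Hedged sketch of a manual training iteration with the slimmed-down Engine API
# (engine and data_iter are assumed to already exist).
data, label = next(data_iter)
engine.zero_grad()                      # delegates to optimizer.zero_grad()
output = engine(data)                   # __call__ forwards to the wrapped model
loss = engine.criterion(output, label)  # or engine.calc_loss(output, label)
engine.backward(loss)                   # delegates to optimizer.backward(loss)
engine.step()                           # all-reduce grads, clip to _clip_grad_norm, then optimizer.step()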


@ -1,2 +0,0 @@
from .grad_scaler import GradScaler
from .amp_type import AMP_TYPE


@@ -1,5 +1,5 @@
 from ._base_schedule import BaseSchedule
-from ._no_pipeline import NoPipelineSchedule
-from ._pipeline import PipelineSchedule
+from ._pipeline_schedule import PipelineSchedule
+from ._non_pipeline_schedule import NonPipelineSchedule

-__all__ = ['BaseSchedule', 'NoPipelineSchedule', 'PipelineSchedule']
+__all__ = ['BaseSchedule', 'PipelineSchedule', 'NonPipelineSchedule']


@@ -5,8 +5,10 @@ from abc import ABC, abstractmethod

 import torch

-from colossalai.core import global_context as gpc
-from colossalai.logging import get_global_dist_logger
+from torch import Tensor
+from typing import Iterable, Union, List, Callable
+from .._base_engine import Engine
+from colossalai.logging import get_dist_logger
 from colossalai.utils import get_current_device

@@ -18,8 +20,9 @@ class BaseSchedule(ABC):
     control of FP16 in class schedule.
     """

-    def __init__(self):
-        self.logger = get_global_dist_logger()
+    def __init__(self, batch_data_process_func: Callable = None):
+        self.logger = get_dist_logger()
+        self.batch_data_process_func = batch_data_process_func

     @staticmethod
     def _move_tensor(element):
@@ -35,6 +38,11 @@ class BaseSchedule(ABC):
             data = data.to(get_current_device()).detach()
         return data

+    def _to_list(self, data):
+        if torch.is_tensor(data):
+            return [data]
+        return data
+
     def load_batch(self, data_iter):
         """Loads a batch from data iterator. It returns the data and labels which are
         already in the same GPU as where the model's.
@@ -44,46 +52,34 @@ class BaseSchedule(ABC):
         """
         if data_iter is None:
             raise RuntimeError('Dataloader is not defined.')
-        data, label = next(data_iter)
+        batch_data = next(data_iter)
+
+        if self.batch_data_process_func:
+            data, label = self.batch_data_process_func(batch_data)
+        else:
+            data, label = batch_data
+        data, label = self._to_list(data), self._to_list(label)
         return self._move_to_device(data), self._move_to_device(label)

-    def initialize(self, model, optimizer):
-        """Initializes the model and the optimizer before training.
-        This is often used in FP16 training.
-
-        :param model: The neural network model
-        :param optimizer: Optimizer for updating the parameters
+    def pre_processing(self, engine: Engine):
+        """To perform actions before running the schedule.
         """
-        return model, optimizer
+        pass

     @abstractmethod
     def forward_backward_step(self,
-                              data_iter,
-                              model,
-                              criterion,
-                              optimizer=None,
-                              forward_only=False,
-                              grad_accum_size: int = 1,
-                              return_loss=True):
+                              engine: Engine,
+                              data_iter: Iterable,
+                              forward_only: bool,
+                              return_loss: bool = True
+                              ):
         """The process function over a batch of dataset for training or evaluation.

-        :param data_iter: Data iterator of the dataset
-        :param model: Model used in training or evaluation
-        :param optimizer: Optimizer used in training or evaluation
-        :param criterion: Loss function
+        :param engine: Colossalai training engine
+        :param inputs: input data
+        :param labels: ground truth
         :param forward_only: If True, the process won't include backward
-        :param grad_accum_size: Steps of gradient accumulation
         :param return_loss: If False, the loss won't be returned
         """
         pass
-
-    @abstractmethod
-    def optimizer_step(self, model, optimizer, grad_clipping: float = 0.0):
-        """Updates the parameters with the optimizer.
-
-        :param model: The neural network model
-        :param optimizer: Optimizer for updating the parameters
-        :param grad_clipping: The norm of gradient clipping
-        :type grad_clipping: float, optional
-        """
-        pass
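Since load_batch() now accepts whatever the dataloader yields and defers unpacking to batch_data_process_func, a custom splitter can be handed to any schedule that inherits this constructor. A hedged sketch for a dict-style batch (import path as implied by the schedule package __init__ above):

# Hedged sketch: custom batch unpacking for a dataloader that yields dict batches.
from colossalai.engine.schedule import NonPipelineSchedule

def split_batch(batch):
    # return the (data, label) pair that load_batch() expects
    return batch['img'], batch['label']

schedule = NonPipelineSchedule(batch_data_process_func=split_batch)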


@ -1,188 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
try:
import apex.amp as apex_amp
except:
pass
try:
import torch.cuda.amp as torch_amp
except:
pass
from typing import Iterable
import torch.nn as nn
from torch.optim import Optimizer
from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)
from colossalai.nn.optimizer._utils import clip_grad_norm_fp32
from ._base_schedule import BaseSchedule
from ._utils import convert_to_fp16, convert_to_fp32
from ..amp import AMP_TYPE, GradScaler
class NoPipelineSchedule(BaseSchedule):
"""A helper schedule class for no pipeline parallelism running environment.
During one process, it loads a batch of dataset and feeds it to the model.
After getting the output and calculating the loss, it will use :meth:`step`
to update the parameters if it is in training mode.
:param amp_type: The type of automatic mixed precision
:param amp_config: The configuration of automatic mixed procision
:type amp_type: AMP_TYPE
:type amp_config: dict
"""
def __init__(
self,
amp_type: AMP_TYPE = None,
amp_config: dict = None,
):
super().__init__()
# mixed precision training
assert amp_type is None or isinstance(amp_type, AMP_TYPE), \
'unrecognised value for argument fp16, it can only be None, torch or apex'
self.use_zero_level_2_3 = False
if amp_type is not None:
self.fp16 = True
self.amp_type = amp_type
if amp_config is not None:
assert isinstance(amp_config, dict), \
f'expected argument fp16_config to be type dictionary, but got {type(amp_config)}'
if self.amp_type == AMP_TYPE.TORCH:
# torch apex
if amp_config is None:
amp_config = dict()
self.amp_cfg = amp_config
elif self.amp_type == AMP_TYPE.APEX:
# apex amp
if amp_config is None:
amp_config = dict(opt_level='O2')
self.logger.warning(
'apex is deprecated, please consider using torch.cuda.amp instead.'
)
self.amp_cfg = amp_config
elif self.amp_type == AMP_TYPE.PARALLEL:
# use fp16 optimizer for tensor parallel training
if amp_config is None:
amp_config = dict()
self.amp_cfg = amp_config
else:
self.fp16 = False
self.amp_type = None
def initialize(self, model: nn.Module, optimizer: Optimizer):
if isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)):
self.use_zero_level_2_3 = True
assert self.amp_type != AMP_TYPE.PARALLEL, \
'ZeRO Level 2 and 3 are mutually exclusive with AMP_TYPE.PARALLEL'
if self.fp16:
if self.amp_type == AMP_TYPE.TORCH:
self._torch_amp_scaler = GradScaler(**self.amp_cfg)
elif self.amp_type == AMP_TYPE.APEX:
model, optimizer = apex_amp.initialize(model, optimizer, **self.amp_cfg)
return model, optimizer
def forward_backward_step(self,
data_iter: Iterable,
model: nn.Module,
criterion: nn.modules.loss._Loss,
optimizer: Optimizer = None,
forward_only: bool = False,
grad_accum_size: int = 1,
return_loss: bool = True):
"""The process function that loads loads a batch of dataset and feeds it to the model.
The returned labels and loss will None if :attr:`return_loss` is False.
:param data_iter: Data iterator of the dataloader, e.g. iter(dataloader)
:param model: Model for training and inference
:param criterion: Loss function for training
:param optimizer: Optimizer used for training
:param forward_only: If True, the model is run for the forward pass, else back propagation will be executed
:param grad_accum_size: The number of iterations for gradient accumulation
:param return_loss: Loss will be returned if True
:type data_iter: Iterator
:type model: torch.nn.Module
:type criterion: torch.nn.modules.loss._Loss
:type optimizer: torch.optim.Optimizer
:type forward_only: bool, optional
:type grad_accum_size: int
:type return_loss: bool, optional
:return: (output, label, loss)
"""
assert forward_only or return_loss, \
'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
data, label = self.load_batch(data_iter)
loss = None
# forward
if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
with torch_amp.autocast():
output = model(*data)
if not isinstance(output, (tuple, list)):
output = (output,)
if return_loss:
loss = criterion(*output, *label)
else:
if self.use_zero_level_2_3 or self.amp_type == AMP_TYPE.PARALLEL:
data = convert_to_fp16(data)
output = model(*data)
if self.use_zero_level_2_3 or self.amp_type == AMP_TYPE.PARALLEL:
output = convert_to_fp32(output)
if not isinstance(output, (tuple, list)):
output = (output,)
if return_loss:
loss = criterion(*output, *label)
loss /= grad_accum_size
if not forward_only:
# backward
if self.use_zero_level_2_3:
optimizer.backward(loss)
elif self.fp16:
if self.amp_type == AMP_TYPE.APEX:
with apex_amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
elif self.amp_type == AMP_TYPE.TORCH:
self._torch_amp_scaler.scale(loss).backward()
elif self.amp_type == AMP_TYPE.PARALLEL:
loss = optimizer.scale_loss(loss)
loss.backward()
# scale back to display the original value in logs
loss.div_(optimizer.grad_scaler.scale)
else:
loss.backward()
if return_loss:
return output, label, loss * grad_accum_size
else:
return output, None, None
def optimizer_step(self, model: nn.Module, optimizer: Optimizer, grad_clipping: float = 0.0):
# step optimizer
if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
if grad_clipping > 0.0:
self._torch_amp_scaler.unscale_(optimizer)
clip_grad_norm_fp32(model.parameters(), grad_clipping)
self._torch_amp_scaler.step(optimizer)
self._torch_amp_scaler.update()
else:
if not self.fp16 and not self.use_zero_level_2_3 and grad_clipping > 0.0:
clip_grad_norm_fp32(model.parameters(), grad_clipping)
optimizer.step()
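The AMP branching that lived inside the deleted NoPipelineSchedule above is gone; mixed precision is now applied once, up front, through colossalai.amp (see the convert_to_amp import in initialize.py later in this commit). A hedged sketch of that call; the exact argument list is an assumption based on the import, not something shown in this diff:

# Hedged sketch: AMP is applied before training instead of inside the schedule.
# convert_to_amp's exact signature is assumed here, not taken from this diff.
import torch
import torch.nn as nn
from colossalai.amp import AMP_TYPE, convert_to_amp

model = nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
criterion = nn.CrossEntropyLoss()
model, optimizer, criterion = convert_to_amp(model, optimizer, criterion, AMP_TYPE.TORCH)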


@ -0,0 +1,61 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Iterable
import torch
import torch.nn as nn
from colossalai.engine import Engine
from torch.optim import Optimizer
from ._base_schedule import BaseSchedule
from colossalai.utils import conditional_context
class NonPipelineSchedule(BaseSchedule):
"""A helper schedule class for no pipeline parallelism running environment.
During one process, it loads a batch of dataset and feeds it to the model.
After getting the output and calculating the loss, it will use :meth:`step`
to update the parameters if it is in training mode.
:param amp_type: The type of automatic mixed precision
:param amp_config: The configuration of automatic mixed procision
:type amp_type: AMP_TYPE
:type amp_config: dict
"""
def forward_backward_step(self,
engine: Engine,
data_iter: Iterable,
forward_only: bool = False,
return_loss: bool = True):
"""The process function that loads loads a batch of dataset and feeds it to the model.
The returned labels and loss will None if :attr:`return_loss` is False.
:param engine: Model for training and inference
:param data_iter: Data iterator of the dataloader, e.g. iter(dataloader)
:param forward_only: If True, the model is run for the forward pass, else back propagation will be executed
:param return_loss: Loss will be returned if True
:type engine: Iterator
:type data_iter: Iterator
:type forward_only: bool, optional
:type return_loss: bool, optional
:return: (output, label, loss)
"""
assert forward_only or return_loss, \
"The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
data, label = self.load_batch(data_iter)
# forward
with conditional_context(torch.no_grad(), enable=forward_only):
output = engine(*data)
if not isinstance(output, (tuple, list)):
output = (output,)
if return_loss:
loss = engine.criterion(*output, *label)
if not forward_only:
engine.backward(loss)
if return_loss:
return output, label, loss
else:
return output, None, None
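A hedged sketch of driving one evaluation batch through the new schedule (engine and test_dataloader are assumed to be set up elsewhere, e.g. by colossalai.initialize()):

# Hedged usage sketch of the engine-driven schedule step (engine/test_dataloader assumed).
from colossalai.engine.schedule import NonPipelineSchedule

schedule = NonPipelineSchedule()
schedule.pre_processing(engine)   # no-op for the non-pipeline case
engine.eval()
output, label, loss = schedule.forward_backward_step(
    engine, iter(test_dataloader), forward_only=True, return_loss=True)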


@ -10,12 +10,12 @@ from torch import Tensor
from colossalai.communication import * from colossalai.communication import *
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.nn import (ZeroRedundancyOptimizer_Level_2, from colossalai.amp.naive_amp import NaiveAMPModel
ZeroRedundancyOptimizer_Level_3) from colossalai.zero import (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from ._base_schedule import BaseSchedule from ._base_schedule import BaseSchedule
from ._utils import convert_to_fp16 from colossalai.amp import AMP_TYPE
from ..amp import AMP_TYPE
def squeeze(x: Union[Tensor, tuple, list]): def squeeze(x: Union[Tensor, tuple, list]):
@ -28,32 +28,25 @@ def squeeze(x: Union[Tensor, tuple, list]):
class PipelineSchedule(BaseSchedule): class PipelineSchedule(BaseSchedule):
"""A helper schedule class for pipeline parallelism running environment. """A helper schedule class for pipeline parallelism running environment.
It uses non-interleaved 1F1B strategy. Other properties are similar as It uses non-interleaved 1F1B strategy. Other properties are similar as
:class:`NoPipelineSchedule`. :class:`NonPipelineSchedule`.
:param num_microbatches: The number of microbatches :param num_microbatches: The number of microbatches
:param amp_type: The type of automatic mixed precision :param amp_type: The type of automatic mixed precision
:param amp_config: The configuration of automatic mixed procision :param amp_config: The configuration of automatic mixed procision
:param sync_data: If set to `True`, will sync data every batch over pipeline stages
:type num_microbatches: int :type num_microbatches: int
:type amp_type: AMP_TYPE :type amp_type: AMP_TYPE
:type amp_config: dict :type amp_config: dict
:type sync_data: bool
""" """
def __init__(self, def __init__(self,
num_microbatches, num_microbatches,
amp_type: AMP_TYPE = None, sync_data: bool = True):
amp_config: dict = None):
super().__init__() super().__init__()
self.num_microbatches = num_microbatches self.num_microbatches = num_microbatches
self.data_sync = True # close after making sure data is identical self.sync_data = sync_data
# amp
# LSGL: amp_config is not used, but leave here for future extension
self.amp_type = amp_type
self.amp_config = amp_config
if self.amp_type is not None:
assert self.amp_type == AMP_TYPE.PARALLEL, 'We only support AMP_TYPE.PARALLEL for pipeline training for now'
def _move_to_device(self, data): def _move_to_device(self, data):
if isinstance(data, ( if isinstance(data, (
@ -67,30 +60,37 @@ class PipelineSchedule(BaseSchedule):
return data return data
def _sync_data(self): def _sync_data(self):
reqs = []
if gpc.is_first_rank(ParallelMode.PIPELINE): if gpc.is_first_rank(ParallelMode.PIPELINE):
src_rank = gpc.get_global_rank() src_rank = gpc.get_global_rank()
dist.broadcast( reqs.append(dist.broadcast(
tensor=self.batch_data, tensor=self.batch_data,
src=src_rank, src=src_rank,
group=gpc.get_group(ParallelMode.PIPELINE_PREV) group=gpc.get_group(ParallelMode.PIPELINE_PREV),
) async_op=True
dist.broadcast( ))
reqs.append(dist.broadcast(
tensor=self.batch_label, tensor=self.batch_label,
src=src_rank, src=src_rank,
group=gpc.get_group(ParallelMode.PIPELINE_PREV) group=gpc.get_group(ParallelMode.PIPELINE_PREV),
) async_op=True
))
if gpc.is_last_rank(ParallelMode.PIPELINE): if gpc.is_last_rank(ParallelMode.PIPELINE):
src_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE) src_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
dist.broadcast( reqs.append(dist.broadcast(
tensor=self.batch_data, tensor=self.batch_data,
src=src_rank, src=src_rank,
group=gpc.get_group(ParallelMode.PIPELINE_NEXT) group=gpc.get_group(ParallelMode.PIPELINE_NEXT),
) async_op=True
dist.broadcast( ))
reqs.append(dist.broadcast(
tensor=self.batch_label, tensor=self.batch_label,
src=src_rank, src=src_rank,
group=gpc.get_group(ParallelMode.PIPELINE_NEXT) group=gpc.get_group(ParallelMode.PIPELINE_NEXT),
) async_op=True
))
for req in reqs:
req.wait()
# Pipeline schedule just puts data in memory # Pipeline schedule just puts data in memory
def load_batch(self, data_iter): def load_batch(self, data_iter):
@ -104,7 +104,7 @@ class PipelineSchedule(BaseSchedule):
assert batch_size % self.num_microbatches == 0, \ assert batch_size % self.num_microbatches == 0, \
"Batch size should divided by the number of microbatches" "Batch size should divided by the number of microbatches"
self.microbatch_size = batch_size // self.num_microbatches self.microbatch_size = batch_size // self.num_microbatches
if self.data_sync: if self.sync_data:
self._sync_data() self._sync_data()
def _get_data_slice(self, tensor): def _get_data_slice(self, tensor):
@ -116,21 +116,20 @@ class PipelineSchedule(BaseSchedule):
self.batch_pos += self.microbatch_size self.batch_pos += self.microbatch_size
return (data,), (label,) return (data,), (label,)
def initialize(self, model, optimizer): def pre_processing(self, engine):
if isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)): if isinstance(engine.optimizer, (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)):
raise TypeError( raise TypeError(
"Pipeline schedule is currently not compatible with ZeRO Level 2 and Level 3" "Pipeline schedule is currently not compatible with ZeRO Level 2 and Level 3"
) )
# LSG: set default dtype to fp16 for communication # LSG: set default dtype to fp16 for communication
if self.amp_type == AMP_TYPE.PARALLEL: if isinstance(engine.model, NaiveAMPModel):
torch.set_default_dtype(torch.half) torch.set_default_dtype(torch.half)
self.logger.info( self.logger.warning(
'default tensor dtype is set to torch.half for fp16 training', 'default tensor dtype is set to torch.half for fp16 training',
ranks=[0]) ranks=[0])
def forward_step(self, model, criterion, input_tensor, return_tensors, def forward_step(self, engine, input_tensor, return_tensors, return_loss=True):
grad_accum_size, return_loss=True):
"""Forward step for passed-in model. If it is the first stage, the input tensor """Forward step for passed-in model. If it is the first stage, the input tensor
is obtained from data_iterator, otherwise the passed-in input_tensor is used. is obtained from data_iterator, otherwise the passed-in input_tensor is used.
Returns output tensor. This is a helper function and can be ignored by users. Returns output tensor. This is a helper function and can be ignored by users.
@ -138,17 +137,16 @@ class PipelineSchedule(BaseSchedule):
if input_tensor is None: if input_tensor is None:
input_tensor, label = self.load_micro_batch() input_tensor, label = self.load_micro_batch()
if self.amp_type == AMP_TYPE.PARALLEL:
input_tensor = convert_to_fp16(input_tensor)
input_tensor = squeeze(input_tensor) input_tensor = squeeze(input_tensor)
output_tensor = model(input_tensor) output_tensor = engine(input_tensor)
output_tensor = squeeze(output_tensor) output_tensor = squeeze(output_tensor)
if gpc.is_last_rank(ParallelMode.PIPELINE): if gpc.is_last_rank(ParallelMode.PIPELINE):
if return_loss: if return_loss:
input_tensor, label = self.load_micro_batch() input_tensor, label = self.load_micro_batch()
loss_reduced = criterion(output_tensor, *label) \ loss_reduced = engine.criterion(output_tensor, *label) \
/ (self.num_microbatches * grad_accum_size) / self.num_microbatches
return_tensors.append( return_tensors.append(
tuple((output_tensor, label[0], loss_reduced))) tuple((output_tensor, label[0], loss_reduced)))
return loss_reduced return loss_reduced
@ -159,7 +157,7 @@ class PipelineSchedule(BaseSchedule):
else: else:
return output_tensor return output_tensor
def backward_step(self, optimizer, input_tensor, output_tensor, output_tensor_grad): def backward_step(self, engine, input_tensor, output_tensor, output_tensor_grad):
"""Backward step through the passed-in output tensor. If it is the last stage, the """Backward step through the passed-in output tensor. If it is the last stage, the
output_tensor_grad is None, otherwise it is the gradients with respect to stage's output tensor. output_tensor_grad is None, otherwise it is the gradients with respect to stage's output tensor.
Returns the gradients with respect to the input tensor (None if first stage). Returns the gradients with respect to the input tensor (None if first stage).
@ -171,9 +169,10 @@ class PipelineSchedule(BaseSchedule):
input_tensor.retain_grad() input_tensor.retain_grad()
# Backward pass. # Backward pass.
if output_tensor_grad is None and self.amp_type == AMP_TYPE.PARALLEL: if output_tensor_grad is None:
output_tensor = optimizer.scale_loss(output_tensor) engine.backward(output_tensor)
torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad) else:
engine.backward_by_grad(output_tensor, output_tensor_grad)
# Collect the grad of the input_tensor. # Collect the grad of the input_tensor.
input_tensor_grad = None input_tensor_grad = None
@ -183,12 +182,9 @@ class PipelineSchedule(BaseSchedule):
return input_tensor_grad return input_tensor_grad
def forward_backward_step(self, def forward_backward_step(self,
engine,
data_iter, data_iter,
model,
criterion,
optimizer=None,
forward_only=False, forward_only=False,
grad_accum_size: int = 1,
return_loss=True): return_loss=True):
"""Runs non-interleaved 1F1B schedule, with communication between pipeline stages. """Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
Returns a tuple with losses if the last stage, an empty tuple otherwise. Returns a tuple with losses if the last stage, an empty tuple otherwise.
@ -226,9 +222,8 @@ class PipelineSchedule(BaseSchedule):
ft_shape = recv_tensor_meta(ft_shape) ft_shape = recv_tensor_meta(ft_shape)
input_tensor = recv_forward(ft_shape) input_tensor = recv_forward(ft_shape)
output_tensor = self.forward_step( output_tensor = self.forward_step(
model, criterion, engine, input_tensor, return_tensors,
input_tensor, return_tensors, return_loss=return_loss
grad_accum_size, return_loss=return_loss
) )
if not gpc.is_last_rank(ParallelMode.PIPELINE): if not gpc.is_last_rank(ParallelMode.PIPELINE):
bt_shape = output_tensor.shape bt_shape = output_tensor.shape
@ -252,9 +247,8 @@ class PipelineSchedule(BaseSchedule):
last_iteration = (i == (num_microbatches_remaining - 1)) last_iteration = (i == (num_microbatches_remaining - 1))
output_tensor = self.forward_step( output_tensor = self.forward_step(
model, criterion, engine, input_tensor, return_tensors,
input_tensor, return_tensors, return_loss=return_loss
grad_accum_size, return_loss=return_loss
) )
if forward_only: if forward_only:
send_forward(output_tensor) send_forward(output_tensor)
@ -276,7 +270,7 @@ class PipelineSchedule(BaseSchedule):
output_tensor = output_tensors.pop(0) output_tensor = output_tensors.pop(0)
input_tensor_grad = self.backward_step( input_tensor_grad = self.backward_step(
optimizer, engine,
input_tensor, output_tensor, input_tensor, output_tensor,
output_tensor_grad output_tensor_grad
) )
@ -297,7 +291,7 @@ class PipelineSchedule(BaseSchedule):
output_tensor_grad = recv_backward(bt_shape) output_tensor_grad = recv_backward(bt_shape)
input_tensor_grad = self.backward_step( input_tensor_grad = self.backward_step(
optimizer, engine,
input_tensor, output_tensor, input_tensor, output_tensor,
output_tensor_grad output_tensor_grad
) )
@ -309,11 +303,8 @@ class PipelineSchedule(BaseSchedule):
output, label, loss = tuple(map(list, zip(*return_tensors))) output, label, loss = tuple(map(list, zip(*return_tensors)))
return (torch.cat(output, dim=0), return (torch.cat(output, dim=0),
torch.cat(label, dim=0), torch.cat(label, dim=0),
sum(loss) * grad_accum_size) sum(loss))
else: else:
return tuple((torch.cat(return_tensors, dim=0), None, None)) return tuple((torch.cat(return_tensors, dim=0), None, None))
else: else:
return tuple((None, None, None)) return tuple((None, None, None))
def optimizer_step(self, model, optimizer, grad_clipping: float = 0.0):
optimizer.step()
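The pipeline schedule keeps the non-interleaved 1F1B strategy but now routes the forward pass, loss and backward pass through the engine and leaves the optimizer step to the caller. A hedged setup sketch (engine and train_dataloader assumed; the batch size must be divisible by num_microbatches, as load_batch() asserts above):

# Hedged sketch: pipeline schedule setup under the new engine-centric API
# (engine and train_dataloader are assumed to exist).
from colossalai.engine.schedule import PipelineSchedule

schedule = PipelineSchedule(num_microbatches=4, sync_data=True)
schedule.pre_processing(engine)   # switches the default dtype to fp16 for NaiveAMPModel
output, label, loss = schedule.forward_backward_step(
    engine, iter(train_dataloader), forward_only=False, return_loss=True)
engine.step()                     # the optimizer step is no longer done by the schedule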


@ -1,27 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Union, List
from torch import Tensor
def convert_to_fp16(data: Union[Tensor, List[Tensor]]):
if isinstance(data, Tensor):
ret = data.half()
elif isinstance(data, (list, tuple)):
ret = [val.half() for val in data]
else:
raise TypeError(f"Expected argument 'data' to be a Tensor or a list/tuple of Tensor, but got {type(data)}")
return ret
def convert_to_fp32(data: Union[Tensor, List[Tensor]]):
if isinstance(data, Tensor):
ret = data.float()
elif isinstance(data, (list, tuple)):
ret = [val.float() for val in data]
else:
raise TypeError(f"Expected argument 'data' to be a Tensor or a list/tuple of Tensor, but got {type(data)}")
return ret


@ -3,377 +3,326 @@
import argparse import argparse
import pprint import pprint
import random import os
from pathlib import Path from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
from typing import Callable, Iterable, Optional, Union
from typing import Tuple
import numpy as np import numpy as np
import torch import torch
from torch.utils.data import DataLoader import torch.nn as nn
from colossalai.engine import AMP_TYPE, NoPipelineSchedule, PipelineSchedule from pathlib import Path
from typing import Iterable, Union, Optional, Tuple, List, Dict
from colossalai.amp import convert_to_amp, AMP_TYPE
from colossalai.context import Config, ParallelMode, ConfigException
from colossalai.core import global_context as gpc
from colossalai.engine import Engine from colossalai.engine import Engine
from colossalai.logging import get_global_dist_logger, init_global_dist_logger from colossalai.logging import get_dist_logger
from colossalai.nn import DataParallelSampler from colossalai.utils import (accumulate_gradient, get_current_device,
from colossalai.nn.model.base_model import BaseModel sync_model_param_in_dp, is_using_ddp, is_using_pp)
from .builder import (ModelInitializer, build_dataset, build_loss, from colossalai.zero import convert_to_zero, ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3
build_model, build_optimizer, from colossalai.builder.builder import build_gradient_handler
build_optimizer_wrapper, build_schedule) from torch.optim.optimizer import Optimizer
from .context import Config, ParallelMode from torch.optim.lr_scheduler import _LRScheduler
from .core import global_context as gpc from torch.utils.data import DataLoader
from .utils import get_current_device, sync_model_param_in_dp from torch.nn.modules.loss import _Loss
from torch.nn.parallel import DistributedDataParallel as DDP
def parse_args(): def get_default_parser():
'''Reads user command line and uses an argument parser to parse the input arguments. '''Reads user command line and uses an argument parser to parse the input arguments.
Input arguments include configuration, host, port, world size, local rank, backend for torch.distributed. Input arguments include configuration, host, port, world size, local rank, backend for torch.distributed.
:return: call the parse arguments function :return: returns the parser with the default arguments, the user may add customized arguments into this parser
:rtype: Namespace :rtype: Namespace
''' '''
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, help='path to the config file') parser.add_argument('--config', type=str, help='path to the config file')
parser.add_argument('--host', parser.add_argument('--host',
type=str, type=str,
default=None,
help='the master address for distributed training') help='the master address for distributed training')
parser.add_argument('--port', parser.add_argument('--port',
type=str, type=int,
default=None,
help='the master port for distributed training') help='the master port for distributed training')
parser.add_argument('--world_size', type=int, help='world size for ') parser.add_argument('--world_size', type=int, help='world size for distributed training')
parser.add_argument('--rank', type=int, help='rank for the default process group')
parser.add_argument('--local_rank', parser.add_argument('--local_rank',
type=int, type=int,
help='rank for the default process group') help='local rank on the node')
parser.add_argument('--backend', parser.add_argument('--backend',
type=str, type=str,
default='nccl', default='nccl',
help='backend for torch.distributed') help='backend for distributed communication')
return parser.parse_args() return parser
def init_dist(config: Union[str, dict] = None, def launch(config: Union[str, Path, Config, Dict],
local_rank: int = None, rank: int,
world_size: int = None, world_size: int,
host: str = None, host: str,
port: str = None, port: int,
backend: str = None): backend: str = 'nccl',
local_rank: int = None,
seed: int = 1024,
verbose: bool = True):
'''This function first parses the configuration arguments, using :func:parse_args() in case one of the input arguments are not given. '''This function first parses the configuration arguments, using :func:parse_args() in case one of the input arguments are not given.
Then initialize and set distributed environment by calling global_context's functions. Then initialize and set distributed environment by calling global_context's functions.
:param config: config file or config file path are both acceptable :param config: config file or config file path are both acceptable
:type config: Union[str, dict], optional :type config: Union[str, dict, Config]
:param local_rank: rank for the default process group, defaults to None :param rank: rank for the default process group
:type rank: int
:param world_size: world size of the default process group
:type world_size: int
:param host: the master address for distributed training
:type host: str
:param port: the master port for distributed training
:type port: str
:param backend: backend for torch.distributed
:type backend: str
:param local_rank: rank for the process on the node and is used to set the default CUDA device,
defaults to None. If local_rank = None, the default device ordinal will be calculated automatically
:type local_rank: int, optional :type local_rank: int, optional
:param world_size: world size of GPUs, defaults to None
:type world_size: int, optional
:param host: the master address for distributed training, defaults to None
:type host: str, optional
:param port: the master port for distributed training, defaults to None
:type port: str, optional
:param backend: backend for torch.distributed, defaults to None
:type backend: str, optional
:raises Exception: raise exception when config type is wrong :raises Exception: raise exception when config type is wrong
''' '''
args = [config, local_rank, world_size, host, port, backend] gpc.verbose = verbose
arg_given = [arg is not None for arg in args]
if not all(arg_given):
args = parse_args()
if config is None:
config = args.config
if local_rank is None:
local_rank = args.local_rank
if world_size is None:
world_size = args.world_size
if host is None:
host = args.host
if port is None:
port = args.port
if backend is None:
backend = args.backend
args = Config(
dict(config=config,
host=host,
port=port,
world_size=world_size,
local_rank=local_rank,
backend=backend))
# set distributed settings
dist_args = Config(
dict(local_rank=args.local_rank,
world_size=args.world_size,
backend=args.backend))
gpc.set_dist_args(dist_args)
# set config # set config
if isinstance(args.config, dict): assert isinstance(config, (Config, str, Path, dict)), \
cfg = args.config f'expected argument config to be Config, str or Path, but got {type(config)}'
elif isinstance(args.config, (str, Path)): if not isinstance(config, Config) and isinstance(config, dict):
cfg = Config.from_file(args.config) config = Config(config)
else: if isinstance(config, (str, Path)):
raise Exception('Config type error: {}'.format(type(args.config))) config = Config.from_file(config)
gpc.load_config(cfg) gpc.load_config(config)
# init dist groups # init default process group
gpc.init_global_dist(args.host, args.port) gpc.init_global_dist(rank, world_size, backend, host, port)
# init process groups for different parallel modes from config
gpc.init_parallel_groups() gpc.init_parallel_groups()
# init dist logger
init_global_dist_logger()
# set cuda device # set cuda device
if torch.cuda.is_available(): if torch.cuda.is_available():
gpc.set_device() # if local rank is not given, calculate automatically
gpc.set_device(local_rank)
gpc.set_seed(seed)
if verbose:
logger = get_dist_logger()
logger.info(f'Distributed environment is initialized, '
f'data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, '
f'tensor parallel size: {gpc.tensor_parallel_size}', ranks=[0])
def launch_from_slurm(config: Union[str, Path, Config, Dict],
                      host: str,
                      port: int,
                      backend: str = 'nccl',
                      seed: int = 1024,
                      verbose: bool = True):
    rank = int(os.environ['SLURM_PROCID'])
    world_size = int(os.environ['SLURM_NPROCS'])
    launch(config=config,
           rank=rank,
           world_size=world_size,
           host=host,
           port=port,
           backend=backend,
           seed=seed,
           verbose=verbose)


def launch_from_openmpi(config: Union[str, Path, Config, Dict],
                        host: str,
                        port: int,
                        backend: str = 'nccl',
                        seed: int = 1024,
                        verbose: bool = True):
    rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
    local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
    world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
    launch(config=config,
           local_rank=local_rank,
           rank=rank,
           world_size=world_size,
           host=host,
           port=port,
           backend=backend,
           seed=seed,
           verbose=verbose)


def launch_from_torch(config: Union[str, Path, Config, Dict],
                      host: str,
                      port: int,
                      backend: str = 'nccl',
                      seed: int = 1024,
                      verbose: bool = True):
    rank = int(os.environ['RANK'])
    local_rank = int(os.environ['LOCAL_RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    launch(config=config,
           local_rank=local_rank,
           rank=rank,
           world_size=world_size,
           host=host,
           port=port,
           backend=backend,
           seed=seed,
           verbose=verbose)
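The three launchers above only differ in where they read the rank information from (SLURM, OpenMPI, or the PyTorch launcher environment) before delegating to launch(). The snippet below is a usage sketch of my own, not part of this patch: it assumes a config file at ./config.py and that the helper is re-exported at package level (otherwise it can be imported from colossalai.initialize).

# usage sketch (not part of this patch): started via
#   python -m torch.distributed.launch --nproc_per_node=8 train.py
# so RANK / LOCAL_RANK / WORLD_SIZE are already set in the environment.
import colossalai

colossalai.launch_from_torch(config='./config.py',
                             host='localhost',
                             port=29500,
                             backend='nccl')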
def initialize(model: Union[nn.Module, List[nn.Module]],
optimizer: Union[Optimizer, List[Optimizer]],
criterion: Union[_Loss, List[_Loss]],
train_dataloader: Optional[Union[Iterable, List[Iterable]]] = None,
test_dataloader: Optional[Union[Iterable, List[Iterable]]] = None,
lr_scheduler: _LRScheduler = None,
verbose: bool = True
) -> Tuple[Engine, DataLoader, DataLoader]:
    '''Core function to wrap the essential training components with our functionality based on the config which is
    loaded into gpc.config.

    :param model: your model instance
    :type model: a single or a list of ``torch.nn.Module`` objects
    :param optimizer: your optimizer instance
    :type optimizer: a single or a list of ``torch.optim.optimizer.Optimizer`` objects
    :param criterion: your criterion instance
    :type criterion: a single or a list of ``torch.nn.modules.loss._Loss`` objects
    :param train_dataloader: dataloader for training data, defaults to None
    :type train_dataloader: a single or a list of ``torch.utils.data.DataLoader`` objects, optional
    :param test_dataloader: dataloader for testing data, defaults to None
    :type test_dataloader: a single or a list of ``torch.utils.data.DataLoader`` objects, optional
    :param lr_scheduler: your lr scheduler instance, defaults to None
    :type lr_scheduler: ``torch.optim.lr_scheduler._LRScheduler``, optional
    :param verbose: whether to print logs, defaults to True
    :type verbose: bool, optional
    :return: (engine, train_dataloader, test_dataloader, lr_scheduler)
    :rtype: tuple
    '''
    # get logger
    logger = get_dist_logger()
    gpc.verbose = verbose

    # get config from gpc
    config = gpc.config

    # print config
    if verbose:
        logger.info(f"\n========== Your Config ========\n"
                    f"{pprint.pformat(gpc.config)}\n"
                    f"================================\n", ranks=[0])

    # cudnn
    cudnn_benchmark = config.get('cudnn_benchmark', True)
    cudnn_deterministic = config.get('cudnn_deterministic', False)
    torch.backends.cudnn.benchmark = cudnn_benchmark
    torch.backends.cudnn.deterministic = cudnn_deterministic
    if verbose:
        logger.info(
            f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])

    # first sync model across dp ranks
    model.to(get_current_device())
    sync_model_param_in_dp(model)

    # check amp and zero
    fp16_cfg = gpc.config.get('fp16', None)
    zero_cfg = gpc.config.get('zero', None)

    if fp16_cfg is not None and fp16_cfg.mode is not None and zero_cfg is not None:
        raise ConfigException(
            "It is not allowed to set fp16 and zero configuration in your config file at the same time")

    # initialize amp
    amp_mode = None
    if fp16_cfg is not None and fp16_cfg.mode is not None:
        cfg_ = fp16_cfg.copy()
        amp_mode = cfg_.pop('mode')
        model, optimizer, criterion = convert_to_amp(model=model,
                                                     optimizer=optimizer,
                                                     criterion=criterion,
                                                     mode=amp_mode,
                                                     amp_config=cfg_)

    if zero_cfg is not None:
        cfg_ = zero_cfg.copy()
        level = cfg_.pop('level')
        model, optimizer = convert_to_zero(model=model,
                                           optimizer=optimizer,
                                           level=level,
                                           zero_config=cfg_)

    # gradient handler
    gradient_handler_cfg = gpc.config.get('gradient_handler', None)
    if gradient_handler_cfg is None:
        # if gradient handler is not specified in the configuration file,
        # check in the following order
        # 1. if optimizer is ZERO, then use zero grad handler
        # 2. if dp size is larger than 1 and pipeline is not used, use pytorch ddp
        # 3. if using pipeline and dp size larger than 1, use data parallel grad handler
        if isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2,
                                  ZeroRedundancyOptimizer_Level_3)):
            gradient_handler_cfg = [dict(type='ZeROGradientHandler')]
            if verbose:
                logger.info(
                    "Training with zero is detected, ZeROGradientHandler is automatically "
                    "added even though not specified in the configuration",
                    ranks=[0])
        elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE:
            model = DDP(model, process_group=gpc.get_group(ParallelMode.DATA))
            if verbose:
                logger.info(
                    'Model is using torch.nn.parallel.DistributedDataParallel', ranks=[0])
        elif is_using_ddp():
            gradient_handler_cfg = [dict(type='DataParallelGradientHandler')]
            if verbose:
                logger.info(
                    "Data parallel training is detected when using pipeline parallel, "
                    "DataParallelGradientHandler is automatically added even though not specified in the configuration",
                    ranks=[0])
    else:
        if not isinstance(gradient_handler_cfg, list):
            raise ConfigException(
                f"expected gradient_handler in the configuration file to be a list but got {type(gradient_handler_cfg)}")

    if gradient_handler_cfg is None:
        gradient_handlers = None
        if verbose and not isinstance(model, DDP):
            logger.warning(
                "No PyTorch DDP or gradient handler is set up, please make sure you do not need "
                "to all-reduce the gradients after a training step.",
                ranks=[0])
    else:
        gradient_handlers = [build_gradient_handler(cfg, model, optimizer) for cfg in gradient_handler_cfg]

    # check if optimizer is ColossalaiOptimizer
    if not isinstance(optimizer, (ColossalaiOptimizer, ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)):
        optimizer = ColossalaiOptimizer(optim=optimizer)

    # gradient accumulation
    grad_accum_size = gpc.config.get('gradient_accumulation', None)
    if grad_accum_size is not None:
        optimizer, train_dataloader, gradient_handlers, lr_scheduler = accumulate_gradient(
            model=model,
            optimizer=optimizer,
            dataloader=train_dataloader,
            accumulate_size=grad_accum_size,
            gradient_handlers=gradient_handlers,
            lr_scheduler=lr_scheduler)

    # clip grad norm
    clip_grad_norm = gpc.config.get('clip_grad_norm', 0.0)
    if clip_grad_norm > 0:
        if zero_cfg is not None:
            raise ConfigException(
                "clip_grad_norm should not be specified with zero, you should specify clip_grad in the zero configuration")
        elif fp16_cfg is not None and fp16_cfg.mode == AMP_TYPE.NAIVE:
            raise ConfigException(
                "clip_grad_norm should not be specified with AMP_TYPE.NAIVE, you should specify clip_grad in the fp16 configuration")

    engine = Engine(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        gradient_handlers=gradient_handlers,
        clip_grad_norm=clip_grad_norm
    )

    return engine, train_dataloader, test_dataloader, lr_scheduler
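As a rough usage sketch (my addition, under the assumption that initialize and launch_from_torch are re-exported at the package level), the new flow is: launch the distributed context, build the model, optimizer and criterion yourself, and let initialize wrap them according to gpc.config; keys such as gradient_accumulation and clip_grad_norm are picked up from the loaded config as shown above. The config path and the toy dataloader below are placeholders.

# usage sketch (not part of this patch)
import torch
import torchvision.models as models
from torch.utils.data import DataLoader, TensorDataset
import colossalai

colossalai.launch_from_torch(config='./config.py',   # may define gradient_accumulation, clip_grad_norm, fp16, ...
                             host='localhost',
                             port=29500)

model = models.resnet18(num_classes=10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = torch.nn.CrossEntropyLoss()
train_loader = DataLoader(TensorDataset(torch.randn(64, 3, 32, 32),
                                        torch.randint(0, 10, (64,))), batch_size=16)

engine, train_loader, test_loader, lr_scheduler = colossalai.initialize(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    train_dataloader=train_loader)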

View File

@ -1,26 +1,10 @@
from .logging import DistributedLogger

__all__ = ['get_dist_logger', 'DistributedLogger']


def get_dist_logger(name='root'):
    """Get logger instance based on name. The DistributedLogger will create singleton instances,
    which means that only one logger instance is created per name.
    """
    return DistributedLogger.get_instance(name=name)
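A brief usage sketch of the new logging API (my addition, not part of the patch); the ranks argument and the per-rank file name assume the distributed context has already been launched.

# usage sketch (not part of this patch)
from colossalai.logging import get_dist_logger

logger = get_dist_logger()                    # singleton, name defaults to 'root'
logger.info('printed on every rank')
logger.info('printed on global rank 0 only', ranks=[0])
logger.log_to_file('./logs')                  # appends to ./logs/rank_<rank>.log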

View File

@ -1,11 +1,13 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import colossalai
import logging
from pathlib import Path
from typing import Union

from colossalai.context.parallel_mode import ParallelMode

_FORMAT = 'colossalai - %(name)s - %(asctime)s %(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=_FORMAT)
@ -16,40 +18,92 @@ class DistributedLogger:
:param name: The name of the logger
:type name: str
"""
__instances = dict()
@staticmethod
def get_instance(name: str):
"""Get the unique single logger instance based on name.
:param name: The name of the logger
:type name: str
:return: a DistributedLogger object
:rtype: DistributedLogger
"""
if name in DistributedLogger.__instances:
return DistributedLogger.__instances[name]
else:
logger = DistributedLogger(name=name)
return logger
def __init__(self, name):
if name in DistributedLogger.__instances:
raise Exception('Logger with the same name has been created, you should use colossalai.logging.get_dist_logger')
else:
self._name = name
self._logger = logging.getLogger(name)
DistributedLogger.__instances[name] = self
@staticmethod
def _check_valid_logging_level(level: str):
assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR'], 'found invalid logging level'
def set_level(self, level: str):
"""Set the logging level
:param level: can only be INFO, DEBUG, WARNING and ERROR
:type level: str
"""
self._check_valid_logging_level(level)
    self._logger.setLevel(getattr(logging, level))

def log_to_file(self,
                path: Union[str, Path],
                mode: str = 'a',
                level: str = 'INFO',
                suffix: str = None):
    """Save the logs to file

    :param path: the file to save the log
    :type path: a string or pathlib.Path object
    :param mode: the mode to write log into the file
    :type mode: str
    :param level: can only be INFO, DEBUG, WARNING and ERROR
    :type level: str
    """
assert isinstance(path, (str, Path)), \
f'expected argument path to be type str or Path, but got {type(path)}'
self._check_valid_logging_level(level)
if isinstance(path, str):
path = Path(path)
# set the default file name if path is a directory
if not colossalai.core.global_context.is_initialized(ParallelMode.GLOBAL):
rank = 0
else:
rank = colossalai.core.global_context.get_global_rank()
if suffix is not None:
log_file_name = f'rank_{rank}_{suffix}.log'
else:
log_file_name = f'rank_{rank}.log'
path = path.joinpath(log_file_name)
# add file handler
file_handler = logging.FileHandler(path, mode)
file_handler.setLevel(getattr(logging, level))
formatter = logging.Formatter(_FORMAT)
file_handler.setFormatter(formatter)
self._logger.addHandler(file_handler)
def _log(self, level, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
    if ranks is None:
        getattr(self._logger, level)(message)
    else:
        local_rank = colossalai.core.global_context.get_local_rank(parallel_mode)
        if local_rank in ranks:
            getattr(self._logger, level)(message)

def info(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
    """Log an info message.

    :param message: The message to be logged
    :type message: str
@ -61,7 +115,7 @@ class DistributedLogger:
    self._log('info', message, parallel_mode, ranks)

def warning(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
    """Log a warning message.

    :param message: The message to be logged
    :type message: str
@ -73,7 +127,7 @@ class DistributedLogger:
    self._log('warning', message, parallel_mode, ranks)

def debug(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
    """Log a debug message.

    :param message: The message to be logged
    :type message: str
@ -85,7 +139,7 @@ class DistributedLogger:
    self._log('debug', message, parallel_mode, ranks)

def error(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
    """Log an error message.

    :param message: The message to be logged
    :type message: str

View File

@ -1,4 +1,3 @@
from .data import *
from .layer import *
from .loss import *
from .lr_scheduler import *

View File

@ -1,3 +0,0 @@
from .caltech101_dataset import Caltech101Dataset
from .cifar10_dataset import CIFAR10Dataset
from .sampler import *

View File

@ -1,14 +0,0 @@
import numpy as np
def pil_img_to_numpy(pil_img):
"""convert a PIL image to numpy nd-array
:param pil_img: a PIL image
:type pil_img: PIL.Image
:return: a nd-array
:rtype: numpy.ndarray
"""
np_img = np.array(pil_img)
np_img = np.rollaxis(np_img, 2) # HWC to CHW
return np_img

View File

@ -1,17 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from abc import ABC
from torch.utils.data import Dataset
from torchvision.transforms import transforms
from colossalai.builder import build_transform
class BaseDataset(Dataset, ABC):
def __init__(self, transform_pipeline: list):
transform_list = [build_transform(cfg) for cfg in transform_pipeline]
transform = transforms.Compose(transform_list)
self._transform_pipeline = transform

View File

@ -1,43 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.distributed as dist
from torchvision.datasets import Caltech101
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import DATASETS
from .base_dataset import BaseDataset
@DATASETS.register_module
class Caltech101Dataset(BaseDataset):
"""`Caltech 101 <http://www.vision.caltech.edu/Image_Datasets/Caltech101/>`_ Dataset.
:param transform_pipeline: A list of functions' config, which takes in an PIL image
and returns a transformed version
:type transform_pipeline: list
"""
def __init__(self, transform_pipeline: list, *args, **kwargs):
super().__init__(transform_pipeline)
if gpc.is_initialized(ParallelMode.GLOBAL) and gpc.get_global_rank() != 0:
dist.barrier()
self._dataset = Caltech101(
transform=self._transform_pipeline, *args, **kwargs)
if gpc.is_initialized(ParallelMode.GLOBAL) and gpc.get_global_rank() == 0:
dist.barrier()
def __len__(self):
return len(self._dataset)
def __getitem__(self, item):
"""
:param item: Index
:type item: int
:return: ((image,), (target,)) where the type of target specified by target_type.
:rtype: tuple
"""
img, label = self._dataset.__getitem__(item)
return (img,), (label,)

View File

@ -1,44 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.distributed as dist
from torchvision.datasets import CIFAR10
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import DATASETS
from .base_dataset import BaseDataset
@DATASETS.register_module
class CIFAR10Dataset(BaseDataset):
"""`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
:param transform_pipeline: A list of functions' config, which takes in an PIL image
and returns a transformed version
:type transform_pipeline: list
"""
def __init__(self, transform_pipeline: list, *args, **kwargs):
super().__init__(transform_pipeline)
if gpc.is_initialized(ParallelMode.GLOBAL) and gpc.get_global_rank() != 0:
dist.barrier()
self._dataset = CIFAR10(transform=self._transform_pipeline,
*args,
**kwargs)
if gpc.is_initialized(ParallelMode.GLOBAL) and gpc.get_global_rank() == 0:
dist.barrier()
def __len__(self):
return len(self._dataset)
def __getitem__(self, item):
"""
:param item: Index
:type item: int
:return: ((image,), (target,)) where the type of target specified by target_type.
:rtype: tuple
"""
img, label = self._dataset.__getitem__(item)
return (img,), (label,)

View File

@ -1,4 +0,0 @@
from .base_sampler import BaseSampler
from .data_parallel_sampler import DataParallelSampler
__all__ = ['BaseSampler', 'DataParallelSampler']

colossalai/nn/init.py Normal file
View File

@ -0,0 +1,33 @@
import math
from torch import Tensor
from torch.nn import init as init
def init_weight_(tensor: Tensor, fan_in: int, fan_out: int = None, init_method: str = 'torch'):
if init_method == 'torch':
a = math.sqrt(5)
nonlinearity = 'leaky_relu'
std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
bound = math.sqrt(3.0) * std
init.uniform_(tensor, -bound, bound)
elif init_method == 'jax':
std = math.sqrt(2.0 / float(fan_in + fan_out))
a = math.sqrt(3.0) * std
init.uniform_(tensor, -a, a)
elif init_method == 'jax_embed':
std = math.sqrt(1.0 / fan_in)
init.trunc_normal_(tensor, std=std / .87962566103423978)
elif init_method == 'zero':
init.zeros_(tensor)
def init_bias_(tensor: Tensor, fan_in: int, init_method: str = 'torch'):
if init_method == 'torch':
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
init.uniform_(tensor, -bound, bound)
elif init_method == 'jax':
init.normal_(tensor, std=1e-6)
elif init_method == 'jax_embed':
init.trunc_normal_(tensor, std=.02)
elif init_method == 'zero':
init.zeros_(tensor)
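The two helpers above switch between PyTorch's default fan-in initialization, the JAX/ViT-style schemes, and plain zeros. A small sketch of my own (not part of the patch) showing them applied to a hand-made linear weight and bias follows.

# usage sketch (not part of this patch)
import torch
from colossalai.nn.init import init_weight_, init_bias_

weight = torch.empty(128, 64)      # (out_features, in_features), as in nn.Linear
bias = torch.empty(128)
init_weight_(weight, fan_in=64, fan_out=128, init_method='jax')   # uniform with std sqrt(2 / (fan_in + fan_out))
init_bias_(bias, fan_in=64, init_method='zero')                   # zero-initialized bias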

View File

@ -1,9 +1,8 @@
from .fused_bias_gelu import bias_gelu_impl
from .parallel_1d import *
from .parallel_2d import *
from .parallel_2p5d import *
from .parallel_3d import *
from .parallel_sequence import *
from .non_parallel_layers import *
from .wrapper import *

View File

@ -2,40 +2,14 @@
# -*- encoding: utf-8 -*-

import math
import collections.abc
from itertools import repeat

import numpy as np
from colossalai.utils.common import print_rank_0
import torch
from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS
from colossalai.utils import checkpoint
from torch import Tensor, nn
def divide(numerator, denominator):
""" only allow exact division """
assert numerator % denominator == 0, \
'{} is not divisible by {}'.format(numerator, denominator)
return numerator // denominator
def gelu(x: Tensor) -> Tensor:
"""Implementation of the gelu activation function.
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
"""
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
def swish(x: Tensor) -> Tensor:
return x * torch.sigmoid(x)
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
def set_tensor_parallel_attribute(param):
if not hasattr(param, IS_TENSOR_PARALLEL):
setattr(param, IS_TENSOR_PARALLEL, True)
class CheckpointModule(nn.Module):
@ -44,15 +18,15 @@ class CheckpointModule(nn.Module):
        self.checkpoint = checkpoint
        self._use_checkpoint = checkpoint

    def _forward(self, *args, **kwargs):
        raise NotImplementedError(
            'CheckpointModule should implement _forward method instead of origin forward')

    def forward(self, *args, **kwargs):
        if self._use_checkpoint:
            return checkpoint(self._forward, *args, **kwargs)
        else:
            return self._forward(*args, **kwargs)

    def train(self, mode: bool = True):
        self._use_checkpoint = self.checkpoint
@ -61,3 +35,38 @@ class CheckpointModule(nn.Module):
    def eval(self):
        self._use_checkpoint = False
        return super().eval()
def divide(numerator, denominator):
""" only allow exact division """
assert numerator % denominator == 0, \
'{} is not divisible by {}'.format(numerator, denominator)
return numerator // denominator
def swish(x: Tensor) -> Tensor:
return x * torch.sigmoid(x)
ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}
def set_tensor_parallel_attribute_by_size(param, size):
setattr(param, IS_TENSOR_PARALLEL, True)
setattr(param, NUM_PARTITIONS, size // np.prod(param.shape))
def set_tensor_parallel_attribute_by_partition(param, num_partitions):
setattr(param, IS_TENSOR_PARALLEL, True)
setattr(param, NUM_PARTITIONS, num_partitions)
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
return x
return tuple(repeat(x, n))
return parse
to_2tuple = _ntuple(2)
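The helpers above are shared by the parallel layers: divide enforces exact splits of hidden sizes across ranks, to_2tuple normalizes image and patch sizes, and the set_tensor_parallel_attribute_* functions tag parameters with how they are sharded. A tiny illustration of mine follows; the module path is assumed from the relative imports used elsewhere in this patch.

# illustration only (not part of this patch)
import torch
from colossalai.nn.layer._common_utils import divide, to_2tuple, ACT2FN

assert divide(768, 12) == 64           # raises if the split is not exact
assert to_2tuple(16) == (16, 16)       # scalars are repeated, iterables are passed through
out = ACT2FN['swish'](torch.randn(4))  # same activation table used by the transformer MLPs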

View File

@ -0,0 +1,35 @@
# adapted from Megatron-LM
# https://github.com/NVIDIA/Megatron-LM/blob/b31e1296354e979722627a6c4dedafe19b51fa97/megatron/model/fused_bias_gelu.py
import torch
@torch.jit.script
def bias_gelu(bias, y):
x = bias + y
return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def bias_gelu_back(g, bias, y):
x = bias + y
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
return ff*g
class GeLUFunction(torch.autograd.Function):
@staticmethod
# bias is an optional argument
def forward(ctx, input, bias):
ctx.save_for_backward(input, bias)
return bias_gelu(bias, input)
@staticmethod
def backward(ctx, grad_output):
input, bias = ctx.saved_tensors
tmp = bias_gelu_back(grad_output, bias, input)
return tmp, tmp
bias_gelu_impl = GeLUFunction.apply
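bias_gelu_impl fuses the bias addition of a preceding skip_bias_add linear layer with the tanh-approximated GeLU and backs it with a hand-written gradient. The check below is my own sanity sketch (same-shaped tensors, CPU) and is not part of the patch; the import path is assumed from the layer __init__ shown earlier.

# sanity sketch (not part of this patch)
import torch
from colossalai.nn.layer import bias_gelu_impl   # re-exported in the layer __init__ above

y = torch.randn(8, requires_grad=True)
bias = torch.randn(8, requires_grad=True)
out = bias_gelu_impl(y, bias)

x = (y + bias).detach()
expected = x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
assert torch.allclose(out, expected, atol=1e-6)
out.sum().backward()                              # uses the custom bias_gelu_back gradient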

View File

@ -0,0 +1,8 @@
from ._vit import (ViTBlock, VanillaViTAttention, VanillaViTBlock, VanillaViTDropPath,
VanillaViTHead, VanillaViTMLP, VanillaViTPatchEmbedding)
__all__ = [
'ViTBlock', 'VanillaViTAttention', 'VanillaViTBlock', 'VanillaViTDropPath',
'VanillaViTHead', 'VanillaViTMLP', 'VanillaViTPatchEmbedding'
]

View File

@ -1,23 +1,47 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch
from torch import nn as nn

from colossalai.builder import build_layer
from colossalai.registry import LAYERS
from .._common_utils import to_2tuple


@LAYERS.register_module
class ViTBlock(nn.Module):
    """Vision Transformer block

    :param attention_cfg: config of attention layer
    :type attention_cfg: dict
    :param droppath_cfg: config of drop path
    :type droppath_cfg: dict
    :param mlp_cfg: config of MLP layer
    :type mlp_cfg: dict
    :param norm_cfg: config of normalization layer
    :type norm_cfg: dict
    """
def __init__(self,
attention_cfg: dict,
droppath_cfg: dict,
mlp_cfg: dict,
norm_cfg: dict,
):
super().__init__()
self.norm1 = build_layer(norm_cfg)
self.attn = build_layer(attention_cfg)
self.drop_path = build_layer(
droppath_cfg) if droppath_cfg['drop_path'] > 0. else nn.Identity()
self.norm2 = build_layer(norm_cfg)
self.mlp = build_layer(mlp_cfg)
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
@LAYERS.register_module

View File

@ -1,5 +1,11 @@
from .layers import Linear1D_Col, Linear1D_Row
from .layers import MixedFusedLayerNorm1D as LayerNorm1D
from ._transformer import TransformerMLP1D, TransformerSelfAttention1D, TransformerLayer1D
from ._vit import ViTMLP1D, ViTSelfAttention1D, ViTHead1D, ViTPatchEmbedding1D, ViTTokenFuser1D, ViTHead
__all__ = [
    'Linear1D_Col', 'Linear1D_Row', 'ViTMLP1D', 'ViTSelfAttention1D', 'ViTHead1D', 'ViTPatchEmbedding1D',
    'ViTTokenFuser1D', 'TransformerMLP1D', 'TransformerSelfAttention1D', 'TransformerLayer1D', 'LayerNorm1D', 'ViTHead'
]

View File

@ -0,0 +1,34 @@
import torch
try:
    import fused_mix_prec_layer_norm_cuda
except ImportError:
    fused_mix_prec_layer_norm_cuda = None
class FusedLayerNormAffineFunction1D(torch.autograd.Function):
@staticmethod
def forward(ctx, input, weight, bias, normalized_shape, eps):
ctx.normalized_shape = normalized_shape
ctx.eps = eps
input_ = input.contiguous()
weight_ = weight.contiguous()
bias_ = bias.contiguous()
output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(
input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
return output
@staticmethod
def backward(ctx, grad_output):
input_, weight_, bias_, mean, invvar = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None
grad_input, grad_weight, grad_bias \
= fused_mix_prec_layer_norm_cuda.backward_affine(
grad_output.contiguous(), mean, invvar,
input_, ctx.normalized_shape,
weight_, bias_, ctx.eps)
return grad_input, grad_weight, grad_bias, None, None

View File

@ -0,0 +1,220 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import math
from torch import Tensor
from torch.nn.parameter import Parameter
from typing import Tuple
from colossalai.context import seed, ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import LAYERS
from colossalai.utils import get_current_device
from .._common_utils import divide, ACT2FN
from .._parallel_utilities import reduce_grad, reduce_input, gather_forward_split_backward, \
split_forward_gather_backward
from ..base_layer import ParallelLayer
from .layers import Linear1D_Col, Linear1D_Row
from .layers import MixedFusedLayerNorm1D as LayerNorm1D
@LAYERS.register_module
class TransformerMLP1D(ParallelLayer):
"""MLP.
MLP will take the input with h hidden state, project it to 4*h
hidden dimension, perform nonlinear transformation, and project the
state back into h hidden dimension.
"""
def __init__(self,
in_features: int,
mlp_ratio: int = 4.0,
act_func: str = 'gelu',
dropout_prob: float = 0.,
dtype=None,
skip_bias_add: bool = False
):
super(TransformerMLP1D, self).__init__()
self.in_features = in_features
self.mlp_ratio = mlp_ratio
self.skip_bias_add = skip_bias_add
# Project to h * mlp_ratio.
self.dense_1 = Linear1D_Col(
self.in_features,
int(self.mlp_ratio * self.in_features),
bias=not skip_bias_add,
dtype=dtype,
gather_output = False,
)
assert act_func in ACT2FN.keys(), f'Invalid value for argument act_func, ' \
f'activation function can only be {list(ACT2FN.keys())}'
self.activation_func = ACT2FN[act_func]
# Project back to h.
self.dense_2 = Linear1D_Row(
int(self.mlp_ratio * self.in_features),
self.in_features,
bias=not skip_bias_add,
dtype=dtype,
parallel_input = True,
)
self.dropout = nn.Dropout(dropout_prob)
# self.layernorm = LayerNorm1D(in_features, dtype=dtype)
self.layernorm = nn.LayerNorm(in_features, dtype=dtype)
def forward(self, x):
if self.skip_bias_add:
intermediate_output, _ = self.dense_1(x)
else:
intermediate_output = self.dense_1(x)
intermediate_output = self.activation_func(intermediate_output)
if self.skip_bias_add:
output, _ = self.dense_2(intermediate_output)
else:
output = self.dense_2(intermediate_output)
with seed(ParallelMode.TENSOR):
output = self.dropout(output)
output = self.layernorm(x + output)
return output
@LAYERS.register_module
class TransformerSelfAttention1D(ParallelLayer):
"""Self attention layer for 1D parallel Transformer
:param hidden_size: hidden size
:type hidden_size: int
:param num_attention_heads: number of attention heads
:type num_attention_heads: int
:param attention_dropout_prob: dropout probability for attention layer
:type attention_dropout_prob: float
:param hidden_dropout_prob: dropout probability for hidden layer
:type hidden_dropout_prob: float
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
hidden_size: int,
num_attention_heads: int,
attention_dropout_prob: float,
hidden_dropout_prob: float,
dtype=None,
):
super().__init__()
self.hidden_size = hidden_size
self.num_attention_heads = divide(num_attention_heads, gpc.tensor_parallel_size)
self.attention_head_size = divide(hidden_size, num_attention_heads)
self.hidden_size_per_partition = divide(hidden_size, gpc.tensor_parallel_size)
self.query_key_value = Linear1D_Col(
hidden_size,
3 * hidden_size,
dtype=dtype,
)
self.attention_dropout = nn.Dropout(attention_dropout_prob)
self.dense = Linear1D_Row(
hidden_size,
hidden_size,
dtype=dtype,
parallel_input=True,
)
self.dropout = nn.Dropout(hidden_dropout_prob)
# need to re-enable torch grad to enable fused optimization.
# self.layernorm = LayerNorm1D(
# hidden_size,
# dtype=dtype)
self.layernorm = nn.LayerNorm(
hidden_size,
dtype=dtype)
def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
query_key_value = self.query_key_value(hidden_states)
new_qkv_shape = query_key_value.shape[:-1] + \
(self.num_attention_heads, 3 * self.attention_head_size)
query_key_value = query_key_value.view(new_qkv_shape)
query_key_value = query_key_value.permute((0, 2, 1, 3))
query_layer, key_layer, value_layer = torch.chunk(
query_key_value, 3, dim=-1)
attention_scores = torch.matmul(
query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / \
math.sqrt(self.attention_head_size)
attention_scores = attention_scores + attention_mask
attention_probs = nn.Softmax(dim=-1)(attention_scores)
with seed(ParallelMode.TENSOR):
attention_probs = self.attention_dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute((0, 2, 1, 3)).contiguous()
new_context_layer_shape = context_layer.size()[
:-2] + (self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
output = self.dense(context_layer)
with seed(ParallelMode.TENSOR):
output = self.dropout(output)
attention_output = self.layernorm(hidden_states + output)
return attention_output
@LAYERS.register_module
class TransformerLayer1D(ParallelLayer):
"""Transformer layer which contains a self-attention layer and a MLP layer
:param hidden_size: hidden size
:type hidden_size: int
:param num_attention_heads: number of attention heads
:type num_attention_heads: int
:param act_func: activation function, defaults to 'gelu'
:type act_func: str, optional
:param mlp_ratio: hidden size of MLP divided by embedding dim, defaults to 4.0
:type mlp_ratio: float, optional
:param attention_dropout_prob: dropout probability for attention layer, defaults to 0.
:type attention_dropout_prob: float, optional
:param hidden_dropout_prob: dropout probability for attention layer, defaults to 0.
:type hidden_dropout_prob: float, optional
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
hidden_size: int,
num_attention_heads: int,
act_func: str = 'gelu',
mlp_ratio: float = 4.0,
attention_dropout_prob: float = 0.,
hidden_dropout_prob: float = 0.,
dtype=None,
):
super().__init__()
self.attention = TransformerSelfAttention1D(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
attention_dropout_prob=attention_dropout_prob,
hidden_dropout_prob=hidden_dropout_prob,
dtype=dtype,
)
self.mlp = TransformerMLP1D(
in_features=hidden_size,
dropout_prob=hidden_dropout_prob,
act_func=act_func,
mlp_ratio=mlp_ratio,
dtype=dtype,
)
def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
attention_output = self.attention(hidden_states, attention_mask)
output = self.mlp(attention_output)
return output
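A rough sketch of mine (not part of the patch) of how TransformerLayer1D is meant to be driven: it behaves like a regular transformer block once a 1D tensor-parallel context has been initialized via one of the launch helpers above and a CUDA device is available, so it is not runnable as a plain single-process script. The import path is assumed from the parallel_1d package layout shown in this patch.

# rough sketch (not part of this patch); assumes an initialized 1D tensor-parallel
# context and a CUDA device.
import torch
from colossalai.nn.layer.parallel_1d import TransformerLayer1D

layer = TransformerLayer1D(hidden_size=768, num_attention_heads=12).cuda()
hidden_states = torch.randn(2, 196, 768, device='cuda')
attention_mask = torch.zeros(2, 1, 1, 196, device='cuda')   # additive mask, 0 = keep
out = layer(hidden_states, attention_mask)                  # same shape as hidden_states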

View File

@ -13,3 +13,6 @@ def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank):
def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
    per_partition_vocab_size = divide(global_vocab_size, world_size)
    return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank)

View File

@ -0,0 +1,411 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
from colossalai import context
import torch
from torch import nn as nn, Tensor, distributed as dist
from torch.nn.init import _calculate_fan_in_and_fan_out
from colossalai.context import seed, ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer._common_utils import divide, ACT2FN
from colossalai.registry import LAYERS
from colossalai.utils import checkpoint
from colossalai.utils import get_current_device
from .layers import Linear1D_Col, Linear1D_Row
from ..base_layer import ParallelLayer
from .._common_utils import to_2tuple
from ..fused_bias_gelu import bias_gelu_impl
@LAYERS.register_module
class ViTMLP1D(ParallelLayer):
"""MLP layer for 1D parallel Vision Transformer
:param in_features: size of each input sample
:type in_features: int
:param mlp_ratio: hidden size of MLP divided by embedding dim
:type mlp_ratio: int
:param act_func: activation function, defaults to 'gelu'
:type act_func: str, optional
:param dropout_prob: dropout probability, defaults to 0.
:type dropout_prob: float, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param checkpoint: whether to checkpoint the layer, defaults to False
:type checkpoint: bool, optional
"""
def __init__(self,
in_features: int,
mlp_ratio: int,
act_func: str = 'gelu',
dropout_prob: float = 0.,
dtype=None,
checkpoint: bool = False,
skip_bias_add: bool = False,
weight_init='torch'
):
super().__init__()
self.in_features = in_features
self.mlp_ratio = mlp_ratio
self.checkpoint = checkpoint
self.skip_bias_add = skip_bias_add
assert weight_init in ('torch', 'jax')
if act_func == 'fused_gelu':
self.act = bias_gelu_impl
skip_dense_1_add_bias = True
else:
self.act = ACT2FN[act_func]
skip_dense_1_add_bias = False
# Project to mlp_ratio * h.
self.dense_1 = Linear1D_Col(
self.in_features,
int(self.mlp_ratio * self.in_features),
dtype=dtype,
gather_output=False,
skip_bias_add=skip_dense_1_add_bias,
init_weight=weight_init,
init_bias=weight_init
)
# Project back to h.
self.dense_2 = Linear1D_Row(
int(self.mlp_ratio * self.in_features),
self.in_features,
dtype=dtype,
parallel_input=True,
init_weight=weight_init, init_bias=weight_init
)
self.dropout = nn.Dropout(dropout_prob)
def _forward(self, hidden_states: Tensor) -> Tensor:
if self.act == bias_gelu_impl:
intermediate_output, bias = self.dense_1(hidden_states)
intermediate_output = self.act(intermediate_output, bias)
else:
intermediate_output = self.dense_1(hidden_states)
intermediate_output = self.act(intermediate_output)
with seed(ParallelMode.TENSOR):
intermediate_output = self.dropout(intermediate_output)
output = self.dense_2(intermediate_output)
output = self.dropout(output)
return output
def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
return checkpoint(self._forward, hidden_states)
def forward(self, hidden_states: Tensor) -> Tensor:
if self.checkpoint:
return self._checkpoint_forward(hidden_states)
else:
return self._forward(hidden_states)
@LAYERS.register_module
class ViTSelfAttention1D(ParallelLayer):
"""Self-attention layer for 1D parallel Vision Transformer
:param hidden_size: hidden size
:type hidden_size: int
:param num_attention_heads: number of attention heads
:type num_attention_heads: int
:param attention_dropout_prob: dropout probability for attention layers
:type attention_dropout_prob: float
:param hidden_dropout_prob: dropout probability for hidden layers
:type hidden_dropout_prob: float
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param checkpoint: whether to checkpoint the layer, defaults to False
:type checkpoint: bool, optional
"""
def __init__(self,
hidden_size: int,
num_attention_heads: int,
attention_dropout_prob: float,
hidden_dropout_prob: float,
dtype=None,
checkpoint: bool = False,
weight_init='torch'
):
super().__init__()
self.hidden_size = hidden_size
self.attention_head_size = divide(hidden_size, num_attention_heads)
self.num_attention_heads_per_partition = divide(num_attention_heads, gpc.tensor_parallel_size)
self.hidden_size_per_partition = divide(hidden_size, gpc.tensor_parallel_size)
self.checkpoint = checkpoint
assert weight_init in ('torch', 'jax')
if weight_init == 'jax':
init_bias = 'zero'
else:
init_bias = weight_init
self.query_key_value = Linear1D_Col(
hidden_size,
3 * hidden_size,
dtype=dtype,
init_weight=weight_init,
init_bias=init_bias
)
self.attention_dropout = nn.Dropout(attention_dropout_prob)
self.dense = Linear1D_Row(
hidden_size,
hidden_size,
dtype=dtype,
parallel_input=True,
init_weight=weight_init, init_bias=init_bias
)
self.dropout = nn.Dropout(hidden_dropout_prob)
self.softmax = nn.Softmax(dim=-1)
def _forward(self, hidden_states: Tensor) -> Tensor:
query_key_value = self.query_key_value(hidden_states)
new_qkv_shape = query_key_value.shape[:-1] + \
(self.num_attention_heads_per_partition, 3 * self.attention_head_size)
query_key_value = query_key_value.view(new_qkv_shape)
query_key_value = query_key_value.permute((0, 2, 1, 3))
query_layer, key_layer, value_layer = torch.chunk(
query_key_value, 3, dim=-1)
attention_scores = torch.matmul(
query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / \
math.sqrt(self.attention_head_size)
attention_probs = self.softmax(attention_scores)
with seed(ParallelMode.TENSOR):
attention_probs = self.attention_dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.transpose(1, 2)
new_context_layer_shape = context_layer.size()[
:-2] + (self.hidden_size_per_partition,)
context_layer = context_layer.reshape(new_context_layer_shape)
output = self.dense(context_layer)
output = self.dropout(output)
return output
def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
return checkpoint(self._forward, hidden_states)
def forward(self, hidden_states: Tensor) -> Tensor:
if self.checkpoint:
return self._checkpoint_forward(hidden_states)
else:
return self._forward(hidden_states)
@LAYERS.register_module
class ViTHead1D(ParallelLayer):
"""Output layer for 1D parallel Vision Transformer
:param hidden_size: hidden size
:type hidden_size: int
:param num_classes: number of classes
:type num_classes: int
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
hidden_size,
num_classes,
dtype=None,
weight_init='torch'
):
super().__init__()
assert weight_init in ('torch', 'jax')
if weight_init == 'jax':
init_weight = 'zero'
init_bias = 'zero'
else:
init_weight = weight_init
init_bias = weight_init
self.linear = Linear1D_Col(
hidden_size,
num_classes,
dtype=dtype,
gather_output=True,
init_weight=init_weight,
init_bias=init_bias
)
def forward(self, x: Tensor) -> Tensor:
x = x[:, 0]
x = self.linear(x)
return x
@LAYERS.register_module
class ViTHead(ParallelLayer):
"""Output layer for 1D parallel Vision Transformer
:param hidden_size: hidden size
:type hidden_size: int
:param num_classes: number of classes
:type num_classes: int
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
hidden_size,
num_classes,
dtype=None,
):
super().__init__()
self.linear = nn.Linear(
hidden_size,
num_classes,
dtype=dtype
)
self._broadcast_linear_params()
def _broadcast_linear_params(self) -> None:
self.to(get_current_device())
ranks = gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)
dist.broadcast(self.linear.weight, src=ranks[0],
group=gpc.get_group(ParallelMode.PARALLEL_1D))
dist.broadcast(self.linear.bias, src=ranks[0],
group=gpc.get_group(ParallelMode.PARALLEL_1D))
def forward(self, x: Tensor) -> Tensor:
x = x[:, 0]
x = self.linear(x)
return x
@LAYERS.register_module
class ViTPatchEmbedding1D(ParallelLayer):
""" 2D Image to Patch Embedding
:param img_size: image size
:type img_size: int
:param patch_size: patch size
:type patch_size: int
:param embed_dim: dimension of embedding
:type embed_dim: int
:param in_chans: number of channels of input image, defaults to 3
:type in_chans: int, optional
:param flatten: whether to flatten output tensor, defaults to True
:type flatten: bool, optional
"""
def __init__(self,
img_size,
patch_size,
embed_dim,
in_chans=3,
flatten=True,
weight_init='torch'):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0],
img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_chans,
self.embed_dim,
kernel_size=patch_size,
stride=patch_size
)
if weight_init == 'jax':
fan_in, _ = _calculate_fan_in_and_fan_out(self.proj.weight)
std = math.sqrt(1.0 / fan_in)
nn.init.trunc_normal_(self.proj.weight, std=std / .87962566103423978)
nn.init.zeros_(self.proj.bias)
# sync
self._broadcast_conv_params()
def _broadcast_conv_params(self) -> None:
self.to(get_current_device())
ranks = gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)
dist.broadcast(self.proj.weight, src=ranks[0],
group=gpc.get_group(ParallelMode.PARALLEL_1D))
dist.broadcast(self.proj.bias, src=ranks[0],
group=gpc.get_group(ParallelMode.PARALLEL_1D))
def forward(self, x: Tensor) -> Tensor:
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
return x
@LAYERS.register_module
class ViTTokenFuser1D(ParallelLayer):
"""
Fuse cls token and pos embedding to the input
:param img_size: image size
:type img_size: int
:param patch_size: patch size
:type patch_size: int
:param embed_dim: dimension of embedding
:type embed_dim: int
:param drop_rate: dropout probability, defaults to 0.
:type drop_rate: float, optional
"""
def __init__(self,
img_size,
patch_size,
embed_dim,
drop_rate=0.
):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0],
img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.embed_dim = embed_dim
self.cls_token = nn.Parameter(torch.zeros(
1, 1, self.embed_dim))
self.pos_embed = nn.Parameter(torch.empty(
1, self.num_patches + 1, self.embed_dim))
nn.init.trunc_normal_(self.pos_embed, std=.02)
# move to cuda before broadcast
self.to(get_current_device())
dist.broadcast(self.pos_embed,
src=gpc.get_ranks_in_group(ParallelMode.TENSOR)[0],
group=gpc.get_group(ParallelMode.TENSOR))
self.pos_drop = nn.Dropout(p=drop_rate)
def forward(self, x: Tensor) -> Tensor:
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_token, x), dim=1)
x = self.pos_drop(x + self.pos_embed)
return x.contiguous()
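The front-end of the 1D ViT is split between ViTPatchEmbedding1D (convolutional patch projection, broadcast so every tensor-parallel rank starts from identical weights) and ViTTokenFuser1D (class token plus positional embedding). A rough composition sketch of mine follows, under the same assumptions as the earlier sketch (initialized 1D context, CUDA device); it is not part of the patch.

# rough sketch (not part of this patch); assumes an initialized 1D tensor-parallel context.
import torch
from colossalai.nn.layer.parallel_1d import ViTPatchEmbedding1D, ViTTokenFuser1D

embed = ViTPatchEmbedding1D(img_size=224, patch_size=16, embed_dim=768)
fuser = ViTTokenFuser1D(img_size=224, patch_size=16, embed_dim=768, drop_rate=0.1)

img = torch.randn(8, 3, 224, 224, device=next(embed.parameters()).device)
x = embed(img)    # (8, 196, 768): flattened patch tokens
x = fuser(x)      # (8, 197, 768): class token prepended, positional embedding added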

View File

@ -1,24 +1,30 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import math
import numbers
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch import Tensor
from torch.nn.parameter import Parameter
from typing import Tuple
import importlib

from colossalai.context import seed, ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import LAYERS
from colossalai.utils import get_current_device
from ._operation import FusedLayerNormAffineFunction1D
from .._common_utils import divide, set_tensor_parallel_attribute_by_partition
from .._parallel_utilities import reduce_grad, reduce_input, gather_forward_split_backward, \
    split_forward_gather_backward
from ..base_layer import ParallelLayer
@LAYERS.register_module
class Linear1D_Col(ParallelLayer):
    """Linear layer with column parallelism.
@ -44,23 +50,29 @@ class Linear1D_Col(ParallelLayer):
                 output_size: int,
                 bias: bool = True,
                 dtype: torch.dtype = None,
                 gather_output: bool = False,
                 skip_bias_add: bool = False,
                 init_weight='torch',
                 init_bias='torch'
                 ):
        super().__init__()

        # Keep input parameters
        self.in_features = in_features
        self.out_features = output_size
        self.gather_output = gather_output
        self.skip_bias_add = skip_bias_add

        if skip_bias_add and not bias:
            raise ValueError('cannot skip bias addition if bias is None')

        self.output_size_per_partition = divide(output_size, gpc.tensor_parallel_size)

        # Parameters.
        # Initialize weight.
        factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
        self.weight = Parameter(torch.empty(
            self.output_size_per_partition, self.in_features,
            **factory_kwargs))
        if bias:
@ -72,6 +84,45 @@ class Linear1D_Col(ParallelLayer):
                self.bias.zero_()
        else:
            self.register_parameter('bias', None)
with seed(ParallelMode.TENSOR):
self.reset_parameters(init_weight, init_bias)
self._set_tensor_parallel_attributes()
def reset_parameters(self, init_weight, init_bias) -> None:
assert init_weight in ('torch', 'jax', 'zero')
assert init_bias in ('torch', 'jax', 'zero')
# setting
fan_in, fan_out = self.in_features, self.out_features
# init weight
if init_weight == 'torch':
a = math.sqrt(5)
nonlinearity = 'leaky_relu'
std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
bound = math.sqrt(3.0) * std
init.uniform_(self.weight, -bound, bound)
elif init_weight == 'jax':
std = math.sqrt(2.0 / float(fan_in + fan_out))
a = math.sqrt(3.0) * std
init.uniform_(self.weight, -a, a)
elif init_weight == 'zero':
init.zeros_(self.weight)
# init bias
if self.bias is not None:
if init_bias == 'torch':
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
init.uniform_(self.bias, -bound, bound)
elif init_bias == 'jax':
init.normal_(self.bias, std=1e-6)
elif init_bias == 'zero':
init.zeros_(self.bias)
def _set_tensor_parallel_attributes(self):
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
if self.bias is not None:
set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
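The 'torch', 'jax' and 'zero' options above reproduce, respectively, PyTorch's default Kaiming-uniform scheme for nn.Linear, the scaled-uniform weight plus tiny-normal bias scheme that appears to mirror common JAX/flax ViT implementations, and plain zero init. Note that fan_in and fan_out come from the full, unpartitioned feature sizes, so every partition is initialized as if it were part of the whole matrix. A minimal single-tensor sketch of the same rules in plain PyTorch, with no parallel context; the helper name init_linear_ is made up for illustration:

import math
import torch
from torch.nn import init

def init_linear_(weight, bias=None, init_weight='torch', init_bias='torch'):
    # fan_out, fan_in of a (out_features, in_features) weight
    fan_out, fan_in = weight.shape
    if init_weight == 'torch':   # same bound PyTorch uses for nn.Linear
        bound = math.sqrt(3.0) * init.calculate_gain('leaky_relu', math.sqrt(5)) / math.sqrt(fan_in)
        init.uniform_(weight, -bound, bound)
    elif init_weight == 'jax':   # U(-a, a) with a = sqrt(3) * sqrt(2 / (fan_in + fan_out))
        a = math.sqrt(3.0) * math.sqrt(2.0 / (fan_in + fan_out))
        init.uniform_(weight, -a, a)
    elif init_weight == 'zero':
        init.zeros_(weight)
    if bias is not None:
        if init_bias == 'torch':
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            init.uniform_(bias, -bound, bound)
        elif init_bias == 'jax':
            init.normal_(bias, std=1e-6)
        elif init_bias == 'zero':
            init.zeros_(bias)

w, b = torch.empty(8, 4), torch.empty(8)
init_linear_(w, b, init_weight='jax', init_bias='jax')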
    def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
        # Set up backprop all-reduce.
@@ -104,7 +155,7 @@ class Linear1D_Row(ParallelLayer):
    :type bias: bool, optional
    :param dtype: The dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    :param parallel_input: If set to ``True``, the input is assumed to be already split across ranks, defaults to ``False``
    :type parallel_input: bool, optional
    """
@ -113,7 +164,10 @@ class Linear1D_Row(ParallelLayer):
                 out_features: int,
                 bias: bool = True,
                 dtype: torch.dtype = None,
                 parallel_input: bool = False,
                 skip_bias_add: bool = False,
                 init_weight='torch',
                 init_bias='torch'
                 ):
        super().__init__()
@@ -121,11 +175,13 @@ class Linear1D_Row(ParallelLayer):
        self.in_features = in_features
        self.out_features = out_features
        self.parallel_input = parallel_input
        self.skip_bias_add = skip_bias_add

        if skip_bias_add and not bias:
            raise ValueError('cannot skip bias addition if bias is None')

        # Divide the weight matrix along the last dimension.
        self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size)

        # Parameters.
        # Initialize weight.
@@ -146,9 +202,46 @@ class Linear1D_Row(ParallelLayer):
                self.bias.zero_()
        else:
            self.register_parameter('bias', None)
with seed(ParallelMode.TENSOR):
self.reset_parameters(init_weight, init_bias)
self._set_tensor_parallel_attributes()
    def reset_parameters(self, init_weight, init_bias) -> None:
        assert init_weight in ('torch', 'jax', 'zero')
assert init_bias in ('torch', 'jax', 'zero')
# setting
fan_in, fan_out = self.in_features, self.out_features
# init weight
if init_weight == 'torch':
a = math.sqrt(5)
nonlinearity = 'leaky_relu'
std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
bound = math.sqrt(3.0) * std
init.uniform_(self.weight, -bound, bound)
elif init_weight == 'jax':
std = math.sqrt(2.0 / float(fan_in + fan_out))
a = math.sqrt(3.0) * std
init.uniform_(self.weight, -a, a)
elif init_weight == 'zero':
init.zeros_(self.weight)
# init bias
if self.bias is not None:
if init_bias == 'torch':
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
init.uniform_(self.bias, -bound, bound)
elif init_bias == 'jax':
init.normal_(self.bias, std=1e-6)
elif init_bias == 'zero':
init.zeros_(self.bias)
dist.broadcast(self.bias,
src=gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0],
group=gpc.get_group(ParallelMode.PARALLEL_1D))
def _set_tensor_parallel_attributes(self):
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
    def forward(self, input_: Tensor) -> Tensor:
        # Set up backprop all-reduce.
@@ -163,4 +256,29 @@ class Linear1D_Row(ParallelLayer):
        if not self.skip_bias_add:
            output = output + self.bias
            return output
        else:
            return output, self.bias
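When skip_bias_add is enabled, the layer hands the bias back to the caller instead of adding it, so a later kernel (for example a fused bias plus GeLU) can apply it in a single elementwise pass. A hedged sketch of the intended calling pattern, assuming a hypothetical linear module built with skip_bias_add=True that returns (output, bias):

import torch.nn.functional as F

def consume_skip_bias_add(linear, x):
    # linear is assumed to return (output, bias) when skip_bias_add=True
    out, bias = linear(x)
    # the deferred bias is folded into the activation, saving one pass over the tensor
    return F.gelu(out + bias)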
@LAYERS.register_module
class MixedFusedLayerNorm1D(torch.nn.Module):
def __init__(self, normalized_shape, eps=1e-5):
super(MixedFusedLayerNorm1D, self).__init__()
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
self.normalized_shape = torch.Size(normalized_shape)
self.eps = eps
self.weight = Parameter(torch.Tensor(*normalized_shape))
self.bias = Parameter(torch.Tensor(*normalized_shape))
self.reset_parameters()
def reset_parameters(self):
init.ones_(self.weight)
init.zeros_(self.bias)
def forward(self, input):
return FusedLayerNormAffineFunction1D.apply(
input, self.weight, self.bias, self.normalized_shape, self.eps)
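Functionally, MixedFusedLayerNorm1D matches an ordinary affine layer norm; the fused autograd function only changes how the kernel is executed. A plain-PyTorch reference for what the forward is expected to compute, assuming the usual layer-norm semantics:

import torch
import torch.nn.functional as F

def layernorm_reference(x, weight, bias, normalized_shape, eps=1e-5):
    # (x - mean) / sqrt(var + eps) * weight + bias over the normalized dims
    return F.layer_norm(x, normalized_shape, weight, bias, eps)

x = torch.randn(2, 4, 16)
w, b = torch.ones(16), torch.zeros(16)
y = layernorm_reference(x, w, b, (16,))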


@ -20,7 +20,6 @@ def matmul_2d(a,
              col_parallel_mode=ParallelMode.PARALLEL_2D_COL,
              ):
    """Matrix multiplication for 2D parallelism

    :param a: matrix :math:`A`
    :type a: torch.tensor
    :param b: matrix :math:`B`
@ -86,25 +85,30 @@ class Matmul_AB_2D(torch.autograd.Function):
            ctx.save_for_backward(A, B)

        A_shape = A.shape
        A = A.reshape((-1, A_shape[-1])).contiguous()
        B_shape = B.shape
        B = B.reshape((-1, B_shape[-1])).contiguous()
        C_shape = (A.shape[0], B.shape[-1])
        C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device())

        A_list = [torch.empty_like(A) for _ in range(gpc.get_world_size(row_parallel_mode)-1)]
        B_list = [torch.empty_like(B) for _ in range(gpc.get_world_size(col_parallel_mode)-1)]
        A_list.insert(gpc.get_local_rank(row_parallel_mode), A)
        B_list.insert(gpc.get_local_rank(col_parallel_mode), B)
        op_a = dist.all_gather(A_list, A, group=gpc.get_group(row_parallel_mode), async_op=True)
        op_a.wait()
        op_b = dist.all_gather(B_list, B, group=gpc.get_group(col_parallel_mode), async_op=True)
        for op in [op_a, op_b]:
            op.wait()

        for i in range(summa_dim):
            src_a = i + summa_dim * row_rank
            src_b = i + summa_dim * col_rank
            src_a = src_a % summa_dim
            src_b = src_b % summa_dim
            A_temp = A_list[src_a]
            B_temp = B_list[src_b]
            torch.addmm(C, A_temp, B_temp, out=C)

        out = C.reshape(out_shape)

        if ctx:
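The rewritten forward above replaces the per-step broadcasts with two all-gathers (one along the row group for A, one along the column group for B) followed by a purely local accumulation. A single-process sketch of the block product each rank ends up computing, using plain PyTorch with no process groups:

import torch

q, n = 2, 4                      # q = summa_dim, n x n full matrices, n divisible by q
A, B = torch.randn(n, n), torch.randn(n, n)
blk = n // q

# tile (i, k) of A and tile (k, j) of B, i.e. what the gathered lists hold
A_blk = lambda i, k: A[i*blk:(i+1)*blk, k*blk:(k+1)*blk]
B_blk = lambda k, j: B[k*blk:(k+1)*blk, j*blk:(j+1)*blk]

C_ref = A @ B
for i in range(q):
    for j in range(q):
        # mirrors the src_a/src_b loop: accumulate A[i, k] @ B[k, j] over k
        C_ij = sum(A_blk(i, k) @ B_blk(k, j) for k in range(q))
        assert torch.allclose(C_ij, C_ref[i*blk:(i+1)*blk, j*blk:(j+1)*blk], atol=1e-5)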
@ -499,36 +503,61 @@ class _LayerNorm_2D(torch.autograd.Function):
        # return input_grad, None, None, None, None, None


class AllGatherLast(torch.autograd.Function):

    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx: Any,
                inputs: Tensor,
                summa_dim: int,
                col_parallel_mode: ParallelMode) -> Tensor:
        ctx.summa_dim = summa_dim
        ctx.row_rank = gpc.get_local_rank(col_parallel_mode)

        last_dim = summa_dim * inputs.size(-1)
        outputs_shape = (last_dim,) + inputs.shape[:-1]
        outputs = torch.empty(
            outputs_shape, dtype=inputs.dtype, device=get_current_device())
        dist.all_gather(
            list(outputs.chunk(summa_dim, dim=0)),
            inputs.permute(2, 0, 1).contiguous(),
            group=gpc.get_group(col_parallel_mode)
        )
        outputs = outputs.permute(1, 2, 0).contiguous()
        return outputs

    @staticmethod
    @custom_bwd
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        grad = output_grad.chunk(ctx.summa_dim, dim=-1)[ctx.row_rank]
        return grad.contiguous(), None, None


class SplitFirst(torch.autograd.Function):

    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx: Any,
                inputs: Tensor,
                summa_dim: int,
                col_parallel_mode: ParallelMode) -> Tensor:
        ctx.summa_dim = summa_dim
        ctx.batch_size = inputs.size(0)
        ctx.para_mode = col_parallel_mode
        row_rank = gpc.get_local_rank(col_parallel_mode)

        outputs = inputs.chunk(summa_dim, dim=0)[row_rank]
        return outputs

    @staticmethod
    @custom_bwd
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        grad_shape = (ctx.batch_size,) + output_grad.shape[1:]
        grad = torch.empty(
            grad_shape, dtype=output_grad.dtype, device=get_current_device())
        dist.all_gather(
            list(grad.chunk(ctx.summa_dim, dim=0)),
            output_grad.contiguous(),
            group=gpc.get_group(ctx.para_mode)
        )
        return grad, None, None
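Shape-wise, AllGatherLast widens the hidden (last) dimension by a factor of summa_dim and SplitFirst narrows the batch (first) dimension by the same factor; their backward passes are the mirror images. A single-process sketch of the round trip, with torch.chunk and torch.cat standing in for the collectives:

import torch

summa_dim = 2
x = torch.randn(4, 8, 6)                        # [b, s, h/q] held by one rank

# AllGatherLast: every rank contributes its h/q slice -> [b, s, h]
gathered = torch.cat([x] * summa_dim, dim=-1)   # stand-in for dist.all_gather
assert gathered.shape == (4, 8, 6 * summa_dim)

# SplitFirst: each rank keeps one b/q slice of the batch -> [b/q, s, h]
rank = 0
kept = gathered.chunk(summa_dim, dim=0)[rank]
assert kept.shape == (4 // summa_dim, 8, 6 * summa_dim)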


@@ -5,19 +5,21 @@ import math

import torch
from torch import nn as nn, Tensor, distributed as dist
from torch.nn.init import _calculate_fan_in_and_fan_out

from colossalai.context import seed, ParallelMode
from colossalai.nn.layer._common_utils import divide, ACT2FN
from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization, get_summa_dim_from_env
from colossalai.registry import LAYERS
from colossalai.utils import checkpoint
from colossalai.utils import get_current_device
from colossalai.core import global_context as gpc
from ._operation import AllGatherLast, SplitFirst
from .layers import Linear2D
from .._common_utils import set_tensor_parallel_attribute_by_partition, to_2tuple
from ..base_layer import ParallelLayer
from ..fused_bias_gelu import bias_gelu_impl
@LAYERS.register_module @LAYERS.register_module
@ -44,8 +46,8 @@ class ViTMLP2D(ParallelLayer):
act_func: str = 'gelu', act_func: str = 'gelu',
dropout_prob: float = 0., dropout_prob: float = 0.,
dtype=None, dtype=None,
checkpoint: bool = False checkpoint: bool = False,
): weight_init='torch'):
super().__init__() super().__init__()
assert_summa_initialization() assert_summa_initialization()
@ -53,27 +55,40 @@ class ViTMLP2D(ParallelLayer):
self.in_features = in_features self.in_features = in_features
self.mlp_ratio = mlp_ratio self.mlp_ratio = mlp_ratio
self.checkpoint = checkpoint self.checkpoint = checkpoint
assert weight_init in ('torch', 'jax')
if act_func == 'fused_gelu':
self.act = bias_gelu_impl
skip_dense_1_add_bias = True
else:
self.act = ACT2FN[act_func]
skip_dense_1_add_bias = False
# Project to mlp_ratio * h. # Project to mlp_ratio * h.
self.dense_1 = Linear2D( self.dense_1 = Linear2D(
self.in_features, self.in_features,
self.mlp_ratio * self.in_features, self.mlp_ratio * self.in_features,
dtype=dtype, dtype=dtype,
init_weight=weight_init, init_bias=weight_init,
skip_bias_add=skip_dense_1_add_bias
) )
self.act = ACT2FN[act_func]
# Project back to h. # Project back to h.
self.dense_2 = Linear2D( self.dense_2 = Linear2D(
self.mlp_ratio * self.in_features, self.mlp_ratio * self.in_features,
self.in_features, self.in_features,
dtype=dtype, dtype=dtype,
init_weight=weight_init, init_bias=weight_init
) )
self.dropout = nn.Dropout(dropout_prob) self.dropout = nn.Dropout(dropout_prob)
def _forward(self, hidden_states: Tensor) -> Tensor: def _forward(self, hidden_states: Tensor) -> Tensor:
intermediate_output = self.dense_1(hidden_states) if self.act == bias_gelu_impl:
intermediate_output = self.act(intermediate_output) intermediate_output, bias = self.dense_1(hidden_states)
intermediate_output = self.act(intermediate_output, bias)
else:
intermediate_output = self.dense_1(hidden_states)
intermediate_output = self.act(intermediate_output)
with seed(ParallelMode.TENSOR): with seed(ParallelMode.TENSOR):
intermediate_output = self.dropout(intermediate_output) intermediate_output = self.dropout(intermediate_output)
@ -117,8 +132,8 @@ class ViTSelfAttention2D(ParallelLayer):
attention_dropout_prob: float, attention_dropout_prob: float,
hidden_dropout_prob: float, hidden_dropout_prob: float,
dtype=None, dtype=None,
checkpoint: bool = False checkpoint: bool = False,
): weight_init='torch'):
super().__init__() super().__init__()
assert_summa_initialization() assert_summa_initialization()
@ -128,17 +143,24 @@ class ViTSelfAttention2D(ParallelLayer):
self.attention_head_size = divide(hidden_size, num_attention_heads) self.attention_head_size = divide(hidden_size, num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size self.all_head_size = self.num_attention_heads * self.attention_head_size
self.checkpoint = checkpoint self.checkpoint = checkpoint
assert weight_init in ('torch', 'jax')
if weight_init == 'jax':
self.init_bias = 'zero'
else:
self.init_bias = weight_init
self.query_key_value = Linear2D( self.query_key_value = Linear2D(
hidden_size, hidden_size,
3 * hidden_size, 3 * hidden_size,
dtype=dtype, dtype=dtype,
init_weight=weight_init, init_bias=self.init_bias
) )
self.attention_dropout = nn.Dropout(attention_dropout_prob) self.attention_dropout = nn.Dropout(attention_dropout_prob)
self.dense = Linear2D( self.dense = Linear2D(
hidden_size, hidden_size,
hidden_size, hidden_size,
dtype=dtype, dtype=dtype,
init_weight=weight_init, init_bias=self.init_bias
) )
self.dropout = nn.Dropout(hidden_dropout_prob) self.dropout = nn.Dropout(hidden_dropout_prob)
self.softmax = nn.Softmax(dim=-1) self.softmax = nn.Softmax(dim=-1)
@ -146,7 +168,7 @@ class ViTSelfAttention2D(ParallelLayer):
def _forward(self, hidden_states: Tensor) -> Tensor: def _forward(self, hidden_states: Tensor) -> Tensor:
query_key_value = self.query_key_value(hidden_states) query_key_value = self.query_key_value(hidden_states)
new_qkv_shape = query_key_value.shape[:-1] + \ new_qkv_shape = query_key_value.shape[:-1] + \
(self.num_attention_heads, 3 * self.attention_head_size) (self.num_attention_heads, 3 * self.attention_head_size)
query_key_value = query_key_value.view(new_qkv_shape) query_key_value = query_key_value.view(new_qkv_shape)
query_key_value = query_key_value.permute((0, 2, 1, 3)) query_key_value = query_key_value.permute((0, 2, 1, 3))
query_layer, key_layer, value_layer = torch.chunk( query_layer, key_layer, value_layer = torch.chunk(
@ -155,7 +177,7 @@ class ViTSelfAttention2D(ParallelLayer):
attention_scores = torch.matmul( attention_scores = torch.matmul(
query_layer, key_layer.transpose(-1, -2)) query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / \ attention_scores = attention_scores / \
math.sqrt(self.attention_head_size) math.sqrt(self.attention_head_size)
attention_probs = self.softmax(attention_scores) attention_probs = self.softmax(attention_scores)
@ -165,7 +187,7 @@ class ViTSelfAttention2D(ParallelLayer):
context_layer = torch.matmul(attention_probs, value_layer) context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.transpose(1, 2) context_layer = context_layer.transpose(1, 2)
new_context_layer_shape = context_layer.size()[ new_context_layer_shape = context_layer.size()[
:-2] + (self.all_head_size,) :-2] + (self.all_head_size,)
context_layer = context_layer.reshape(new_context_layer_shape) context_layer = context_layer.reshape(new_context_layer_shape)
output = self.dense(context_layer) output = self.dense(context_layer)
@ -199,14 +221,22 @@ class ViTHead2D(ParallelLayer):
hidden_size, hidden_size,
num_classes, num_classes,
dtype=None, dtype=None,
): weight_init='torch'):
super().__init__() super().__init__()
assert_summa_initialization() assert_summa_initialization()
assert weight_init in ('torch', 'jax')
if weight_init == 'jax':
self.init_weight = 'zero'
self.init_bias = 'zero'
else:
self.init_weight = weight_init
self.init_bias = weight_init
self.summa_dim = get_summa_dim_from_env() self.summa_dim = get_summa_dim_from_env()
self.linear = Linear2D( self.linear = Linear2D(
hidden_size, hidden_size,
num_classes, num_classes,
dtype=dtype, dtype=dtype,
init_weight=self.init_weight, init_bias=self.init_bias
) )
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
@ -236,7 +266,8 @@ class ViTPatchEmbedding2D(ParallelLayer):
patch_size, patch_size,
embed_dim, embed_dim,
in_chans=3, in_chans=3,
flatten=True): flatten=True,
weight_init='torch'):
super().__init__() super().__init__()
img_size = to_2tuple(img_size) img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size) patch_size = to_2tuple(patch_size)
@ -249,39 +280,28 @@ class ViTPatchEmbedding2D(ParallelLayer):
img_size[1] // patch_size[1]) img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1] self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten self.flatten = flatten
        self.embed_dim = embed_dim // (self.summa_dim ** 2)

        with seed(ParallelMode.TENSOR):
            # ensure the partitions are initialized differently
            self.proj = nn.Conv2d(in_chans,
                                  self.embed_dim,
                                  kernel_size=patch_size,
                                  stride=patch_size,
                                  device=get_current_device()
                                  )

        if weight_init == 'jax':
            with seed(ParallelMode.TENSOR):
                fan_in, _ = _calculate_fan_in_and_fan_out(self.proj.weight)
                std = math.sqrt(1.0 / fan_in)
                nn.init.trunc_normal_(self.proj.weight, std=std / .87962566103423978)
                nn.init.zeros_(self.proj.bias)

    def _set_tensor_parallel_attribute(self):
        num_partition = gpc.get_world_size(ParallelMode.TENSOR)
        set_tensor_parallel_attribute_by_partition(self.proj.weight, num_partition)
        set_tensor_parallel_attribute_by_partition(self.proj.bias, num_partition)
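The divisor .87962566103423978 appears to be the standard deviation of a unit normal truncated to [-2, 2]; dividing by it compensates for the variance lost to truncation, so the initialized weights come out with a standard deviation close to sqrt(1/fan_in). A quick empirical check in plain PyTorch:

import torch

samples = torch.empty(1_000_000)
torch.nn.init.trunc_normal_(samples, mean=0.0, std=1.0, a=-2.0, b=2.0)
print(samples.std())    # ~0.8796, matching the correction constant used above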
def _broadcast_conv_params(self) -> None:
self.to(get_current_device())
ranks_in_col = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2D_COL)
dist.broadcast(self.proj.weight, src=ranks_in_col[0],
group=gpc.get_group(ParallelMode.PARALLEL_2D_COL))
dist.broadcast(self.proj.bias, src=ranks_in_col[0],
group=gpc.get_group(ParallelMode.PARALLEL_2D_COL))
def _sync_grad_during_backward(self, grad: Tensor) -> None:
dist.all_reduce(grad, group=gpc.get_group(
ParallelMode.PARALLEL_2D_COL))
grad = grad / self.summa_dim
return grad
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
B, C, H, W = x.shape B, C, H, W = x.shape
@ -293,6 +313,24 @@ class ViTPatchEmbedding2D(ParallelLayer):
return x return x
@LAYERS.register_module
class ViTInputSplitter2D(ParallelLayer):
"""Split the input tensor for 2D parallel Vision Transformer
"""
def __init__(self):
super().__init__()
assert_summa_initialization()
self.summa_dim = get_summa_dim_from_env()
def forward(self, x: Tensor) -> Tensor:
x = AllGatherLast.apply(
x, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
x = SplitFirst.apply(
x, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
return x
@LAYERS.register_module @LAYERS.register_module
class ViTTokenFuser2D(ParallelLayer): class ViTTokenFuser2D(ParallelLayer):
""" """
@ -328,64 +366,32 @@ class ViTTokenFuser2D(ParallelLayer):
self.embed_dim = embed_dim self.embed_dim = embed_dim
self.cls_token = nn.Parameter(torch.zeros( self.cls_token = nn.Parameter(torch.zeros(
1, 1, self.embed_dim // self.summa_dim)) (1, 1, self.embed_dim // (self.summa_dim ** 2)),
self.pos_embed = nn.Parameter(torch.zeros( device=get_current_device()))
1, self.num_patches + 1, self.embed_dim // self.summa_dim)) self.pos_embed = nn.Parameter(torch.empty(
(1, self.num_patches + 1, self.embed_dim // (self.summa_dim ** 2)),
device=get_current_device()))
with seed(ParallelMode.TENSOR):
nn.init.trunc_normal_(self.pos_embed, std=.02)
# move to cuda before broadcast
self.to(get_current_device())
# sync param in both forward and backward
_cls_token = self.cls_token.view(-1)
_pos_embed = self.pos_embed.view(-1)
self._param = torch.cat([_cls_token, _pos_embed], dim=0)
self._broadcast_params(self._param)
self._param.register_hook(self._sync_grad_hook)
self.pos_drop = nn.Dropout(p=drop_rate) self.pos_drop = nn.Dropout(p=drop_rate)
self._set_tensor_parallel_attribute() self._set_tensor_parallel_attribute()
def _set_tensor_parallel_attribute(self): def _set_tensor_parallel_attribute(self):
set_tensor_parallel_attribute(self.cls_token) num_partition = gpc.get_world_size(ParallelMode.TENSOR)
set_tensor_parallel_attribute(self.pos_embed) set_tensor_parallel_attribute_by_partition(self.cls_token, num_partition)
set_tensor_parallel_attribute_by_partition(self.pos_embed, num_partition)
def _broadcast_params(self, param) -> None:
" broadcast to all column ranks for data consistency "
ranks_in_col = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2D_COL)
col_group = gpc.get_group(ParallelMode.PARALLEL_2D_COL)
dist.broadcast(param, src=ranks_in_col[0],
group=col_group)
def _sync_grad_hook(self, grad) -> None:
dist.all_reduce(grad, group=gpc.get_group(
ParallelMode.PARALLEL_2D_COL))
grad = grad / self.summa_dim
return grad
    def forward(self, x: Tensor) -> Tensor:
        # stole cls_tokens impl from Phil Wang, thanks
        cls_token = AllGatherLast.apply(
            self.cls_token, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
        cls_token = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)

        pos_embed = AllGatherLast.apply(
            self.pos_embed, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
        x = x + pos_embed
        with seed(ParallelMode.TENSOR):
            x = self.pos_drop(x)
        return x
@LAYERS.register_module
class ViTInputSplitter2D(ParallelLayer):
"""Split the input tensor for 2D parallel Vision Transformer
"""
def __init__(self):
super().__init__()
assert_summa_initialization()
self.summa_dim = get_summa_dim_from_env()
def forward(self, x: Tensor) -> Tensor:
batch_size = x.size(0)
return _ViT_Split_Input_2D.apply(
x,
batch_size,
self.summa_dim,
ParallelMode.PARALLEL_2D_COL
)


@ -11,7 +11,7 @@ from colossalai.registry import LAYERS
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from ._operation import Matmul_AB_2D, Add_Bias_2D, _LayerNorm_2D from ._operation import Matmul_AB_2D, Add_Bias_2D, _LayerNorm_2D
from ._utils import get_summa_dim_from_env, assert_summa_initialization from ._utils import get_summa_dim_from_env, assert_summa_initialization
from .._common_utils import divide, set_tensor_parallel_attribute from .._common_utils import divide, set_tensor_parallel_attribute_by_partition
from ..base_layer import ParallelLayer from ..base_layer import ParallelLayer
@ -36,8 +36,9 @@ class Linear2D(ParallelLayer):
out_features: int, out_features: int,
bias: bool = True, bias: bool = True,
dtype=None, dtype=None,
skip_bias_add: bool = False skip_bias_add: bool = False,
): init_weight='torch',
init_bias='torch'):
super().__init__() super().__init__()
self.in_features = in_features self.in_features = in_features
@ -72,31 +73,45 @@ class Linear2D(ParallelLayer):
self.register_parameter('bias', None) self.register_parameter('bias', None)
# initialize parameters # initialize parameters
self.reset_parameters() with seed(ParallelMode.TENSOR):
self.reset_parameters(init_weight, init_bias)
self._set_tensor_parallel_attributes() self._set_tensor_parallel_attributes()
def _set_tensor_parallel_attributes(self): def _set_tensor_parallel_attributes(self):
set_tensor_parallel_attribute(self.weight) num_partition = gpc.get_world_size(ParallelMode.TENSOR)
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
if self.bias is not None: if self.bias is not None:
set_tensor_parallel_attribute(self.bias) set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
def reset_parameters(self) -> None: def reset_parameters(self, init_weight, init_bias) -> None:
assert init_weight in ('torch', 'jax', 'zero')
assert init_bias in ('torch', 'jax', 'zero')
# setting # setting
fan_in = self.in_features fan_in, fan_out = self.in_features, self.out_features
a = math.sqrt(5)
nonlinearity = 'leaky_relu'
# init weight # init weight
std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in) if init_weight == 'torch':
bound = math.sqrt(3.0) * std a = math.sqrt(5)
with seed(ParallelMode.TENSOR): nonlinearity = 'leaky_relu'
std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
bound = math.sqrt(3.0) * std
init.uniform_(self.weight, -bound, bound) init.uniform_(self.weight, -bound, bound)
elif init_weight == 'jax':
std = math.sqrt(2.0 / float(fan_in + fan_out))
a = math.sqrt(3.0) * std
init.uniform_(self.weight, -a, a)
elif init_weight == 'zero':
init.zeros_(self.weight)
# init bias # init bias
if self.bias is not None: if self.bias is not None:
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 if init_bias == 'torch':
with seed(ParallelMode.TENSOR): bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
init.uniform_(self.bias, -bound, bound) init.uniform_(self.bias, -bound, bound)
elif init_bias == 'jax':
init.normal_(self.bias, std=1e-6)
elif init_bias == 'zero':
init.zeros_(self.bias)
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
# input: [m/q, n/q, k/q] # input: [m/q, n/q, k/q]
@ -192,28 +207,19 @@ class LayerNorm2D(ParallelLayer):
        # create parameters
        factory_kwargs = {'device': get_current_device(), 'dtype': dtype}

        self.gamma = Parameter(torch.ones(
            self.partitioned_partition,
            **factory_kwargs))
        self.beta = Parameter(torch.zeros(
            self.partitioned_partition,
            **factory_kwargs))

        self._set_tensor_parallel_attributes()

    def _set_tensor_parallel_attributes(self):
        num_partition = gpc.get_world_size(ParallelMode.TENSOR)
        set_tensor_parallel_attribute_by_partition(self.gamma, num_partition)
        set_tensor_parallel_attribute_by_partition(self.beta, num_partition)
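With gamma and beta now allocated per partition on every rank, only the statistics E[x] and Var[x] still need a reduction across the row group, since each rank sees just hidden_size/q features. A single-process sketch of how the partial sums combine into the same result as an ordinary layer norm, with chunks standing in for ranks and sum() standing in for the all-reduce:

import torch
import torch.nn.functional as F

hidden, q = 8, 2
x = torch.randn(4, hidden)
parts = x.chunk(q, dim=-1)                     # what each rank in a row would hold

# each rank contributes partial sums; dist.all_reduce would combine them
sum_x = sum(p.sum(dim=-1, keepdim=True) for p in parts)
sum_x2 = sum((p * p).sum(dim=-1, keepdim=True) for p in parts)
E_x = sum_x / hidden
Var_x = sum_x2 / hidden - E_x * E_x

out = (x - E_x) / torch.sqrt(Var_x + 1e-5)
ref = F.layer_norm(x, (hidden,), eps=1e-5)
assert torch.allclose(out, ref, atol=1e-4)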
    def forward(self, x: Tensor) -> Tensor:
        with torch.no_grad():


@ -1,11 +1,10 @@
from ._operation import Matmul_AB_2p5D, Matmul_ABT_2p5D, Matmul_ATB_2p5D, Add_Bias_2p5D
from ._transformer import TransformerMLP2p5D, TransformerSelfAttention2p5D, TransformerLayer2p5D
from ._vit import ViTMLP2p5D, ViTSelfAttention2p5D, ViTHead2p5D, ViTPatchEmbedding2p5D, ViTTokenFuser2p5D, ViTInputSplitter2p5D
from .layers import Linear2p5D, LayerNorm2p5D

__all__ = [
    'Matmul_AB_2p5D', 'Matmul_ABT_2p5D', 'Matmul_ATB_2p5D', 'Add_Bias_2p5D',
    'TransformerMLP2p5D', 'TransformerSelfAttention2p5D', 'TransformerLayer2p5D',
    'ViTMLP2p5D', 'ViTSelfAttention2p5D', 'ViTHead2p5D', 'ViTPatchEmbedding2p5D', 'ViTTokenFuser2p5D',
    'ViTInputSplitter2p5D',


@ -6,7 +6,8 @@ from torch import Tensor
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device
from torch.cuda.amp import custom_bwd, custom_fwd
def get_parallel_group(parallel_mode: ParallelMode): def get_parallel_group(parallel_mode: ParallelMode):
@ -26,18 +27,17 @@ class Matmul_AB_2p5D(torch.autograd.Function):
""" """
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, def forward(ctx: Any,
A: Tensor, A: Tensor,
B: Tensor, B: Tensor,
tesseract_dim: int, tesseract_dim: int,
tesseract_dep: int,
out_shape: Tuple[int, ...], out_shape: Tuple[int, ...],
row_rank: int, row_rank: int,
col_rank: int, col_rank: int,
dep_rank: int, dep_rank: int,
row_parallel_mode: ParallelMode, row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode,
dep_parallel_mode: ParallelMode,
data_parallel_rank: int, data_parallel_rank: int,
pipeline_parallel_rank: int, pipeline_parallel_rank: int,
pipeline_parallel_size: int, pipeline_parallel_size: int,
@ -49,41 +49,43 @@ class Matmul_AB_2p5D(torch.autograd.Function):
assert A.shape[-1] == B.shape[-2], \ assert A.shape[-1] == B.shape[-2], \
'Invalid shapes: A={}, B={} for AB.'.format(A.shape, B.shape) 'Invalid shapes: A={}, B={} for AB.'.format(A.shape, B.shape)
empty_cache()
if ctx: if ctx:
ctx.save_for_backward(A, B) ctx.save_for_backward(A, B)
A_shape = A.shape A_shape = A.shape
A = A.reshape((-1, A_shape[-1])) A = A.reshape((-1, A_shape[-1])).contiguous()
B_shape = B.shape B_shape = B.shape
B = B.reshape((-1, B_shape[-1])) B = B.reshape((-1, B_shape[-1])).contiguous()
C_shape = (A.shape[0], B.shape[-1]) C_shape = (A.shape[0], B.shape[-1])
C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device()) C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device())
for i in range(tesseract_dim): A_list = [torch.empty_like(A) for _ in range(gpc.get_world_size(row_parallel_mode)-1)]
A_temp = A.clone() B_list = [torch.empty_like(B) for _ in range(gpc.get_world_size(col_parallel_mode)-1)]
B_temp = B.clone() A_list.insert(gpc.get_local_rank(row_parallel_mode), A)
src_a = i + row_rank * tesseract_dim + dep_rank * ( B_list.insert(gpc.get_local_rank(col_parallel_mode), B)
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + pipeline_parallel_rank * tensor_parallel_size op_a = dist.all_gather(A_list, A, group=gpc.get_group(row_parallel_mode), async_op=True)
dist.broadcast(A_temp, src=src_a, op_a.wait()
group=get_parallel_group(row_parallel_mode)) op_b = dist.all_gather(B_list, B, group=gpc.get_group(col_parallel_mode), async_op=True)
src_b = col_rank + i * tesseract_dim + dep_rank * ( for op in [op_a, op_b]:
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + pipeline_parallel_rank * tensor_parallel_size op.wait()
dist.broadcast(B_temp, src=src_b,
group=get_parallel_group(col_parallel_mode))
torch.addmm(C, A_temp, B_temp, out=C)
for i in range(tesseract_dim):
src_a = i + tesseract_dim * row_rank
src_b = i + tesseract_dim * col_rank
src_a = src_a % tesseract_dim
src_b = src_b % tesseract_dim
A_temp = A_list[src_a]
B_temp = B_list[src_b]
torch.addmm(C, A_temp, B_temp, out=C)
out = C.reshape(out_shape) out = C.reshape(out_shape)
if ctx: if ctx:
ctx.tesseract_dim = tesseract_dim ctx.tesseract_dim = tesseract_dim
ctx.tesseract_dep = tesseract_dep
ctx.row_rank = row_rank ctx.row_rank = row_rank
ctx.col_rank = col_rank ctx.col_rank = col_rank
ctx.dep_rank = dep_rank ctx.dep_rank = dep_rank
ctx.row_parallel_mode = row_parallel_mode ctx.row_parallel_mode = row_parallel_mode
ctx.col_parallel_mode = col_parallel_mode ctx.col_parallel_mode = col_parallel_mode
ctx.dep_parallel_mode = dep_parallel_mode
ctx.A_shape = A_shape ctx.A_shape = A_shape
ctx.B_shape = B_shape ctx.B_shape = B_shape
ctx.data_parallel_rank = data_parallel_rank ctx.data_parallel_rank = data_parallel_rank
@ -94,34 +96,32 @@ class Matmul_AB_2p5D(torch.autograd.Function):
return out return out
@staticmethod @staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors A, B = ctx.saved_tensors
A_grad = Matmul_ABT_2p5D.forward( with torch.no_grad():
None, A_grad = Matmul_ABT_2p5D.apply(
output_grad, B, output_grad, B,
ctx.tesseract_dim, ctx.tesseract_dep, ctx.A_shape, ctx.tesseract_dim, ctx.A_shape,
ctx.row_rank, ctx.col_rank, ctx.dep_rank, ctx.row_rank, ctx.col_rank, ctx.dep_rank,
ctx.row_parallel_mode, ctx.row_parallel_mode,
ctx.col_parallel_mode, ctx.col_parallel_mode,
ctx.dep_parallel_mode, ctx.data_parallel_rank,
ctx.data_parallel_rank, ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_rank, ctx.pipeline_parallel_size,
ctx.pipeline_parallel_size, ctx.tensor_parallel_size
ctx.tensor_parallel_size )
) B_grad = Matmul_ATB_2p5D.apply(
B_grad = Matmul_ATB_2p5D.forward( A, output_grad,
None, ctx.tesseract_dim, ctx.B_shape,
A, output_grad, ctx.row_rank, ctx.col_rank, ctx.dep_rank,
ctx.tesseract_dim, ctx.tesseract_dep, ctx.B_shape, ctx.row_parallel_mode,
ctx.row_rank, ctx.col_rank, ctx.dep_rank, ctx.col_parallel_mode,
ctx.row_parallel_mode, ctx.data_parallel_rank,
ctx.col_parallel_mode, ctx.pipeline_parallel_rank,
ctx.dep_parallel_mode, ctx.pipeline_parallel_size,
ctx.data_parallel_rank, ctx.tensor_parallel_size
ctx.pipeline_parallel_rank, )
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None
@ -130,18 +130,17 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
""" """
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, def forward(ctx: Any,
A: Tensor, A: Tensor,
B: Tensor, B: Tensor,
tesseract_dim: int, tesseract_dim: int,
tesseract_dep: int,
out_shape: Tuple[int, ...], out_shape: Tuple[int, ...],
row_rank: int, row_rank: int,
col_rank: int, col_rank: int,
dep_rank: int, dep_rank: int,
row_parallel_mode: ParallelMode, row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode,
dep_parallel_mode: ParallelMode,
data_parallel_rank: int, data_parallel_rank: int,
pipeline_parallel_rank: int, pipeline_parallel_rank: int,
pipeline_parallel_size: int, pipeline_parallel_size: int,
@ -151,7 +150,6 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
assert A.shape[-1] == B.shape[-1], \ assert A.shape[-1] == B.shape[-1], \
'Invalid shapes: A={}, B={} for ABT.'.format(A.shape, B.shape) 'Invalid shapes: A={}, B={} for ABT.'.format(A.shape, B.shape)
empty_cache()
if ctx: if ctx:
ctx.save_for_backward(A, B) ctx.save_for_backward(A, B)
@ -180,13 +178,11 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
if ctx: if ctx:
ctx.tesseract_dim = tesseract_dim ctx.tesseract_dim = tesseract_dim
ctx.tesseract_dep = tesseract_dep
ctx.row_rank = row_rank ctx.row_rank = row_rank
ctx.col_rank = col_rank ctx.col_rank = col_rank
ctx.dep_rank = dep_rank ctx.dep_rank = dep_rank
ctx.row_parallel_mode = row_parallel_mode ctx.row_parallel_mode = row_parallel_mode
ctx.col_parallel_mode = col_parallel_mode ctx.col_parallel_mode = col_parallel_mode
ctx.dep_parallel_mode = dep_parallel_mode
ctx.A_shape = A_shape ctx.A_shape = A_shape
ctx.B_shape = B_shape ctx.B_shape = B_shape
ctx.data_parallel_rank = data_parallel_rank ctx.data_parallel_rank = data_parallel_rank
@ -197,34 +193,32 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
return out return out
@staticmethod @staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors A, B = ctx.saved_tensors
A_grad = Matmul_AB_2p5D.forward( with torch.no_grad():
None, A_grad = Matmul_AB_2p5D.apply(
output_grad, B, output_grad, B,
ctx.tesseract_dim, ctx.tesseract_dep, ctx.A_shape, ctx.tesseract_dim, ctx.A_shape,
ctx.row_rank, ctx.col_rank, ctx.dep_rank, ctx.row_rank, ctx.col_rank, ctx.dep_rank,
ctx.row_parallel_mode, ctx.row_parallel_mode,
ctx.col_parallel_mode, ctx.col_parallel_mode,
ctx.dep_parallel_mode, ctx.data_parallel_rank,
ctx.data_parallel_rank, ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_rank, ctx.pipeline_parallel_size,
ctx.pipeline_parallel_size, ctx.tensor_parallel_size
ctx.tensor_parallel_size )
) B_grad = Matmul_ATB_2p5D.apply(
B_grad = Matmul_ATB_2p5D.forward( output_grad, A,
None, ctx.tesseract_dim, ctx.B_shape,
output_grad, A, ctx.row_rank, ctx.col_rank, ctx.dep_rank,
ctx.tesseract_dim, ctx.tesseract_dep, ctx.B_shape, ctx.row_parallel_mode,
ctx.row_rank, ctx.col_rank, ctx.dep_rank, ctx.col_parallel_mode,
ctx.row_parallel_mode, ctx.data_parallel_rank,
ctx.col_parallel_mode, ctx.pipeline_parallel_rank,
ctx.dep_parallel_mode, ctx.pipeline_parallel_size,
ctx.data_parallel_rank, ctx.tensor_parallel_size
ctx.pipeline_parallel_rank, )
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None
@ -233,18 +227,17 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
""" """
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, def forward(ctx: Any,
A: Tensor, A: Tensor,
B: Tensor, B: Tensor,
tesseract_dim: int, tesseract_dim: int,
tesseract_dep: int,
out_shape: Tuple[int, ...], out_shape: Tuple[int, ...],
row_rank: int, row_rank: int,
col_rank: int, col_rank: int,
dep_rank: int, dep_rank: int,
row_parallel_mode: ParallelMode, row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode,
dep_parallel_mode: ParallelMode,
data_parallel_rank: int, data_parallel_rank: int,
pipeline_parallel_rank: int, pipeline_parallel_rank: int,
pipeline_parallel_size: int, pipeline_parallel_size: int,
@ -253,7 +246,6 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
assert A.shape[-2] == B.shape[-2], \ assert A.shape[-2] == B.shape[-2], \
'Invalid shapes: A={}, B={} for ATB.'.format(A.shape, B.shape) 'Invalid shapes: A={}, B={} for ATB.'.format(A.shape, B.shape)
empty_cache()
if ctx: if ctx:
ctx.save_for_backward(A, B) ctx.save_for_backward(A, B)
@ -284,13 +276,11 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
if ctx: if ctx:
ctx.tesseract_dim = tesseract_dim ctx.tesseract_dim = tesseract_dim
ctx.tesseract_dep = tesseract_dep
ctx.row_rank = row_rank ctx.row_rank = row_rank
ctx.col_rank = col_rank ctx.col_rank = col_rank
ctx.dep_rank = dep_rank ctx.dep_rank = dep_rank
ctx.row_parallel_mode = row_parallel_mode ctx.row_parallel_mode = row_parallel_mode
ctx.col_parallel_mode = col_parallel_mode ctx.col_parallel_mode = col_parallel_mode
ctx.dep_parallel_mode = dep_parallel_mode
ctx.A_shape = A_shape ctx.A_shape = A_shape
ctx.B_shape = B_shape ctx.B_shape = B_shape
ctx.data_parallel_rank = data_parallel_rank ctx.data_parallel_rank = data_parallel_rank
@ -301,34 +291,32 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
return out return out
@staticmethod @staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors A, B = ctx.saved_tensors
A_grad = Matmul_ABT_2p5D.forward( with torch.no_grad():
None, A_grad = Matmul_ABT_2p5D.apply(
B, output_grad, B, output_grad,
ctx.tesseract_dim, ctx.tesseract_dep, ctx.A_shape, ctx.tesseract_dim, ctx.A_shape,
ctx.row_rank, ctx.col_rank, ctx.dep_rank, ctx.row_rank, ctx.col_rank, ctx.dep_rank,
ctx.row_parallel_mode, ctx.row_parallel_mode,
ctx.col_parallel_mode, ctx.col_parallel_mode,
ctx.dep_parallel_mode, ctx.data_parallel_rank,
ctx.data_parallel_rank, ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_rank, ctx.pipeline_parallel_size,
ctx.pipeline_parallel_size, ctx.tensor_parallel_size
ctx.tensor_parallel_size )
) B_grad = Matmul_AB_2p5D.apply(
B_grad = Matmul_AB_2p5D.forward( A, output_grad,
None, ctx.tesseract_dim, ctx.B_shape,
A, output_grad, ctx.row_rank, ctx.col_rank, ctx.dep_rank,
ctx.tesseract_dim, ctx.tesseract_dep, ctx.B_shape, ctx.row_parallel_mode,
ctx.row_rank, ctx.col_rank, ctx.dep_rank, ctx.col_parallel_mode,
ctx.row_parallel_mode, ctx.data_parallel_rank,
ctx.col_parallel_mode, ctx.pipeline_parallel_rank,
ctx.dep_parallel_mode, ctx.pipeline_parallel_size,
ctx.data_parallel_rank, ctx.tensor_parallel_size
ctx.pipeline_parallel_rank, )
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None
@ -337,18 +325,16 @@ class Add_Bias_2p5D(torch.autograd.Function):
""" """
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, def forward(ctx: Any,
input: Tensor, input: Tensor,
bias: Tensor, bias: Tensor,
output_size_per_partition: int, output_size_per_partition: int,
tesseract_dim: int, tesseract_dim: int,
tesseract_dep: int,
row_rank: int, row_rank: int,
col_rank: int, col_rank: int,
dep_rank: int, dep_rank: int,
row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode,
dep_parallel_mode: ParallelMode,
skip_bias_add: bool, skip_bias_add: bool,
data_parallel_rank: int, data_parallel_rank: int,
pipeline_parallel_rank: int, pipeline_parallel_rank: int,
@ -371,10 +357,7 @@ class Add_Bias_2p5D(torch.autograd.Function):
ctx.col_rank = col_rank ctx.col_rank = col_rank
ctx.dep_rank = dep_rank ctx.dep_rank = dep_rank
ctx.tesseract_dim = tesseract_dim ctx.tesseract_dim = tesseract_dim
ctx.tesseract_dep = tesseract_dep
ctx.row_parallel_mode = row_parallel_mode
ctx.col_parallel_mode = col_parallel_mode ctx.col_parallel_mode = col_parallel_mode
ctx.dep_parallel_mode = dep_parallel_mode
ctx.bias = skip_bias_add ctx.bias = skip_bias_add
ctx.data_parallel_rank = data_parallel_rank ctx.data_parallel_rank = data_parallel_rank
ctx.pipeline_parallel_rank = pipeline_parallel_rank ctx.pipeline_parallel_rank = pipeline_parallel_rank
@ -388,15 +371,13 @@ class Add_Bias_2p5D(torch.autograd.Function):
return output return output
@staticmethod @staticmethod
def backward(ctx, output_grad): @custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
row_rank = ctx.row_rank row_rank = ctx.row_rank
col_rank = ctx.col_rank col_rank = ctx.col_rank
dep_rank = ctx.dep_rank dep_rank = ctx.dep_rank
tesseract_dim = ctx.tesseract_dim tesseract_dim = ctx.tesseract_dim
tesseract_dep = ctx.tesseract_dep
row_parallel_mode = ctx.row_parallel_mode
col_parallel_mode = ctx.col_parallel_mode col_parallel_mode = ctx.col_parallel_mode
dep_parallel_mode = ctx.dep_parallel_mode
data_parallel_rank = ctx.data_parallel_rank data_parallel_rank = ctx.data_parallel_rank
pipeline_parallel_rank = ctx.pipeline_parallel_rank pipeline_parallel_rank = ctx.pipeline_parallel_rank
pipeline_parallel_size = ctx.pipeline_parallel_size pipeline_parallel_size = ctx.pipeline_parallel_size
@ -428,29 +409,25 @@ class Add_Bias_2p5D(torch.autograd.Function):
class _LayerNorm_2p5D(torch.autograd.Function): class _LayerNorm_2p5D(torch.autograd.Function):
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx: Any, def forward(ctx: Any,
input: Tensor, input: Tensor,
E_x: Tensor, E_x: Tensor,
Var_x: Tensor, Var_x: Tensor,
hidden_size: int, hidden_size: int,
row_parallel_mode: ParallelMode, row_parallel_mode: ParallelMode) -> Tensor:
col_parallel_mode: ParallelMode,
dep_parallel_mode: ParallelMode) -> Tensor:
input = input - E_x input = input - E_x
# in here, input = x - E[x], Var_x = 1 / sqrt(Var[x] + eps) # in here, input = x - E[x], Var_x = 1 / sqrt(Var[x] + eps)
ctx.hidden_size = hidden_size ctx.hidden_size = hidden_size
output = input * Var_x output = input * Var_x
ctx.save_for_backward(output, Var_x) ctx.save_for_backward(output, Var_x)
ctx.row_parallel_mode = row_parallel_mode ctx.row_parallel_mode = row_parallel_mode
ctx.col_parallel_mode = col_parallel_mode
ctx.dep_parallel_mode = dep_parallel_mode
return output return output
@staticmethod @staticmethod
@custom_bwd
def backward(ctx, output_grad): def backward(ctx, output_grad):
row_parallel_mode = ctx.row_parallel_mode row_parallel_mode = ctx.row_parallel_mode
col_parallel_mode = ctx.col_parallel_mode
dep_parallel_mode = ctx.dep_parallel_mode
x, Var_x = ctx.saved_tensors x, Var_x = ctx.saved_tensors
# in here, Var_x = 1 / sqrt(Var[x] + eps), x = (x - E[x]) * Var_x # in here, Var_x = 1 / sqrt(Var[x] + eps), x = (x - E[x]) * Var_x
with torch.no_grad(): with torch.no_grad():
@ -473,63 +450,122 @@ class _LayerNorm_2p5D(torch.autograd.Function):
return input_grad, None, None, None, None, None, None return input_grad, None, None, None, None, None, None
class Sum_2p5D(torch.autograd.Function): # class Sum_2p5D(torch.autograd.Function):
"""Compute the sum of input tensors # """Compute the sum of input tensors
""" # """
# @staticmethod
# def forward(ctx,
# inputs,
# dim,
# tesseract_dim,
# row_parallel_mode,
# keepdim=False):
# # input: [b/q, s, h/q]
# ctx.save_for_backward(inputs)
# # sum: [b/q, s]
# out = torch.sum(inputs, dim=dim, keepdim=keepdim)
# torch.distributed.all_reduce(
# out, group=gpc.get_group(row_parallel_mode))
# return out
# @staticmethod
# def backward(ctx, output_grad):
# with torch.no_grad():
# inputs = ctx.saved_tensors
# input_grad = torch.ones(inputs.shape, dtype=output_grad.dtype)
# return input_grad, None, None, None, None, None
# class _ViT_Split_2p5D(torch.autograd.Function):
# @staticmethod
# @custom_fwd(cast_inputs=torch.float16)
# def forward(ctx, inputs, batch_size,
# tesseract_dim, tesseract_dep,
# xz_parallel_mode):
# # inputs: [b, s, h/q]
# # output: [b/dq, s, h/q]
# ctx.BATCH_SIZE = batch_size
# ctx.tesseract_dim = tesseract_dim
# ctx.tesseract_dep = tesseract_dep
# ctx.xz_parallel_mode = xz_parallel_mode
# xz_rank = gpc.get_local_rank(xz_parallel_mode)
# output = torch.chunk(inputs, tesseract_dep *
# tesseract_dim, dim=0)[xz_rank]
# output = output.clone()
# return output
# @staticmethod
# @custom_bwd
# def backward(ctx, output_grad):
# # output_grad: [b/dq, s, h/q]
# # grads: [b, s, h/q]
# # *
# grads_shape = (ctx.BATCH_SIZE,) + output_grad.shape[1:]
# grads = torch.empty(grads_shape,
# dtype=output_grad.dtype,
# device=get_current_device())
# dist.all_gather(list(grads.chunk(ctx.tesseract_dim * ctx.tesseract_dep, dim=0)),
# output_grad.contiguous(),
# group=get_parallel_group(ctx.xz_parallel_mode))
# return grads, None, None, None, None
class AllGatherLast(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, @custom_fwd(cast_inputs=torch.float16)
inputs, def forward(ctx: Any,
dim, inputs: Tensor,
tesseract_dim, tesseract_dim: int,
row_parallel_mode, col_parallel_mode: ParallelMode) -> Tensor:
keepdim=False):
# input: [b/q, s, h/q]
empty_cache()
ctx.save_for_backward(inputs)
# sum: [b/q, s]
out = torch.sum(inputs, dim=dim, keepdim=keepdim)
torch.distributed.all_reduce(
out, group=gpc.get_group(row_parallel_mode))
return out
@staticmethod
def backward(ctx, output_grad):
with torch.no_grad():
inputs = ctx.saved_tensors
input_grad = torch.ones(inputs.shape, dtype=output_grad.dtype)
return input_grad, None, None, None, None, None
class _ViT_Split_2p5D(torch.autograd.Function):
@staticmethod
def forward(ctx, inputs, batch_size,
tesseract_dim, tesseract_dep,
xz_parallel_mode):
# inputs: [b, s, h/q]
# output: [b/dq, s, h/q]
empty_cache()
ctx.batch_size = batch_size
ctx.tesseract_dim = tesseract_dim ctx.tesseract_dim = tesseract_dim
ctx.tesseract_dep = tesseract_dep ctx.row_rank = gpc.get_local_rank(col_parallel_mode)
ctx.xz_parallel_mode = xz_parallel_mode
xz_rank = gpc.get_local_rank(xz_parallel_mode) last_dim = tesseract_dim * inputs.size(-1)
output = torch.chunk(inputs, tesseract_dep * outputs_shape = (last_dim,) + inputs.shape[:-1]
tesseract_dim, dim=0)[xz_rank] outputs = torch.empty(
output = output.clone() outputs_shape, dtype=inputs.dtype, device=get_current_device())
return output dist.all_gather(
list(outputs.chunk(tesseract_dim, dim=0)),
inputs.permute(2, 0, 1).contiguous(),
group=gpc.get_group(col_parallel_mode)
)
outputs = outputs.permute(1, 2, 0).contiguous()
return outputs
@staticmethod @staticmethod
def backward(ctx, output_grad): @custom_bwd
# output_grad: [b/dq, s, h/q] def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
# grads: [b, s, h/q] grad = output_grad.chunk(ctx.tesseract_dim, dim=-1)[ctx.row_rank]
# * return grad.contiguous(), None, None
grads_shape = (ctx.batch_size,) + output_grad.shape[1:]
grads = torch.empty(grads_shape,
dtype=output_grad.dtype, class SplitFirst(torch.autograd.Function):
device=get_current_device())
dist.all_gather(list(grads.chunk(ctx.tesseract_dim * ctx.tesseract_dep, dim=0)), @staticmethod
output_grad.contiguous(), @custom_fwd(cast_inputs=torch.float16)
group=get_parallel_group(ctx.xz_parallel_mode)) def forward(ctx: Any,
return grads, None, None, None, None inputs: Tensor,
tesseract_dim: int,
col_parallel_mode: ParallelMode) -> Tensor:
ctx.tesseract_dim = tesseract_dim
ctx.batch_size = inputs.size(0)
ctx.para_mode = col_parallel_mode
row_rank = gpc.get_local_rank(col_parallel_mode)
outputs = inputs.chunk(tesseract_dim, dim=0)[row_rank]
return outputs
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
grad_shape = (ctx.batch_size,) + output_grad.shape[1:]
grad = torch.empty(
grad_shape, dtype=output_grad.dtype, device=get_current_device())
dist.all_gather(
list(grad.chunk(ctx.tesseract_dim, dim=0)),
output_grad.contiguous(),
group=gpc.get_group(ctx.para_mode)
)
return grad, None, None


@ -12,10 +12,11 @@ from ._utils import assert_tesseract_initialization, \
get_tesseract_dim_dep_from_env get_tesseract_dim_dep_from_env
from .layers import Linear2p5D, LayerNorm2p5D from .layers import Linear2p5D, LayerNorm2p5D
from .._common_utils import ACT2FN from .._common_utils import ACT2FN
from ..base_layer import ParallelLayer
@LAYERS.register_module @LAYERS.register_module
class TransformerMLP2p5D(nn.Module): class TransformerMLP2p5D(ParallelLayer):
""" """
MLP will take the input with h hidden state, project it to mlp_ratio * h MLP will take the input with h hidden state, project it to mlp_ratio * h
hidden dimension, perform nonlinear transformation, and project the hidden dimension, perform nonlinear transformation, and project the
@ -36,21 +37,24 @@ class TransformerMLP2p5D(nn.Module):
def __init__(self, def __init__(self,
in_features: int, in_features: int,
mlp_ratio: int, mlp_ratio: int = 4.0,
act_func: str = 'gelu', act_func: str = 'gelu',
dropout_prob: float = 0., dropout_prob: float = 0.,
dtype=None, dtype=None,
skip_bias_add: bool = False
): ):
super().__init__() super().__init__()
assert_tesseract_initialization() assert_tesseract_initialization()
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env() self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
self.in_features = in_features self.in_features = in_features
self.skip_bias_add = skip_bias_add
# Project to h * mlp_ratio. # Project to h * mlp_ratio.
self.dense_1 = Linear2p5D( self.dense_1 = Linear2p5D(
in_features, in_features,
mlp_ratio * in_features, int(mlp_ratio * in_features),
dtype=dtype dtype=dtype,
skip_bias_add=skip_bias_add
) )
assert act_func in ACT2FN.keys(), f'Invalid value for argument act_func, ' \ assert act_func in ACT2FN.keys(), f'Invalid value for argument act_func, ' \
@ -59,24 +63,34 @@ class TransformerMLP2p5D(nn.Module):
# Project back to h. # Project back to h.
self.dense_2 = Linear2p5D( self.dense_2 = Linear2p5D(
mlp_ratio * in_features, int(mlp_ratio * in_features),
in_features, in_features,
dtype=dtype dtype=dtype,
skip_bias_add=skip_bias_add
) )
self.dropout = nn.Dropout(dropout_prob) self.dropout = nn.Dropout(dropout_prob)
self.layernorm = LayerNorm2p5D(in_features, dtype=dtype) self.layernorm = LayerNorm2p5D(in_features, dtype=dtype)
    def forward(self, x: Tensor) -> Tensor:
        if self.skip_bias_add:
            intermediate_output, _ = self.dense_1(x)
        else:
            intermediate_output = self.dense_1(x)

        intermediate_output = self.activation_func(intermediate_output)

        if self.skip_bias_add:
            output, _ = self.dense_2(intermediate_output)
        else:
            output = self.dense_2(intermediate_output)

        output = self.dropout(output)
        output = self.layernorm(x + output)
        return output
@LAYERS.register_module @LAYERS.register_module
class TransformerSelfAttention2p5D(nn.Module): class TransformerSelfAttention2p5D(ParallelLayer):
"""Self attention layer for 2.5D parallel Transformer """Self attention layer for 2.5D parallel Transformer
:param hidden_size: hidden size :param hidden_size: hidden size
@ -92,10 +106,10 @@ class TransformerSelfAttention2p5D(nn.Module):
""" """
def __init__(self, def __init__(self,
hidden_size, hidden_size: int,
num_attention_heads, num_attention_heads: int,
attention_dropout_prob, attention_dropout_prob: float,
hidden_dropout_prob, hidden_dropout_prob: float,
dtype=None, dtype=None,
): ):
super().__init__() super().__init__()
@ -127,7 +141,7 @@ class TransformerSelfAttention2p5D(nn.Module):
def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor: def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
query_key_value = self.query_key_value(hidden_states) query_key_value = self.query_key_value(hidden_states)
new_qkv_shape = query_key_value.shape[:-1] + \ new_qkv_shape = query_key_value.shape[:-1] + \
(self.num_attention_heads, 3 * self.attention_head_size) (self.num_attention_heads, 3 * self.attention_head_size)
query_key_value = query_key_value.view(new_qkv_shape) query_key_value = query_key_value.view(new_qkv_shape)
query_key_value = query_key_value.permute((0, 2, 1, 3)) query_key_value = query_key_value.permute((0, 2, 1, 3))
query_layer, key_layer, value_layer = torch.chunk( query_layer, key_layer, value_layer = torch.chunk(
@ -136,7 +150,7 @@ class TransformerSelfAttention2p5D(nn.Module):
attention_scores = torch.matmul( attention_scores = torch.matmul(
query_layer, key_layer.transpose(-1, -2)) query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / \ attention_scores = attention_scores / \
math.sqrt(self.attention_head_size) math.sqrt(self.attention_head_size)
attention_scores = attention_scores + attention_mask attention_scores = attention_scores + attention_mask
attention_probs = nn.Softmax(dim=-1)(attention_scores) attention_probs = nn.Softmax(dim=-1)(attention_scores)
attention_probs = self.attention_dropout(attention_probs) attention_probs = self.attention_dropout(attention_probs)
@ -144,7 +158,7 @@ class TransformerSelfAttention2p5D(nn.Module):
context_layer = torch.matmul(attention_probs, value_layer) context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute((0, 2, 1, 3)).contiguous() context_layer = context_layer.permute((0, 2, 1, 3)).contiguous()
new_context_layer_shape = context_layer.size()[ new_context_layer_shape = context_layer.size()[
:-2] + (self.all_head_size,) :-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape) context_layer = context_layer.view(*new_context_layer_shape)
output = self.dense(context_layer) output = self.dense(context_layer)
@ -155,7 +169,7 @@ class TransformerSelfAttention2p5D(nn.Module):
@LAYERS.register_module @LAYERS.register_module
class TransformerLayer2p5D(nn.Module): class TransformerLayer2p5D(ParallelLayer):
"""Transformer layer which contains a self-attention layer and a MLP layer """Transformer layer which contains a self-attention layer and a MLP layer
:param hidden_size: hidden size :param hidden_size: hidden size
@ -175,10 +189,10 @@ class TransformerLayer2p5D(nn.Module):
""" """
def __init__(self, def __init__(self,
hidden_size, hidden_size: int,
num_attention_heads, num_attention_heads: int,
act_func='gelu', act_func: str = 'gelu',
mlp_ratio=4, mlp_ratio: float = 4.0,
attention_dropout_prob: float = 0., attention_dropout_prob: float = 0.,
hidden_dropout_prob: float = 0., hidden_dropout_prob: float = 0.,
dtype=None, dtype=None,


@ -5,22 +5,25 @@ import math
import torch import torch
from torch import nn as nn, Tensor, distributed as dist from torch import nn as nn, Tensor, distributed as dist
from torch.nn.init import _calculate_fan_in_and_fan_out
from colossalai.context.parallel_mode import ParallelMode from colossalai.context import seed, ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.nn.layer.vanilla_vision_transformer.layers import to_2tuple
from colossalai.registry import LAYERS from colossalai.registry import LAYERS
from colossalai.utils import checkpoint
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from ._operation import _ViT_Split_2p5D from ._operation import AllGatherLast, SplitFirst
from ._utils import assert_tesseract_initialization, \ from ._utils import assert_tesseract_initialization, \
get_tesseract_dim_dep_from_env get_tesseract_dim_dep_from_env
from .layers import Linear2p5D from .layers import Linear2p5D
from .._common_utils import ACT2FN, divide, CheckpointModule from ..base_layer import ParallelLayer
from .._common_utils import set_tensor_parallel_attribute from ..fused_bias_gelu import bias_gelu_impl
from .._common_utils import (ACT2FN, divide, to_2tuple,
set_tensor_parallel_attribute_by_partition)
@LAYERS.register_module @LAYERS.register_module
class ViTMLP2p5D(CheckpointModule): class ViTMLP2p5D(ParallelLayer):
"""MLP layer for 2.5D parallel Vision Transformer """MLP layer for 2.5D parallel Vision Transformer
:param in_features: size of each input sample :param in_features: size of each input sample
@ -43,19 +46,32 @@ class ViTMLP2p5D(CheckpointModule):
act_func: str = 'gelu', act_func: str = 'gelu',
dropout_prob: float = 0., dropout_prob: float = 0.,
dtype=None, dtype=None,
checkpoint: bool = False checkpoint: bool = False,
weight_init='torch'
): ):
super().__init__(checkpoint=checkpoint) super().__init__()
assert_tesseract_initialization() assert_tesseract_initialization()
self.in_features = in_features self.in_features = in_features
self.mlp_ratio = mlp_ratio self.mlp_ratio = mlp_ratio
self.checkpoint = checkpoint
assert weight_init in ('torch', 'jax')
if act_func == 'fused_gelu':
self.act = bias_gelu_impl
skip_dense_1_add_bias = True
else:
self.act = ACT2FN[act_func]
skip_dense_1_add_bias = False
# Project to mlp_ratio * h. # Project to mlp_ratio * h.
self.dense_1 = Linear2p5D( self.dense_1 = Linear2p5D(
self.in_features, self.in_features,
self.mlp_ratio * self.in_features, self.mlp_ratio * self.in_features,
dtype=dtype, dtype=dtype,
init_weight=weight_init,
init_bias=weight_init,
skip_bias_add=skip_dense_1_add_bias
) )
self.act = ACT2FN[act_func] self.act = ACT2FN[act_func]
@ -65,20 +81,39 @@ class ViTMLP2p5D(CheckpointModule):
self.mlp_ratio * self.in_features, self.mlp_ratio * self.in_features,
self.in_features, self.in_features,
dtype=dtype, dtype=dtype,
init_weight=weight_init,
init_bias=weight_init
) )
self.dropout = nn.Dropout(dropout_prob) self.dropout = nn.Dropout(dropout_prob)
     def _forward(self, hidden_states: Tensor) -> Tensor:
-        intermediate_output = self.dense_1(hidden_states)
-        intermediate_output = self.act(intermediate_output)
-        intermediate_output = self.dropout(intermediate_output)
+        if self.act == bias_gelu_impl:
+            intermediate_output, bias = self.dense_1(hidden_states)
+            intermediate_output = self.act(intermediate_output, bias)
+        else:
+            intermediate_output = self.dense_1(hidden_states)
+            intermediate_output = self.act(intermediate_output)
+        with seed(ParallelMode.TENSOR):
+            intermediate_output = self.dropout(intermediate_output)
         output = self.dense_2(intermediate_output)
-        output = self.dropout(output)
+        with seed(ParallelMode.TENSOR):
+            output = self.dropout(output)
         return output

+    def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
+        return checkpoint(self._forward, hidden_states)
+
+    def forward(self, hidden_states: Tensor) -> Tensor:
+        if self.checkpoint:
+            return self._checkpoint_forward(hidden_states)
+        else:
+            return self._forward(hidden_states)
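A plain-PyTorch reference for the fused path above, stated as an assumption rather than the library kernel itself: bias_gelu_impl is taken to evaluate gelu(x + bias) with the tanh approximation, which is why dense_1 hands the bias back separately instead of adding it.

import math
import torch
import torch.nn.functional as F

def bias_gelu_reference(x: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # tanh-approximated gelu applied to (x + bias) in one pass
    y = x + bias
    return 0.5 * y * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (y + 0.044715 * y ** 3)))

x, b = torch.randn(2, 4, 16), torch.randn(16)
# the unfused "add bias, then gelu" path (approximate='tanh' needs PyTorch >= 1.12)
ref = F.gelu(x + b, approximate='tanh')
print(torch.allclose(bias_gelu_reference(x, b), ref, atol=1e-6))  # True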
@LAYERS.register_module @LAYERS.register_module
class ViTSelfAttention2p5D(CheckpointModule): class ViTSelfAttention2p5D(ParallelLayer):
"""Self-attention layer for 2.5D parallel Vision Transformer """Self-attention layer for 2.5D parallel Vision Transformer
:param hidden_size: hidden size :param hidden_size: hidden size
@ -101,9 +136,10 @@ class ViTSelfAttention2p5D(CheckpointModule):
attention_dropout_prob, attention_dropout_prob,
hidden_dropout_prob, hidden_dropout_prob,
dtype=None, dtype=None,
checkpoint: bool = False checkpoint: bool = False,
weight_init='torch'
): ):
super().__init__(checkpoint=checkpoint) super().__init__()
assert_tesseract_initialization() assert_tesseract_initialization()
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env() self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
@ -112,19 +148,30 @@ class ViTSelfAttention2p5D(CheckpointModule):
num_attention_heads, self.tesseract_dim) # * num_attention_heads, self.tesseract_dim) # *
self.attention_head_size = divide(hidden_size, num_attention_heads) self.attention_head_size = divide(hidden_size, num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size self.all_head_size = self.num_attention_heads * self.attention_head_size
self.checkpoint = checkpoint
assert weight_init in ('torch', 'jax')
if weight_init == 'jax':
self.init_bias = 'zero'
else:
self.init_bias = weight_init
self.query_key_value = Linear2p5D( self.query_key_value = Linear2p5D(
hidden_size, hidden_size,
3 * hidden_size, 3 * hidden_size,
dtype=dtype, dtype=dtype,
init_weight=weight_init,
init_bias=self.init_bias
) )
self.attention_dropout = nn.Dropout(attention_dropout_prob) self.attention_dropout = nn.Dropout(attention_dropout_prob)
self.dense = Linear2p5D( self.dense = Linear2p5D(
hidden_size, hidden_size,
hidden_size, hidden_size,
dtype=dtype, dtype=dtype,
init_weight=weight_init,
init_bias=self.init_bias
) )
self.dropout = nn.Dropout(hidden_dropout_prob) self.dropout = nn.Dropout(hidden_dropout_prob)
self.softmax = nn.Softmax(dim=-1)
def _forward(self, hidden_states: Tensor) -> Tensor: def _forward(self, hidden_states: Tensor) -> Tensor:
query_key_value = self.query_key_value(hidden_states) query_key_value = self.query_key_value(hidden_states)
@ -140,8 +187,10 @@ class ViTSelfAttention2p5D(CheckpointModule):
attention_scores = attention_scores / \ attention_scores = attention_scores / \
math.sqrt(self.attention_head_size) math.sqrt(self.attention_head_size)
attention_probs = nn.Softmax(dim=-1)(attention_scores) attention_probs = self.softmax(attention_scores)
attention_probs = self.attention_dropout(attention_probs)
with seed(ParallelMode.TENSOR):
attention_probs = self.attention_dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer) context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.transpose(1, 2) context_layer = context_layer.transpose(1, 2)
@ -150,12 +199,22 @@ class ViTSelfAttention2p5D(CheckpointModule):
context_layer = context_layer.reshape(new_context_layer_shape) context_layer = context_layer.reshape(new_context_layer_shape)
output = self.dense(context_layer) output = self.dense(context_layer)
output = self.dropout(output) with seed(ParallelMode.TENSOR):
output = self.dropout(output)
return output return output
def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
return checkpoint(self._forward, hidden_states)
def forward(self, hidden_states: Tensor) -> Tensor:
if self.checkpoint:
return self._checkpoint_forward(hidden_states)
else:
return self._forward(hidden_states)
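The dropouts above run inside seed(ParallelMode.TENSOR), which swaps in the RNG state reserved for the tensor-parallel mode and restores the previous state afterwards. A minimal single-process sketch of that kind of seeded context follows; tensor_parallel_seed is a hypothetical stand-in, not colossalai.context.seed itself.

import contextlib
import torch

@contextlib.contextmanager
def tensor_parallel_seed(seed_value: int):
    state = torch.random.get_rng_state()   # save the current RNG state
    torch.manual_seed(seed_value)          # switch to the mode-specific seed
    try:
        yield
    finally:
        torch.random.set_rng_state(state)  # restore the previous RNG state

drop = torch.nn.Dropout(p=0.5)
x = torch.ones(4, 4)
with tensor_parallel_seed(1234):
    a = drop(x)
with tensor_parallel_seed(1234):
    b = drop(x)
print(torch.equal(a, b))  # True: the managed state makes the dropout mask reproducible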
@LAYERS.register_module @LAYERS.register_module
class ViTHead2p5D(nn.Module): class ViTHead2p5D(ParallelLayer):
"""Output layer for 2.5D parallel Vision Transformer """Output layer for 2.5D parallel Vision Transformer
:param hidden_size: hidden size :param hidden_size: hidden size
@ -170,13 +229,24 @@ class ViTHead2p5D(nn.Module):
hidden_size, hidden_size,
num_classes, num_classes,
dtype=None, dtype=None,
weight_init='torch'
): ):
super().__init__() super().__init__()
assert_tesseract_initialization() assert_tesseract_initialization()
assert weight_init in ('torch', 'jax')
if weight_init == 'jax':
self.init_weight = 'zero'
self.init_bias = 'zero'
else:
self.init_weight = weight_init
self.init_bias = weight_init
self.linear = Linear2p5D( self.linear = Linear2p5D(
hidden_size, hidden_size,
num_classes, num_classes,
dtype=dtype, dtype=dtype,
init_weight=self.init_weight,
init_bias=self.init_bias
) )
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
@ -186,7 +256,7 @@ class ViTHead2p5D(nn.Module):
@LAYERS.register_module @LAYERS.register_module
class ViTPatchEmbedding2p5D(nn.Module): class ViTPatchEmbedding2p5D(ParallelLayer):
""" 2.5D Image to Patch Embedding """ 2.5D Image to Patch Embedding
    :param img_size: image size
@ -206,7 +276,8 @@ class ViTPatchEmbedding2p5D(nn.Module):
patch_size, patch_size,
embed_dim, embed_dim,
in_chans=3, in_chans=3,
flatten=True): flatten=True,
weight_init='torch'):
super().__init__() super().__init__()
img_size = to_2tuple(img_size) img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size) patch_size = to_2tuple(patch_size)
@ -219,34 +290,28 @@ class ViTPatchEmbedding2p5D(nn.Module):
img_size[1] // patch_size[1]) img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1] self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten self.flatten = flatten
-        self.embed_dim = embed_dim // self.tesseract_dim  # *
-        self.proj = nn.Conv2d(in_chans,
-                              self.embed_dim,
-                              kernel_size=patch_size,
-                              stride=patch_size,
-                              )
-
-        # move self to cuda before sync
-        self.to(get_current_device())
-
-        # sync
-        self._broadcast_conv_params()
-        self.proj.weight.register_hook(self._sync_grad_during_backward)
-        self.proj.bias.register_hook(self._sync_grad_during_backward)
-
-    def _broadcast_conv_params(self) -> None:
-        xz_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2P5D_XZ)
-        dist.broadcast(self.proj.weight, src=xz_rank[0],
-                       group=gpc.get_group(ParallelMode.PARALLEL_2P5D_XZ))
-        dist.broadcast(self.proj.bias, src=xz_rank[0],
-                       group=gpc.get_group(ParallelMode.PARALLEL_2P5D_XZ))
-
-    def _sync_grad_during_backward(self, grad: Tensor) -> None:
-        dist.all_reduce(grad, group=gpc.get_group(
-            ParallelMode.PARALLEL_2P5D_XZ))
-        grad = grad / self.tesseract_dim / self.tesseract_dep  # *
-        return grad
+        self.embed_dim = embed_dim // (self.tesseract_dep * self.tesseract_dim ** 2)  # *
+        with seed(ParallelMode.TENSOR):
+            self.proj = nn.Conv2d(in_chans,
+                                  self.embed_dim,
+                                  kernel_size=patch_size,
+                                  stride=patch_size,
+                                  device=get_current_device()
+                                  )
+        self._set_tensor_parallel_attribute()
+
+        if weight_init == 'jax':
+            with seed(ParallelMode.TENSOR):
+                fan_in, _ = _calculate_fan_in_and_fan_out(self.proj.weight)
+                std = math.sqrt(1.0 / fan_in)
+                nn.init.trunc_normal_(self.proj.weight, std=std / .87962566103423978)
+                nn.init.zeros_(self.proj.bias)
+
+    def _set_tensor_parallel_attribute(self):
+        num_partition = gpc.get_world_size(ParallelMode.TENSOR)
+        set_tensor_parallel_attribute_by_partition(self.proj.weight, num_partition)
+        set_tensor_parallel_attribute_by_partition(self.proj.bias, num_partition)
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
B, C, H, W = x.shape B, C, H, W = x.shape
@ -259,7 +324,25 @@ class ViTPatchEmbedding2p5D(nn.Module):
@LAYERS.register_module @LAYERS.register_module
class ViTTokenFuser2p5D(nn.Module): class ViTInputSplitter2p5D(ParallelLayer):
"""Split the input tensor for 2D parallel Vision Transformer
"""
def __init__(self):
super().__init__()
assert_tesseract_initialization()
self.tesseract_dim, _ = get_tesseract_dim_dep_from_env()
def forward(self, x: Tensor) -> Tensor:
x = AllGatherLast.apply(
x, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
x = SplitFirst.apply(
x, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
return x
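A single-process reading of what the splitter does to shapes, based only on the names of the two ops (an assumption, since AllGatherLast and SplitFirst are defined elsewhere): gather the last (hidden) dimension across the tesseract_dim column ranks, then split the first (batch) dimension the same number of ways.

import torch

q = 2                                                    # tesseract_dim
shards = [torch.randn(8, 16, 32) for _ in range(q)]      # one [b, s, h/q] shard per column rank

gathered = torch.cat(shards, dim=-1)                     # AllGatherLast: [b, s, h]
local = torch.chunk(gathered, q, dim=0)[0].contiguous()  # SplitFirst, as seen by rank 0: [b/q, s, h]
print(gathered.shape, local.shape)  # torch.Size([8, 16, 64]) torch.Size([4, 16, 64])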
@LAYERS.register_module
class ViTTokenFuser2p5D(ParallelLayer):
""" """
Fuse cls token and pos embedding to the input Fuse cls token and pos embedding to the input
@ -293,59 +376,46 @@ class ViTTokenFuser2p5D(nn.Module):
self.embed_dim = embed_dim self.embed_dim = embed_dim
-        self.cls_token = nn.Parameter(torch.zeros(
-            1, 1, self.embed_dim // self.tesseract_dim))  # *
-        self.pos_embed = nn.Parameter(torch.zeros(
-            1, self.num_patches + 1, self.embed_dim // self.tesseract_dim))  # *
-
-        # move to cuda before broadcast
-        self.to(get_current_device())
-        self._broadcast_params()
-        self.cls_token.register_hook(self._sync_grad_hook)
-        self.pos_embed.register_hook(self._sync_grad_hook)
+        self.cls_token = nn.Parameter(torch.zeros(
+            (1, 1, self.embed_dim // (self.tesseract_dep * self.tesseract_dim ** 2)),
+            device=get_current_device()))
+        self.pos_embed = nn.Parameter(torch.empty(
+            (1, self.num_patches + 1, self.embed_dim // (self.tesseract_dep * self.tesseract_dim ** 2)),
+            device=get_current_device()))
+        with seed(ParallelMode.TENSOR):
+            nn.init.trunc_normal_(self.pos_embed, std=.02)
         self.pos_drop = nn.Dropout(p=drop_rate)
         self._set_tensor_parallel_attribute()

     def _set_tensor_parallel_attribute(self):
-        set_tensor_parallel_attribute(self.cls_token)
-        set_tensor_parallel_attribute(self.pos_embed)
+        num_partition = gpc.get_world_size(ParallelMode.TENSOR)
+        set_tensor_parallel_attribute_by_partition(self.cls_token, num_partition)
+        set_tensor_parallel_attribute_by_partition(self.pos_embed, num_partition)

-    def _broadcast_params(self) -> None:
+    def _broadcast_params(self, param) -> None:
         " broadcast to all column ranks for data consistency "
-        xz_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2P5D_XZ)
-        dist.broadcast(self.cls_token, src=xz_rank[0],
-                       group=gpc.get_group(ParallelMode.PARALLEL_2P5D_XZ))
-        dist.broadcast(self.pos_embed, src=xz_rank[0],
-                       group=gpc.get_group(ParallelMode.PARALLEL_2P5D_XZ))
+        if self.tesseract_dep > 1:
+            xz_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2P5D_XZ)
+            xz_group = gpc.get_group(ParallelMode.PARALLEL_2P5D_XZ)
+            dist.broadcast(param, src=xz_rank[0],
+                           group=xz_group)

     def _sync_grad_hook(self, grad) -> None:
         dist.all_reduce(grad, group=gpc.get_group(
             ParallelMode.PARALLEL_2P5D_XZ))
-        grad = grad / self.tesseract_dim / self.tesseract_dep  # *
+        grad = grad / self.tesseract_dim  # / self.tesseract_dep  # *
         return grad

     def forward(self, x: Tensor) -> Tensor:
         # stole cls_tokens impl from Phil Wang, thanks
-        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
+        cls_token = AllGatherLast.apply(
+            self.cls_token, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
+        cls_token = cls_token.expand(x.shape[0], -1, -1)
         x = torch.cat((cls_token, x), dim=1)

-        x = self.pos_drop(x + self.pos_embed)
+        pos_embed = AllGatherLast.apply(
+            self.pos_embed, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
+        x = x + pos_embed
+        with seed(ParallelMode.TENSOR):
+            x = self.pos_drop(x)
         return x
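A worked example of the partition size used for cls_token and pos_embed above; the numbers are illustrative, not from the patch. Each rank keeps embed_dim // (tesseract_dep * tesseract_dim ** 2) channels of the embedding.

# illustrative values only
embed_dim, tesseract_dim, tesseract_dep = 768, 2, 2
per_rank_channels = embed_dim // (tesseract_dep * tesseract_dim ** 2)
print(per_rank_channels)  # 96, i.e. 768 // (2 * 4)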
@LAYERS.register_module
class ViTInputSplitter2p5D(nn.Module):
def __init__(self):
super().__init__()
assert_tesseract_initialization()
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
def forward(self, x: Tensor) -> Tensor:
batch_size = x.size(0)
return _ViT_Split_2p5D.apply(
x,
batch_size,
self.tesseract_dim,
self.tesseract_dep,
ParallelMode.PARALLEL_2P5D_XZ,
)

@ -10,7 +10,7 @@ from colossalai.registry import LAYERS
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from ._operation import Matmul_AB_2p5D, Add_Bias_2p5D, _LayerNorm_2p5D from ._operation import Matmul_AB_2p5D, Add_Bias_2p5D, _LayerNorm_2p5D
from ._utils import get_tesseract_dim_dep_from_env, assert_tesseract_initialization from ._utils import get_tesseract_dim_dep_from_env, assert_tesseract_initialization
from .._common_utils import divide, set_tensor_parallel_attribute from .._common_utils import divide, set_tensor_parallel_attribute_by_partition
from ..base_layer import ParallelLayer from ..base_layer import ParallelLayer
@ -33,7 +33,9 @@ class Linear2p5D(ParallelLayer):
out_features: int, out_features: int,
bias: bool = True, bias: bool = True,
dtype=None, dtype=None,
skip_bias_add: bool = False skip_bias_add: bool = False,
init_weight='torch',
init_bias='torch'
): ):
super().__init__() super().__init__()
@ -46,7 +48,7 @@ class Linear2p5D(ParallelLayer):
self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
self.dep_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) self.dep_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env() self.tesseract_dim, _ = get_tesseract_dim_dep_from_env()
# partitioning dimension # partitioning dimension
self.input_size_per_partition = divide(in_features, self.tesseract_dim) self.input_size_per_partition = divide(in_features, self.tesseract_dim)
@ -69,46 +71,59 @@ class Linear2p5D(ParallelLayer):
self.register_parameter('bias', None) self.register_parameter('bias', None)
# initialize parameters # initialize parameters
self.reset_parameters() with seed(ParallelMode.TENSOR):
self.reset_parameters(init_weight, init_bias)
self._set_tensor_parallel_attributes() self._set_tensor_parallel_attributes()
def _set_tensor_parallel_attributes(self): def _set_tensor_parallel_attributes(self):
set_tensor_parallel_attribute(self.weight) num_partition = gpc.get_world_size(ParallelMode.TENSOR)
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
if self.bias is not None: if self.bias is not None:
set_tensor_parallel_attribute(self.bias) set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
-    def reset_parameters(self) -> None:
+    def reset_parameters(self, init_weight, init_bias) -> None:
+        assert init_weight in ('torch', 'jax', 'zero')
+        assert init_bias in ('torch', 'jax', 'zero')
         # setting
-        fan_in = self.in_features
-        a = math.sqrt(5)
-        nonlinearity = 'leaky_relu'
+        fan_in, fan_out = self.in_features, self.out_features

         # init weight
-        std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
-        bound = math.sqrt(3.0) * std
-        with seed(ParallelMode.TENSOR):
+        if init_weight == 'torch':
+            a = math.sqrt(5)
+            nonlinearity = 'leaky_relu'
+            std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
+            bound = math.sqrt(3.0) * std
             init.uniform_(self.weight, -bound, bound)
+        elif init_weight == 'jax':
+            std = math.sqrt(2.0 / float(fan_in + fan_out))
+            a = math.sqrt(3.0) * std
+            init.uniform_(self.weight, -a, a)
+        elif init_weight == 'zero':
+            init.zeros_(self.weight)

         # init bias
         if self.bias is not None:
-            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
-            with seed(ParallelMode.TENSOR):
+            if init_bias == 'torch':
+                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                 init.uniform_(self.bias, -bound, bound)
+            elif init_bias == 'jax':
+                init.normal_(self.bias, std=1e-6)
+            elif init_bias == 'zero':
+                init.zeros_(self.bias)
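For reference, the three weight-initialization bounds selected above written out in plain PyTorch with no tensor parallelism: 'torch' reproduces nn.Linear's default kaiming-uniform bound, 'jax' is the xavier/glorot-uniform bound, and 'zero' zero-fills. The fan_in and fan_out values are illustrative.

import math
import torch
from torch.nn import init

fan_in, fan_out = 1024, 4096
w = torch.empty(fan_in, fan_out)

# 'torch': same bound as nn.Linear's default reset_parameters
std = init.calculate_gain('leaky_relu', math.sqrt(5)) / math.sqrt(fan_in)
torch_bound = math.sqrt(3.0) * std              # equals 1 / sqrt(fan_in)
init.uniform_(w, -torch_bound, torch_bound)

# 'jax': xavier/glorot uniform
jax_bound = math.sqrt(3.0) * math.sqrt(2.0 / float(fan_in + fan_out))
init.uniform_(w, -jax_bound, jax_bound)

# 'zero'
init.zeros_(w)

print(round(torch_bound, 5), round(jax_bound, 5))  # 0.03125 0.03423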
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
# input: [m/dq, n/q, k/q] # input: [m/dq, n/q, k/q]
# output: [m/dq, n/q, h/q] # output: [m/dq, n/q, h/q]
out_shape = x.shape[:-1] + (self.hidden_size_per_partition,) out_shape = x.shape[:-1] + (self.hidden_size_per_partition,)
output = Matmul_AB_2p5D.apply( output = Matmul_AB_2p5D.apply(
x, x,
self.weight, self.weight,
self.tesseract_dim, self.tesseract_dim,
self.tesseract_dep,
out_shape, out_shape,
self.row_rank, self.col_rank, self.dep_rank, self.row_rank, self.col_rank, self.dep_rank,
ParallelMode.PARALLEL_2P5D_ROW, ParallelMode.PARALLEL_2P5D_ROW,
ParallelMode.PARALLEL_2P5D_COL, ParallelMode.PARALLEL_2P5D_COL,
ParallelMode.PARALLEL_2P5D_DEP,
self.data_parallel_rank, self.data_parallel_rank,
self.pipeline_parallel_rank, self.pipeline_parallel_rank,
self.pipeline_parallel_size, self.pipeline_parallel_size,
@ -121,11 +136,9 @@ class Linear2p5D(ParallelLayer):
None, None,
self.bias, self.bias,
self.hidden_size_per_partition, self.hidden_size_per_partition,
self.tesseract_dim, self.tesseract_dep, self.tesseract_dim,
self.row_rank, self.col_rank, self.dep_rank, self.row_rank, self.col_rank, self.dep_rank,
ParallelMode.PARALLEL_2P5D_ROW,
ParallelMode.PARALLEL_2P5D_COL, ParallelMode.PARALLEL_2P5D_COL,
ParallelMode.PARALLEL_2P5D_DEP,
True, True,
self.data_parallel_rank, self.data_parallel_rank,
self.pipeline_parallel_rank, self.pipeline_parallel_rank,
@ -138,11 +151,9 @@ class Linear2p5D(ParallelLayer):
output, output,
self.bias, self.bias,
self.hidden_size_per_partition, self.hidden_size_per_partition,
self.tesseract_dim, self.tesseract_dep, self.tesseract_dim,
self.row_rank, self.col_rank, self.dep_rank, self.row_rank, self.col_rank, self.dep_rank,
ParallelMode.PARALLEL_2P5D_ROW,
ParallelMode.PARALLEL_2P5D_COL, ParallelMode.PARALLEL_2P5D_COL,
ParallelMode.PARALLEL_2P5D_DEP,
False, False,
self.data_parallel_rank, self.data_parallel_rank,
self.pipeline_parallel_rank, self.pipeline_parallel_rank,
@ -168,6 +179,7 @@ class LayerNorm2p5D(ParallelLayer):
:param dtype: The dtype of parameters, defaults to None :param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional :type dtype: torch.dtype, optional
""" """
def __init__(self, def __init__(self,
normalized_shape: int, normalized_shape: int,
eps: float = 1e-05, eps: float = 1e-05,
@ -184,7 +196,7 @@ class LayerNorm2p5D(ParallelLayer):
self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
self.dep_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) self.dep_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env() self.tesseract_dim, _ = get_tesseract_dim_dep_from_env()
# partitioning dimension # partitioning dimension
self.partitioned_partition = divide( self.partitioned_partition = divide(
@ -193,27 +205,19 @@ class LayerNorm2p5D(ParallelLayer):
# create parameters # create parameters
factory_kwargs = {'device': get_current_device(), 'dtype': dtype} factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
if self.row_rank == 0: self.gamma = Parameter(torch.ones(
self.gamma = Parameter(torch.ones( self.partitioned_partition,
self.partitioned_partition, **factory_kwargs))
**factory_kwargs)) self.beta = Parameter(torch.zeros(
self.beta = Parameter(torch.zeros( self.partitioned_partition,
self.partitioned_partition, **factory_kwargs))
**factory_kwargs))
else:
self.gamma = Parameter(torch.tensor(
1.0,
requires_grad=True,
**factory_kwargs))
self.beta = Parameter(torch.tensor(
1.0,
requires_grad=True,
**factory_kwargs))
self._set_tensor_parallel_attribute() self._set_tensor_parallel_attribute()
def _set_tensor_parallel_attribute(self): def _set_tensor_parallel_attribute(self):
set_tensor_parallel_attribute(self.gamma) num_partition = gpc.get_world_size(ParallelMode.TENSOR)
set_tensor_parallel_attribute(self.beta) set_tensor_parallel_attribute_by_partition(self.gamma, num_partition)
set_tensor_parallel_attribute_by_partition(self.beta, num_partition)
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
with torch.no_grad(): with torch.no_grad():
@ -233,16 +237,12 @@ class LayerNorm2p5D(ParallelLayer):
Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon) Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon)
output = _LayerNorm_2p5D.apply(x, E_x, Var_x, self.normalized_shape, output = _LayerNorm_2p5D.apply(x, E_x, Var_x, self.normalized_shape,
ParallelMode.PARALLEL_2P5D_ROW, ParallelMode.PARALLEL_2P5D_ROW)
ParallelMode.PARALLEL_2P5D_COL,
ParallelMode.PARALLEL_2P5D_DEP)
bias = Add_Bias_2p5D.apply( bias = Add_Bias_2p5D.apply(
None, self.beta, self.partitioned_partition, None, self.beta, self.partitioned_partition,
self.tesseract_dim, self.tesseract_dep, self.tesseract_dim,
self.row_rank, self.col_rank, self.dep_rank, self.row_rank, self.col_rank, self.dep_rank,
ParallelMode.PARALLEL_2P5D_ROW,
ParallelMode.PARALLEL_2P5D_COL, ParallelMode.PARALLEL_2P5D_COL,
ParallelMode.PARALLEL_2P5D_DEP,
True, True,
self.data_parallel_rank, self.data_parallel_rank,
self.pipeline_parallel_rank, self.pipeline_parallel_rank,
@ -251,11 +251,9 @@ class LayerNorm2p5D(ParallelLayer):
) )
scale = Add_Bias_2p5D.apply( scale = Add_Bias_2p5D.apply(
None, self.gamma, self.partitioned_partition, None, self.gamma, self.partitioned_partition,
self.tesseract_dim, self.tesseract_dep, self.tesseract_dim,
self.row_rank, self.col_rank, self.dep_rank, self.row_rank, self.col_rank, self.dep_rank,
ParallelMode.PARALLEL_2P5D_ROW,
ParallelMode.PARALLEL_2P5D_COL, ParallelMode.PARALLEL_2P5D_COL,
ParallelMode.PARALLEL_2P5D_DEP,
True, True,
self.data_parallel_rank, self.data_parallel_rank,
self.pipeline_parallel_rank, self.pipeline_parallel_rank,

@ -1,21 +1,223 @@
#!/usr/bin/env python #!/usr/bin/env python
# -*- encoding: utf-8 -*- # -*- encoding: utf-8 -*-
from typing import Any, Tuple from typing import Any, Optional, Tuple
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from colossalai.communication import all_gather, reduce_scatter, scatter from colossalai.communication import all_gather, all_reduce, reduce_scatter
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.utils import empty_cache, get_current_device
from torch import Tensor from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
class linear_3d(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
input_: Tensor,
weight: Tensor,
bias: Optional[Tensor],
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
input_dim: int = 0,
weight_dim: int = -1,
output_dim: int = 0) -> Tensor:
assert input_.shape[-1] == weight.shape[0], \
'Invalid shapes: input = {}, weight = {}.'.format(input_.shape, weight.shape)
ctx.use_bias = bias is not None
input_ = all_gather(input_, input_dim, input_parallel_mode)
input_ = torch.cat(input_, dim=input_dim)
# weight = all_gather(weight, weight_dim, weight_parallel_mode)
ctx.save_for_backward(input_, weight)
output = torch.matmul(input_, weight)
output = reduce_scatter(output, output_dim, output_parallel_mode)
if bias is not None:
# ranks_in_group = gpc.get_ranks_in_group(output_parallel_mode)
# src_rank = ranks_in_group[gpc.get_local_rank(input_parallel_mode)]
# dist.broadcast(bias,
# src=src_rank,
# group=gpc.get_group(output_parallel_mode))
# bias = all_gather(bias, -1, weight_parallel_mode)
output += bias
# ctx.src_rank = src_rank
# ctx.save_for_backward(input_, weight)
# output = torch.matmul(input_, weight)
# dist.all_reduce(output, group=gpc.get_group(output_parallel_mode))
# output += bias
ctx.input_parallel_mode = input_parallel_mode
ctx.weight_parallel_mode = weight_parallel_mode
ctx.output_parallel_mode = output_parallel_mode
ctx.input_dim = input_dim
ctx.weight_dim = weight_dim
ctx.output_dim = output_dim
return output
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
input_, weight = ctx.saved_tensors
with torch.no_grad():
# input_grad = torch.matmul(output_grad, weight.transpose(0, 1))
# dist.all_reduce(input_grad,
# group=gpc.get_group(ctx.input_parallel_mode))
# weight_grad = torch.matmul(
# input_.reshape(-1, input_.shape[-1]).transpose(0, 1),
# output_grad.reshape(-1, output_grad.shape[-1]))
# dist.all_reduce(weight_grad,
# group=gpc.get_group(ctx.weight_parallel_mode))
# bias_grad = torch.sum(output_grad,
# dim=tuple(
# range(len(output_grad.shape))[:-1]))
# bias_grad = reduce_scatter(bias_grad, -1,
# ctx.weight_parallel_mode)
# dist.reduce(bias_grad,
# dst=ctx.src_rank,
# group=gpc.get_group(ctx.output_parallel_mode))
# if gpc.get_local_rank(
# ctx.output_parallel_mode) != gpc.get_local_rank(
# ctx.input_parallel_mode):
# bias_grad = None
# input_ = all_gather(input_, ctx.input_dim, ctx.input_parallel_mode)
# weight = all_gather(weight, ctx.weight_dim,
# ctx.weight_parallel_mode)
output_grad = all_gather(output_grad, ctx.output_dim,
ctx.output_parallel_mode)
output_grad = torch.cat(output_grad, dim=ctx.output_dim)
input_grad = torch.matmul(output_grad, weight.transpose(0, 1))
input_grad, input_op = reduce_scatter(input_grad, ctx.input_dim,
ctx.input_parallel_mode,
async_op=True)
weight_grad = torch.matmul(
input_.reshape(-1, input_.shape[-1]).transpose(0, 1),
output_grad.reshape(-1, output_grad.shape[-1]))
# weight_grad = torch.matmul(
# input_.reshape(-1, input_.shape[-1]).transpose(0, 1),
# output_grad.reshape(-1, output_grad.shape[-1]))
# weight_grad = reduce_scatter(weight_grad, ctx.weight_dim,
# ctx.weight_parallel_mode)
if ctx.use_bias:
bias_grad = torch.sum(output_grad,
dim=tuple(
range(len(output_grad.shape))[:-1]))
# bias_grad =all_reduce(bias_grad, ctx.output_parallel_mode)
# dist.all_reduce(bias_grad,
# group=gpc.get_group(ctx.weight_parallel_mode))
weight_grad = torch.cat([weight_grad, torch.unsqueeze(bias_grad, dim=0)])
weight_grad, weight_op = all_reduce(weight_grad, ctx.weight_parallel_mode, async_op=True)
input_op.wait()
weight_op.wait()
if ctx.use_bias:
bias_grad = weight_grad[-1]
weight_grad = weight_grad[:-1]
return input_grad, weight_grad, bias_grad, None, None, None, None, None, None
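One detail of the backward above worth spelling out: the bias gradient is packed into the weight gradient as an extra row so that both share a single all-reduce over the weight group, and it is split off again afterwards. A single-process sketch of the packing, with the collective itself omitted:

import torch

weight_grad = torch.randn(8, 16)
bias_grad = torch.randn(16)

packed = torch.cat([weight_grad, bias_grad.unsqueeze(0)])  # [9, 16]: one tensor, one collective
# ... all_reduce(packed, weight_parallel_mode) would run here ...
bias_out, weight_out = packed[-1], packed[:-1]

print(torch.equal(weight_out, weight_grad), torch.equal(bias_out, bias_grad))  # True True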
class layer_norm_3d(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, input_: Tensor, weight: Tensor, bias: Tensor,
normalized_shape: int, eps: float,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode) -> Tensor:
# mean = torch.sum(input_, dim=-1)
# dist.all_reduce(mean, group=gpc.get_group(output_parallel_mode))
# mean /= normalized_shape
# mu = input_ - mean
# var = torch.sum(torch.pow(mu, 2), dim=-1)
# dist.all_reduce(var, group=gpc.get_group(output_parallel_mode))
# var /= normalized_shape
# std_dev = torch.sqrt(var + eps)
# ctx.save_for_backward(input_, mu, std_dev, weight)
# output = weight * mu / std_dev + bias
mean = all_reduce(torch.sum(input_, dim=-1, keepdim=True),
output_parallel_mode) / normalized_shape
mu = input_ - mean
var = all_reduce(torch.sum(mu**2, dim=-1, keepdim=True),
output_parallel_mode) / normalized_shape
sigma = torch.sqrt(var + eps)
# ranks_in_group = gpc.get_ranks_in_group(input_parallel_mode)
# src_rank = ranks_in_group[gpc.get_local_rank(output_parallel_mode)]
# transforms = torch.stack([weight, bias]).contiguous()
# dist.broadcast(transforms,
# src=src_rank,
# group=gpc.get_group(input_parallel_mode))
# transforms = all_gather(transforms, -1, weight_parallel_mode)
# weight, bias = transforms[0], transforms[1]
ctx.save_for_backward(mu, sigma, weight)
z = mu / sigma
output = weight * z + bias
# ctx.src_rank = src_rank
ctx.normalized_shape = normalized_shape
ctx.input_parallel_mode = input_parallel_mode
ctx.weight_parallel_mode = weight_parallel_mode
ctx.output_parallel_mode = output_parallel_mode
return output
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
mu, sigma, weight = ctx.saved_tensors
with torch.no_grad():
bias_grad, weight_grad = output_grad, output_grad * mu / sigma
grads = torch.stack([bias_grad, weight_grad]).contiguous()
grads = torch.sum(grads, dim=tuple(range(len(grads.shape))[1:-1]))
grads = all_reduce(grads, ctx.weight_parallel_mode)
grads = all_reduce(grads, ctx.input_parallel_mode)
bias_grad, weight_grad = grads[0], grads[1]
# grads = reduce_scatter(grads, -1, ctx.weight_parallel_mode)
# dist.reduce(grads,
# dst=ctx.src_rank,
# group=gpc.get_group(ctx.input_parallel_mode))
# if gpc.get_local_rank(
# ctx.input_parallel_mode) == gpc.get_local_rank(
# ctx.output_parallel_mode):
# bias_grad, weight_grad = grads[0], grads[1]
# else:
# bias_grad, weight_grad = None, None
dz = output_grad * weight
dvar = dz * mu * (-0.5) * sigma**(-3)
dvar = all_reduce(torch.sum(dvar, dim=-1, keepdim=True), ctx.output_parallel_mode)
dmean = dz * (-1 / sigma) + dvar * -2 * mu / ctx.normalized_shape
dmean = all_reduce(torch.sum(dmean, dim=-1, keepdim=True), ctx.output_parallel_mode)
input_grad = dz / sigma + dvar * 2 * mu / ctx.normalized_shape + dmean / ctx.normalized_shape
return input_grad, weight_grad, bias_grad, None, None, None, None, None
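A single-process reference check, added for clarity and not part of the patch: with the whole hidden dimension on one rank, the manual gradients in layer_norm_3d.backward reduce to the standard layer-norm backward and match autograd.

import torch

N = 16
x = torch.randn(4, N, dtype=torch.float64, requires_grad=True)
gamma = torch.randn(N, dtype=torch.float64, requires_grad=True)
beta = torch.randn(N, dtype=torch.float64, requires_grad=True)
eps = 1e-5

# same forward formula as layer_norm_3d, without the all_reduce
mean = x.mean(dim=-1, keepdim=True)
mu = x - mean
var = mu.pow(2).mean(dim=-1, keepdim=True)
sigma = torch.sqrt(var + eps)
out = gamma * (mu / sigma) + beta

g = torch.randn_like(out)
out.backward(g)

with torch.no_grad():
    # same formulas as the backward above
    dz = g * gamma
    dvar = (dz * mu * -0.5 * sigma ** -3).sum(dim=-1, keepdim=True)
    dmean = (dz * (-1 / sigma) + dvar * -2 * mu / N).sum(dim=-1, keepdim=True)
    dx = dz / sigma + dvar * 2 * mu / N + dmean / N

print(torch.allclose(dx, x.grad))                                  # True
print(torch.allclose((g * mu / sigma).sum(dim=0), gamma.grad),
      torch.allclose(g.sum(dim=0), beta.grad))                     # True True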
class Matmul_AB_3D(torch.autograd.Function): class Matmul_AB_3D(torch.autograd.Function):
"""Matrix multiplication for :math:`C = AB` """Matrix multiplication for :math:`C = AB`
""" """
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, def forward(ctx: Any,
A: Tensor, A: Tensor,
B: Tensor, B: Tensor,
@ -29,7 +231,6 @@ class Matmul_AB_3D(torch.autograd.Function):
# A: [m/q^2, n, k/q] # A: [m/q^2, n, k/q]
# B: [k/q, h/q^2] # B: [k/q, h/q^2]
# C: [m/q^2, n, h/q] # C: [m/q^2, n, h/q]
empty_cache()
ctx.save_for_backward(A, B) ctx.save_for_backward(A, B)
assert A.shape[-1] == B.shape[0], \ assert A.shape[-1] == B.shape[0], \
@ -52,6 +253,7 @@ class Matmul_AB_3D(torch.autograd.Function):
return out return out
@staticmethod @staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors A, B = ctx.saved_tensors
with torch.no_grad(): with torch.no_grad():
@ -72,6 +274,7 @@ class Matmul_ABT_3D(torch.autograd.Function):
"""Matrix multiplication for :math:`C = AB^T` """Matrix multiplication for :math:`C = AB^T`
""" """
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, def forward(ctx: Any,
A: Tensor, A: Tensor,
B: Tensor, B: Tensor,
@ -85,7 +288,6 @@ class Matmul_ABT_3D(torch.autograd.Function):
# A: [m/q^2, n, h/q] # A: [m/q^2, n, h/q]
# B: [k/q, h/q^2] # B: [k/q, h/q^2]
# C: [m/q^2, n, k/q] # C: [m/q^2, n, k/q]
empty_cache()
ctx.save_for_backward(A, B) ctx.save_for_backward(A, B)
A_temp = all_gather(A, input_dim, input_parallel_mode) A_temp = all_gather(A, input_dim, input_parallel_mode)
@ -105,6 +307,7 @@ class Matmul_ABT_3D(torch.autograd.Function):
return out return out
@staticmethod @staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors A, B = ctx.saved_tensors
with torch.no_grad(): with torch.no_grad():
@ -125,6 +328,7 @@ class Matmul_ATB_3D(torch.autograd.Function):
"""Matrix multiplication for :math:`C = A^TB` """Matrix multiplication for :math:`C = A^TB`
""" """
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, def forward(ctx: Any,
A: Tensor, A: Tensor,
B: Tensor, B: Tensor,
@ -138,7 +342,6 @@ class Matmul_ATB_3D(torch.autograd.Function):
# A: [m/q^2, n, k/q] # A: [m/q^2, n, k/q]
# B: [m/q^2, n, h/q] # B: [m/q^2, n, h/q]
# C: [k/q, h/q^2] # C: [k/q, h/q^2]
empty_cache()
ctx.save_for_backward(A, B) ctx.save_for_backward(A, B)
A_temp = all_gather(A, input_dim, input_parallel_mode) A_temp = all_gather(A, input_dim, input_parallel_mode)
@ -160,6 +363,7 @@ class Matmul_ATB_3D(torch.autograd.Function):
return out return out
@staticmethod @staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors A, B = ctx.saved_tensors
with torch.no_grad(): with torch.no_grad():
@ -180,6 +384,7 @@ class Add_3D(torch.autograd.Function):
"""Matrix add bias: :math:`C = A + b` """Matrix add bias: :math:`C = A + b`
""" """
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, input_: Tensor, bias: Tensor, depth: int, def forward(ctx: Any, input_: Tensor, bias: Tensor, depth: int,
input_parallel_mode: ParallelMode, input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
@ -206,6 +411,7 @@ class Add_3D(torch.autograd.Function):
return out return out
@staticmethod @staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
# output_grad: [m/q^2, n, h/q] # output_grad: [m/q^2, n, h/q]
with torch.no_grad(): with torch.no_grad():
@ -217,8 +423,8 @@ class Add_3D(torch.autograd.Function):
dst=ctx.src_rank, dst=ctx.src_rank,
group=gpc.get_group(ctx.A_group_parallel_mode)) group=gpc.get_group(ctx.A_group_parallel_mode))
if gpc.get_local_rank( if gpc.get_local_rank(
ctx.A_group_parallel_mode) != gpc.get_local_rank( ctx.A_group_parallel_mode) != gpc.get_local_rank(
ctx.C_group_parallel_mode): ctx.C_group_parallel_mode):
bias_grad = None bias_grad = None
return output_grad, bias_grad, None, None, None, None return output_grad, bias_grad, None, None, None, None
@ -227,6 +433,7 @@ class Mul_3D(torch.autograd.Function):
"""Matrix multiplication for :math:`C = A * b` """Matrix multiplication for :math:`C = A * b`
""" """
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, input_: Tensor, bias: Tensor, depth: int, def forward(ctx: Any, input_: Tensor, bias: Tensor, depth: int,
input_parallel_mode: ParallelMode, input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
@ -243,7 +450,7 @@ class Mul_3D(torch.autograd.Function):
# [h/q] # [h/q]
bias_temp = all_gather(bias_temp, -1, weight_parallel_mode) bias_temp = all_gather(bias_temp, -1, weight_parallel_mode)
empty_cache() # empty_cache()
ctx.save_for_backward(input_, bias_temp) ctx.save_for_backward(input_, bias_temp)
out = torch.mul(input_, bias_temp) out = torch.mul(input_, bias_temp)
@ -257,6 +464,7 @@ class Mul_3D(torch.autograd.Function):
return out return out
@staticmethod @staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
# output_grad: [m/q^2, n, h/q] # output_grad: [m/q^2, n, h/q]
with torch.no_grad(): with torch.no_grad():
@ -272,8 +480,8 @@ class Mul_3D(torch.autograd.Function):
dst=ctx.src_rank, dst=ctx.src_rank,
group=gpc.get_group(ctx.A_group_parallel_mode)) group=gpc.get_group(ctx.A_group_parallel_mode))
if gpc.get_local_rank( if gpc.get_local_rank(
ctx.A_group_parallel_mode) != gpc.get_local_rank( ctx.A_group_parallel_mode) != gpc.get_local_rank(
ctx.C_group_parallel_mode): ctx.C_group_parallel_mode):
bias_grad = None bias_grad = None
return input_grad, bias_grad, None, None, None, None return input_grad, bias_grad, None, None, None, None
@ -282,6 +490,7 @@ class Sum_3D(torch.autograd.Function):
"""Compute the sum of input tensors """Compute the sum of input tensors
""" """
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, def forward(ctx: Any,
input_: Tensor, input_: Tensor,
dim: int, dim: int,
@ -299,6 +508,7 @@ class Sum_3D(torch.autograd.Function):
return out return out
@staticmethod @staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
with torch.no_grad(): with torch.no_grad():
output_grad = output_grad.contiguous() output_grad = output_grad.contiguous()
@ -315,35 +525,39 @@ class Reduce_3D(torch.autograd.Function):
"""Reduce input tensors """Reduce input tensors
""" """
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, input_: Tensor, depth: int, def forward(ctx: Any, input_: Tensor, depth: int,
parallel_mode: ParallelMode) -> Tensor: parallel_mode: ParallelMode) -> Tensor:
dist.all_reduce(input_, group=gpc.get_group(parallel_mode)) dist.all_reduce(input_, group=gpc.get_group(parallel_mode))
return input_.clone() return input_.clone()
@staticmethod @staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
return output_grad, None, None return output_grad, None, None
class Slice_3D(torch.autograd.Function): # class Slice_3D(torch.autograd.Function):
"""Slice input tensor # """Slice input tensor
""" # """
@staticmethod # @staticmethod
def forward(ctx: Any, input_: Tensor, dim: int, depth: int, # @custom_fwd(cast_inputs=torch.float16)
parallel_mode: ParallelMode) -> Tensor: # def forward(ctx: Any, input_: Tensor, dim: int, depth: int,
rank = gpc.get_local_rank(parallel_mode) # parallel_mode: ParallelMode) -> Tensor:
out = torch.chunk(input_, depth, dim=dim)[rank].contiguous() # rank = gpc.get_local_rank(parallel_mode)
# out = torch.chunk(input_, depth, dim=dim)[rank].contiguous()
ctx.depth = depth # ctx.depth = depth
ctx.parallel_mode = parallel_mode # ctx.parallel_mode = parallel_mode
ctx.dim = dim # ctx.dim = dim
ctx.input_shape = input_.shape # ctx.input_shape = input_.shape
return out # return out
@staticmethod # @staticmethod
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: # @custom_bwd
with torch.no_grad(): # def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
input_grad = all_gather(output_grad, ctx.dim, ctx.parallel_mode) # with torch.no_grad():
input_grad.reshape(ctx.input_shape) # input_grad = all_gather(output_grad, ctx.dim, ctx.parallel_mode)
return input_grad, None, None, None # input_grad.reshape(ctx.input_shape)
# return input_grad, None, None, None

@ -3,7 +3,8 @@
import os import os
from colossalai.constants import DEPTH_3D from colossalai.constants import (DEPTH_3D, INPUT_GROUP_3D, OUTPUT_GROUP_3D,
WEIGHT_GROUP_3D)
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from torch import Tensor from torch import Tensor
@ -23,6 +24,10 @@ def get_depth_from_env() -> int:
) )
def get_parallel_mode_from_env(group):
return getattr(ParallelMode, os.environ[group])
def get_last_group(a, b): def get_last_group(a, b):
mapping = { mapping = {
ParallelMode.PARALLEL_3D_INPUT: 'A', ParallelMode.PARALLEL_3D_INPUT: 'A',
@ -41,6 +46,11 @@ def get_last_group(a, b):
return ParallelMode.PARALLEL_3D_OUTPUT return ParallelMode.PARALLEL_3D_OUTPUT
def swap_in_out_group():
os.environ[INPUT_GROUP_3D], os.environ[OUTPUT_GROUP_3D] = \
os.environ[OUTPUT_GROUP_3D], os.environ[INPUT_GROUP_3D]
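How the env-keyed lookup above is meant to be used, assuming the two helpers defined in this module are imported; the string values assigned here are illustrative assumptions, since the real setup writes them during initialization.

import os
from colossalai.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D
from colossalai.context.parallel_mode import ParallelMode

# the initializer is assumed to record the ParallelMode member names in the env
os.environ[INPUT_GROUP_3D] = 'PARALLEL_3D_INPUT'
os.environ[OUTPUT_GROUP_3D] = 'PARALLEL_3D_OUTPUT'

input_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)    # ParallelMode.PARALLEL_3D_INPUT
swap_in_out_group()                                        # exchanges the two env entries
print(get_parallel_mode_from_env(INPUT_GROUP_3D))          # ParallelMode.PARALLEL_3D_OUTPUT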
def dbg_check_shape(tensor: Tensor, shape: tuple): def dbg_check_shape(tensor: Tensor, shape: tuple):
rank = gpc.get_global_rank() rank = gpc.get_global_rank()
if rank == 0: if rank == 0:

@ -1,17 +1,20 @@
import math import math
from typing import Tuple import os
from typing import Tuple, Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from colossalai.constants import (INPUT_GROUP_3D, OUTPUT_GROUP_3D,
WEIGHT_GROUP_3D)
from colossalai.context import ParallelMode, seed from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.registry import LAYERS from colossalai.registry import LAYERS
from colossalai.nn.init import init_bias_, init_weight_
from colossalai.utils import checkpoint, get_current_device from colossalai.utils import checkpoint, get_current_device
from torch import Tensor, dtype, nn from torch import Tensor, dtype, nn
from .._common_utils import ACT2FN, divide, set_tensor_parallel_attribute from .._common_utils import ACT2FN, divide, set_tensor_parallel_attribute_by_size, to_2tuple
from ..vanilla_vision_transformer.layers import to_2tuple from ._utils import get_depth_from_env, get_parallel_mode_from_env, get_last_group
from ._utils import get_depth_from_env
from .layers import Linear3D from .layers import Linear3D
@ -32,34 +35,42 @@ class ViTPatchEmbedding3D(nn.Module):
:param flatten: whether to flatten output tensor, defaults to True :param flatten: whether to flatten output tensor, defaults to True
:type flatten: bool, optional :type flatten: bool, optional
""" """
def __init__(self, def __init__(self,
img_size: int, img_size: int,
patch_size: int, patch_size: int,
in_chans: int, in_chans: int,
embed_size: int, embed_size: int,
drop_prob: float, drop_prob: float,
flatten: bool = True): flatten: bool = True,
init_method: str = 'torch'):
super().__init__() super().__init__()
self.depth = get_depth_from_env() self.depth = get_depth_from_env()
self.input_parallel_mode = ParallelMode.PARALLEL_3D_INPUT self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = ParallelMode.PARALLEL_3D_WEIGHT self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = ParallelMode.PARALLEL_3D_OUTPUT self.output_parallel_mode = get_last_group(self.input_parallel_mode,
self.weight_parallel_mode)
img_size = to_2tuple(img_size) img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size) patch_size = to_2tuple(patch_size)
self.img_size = img_size self.img_size = img_size
self.patch_size = patch_size self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0], self.grid_size = (img_size[0] // patch_size[0],
img_size[1] // patch_size[1]) img_size[1] // patch_size[1])
self.in_chans = in_chans
self.embed_size = embed_size self.embed_size = embed_size
self.embed_size_per_partition = divide(self.embed_size, self.depth) self.embed_size_per_partition = divide(self.embed_size, self.depth)
self.num_patches = self.grid_size[0] * self.grid_size[1] self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten self.flatten = flatten
self.init_weight = 'torch'
self.init_bias = 'torch'
if init_method == 'jax':
self.init_weight = 'jax_embed'
self.init_bias = 'zero'
with seed(ParallelMode.TENSOR): self.proj = nn.Conv2d(self.in_chans,
self.proj = nn.Conv2d(in_chans, self.embed_size_per_partition,
self.embed_size_per_partition, kernel_size=patch_size,
kernel_size=patch_size, stride=patch_size)
stride=patch_size)
self.cls_token = nn.Parameter( self.cls_token = nn.Parameter(
torch.zeros(1, 1, self.embed_size_per_partition)) torch.zeros(1, 1, self.embed_size_per_partition))
@ -68,23 +79,26 @@ class ViTPatchEmbedding3D(nn.Module):
self.embed_size_per_partition)) self.embed_size_per_partition))
self.pos_drop = nn.Dropout(drop_prob) self.pos_drop = nn.Dropout(drop_prob)
self._sync_parameters() self.reset_parameters(self.init_weight, self.init_bias)
self.proj.weight.register_hook(self._sync_grad_hook) self._set_tensor_parallel_attributes()
self.proj.bias.register_hook(self._sync_grad_hook)
self.cls_token.register_hook(self._sync_grad_hook)
self.pos_embed.register_hook(self._sync_grad_hook)
self._set_tensor_parallel_attribute()
def _set_tensor_parallel_attribute(self): def _set_tensor_parallel_attributes(self):
set_tensor_parallel_attribute(self.proj.weight) set_tensor_parallel_attribute_by_size(self.proj.weight, self.in_chans * self.embed_size * self.num_patches)
set_tensor_parallel_attribute(self.proj.bias) set_tensor_parallel_attribute_by_size(self.proj.bias, self.embed_size)
set_tensor_parallel_attribute(self.cls_token) set_tensor_parallel_attribute_by_size(self.cls_token, 1 * 1 * self.embed_size)
set_tensor_parallel_attribute(self.pos_embed) set_tensor_parallel_attribute_by_size(self.pos_embed, 1 * (self.num_patches + 1) * self.embed_size)
def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]: def reset_parameters(self, init_weight, init_bias):
return self.input_parallel_mode, self.weight_parallel_mode fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.proj.weight)
# std = math.sqrt(1.0 / fan_in)
# nn.init.trunc_normal_(self.proj.weight, std=std / .87962566103423978)
# nn.init.zeros_(self.proj.bias)
if init_weight != 'torch':
init_weight_(self.proj.weight, fan_in, init_method=init_weight)
init_bias_(self.pos_embed, fan_in, init_method=init_weight)
if init_bias != 'torch':
init_bias_(self.proj.bias, fan_in, init_method=init_bias)
def _sync_parameters(self):
self.to(get_current_device()) self.to(get_current_device())
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0] weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
dist.broadcast(self.proj.weight, dist.broadcast(self.proj.weight,
@ -100,10 +114,11 @@ class ViTPatchEmbedding3D(nn.Module):
dist.broadcast(self.proj.bias, dist.broadcast(self.proj.bias,
src=input_src_rank, src=input_src_rank,
group=gpc.get_group(self.input_parallel_mode)) group=gpc.get_group(self.input_parallel_mode))
set_tensor_parallel_attribute(self.proj.weight)
set_tensor_parallel_attribute(self.proj.bias) self.proj.weight.register_hook(self._sync_grad_hook)
set_tensor_parallel_attribute(self.cls_token) self.proj.bias.register_hook(self._sync_grad_hook)
set_tensor_parallel_attribute(self.pos_embed) self.cls_token.register_hook(self._sync_grad_hook)
self.pos_embed.register_hook(self._sync_grad_hook)
def _sync_grad_hook(self, grad) -> None: def _sync_grad_hook(self, grad) -> None:
dist.all_reduce(grad, group=gpc.get_group(self.input_parallel_mode)) dist.all_reduce(grad, group=gpc.get_group(self.input_parallel_mode))
@ -111,6 +126,12 @@ class ViTPatchEmbedding3D(nn.Module):
return grad return grad
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
# split a partition from inputs
x = torch.chunk(x, self.depth, dim=0)[gpc.get_local_rank(
self.weight_parallel_mode)].contiguous()
x = torch.chunk(x, self.depth, dim=0)[gpc.get_local_rank(
self.input_parallel_mode)].contiguous()
B, C, H, W = x.shape B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \ assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
@ -118,12 +139,6 @@ class ViTPatchEmbedding3D(nn.Module):
if self.flatten: if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
# split a partition from embedded states
x = torch.chunk(x, self.depth, dim=0)[gpc.get_local_rank(
self.weight_parallel_mode)].contiguous()
x = torch.chunk(x, self.depth, dim=0)[gpc.get_local_rank(
self.input_parallel_mode)].contiguous()
# add cls token & pos embedding # add cls token & pos embedding
# [b/q^2,s,h/q] --> [b/q^2, 1+s, h/q] # [b/q^2,s,h/q] --> [b/q^2, 1+s, h/q]
cls_token = self.cls_token.expand(x.shape[0], -1, -1) cls_token = self.cls_token.expand(x.shape[0], -1, -1)
@ -158,6 +173,7 @@ class ViTSelfAttention3D(nn.Module):
:param bias: whether to add bias, defaults to True :param bias: whether to add bias, defaults to True
:type bias: bool, optional :type bias: bool, optional
""" """
def __init__(self, def __init__(self,
hidden_size: int, hidden_size: int,
num_attention_heads: int, num_attention_heads: int,
@ -165,41 +181,52 @@ class ViTSelfAttention3D(nn.Module):
hidden_dropout_prob: float, hidden_dropout_prob: float,
dtype: dtype = None, dtype: dtype = None,
bias: bool = True, bias: bool = True,
checkpoint: bool = False): checkpoint: bool = False,
init_method: str = 'torch'):
super().__init__() super().__init__()
self.depth = get_depth_from_env() self.depth = get_depth_from_env()
self.input_parallel_mode = ParallelMode.PARALLEL_3D_INPUT # self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = ParallelMode.PARALLEL_3D_WEIGHT # self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = ParallelMode.PARALLEL_3D_OUTPUT # self.output_parallel_mode = get_last_group(self.input_parallel_mode,
# self.weight_parallel_mode)
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.num_attention_heads = divide(num_attention_heads, self.depth) self.num_attention_heads = divide(num_attention_heads, self.depth)
self.attention_head_size = divide(hidden_size, num_attention_heads) self.attention_head_size = divide(hidden_size, num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size self.all_head_size = self.num_attention_heads * self.attention_head_size
self.checkpoint = checkpoint self.checkpoint = checkpoint
self.init_weight = 'torch'
self.init_bias = 'torch'
if init_method == 'jax':
self.init_weight = 'jax'
self.init_bias = 'zero'
self.query_key_value = Linear3D(self.hidden_size, self.query_key_value = Linear3D(self.hidden_size,
3 * self.hidden_size, 3 * self.hidden_size,
self.input_parallel_mode, # self.input_parallel_mode,
self.weight_parallel_mode, # self.weight_parallel_mode,
dtype=dtype, dtype=dtype,
bias=bias) bias=bias,
init_weight=self.init_weight,
init_bias=self.init_bias)
self.attention_dropout = nn.Dropout(attention_probs_dropout_prob) self.attention_dropout = nn.Dropout(attention_probs_dropout_prob)
self.dense = Linear3D(self.hidden_size, self.dense = Linear3D(self.hidden_size,
self.hidden_size, self.hidden_size,
self.output_parallel_mode, # self.output_parallel_mode,
self.weight_parallel_mode, # self.weight_parallel_mode,
dtype=dtype, dtype=dtype,
bias=bias) bias=bias,
init_weight=self.init_weight,
init_bias=self.init_bias)
self.dropout = nn.Dropout(hidden_dropout_prob) self.dropout = nn.Dropout(hidden_dropout_prob)
self.softmax = nn.Softmax(dim=-1) self.softmax = nn.Softmax(dim=-1)
def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]: # def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
return self.input_parallel_mode, self.weight_parallel_mode # return self.input_parallel_mode, self.weight_parallel_mode
def _forward(self, hidden_states: Tensor) -> Tensor: def _forward(self, hidden_states: Tensor) -> Tensor:
query_key_value = self.query_key_value(hidden_states) query_key_value = self.query_key_value(hidden_states)
new_qkv_shape = query_key_value.shape[:-1] + \ new_qkv_shape = query_key_value.shape[:-1] + \
(self.num_attention_heads, 3 * self.attention_head_size) (self.num_attention_heads, 3 * self.attention_head_size)
query_key_value = query_key_value.view(new_qkv_shape) query_key_value = query_key_value.view(new_qkv_shape)
query_key_value = query_key_value.permute((0, 2, 1, 3)) query_key_value = query_key_value.permute((0, 2, 1, 3))
query_layer, key_layer, value_layer = torch.chunk(query_key_value, query_layer, key_layer, value_layer = torch.chunk(query_key_value,
@ -259,6 +286,7 @@ class ViTMLP3D(nn.Module):
:param bias: whether to add bias, defaults to True
:type bias: bool, optional
"""
def __init__(self,
hidden_size: int,
mlp_ratio: int,
@ -266,33 +294,41 @@ class ViTMLP3D(nn.Module):
hidden_act: str = 'gelu',
dtype: dtype = None,
bias: bool = True,
- checkpoint: bool = False):
+ checkpoint: bool = False,
+ init_method: str = 'torch'):
super().__init__()
- self.depth = get_depth_from_env()
- self.input_parallel_mode = ParallelMode.PARALLEL_3D_INPUT
- self.weight_parallel_mode = ParallelMode.PARALLEL_3D_WEIGHT
- self.output_parallel_mode = ParallelMode.PARALLEL_3D_OUTPUT
+ # self.depth = get_depth_from_env()
+ # self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
+ # self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
+ # self.output_parallel_mode = get_last_group(self.input_parallel_mode,
+ #     self.weight_parallel_mode)
self.hidden_size = hidden_size
self.mlp_ratio = mlp_ratio
self.checkpoint = checkpoint
+ self.init_weight = init_method
+ self.init_bias = init_method
self.dense_1 = Linear3D(self.hidden_size,
self.mlp_ratio * self.hidden_size,
- self.input_parallel_mode,
- self.weight_parallel_mode,
+ # self.input_parallel_mode,
+ # self.weight_parallel_mode,
dtype=dtype,
- bias=bias)
+ bias=bias,
+ init_weight=self.init_weight,
+ init_bias=self.init_bias)
self.activation_func = ACT2FN[hidden_act]
self.dense_2 = Linear3D(self.mlp_ratio * self.hidden_size,
self.hidden_size,
- self.output_parallel_mode,
- self.weight_parallel_mode,
+ # self.output_parallel_mode,
+ # self.weight_parallel_mode,
dtype=dtype,
- bias=bias)
+ bias=bias,
+ init_weight=self.init_weight,
+ init_bias=self.init_bias)
self.dropout = nn.Dropout(hidden_dropout_prob)
- def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
-     return self.input_parallel_mode, self.weight_parallel_mode
+ # def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
+ #     return self.input_parallel_mode, self.weight_parallel_mode
def _forward(self, hidden_states: Tensor) -> Tensor:
intermediate_output = self.dense_1(hidden_states)
@ -331,37 +367,46 @@ class ViTHead3D(nn.Module):
:param bias: whether to add bias, defaults to True
:type bias: bool, optional
"""
def __init__(self,
in_features: int,
num_classes: int,
dtype: dtype = None,
- bias: bool = True):
+ bias: bool = True,
+ init_method: str = 'torch'):
super().__init__()
- self.depth = get_depth_from_env()
- self.input_parallel_mode = ParallelMode.PARALLEL_3D_INPUT
- self.weight_parallel_mode = ParallelMode.PARALLEL_3D_WEIGHT
- self.output_parallel_mode = ParallelMode.PARALLEL_3D_OUTPUT
+ # self.depth = get_depth_from_env()
+ # self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
+ # self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
+ # self.output_parallel_mode = get_last_group(self.input_parallel_mode,
+ #     self.weight_parallel_mode)
self.in_features = in_features
self.num_classes = num_classes
- out_features = math.ceil(self.num_classes /
-     (self.depth**2)) * (self.depth**2)
- self.num_classes_per_partition = divide(self.num_classes, self.depth)
- self.linear = Linear3D(self.in_features,
-     out_features,
-     self.input_parallel_mode,
-     self.weight_parallel_mode,
-     dtype=dtype,
-     bias=bias)
- def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
-     return self.linear.groups_for_next_layer()
+ # out_features = math.ceil(self.num_classes /
+ #     (self.depth**2)) * (self.depth**2)
+ # self.num_classes_per_partition = divide(self.num_classes, self.depth)
+ self.init_weight = 'torch'
+ self.init_bias = 'torch'
+ if init_method == 'jax':
+     self.init_weight = 'zero'
+     self.init_bias = 'zero'
+ self.linear = Linear3D(self.in_features,
+     self.num_classes,
+     # self.input_parallel_mode,
+     # self.weight_parallel_mode,
+     dtype=dtype,
+     bias=bias,
+     init_weight=self.init_weight,
+     init_bias=self.init_bias)
def forward(self, x: Tensor) -> Tensor:
# [b/q^2, s, h/q] --> [b/q^2, h/q]
x = x[:, 0]
# [b/q^2, h/q] --> [b/q^2, c/q]
x = self.linear(x)
- return x[:, :self.num_classes_per_partition]
+ # return x[:, :self.num_classes_per_partition]
+ return x
def extra_repr(self):
return 'in_features={}, num_classes={}'.format(self.in_features,
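A hedged construction sketch of the head change above (class name and constructor arguments are the ones shown in this diff; it is not runnable outside an initialized 3-D tensor-parallel context):

# Hypothetical usage sketch only; assumes the 3-D parallel groups are set up.
head = ViTHead3D(in_features=768, num_classes=1000, init_method='jax')
# init_method='jax' resolves both init_weight and init_bias to 'zero', so the
# inner Linear3D classifier starts from zeros, and forward() now returns
# self.linear(x[:, 0]) directly instead of slicing num_classes_per_partition
# columns from the output.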


@ -2,19 +2,28 @@
# -*- encoding: utf-8 -*-
import math
+ import os
from typing import Tuple
import torch
+ import torch.distributed as dist
import torch.nn as nn
+ from colossalai.constants import (INPUT_GROUP_3D, OUTPUT_GROUP_3D,
+     WEIGHT_GROUP_3D)
from colossalai.context import ParallelMode, seed
+ from colossalai.core import global_context as gpc
+ from colossalai.nn.init import init_bias_, init_weight_
from colossalai.registry import LAYERS
from colossalai.utils import get_current_device
from torch import Tensor, dtype
from torch.nn import Parameter
+ from torch.nn import init as init
- from .._common_utils import divide, set_tensor_parallel_attribute
+ from .._common_utils import divide, set_tensor_parallel_attribute_by_size
- from ._operation import Add_3D, Matmul_AB_3D, Mul_3D, Sum_3D
+ from ._operation import (Add_3D, Matmul_AB_3D, Mul_3D, Sum_3D, layer_norm_3d,
+     linear_3d)
- from ._utils import get_depth_from_env, get_last_group
+ from ._utils import (get_depth_from_env, get_last_group,
+     get_parallel_mode_from_env, swap_in_out_group)
@LAYERS.register_module
@ -22,20 +31,19 @@ class LayerNorm3D(nn.Module):
def __init__(
self,
normalized_shape: int,
- input_parallel_mode: ParallelMode,
- weight_parallel_mode: ParallelMode,
+ # input_parallel_mode: ParallelMode,
+ # weight_parallel_mode: ParallelMode,
eps: float = 1e-12,
dtype: dtype = None,
):
super().__init__()
- self.input_parallel_mode = input_parallel_mode
- self.weight_parallel_mode = weight_parallel_mode
+ self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
+ self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode,
self.weight_parallel_mode)
self.depth = get_depth_from_env()
self.normalized_shape = normalized_shape
- self.normalized_shape_per_partition = divide(normalized_shape,
-     self.depth**2)
+ self.normalized_shape_per_partition = divide(normalized_shape, self.depth)
self.weight = Parameter(
torch.ones(self.normalized_shape_per_partition,
@ -49,37 +57,40 @@ class LayerNorm3D(nn.Module):
self._set_tensor_parallel_attributes()
def _set_tensor_parallel_attributes(self):
- set_tensor_parallel_attribute(self.weight)
- set_tensor_parallel_attribute(self.bias)
+ set_tensor_parallel_attribute_by_size(self.weight, self.normalized_shape)
+ set_tensor_parallel_attribute_by_size(self.bias, self.normalized_shape)
- def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
-     return self.input_parallel_mode, self.weight_parallel_mode
def reset_parameters(self):
- nn.init.zeros_(self.bias)
- nn.init.ones_(self.weight)
+ init.zeros_(self.bias)
+ init.ones_(self.weight)
def forward(self, input_: Tensor) -> Tensor:
- '''x = weight * (x - mean) / sqrt(var + eps) + bias'''
- # input: [m/q^2, n, h/q]
- # [m/q^2, n, 1]
- mean = Sum_3D.apply(input_, -1, self.depth, self.output_parallel_mode,
-     True) / self.normalized_shape
- # [m/q^2, n, 1]
- var = (input_ - mean).pow(2)
- var = Sum_3D.apply(var, -1, self.depth, self.output_parallel_mode,
-     True) / self.normalized_shape
- output = (input_ - mean) / torch.sqrt(var + self.variance_epsilon)
- output = Mul_3D.apply(output, self.weight, self.depth,
-     self.input_parallel_mode,
-     self.weight_parallel_mode,
-     self.output_parallel_mode)
- output = Add_3D.apply(output, self.bias, self.depth,
-     self.input_parallel_mode,
-     self.weight_parallel_mode,
-     self.output_parallel_mode)
- return output
+ # '''x = weight * (x - mean) / sqrt(var + eps) + bias'''
+ # # input: [m/q^2, n, h/q]
+ # # [m/q^2, n, 1]
+ # mean = Sum_3D.apply(input_, -1, self.depth, self.output_parallel_mode,
+ #     True) / self.normalized_shape
+ # # [m/q^2, n, 1]
+ # var = (input_ - mean).pow(2)
+ # var = Sum_3D.apply(var, -1, self.depth, self.output_parallel_mode,
+ #     True) / self.normalized_shape
+ # output = (input_ - mean) / torch.sqrt(var + self.variance_epsilon)
+ # output = Mul_3D.apply(output, self.weight, self.depth,
+ #     self.input_parallel_mode,
+ #     self.weight_parallel_mode,
+ #     self.output_parallel_mode)
+ # output = Add_3D.apply(output, self.bias, self.depth,
+ #     self.input_parallel_mode,
+ #     self.weight_parallel_mode,
+ #     self.output_parallel_mode)
+ # return output
+ return layer_norm_3d.apply(input_, self.weight, self.bias,
+     self.normalized_shape,
+     self.variance_epsilon,
+     self.input_parallel_mode,
+     self.weight_parallel_mode,
+     self.output_parallel_mode)
def extra_repr(self):
return '{}, eps={}'.format(self.normalized_shape,
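For reference, a single-device sketch (not ColossalAI code) of the formula noted in the commented-out body, checked against torch.nn.functional.layer_norm:

# y = weight * (x - mean) / sqrt(var + eps) + bias, verified on one device.
import torch
import torch.nn.functional as F

x = torch.randn(2, 5, 16)
w, b, eps = torch.randn(16), torch.randn(16), 1e-12
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, unbiased=False, keepdim=True)
manual = w * (x - mean) / torch.sqrt(var + eps) + b
assert torch.allclose(manual, F.layer_norm(x, (16,), w, b, eps), atol=1e-5)
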
@ -88,33 +99,36 @@ class LayerNorm3D(nn.Module):
@LAYERS.register_module
class Linear3D(nn.Module):
- def __init__(self,
-     in_features: int,
-     out_features: int,
-     input_parallel_mode: ParallelMode,
-     weight_parallel_mode: ParallelMode,
-     bias: bool = True,
-     dtype: dtype = None):
+ def __init__(
+     self,
+     in_features: int,
+     out_features: int,
+     # input_parallel_mode: ParallelMode,
+     # weight_parallel_mode: ParallelMode,
+     bias: bool = True,
+     dtype: dtype = None,
+     init_weight: str = 'torch',
+     init_bias: str = 'torch'):
super().__init__()
self.in_features = in_features
self.out_features = out_features
- self.input_parallel_mode = input_parallel_mode
- self.weight_parallel_mode = weight_parallel_mode
+ self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
+ self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode,
self.weight_parallel_mode)
- self.with_bias = bias
+ # self.with_bias = bias
self.depth = get_depth_from_env()
self.in_features_per_partition = divide(in_features, self.depth)
- self.out_features_per_partition = divide(out_features, self.depth**2)
- # [k/q, h/q^2]
+ self.out_features_per_partition = divide(out_features, self.depth)
+ # [k/q, h/q]
self.weight = Parameter(
torch.empty(self.in_features_per_partition,
self.out_features_per_partition,
device=get_current_device(),
dtype=dtype))
- # [h/q^2]
+ # [h/q]
if bias:
self.bias = Parameter(
torch.zeros(self.out_features_per_partition,
@ -123,49 +137,54 @@ class Linear3D(nn.Module):
else:
self.register_parameter('bias', None)
- self.reset_parameters()
+ self.reset_parameters(init_weight, init_bias)
self._set_tensor_parallel_attributes()
+ swap_in_out_group()
def _set_tensor_parallel_attributes(self):
- set_tensor_parallel_attribute(self.weight)
+ set_tensor_parallel_attribute_by_size(self.weight, self.in_features * self.out_features)
if self.bias is not None:
- set_tensor_parallel_attribute(self.bias)
+ set_tensor_parallel_attribute_by_size(self.bias, self.out_features)
- def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
-     return self.output_parallel_mode, self.weight_parallel_mode
- def reset_parameters(self):
+ def reset_parameters(self, init_weight, init_bias) -> None:
# setting
- fan_in = self.in_features
- a = math.sqrt(5)
- nonlinearity = 'leaky_relu'
+ fan_in, fan_out = self.in_features, self.out_features
+ weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
+ output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0]
# init weight
- std = nn.init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
- bound = math.sqrt(3.0) * std
- with seed(ParallelMode.TENSOR):
-     nn.init.uniform_(self.weight, -bound, bound)
+ init_weight_(self.weight, fan_in, fan_out, init_method=init_weight)
+ dist.broadcast(self.weight,
+     src=weight_src_rank,
+     group=gpc.get_group(self.weight_parallel_mode))
# init bias
- if self.with_bias:
-     bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
-     with seed(ParallelMode.TENSOR):
-         nn.init.uniform_(self.bias, -bound, bound)
+ if self.bias is not None:
+     init_bias_(self.bias, fan_in, init_method=init_bias)
+     dist.broadcast(self.bias,
+         src=weight_src_rank,
+         group=gpc.get_group(self.weight_parallel_mode))
+     dist.broadcast(self.bias,
+         src=output_src_rank,
+         group=gpc.get_group(self.output_parallel_mode))
def forward(self, input_: Tensor) -> Tensor:
- # input: [m/q^2, n, k/q]
- # output: [m/q^2, n, h/q]
- output = Matmul_AB_3D.apply(input_, self.weight, self.depth,
-     self.input_parallel_mode,
-     self.weight_parallel_mode,
-     self.output_parallel_mode)
- if self.with_bias:
-     output = Add_3D.apply(output, self.bias, self.depth,
-         self.output_parallel_mode,
-         self.weight_parallel_mode,
-         self.input_parallel_mode)
- return output
+ # # input: [m/q^2, n, k/q]
+ # # output: [m/q^2, n, h/q]
+ # output = Matmul_AB_3D.apply(input_, self.weight, self.depth,
+ #     self.input_parallel_mode,
+ #     self.weight_parallel_mode,
+ #     self.output_parallel_mode)
+ # if self.bias is not None:
+ #     output = Add_3D.apply(output, self.bias, self.depth,
+ #         self.output_parallel_mode,
+ #         self.weight_parallel_mode,
+ #         self.input_parallel_mode)
+ # return output
+ return linear_3d.apply(input_, self.weight, self.bias,
+     self.input_parallel_mode,
+     self.weight_parallel_mode,
+     self.output_parallel_mode)
def extra_repr(self):
return 'in_features={}, out_features={}, bias={}'.format(
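A hedged sketch of how the reworked Linear3D might be constructed after this change; the constructor arguments come from the diff, while the import path and sizes are assumptions for illustration and the call requires an initialized 3-D tensor-parallel context:

# Hypothetical usage; assumes colossalai.launch(...) has set up a 3-D
# tensor-parallel configuration so get_parallel_mode_from_env can resolve the
# input/weight/output groups.
from colossalai.nn.layer.parallel_3d.layers import Linear3D  # path assumed

proj = Linear3D(in_features=768,
                out_features=3072,
                bias=True,
                init_weight='jax',   # values used by the ViT layers above
                init_bias='zero')
# Per rank, weight is [in_features/depth, out_features/depth] and bias is
# [out_features/depth]; both are broadcast from the group source ranks inside
# reset_parameters.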


@ -1,3 +0,0 @@
from .layers import ViTBlock
__all__ = ['ViTBlock']


@ -1,59 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from torch import nn as nn
from colossalai.builder import build_layer
from colossalai.registry import LAYERS
@LAYERS.register_module
class ViTBlock(nn.Module):
"""Vision Transformer block
:param attention_cfg: config of attention layer
:type attention_cfg: dict
:param droppath_cfg: config of drop path
:type droppath_cfg: dict
:param mlp_cfg: config of MLP layer
:type mlp_cfg: dict
:param norm_cfg: config of normalization layer
:type norm_cfg: dict
"""
def __init__(self,
attention_cfg: dict,
droppath_cfg: dict,
mlp_cfg: dict,
norm_cfg: dict,
):
super().__init__()
self.norm1 = build_layer(norm_cfg)
self.attn = build_layer(attention_cfg)
self.drop_path = build_layer(
droppath_cfg) if droppath_cfg['drop_path'] > 0. else nn.Identity()
self.norm2 = build_layer(norm_cfg)
self.mlp = build_layer(mlp_cfg)
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
# x_ = x
# x_ = self.norm1(x_)
# if self.checkpoint:
# x_ = checkpoint(self.attn, x_)
# else:
# x_ = self.attn(x_)
# x_ = self.drop_path(x_)
# x = x + x_
#
# x_ = x
# x_ = self.norm2(x_)
# if self.checkpoint:
# x_ = checkpoint(self.mlp, x_)
# else:
# x_ = self.mlp(x_)
# x_ = self.drop_path(x_)
# x = x + x_
return x


@ -1,5 +0,0 @@
from .basic_block import ResNetBasicBlock
from .bottleneck import ResNetBottleneck
from .reslayer import ResLayer
__all__ = ['ResLayer', 'ResNetBottleneck', 'ResNetBasicBlock']


@ -1,64 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Optional, Callable
import torch.nn as nn
from torch import Tensor
from colossalai.registry import LAYERS
from .conv import conv3x3
@LAYERS.register_module
class ResNetBasicBlock(nn.Module):
"""Basic ResNet block
"""
expansion: int = 1
def __init__(
self,
inplanes: int,
planes: int,
stride: int = 1,
downsample: Optional[nn.Module] = None,
groups: int = 1,
base_width: int = 64,
dilation: int = 1,
norm_layer: Optional[Callable[..., nn.Module]] = None
) -> None:
super().__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if groups != 1 or base_width != 64:
raise ValueError(
'BasicBlock only supports groups=1 and base_width=64')
if dilation > 1:
raise NotImplementedError(
"Dilation > 1 not supported in BasicBlock")
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x: Tensor) -> Tensor:
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out


@ -1,69 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Optional, Callable
import torch.nn as nn
from torch import Tensor
from colossalai.registry import LAYERS
from .conv import conv3x3, conv1x1
@LAYERS.register_module
class ResNetBottleneck(nn.Module):
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
# while original implementation places the stride at the first 1x1 convolution(self.conv1)
# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
# This variant is also known as ResNet V1.5 and improves accuracy according to
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
expansion: int = 4
def __init__(
self,
inplanes: int,
planes: int,
stride: int = 1,
downsample: Optional[nn.Module] = None,
groups: int = 1,
base_width: int = 64,
dilation: int = 1,
norm_layer: Optional[Callable[..., nn.Module]] = None
) -> None:
super().__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
width = int(planes * (base_width / 64.)) * groups
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(inplanes, width)
self.bn1 = norm_layer(width)
self.conv2 = conv3x3(width, width, stride, groups, dilation)
self.bn2 = norm_layer(width)
self.conv3 = conv1x1(width, planes * self.expansion)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x: Tensor) -> Tensor:
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out


@ -1,15 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.nn as nn
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=dilation, groups=groups, bias=False, dilation=dilation)
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


@ -1,63 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.nn as nn
from colossalai.registry import LAYERS
from .conv import conv1x1
@LAYERS.register_module
class ResLayer(nn.Module):
def __init__(self,
block_type: str,
norm_layer_type: str,
inplanes: int,
planes: int,
blocks: int,
groups: int,
base_width: int,
stride: int = 1,
dilation: int = 1,
dilate: bool = False,
):
super().__init__()
self.block = LAYERS.get_module(block_type)
self.norm_layer = LAYERS.get_module(norm_layer_type)
self.inplanes = inplanes
self.planes = planes
self.blocks = blocks
self.groups = groups
self.dilation = dilation
self.base_width = base_width
self.dilate = dilate
self.stride = stride
self.layer = self._make_layer()
def _make_layer(self):
norm_layer = self.norm_layer
downsample = None
previous_dilation = self.dilation
if self.dilate:
self.dilation *= self.stride
self.stride = 1
if self.stride != 1 or self.inplanes != self.planes * self.block.expansion:
downsample = nn.Sequential(
conv1x1(self.inplanes, self.planes * self.block.expansion, self.stride),
norm_layer(self.planes * self.block.expansion),
)
layers = []
layers.append(self.block(self.inplanes, self.planes, self.stride, downsample, self.groups,
self.base_width, previous_dilation, norm_layer))
self.inplanes = self.planes * self.block.expansion
for _ in range(1, self.blocks):
layers.append(self.block(self.inplanes, self.planes, groups=self.groups,
base_width=self.base_width, dilation=self.dilation,
norm_layer=norm_layer))
return nn.Sequential(*layers)
def forward(self, x):
return self.layer(x)


@ -1,7 +0,0 @@
from .layers import (VanillaViTBlock, VanillaViTMLP, VanillaViTPatchEmbedding,
VanillaViTAttention, VanillaViTDropPath, VanillaViTHead)
__all__ = [
'VanillaViTBlock', 'VanillaViTMLP', 'VanillaViTPatchEmbedding',
'VanillaViTAttention', 'VanillaViTDropPath', 'VanillaViTHead'
]


@ -1,4 +1,3 @@
- from .base_loss import BaseLoss
from .cross_entropy_2d import CrossEntropyLoss2D
from .cross_entropy_2p5d import CrossEntropyLoss2p5D
from .cross_entropy_3d import CrossEntropyLoss3D


@ -1,13 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from abc import ABC, abstractmethod
class BaseLoss(ABC):
"""Absctract loss class
"""
@abstractmethod
def calc_loss(self, *args, **kwargs):
pass


@ -1,120 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer.parallel_1d._utils import vocab_range_from_per_partition_vocab_size
class _VocabParallelCrossEntropy_1D(torch.autograd.Function):
@staticmethod
def forward(ctx, vocab_parallel_logits, target):
# Maximum value along vocab dimension across all GPUs.
logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
torch.distributed.all_reduce(logits_max,
op=torch.distributed.ReduceOp.MAX,
group=gpc.get_group(ParallelMode.PARALLEL_1D))
# Subtract the maximum value.
vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
# Get the partition's vocab indices
partition_vocab_size = vocab_parallel_logits.size()[-1]
rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
vocab_start_index, vocab_end_index = vocab_range_from_per_partition_vocab_size(
partition_vocab_size, rank, world_size)
# Create a mask of valid vocab ids (1 means it needs to be masked).
target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
masked_target = target.clone() - vocab_start_index
masked_target[target_mask] = 0
# Get predicted-logits = logits[target].
# For Simplicity, we convert logits to a 2-D tensor with size
# [*, partition-vocab-size] and target to a 1-D tensor of size [*].
logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
masked_target_1d = masked_target.view(-1)
arange_1d = torch.arange(start=0, end=logits_2d.size()[0],
device=logits_2d.device)
predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
predicted_logits_1d = predicted_logits_1d.clone().contiguous()
predicted_logits = predicted_logits_1d.view_as(target)
predicted_logits[target_mask] = 0.0
# All reduce is needed to get the chunks from other GPUs.
torch.distributed.all_reduce(predicted_logits,
op=torch.distributed.ReduceOp.SUM,
group=gpc.get_group(ParallelMode.PARALLEL_1D))
# Sum of exponential of logits along vocab dimension across all GPUs.
exp_logits = vocab_parallel_logits
torch.exp(vocab_parallel_logits, out=exp_logits)
sum_exp_logits = exp_logits.sum(dim=-1)
torch.distributed.all_reduce(sum_exp_logits,
op=torch.distributed.ReduceOp.SUM,
group=gpc.get_group(ParallelMode.PARALLEL_1D))
# Loss = log(sum(exp(logits))) - predicted-logit.
loss = torch.log(sum_exp_logits) - predicted_logits
# Store softmax, target-mask and masked-target for backward pass.
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
return loss
@staticmethod
def backward(ctx, grad_output):
# Retrieve tensors from the forward path.
softmax, target_mask, masked_target_1d = ctx.saved_tensors
# All the inputs have softmax as their gradient.
grad_input = softmax
# For simplicity, work with the 2D gradient.
partition_vocab_size = softmax.size()[-1]
grad_2d = grad_input.view(-1, partition_vocab_size)
# Add the gradient from matching classes.
arange_1d = torch.arange(start=0, end=grad_2d.size()[0],
device=grad_2d.device)
grad_2d[arange_1d, masked_target_1d] -= (
1.0 - target_mask.view(-1).float())
# Finally elementwise multiplication with the output gradients.
grad_input.mul_(grad_output.unsqueeze(dim=-1))
return grad_input, None
class LmLoss1D(_Loss):
def forward(self, lm_logits, lm_labels, loss_mask):
lm_loss = _VocabParallelCrossEntropy_1D.apply(lm_logits, lm_labels)
lm_loss = torch.sum(
lm_loss.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
return lm_loss
class SopLoss1D(_Loss):
def forward(self, sop_logits, sentence_order):
sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
sentence_order.view(-1),
ignore_index=-1)
return sop_loss
class BERTDualHeadLoss(_Loss):
def __init__(self):
self.lm_loss = LmLoss1D()
self.sop_loss = SopLoss1D()
def forward(self, lm_logits, sop_logits, lm_labels, loss_mask, sentence_order):
lm_loss = self.lm_loss(lm_logits, lm_labels, loss_mask)
sop_loss = self.sop_loss(sop_logits, sentence_order)
return lm_loss + sop_loss
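The deleted 1-D vocab-parallel loss relies on the identity loss_i = log(sum_j exp(z_ij)) - z_i[y_i]; a minimal single-device check of that identity (illustrative only, not the parallel implementation):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)            # [batch, vocab]
targets = torch.randint(0, 10, (4,))
manual = torch.logsumexp(logits, dim=-1) - logits[torch.arange(4), targets]
# matches the standard per-sample cross entropy
assert torch.allclose(manual, F.cross_entropy(logits, targets, reduction='none'))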


@ -7,18 +7,18 @@ from colossalai.core import global_context as gpc
from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization, get_summa_dim_from_env
from colossalai.registry import LOSSES
from colossalai.utils import get_current_device
+ from torch.cuda.amp import custom_bwd, custom_fwd
class _ParallelCrossEntropyLossFunction_2D(torch.autograd.Function):
### Modified based on megatron.mpu.cross_entropy ###
@staticmethod
+ @custom_fwd(cast_inputs=torch.float32)
def forward(ctx, logits, targets):
# logits: [b/q, h/q]
# labels: [b/q]
- # loss: [b/q]
- # vocab_parallel_logits: [b/q, s, v/q]
- # target: [b/q, s]
logits_max = torch.max(logits, dim=-1)[0]
torch.distributed.all_reduce(
logits_max,
@ -58,6 +58,7 @@ class _ParallelCrossEntropyLossFunction_2D(torch.autograd.Function):
return loss
@staticmethod
+ @custom_bwd
def backward(ctx, output_grad):
# Retrieve tensors from the forward path.
softmax, target_mask, masked_target = ctx.saved_tensors
@ -91,12 +92,14 @@ class _ReduceByColumn(torch.autograd.Function):
return input_
@staticmethod
+ @custom_fwd(cast_inputs=torch.float32)
def forward(ctx, input_):
dist.all_reduce(input_, group=gpc.get_group(
ParallelMode.PARALLEL_2D_COL))
return input_
@staticmethod
+ @custom_bwd
def backward(ctx, grad_output):
return grad_output
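The added custom_fwd(cast_inputs=torch.float32)/custom_bwd decorators are the standard torch.cuda.amp pattern for keeping a hand-written autograd Function in fp32 under autocast; a minimal sketch (not ColossalAI code):

import torch
from torch.cuda.amp import custom_fwd, custom_bwd

class Fp32Square(torch.autograd.Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)   # fp16 inputs are cast to fp32 here
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    @custom_bwd                              # backward runs in the same autocast state
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out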


@ -1,32 +1,20 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
- import os
import torch
import torch.distributed as dist
- from torch.nn.modules.loss import _Loss
- from colossalai.communication import all_gather
+ from colossalai.constants import (INPUT_GROUP_3D, OUTPUT_GROUP_3D,
+     WEIGHT_GROUP_3D)
from colossalai.core import global_context as gpc
from colossalai.nn.layer.parallel_3d._operation import Reduce_3D
- from colossalai.nn.layer.parallel_3d._utils import get_last_group, get_depth_from_env
+ from colossalai.nn.layer.parallel_3d._utils import (get_depth_from_env,
+     get_last_group,
+     get_parallel_mode_from_env)
from colossalai.registry import LOSSES
from colossalai.utils import get_current_device
+ from torch.nn.modules.loss import _Loss
def accuracy_3d(output, target, input_parallel_mode, weight_parallel_mode):
depth = get_depth_from_env()
output_parallel_mode = get_last_group(input_parallel_mode,
weight_parallel_mode)
j = gpc.get_local_rank(input_parallel_mode)
i = gpc.get_local_rank(weight_parallel_mode)
target = torch.chunk(target, depth, dim=0)[i]
target = torch.chunk(target, depth, dim=0)[j]
output = all_gather(output, -1, output_parallel_mode)
prediction = torch.argmax(output, dim=-1)
correct = torch.sum(prediction == target)
dist.all_reduce(correct, group=gpc.get_group(input_parallel_mode))
dist.all_reduce(correct, group=gpc.get_group(weight_parallel_mode))
return correct.item()
class _ParallelCrossEntropyLossFunction_3D(torch.autograd.Function):
@ -112,16 +100,18 @@ class CrossEntropyLoss3D(_Loss):
:param reduction: whether to average the loss, defaults to True
:type reduction: bool, optional
"""
- def __init__(self,
-     input_parallel_mode,
-     weight_parallel_mode,
-     reduction=True):
+ def __init__(
+     self,
+     # input_parallel_mode,
+     # weight_parallel_mode,
+     reduction=True,
+     label_smoothing=0.0):
super().__init__()
self.depth = get_depth_from_env()
- self.input_parallel_mode = input_parallel_mode
- self.weight_parallel_mode = weight_parallel_mode
- self.output_parallel_mode = get_last_group(input_parallel_mode,
-     weight_parallel_mode)
+ self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
+ self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
+ self.output_parallel_mode = get_last_group(self.input_parallel_mode,
+     self.weight_parallel_mode)
self.input_rank = gpc.get_local_rank(self.input_parallel_mode)
self.weight_rank = gpc.get_local_rank(self.weight_parallel_mode)
self.reduction_mean = reduction
@ -141,53 +131,53 @@ class CrossEntropyLoss3D(_Loss):
return loss
@LOSSES.register_module # @LOSSES.register_module
class LabelSmoothingCrossEntropy3D(_Loss): # class LabelSmoothingCrossEntropy3D(_Loss):
""" # """
NLL loss with label smoothing, adapted from timm.loss.LabelSmoothingCrossEntropy # NLL loss with label smoothing, adapted from timm.loss.LabelSmoothingCrossEntropy
:param input_parallel_mode: parallel mode for input tensor # :param input_parallel_mode: parallel mode for input tensor
:type input_parallel_mode: ParallelMode # :type input_parallel_mode: ParallelMode
:param weight_parallel_mode: parallel mode for weight # :param weight_parallel_mode: parallel mode for weight
:type weight_parallel_mode: ParallelMode # :type weight_parallel_mode: ParallelMode
:param smoothing: label smoothing value, defaults to 0.1 # :param smoothing: label smoothing value, defaults to 0.1
:type smoothing: float # :type smoothing: float
:param reduction: whether to average the loss, defaults to True # :param reduction: whether to average the loss, defaults to True
:type reduction: bool, optional # :type reduction: bool, optional
""" # """
def __init__(self, # def __init__(self,
input_parallel_mode, # input_parallel_mode,
weight_parallel_mode, # weight_parallel_mode,
smoothing=0.1, # smoothing=0.1,
reduction=True): # reduction=True):
super().__init__() # super().__init__()
assert smoothing < 1.0 # assert smoothing < 1.0
self.smoothing = smoothing # self.smoothing = smoothing
self.confidence = 1. - smoothing # self.confidence = 1. - smoothing
self.depth = get_depth_from_env() # self.depth = get_depth_from_env()
self.input_parallel_mode = input_parallel_mode # self.input_parallel_mode = input_parallel_mode
self.weight_parallel_mode = weight_parallel_mode # self.weight_parallel_mode = weight_parallel_mode
self.output_parallel_mode = get_last_group(input_parallel_mode, # self.output_parallel_mode = get_last_group(input_parallel_mode,
weight_parallel_mode) # weight_parallel_mode)
self.reduction_mean = reduction # self.reduction_mean = reduction
def forward(self, logits, targets): # def forward(self, logits, targets):
# split label partition from the entire batch # # split label partition from the entire batch
j = gpc.get_local_rank(self.input_parallel_mode) # j = gpc.get_local_rank(self.input_parallel_mode)
i = gpc.get_local_rank(self.weight_parallel_mode) # i = gpc.get_local_rank(self.weight_parallel_mode)
targets = torch.chunk(targets, self.depth, dim=0)[i] # targets = torch.chunk(targets, self.depth, dim=0)[i]
targets = torch.chunk(targets, self.depth, dim=0)[j] # targets = torch.chunk(targets, self.depth, dim=0)[j]
exp_logits = torch.exp(logits) # exp_logits = torch.exp(logits)
sum_exp_logits = Sum3D.apply(exp_logits, -1, depth, # sum_exp_logits = Sum3D.apply(exp_logits, -1, depth,
self.output_parallel_mode, False) # self.output_parallel_mode, False)
log_probs = torch.log(sum_exp_logits) - logits # log_probs = torch.log(sum_exp_logits) - logits
nll_loss = _ParallelCrossEntropyLossFunction_3D.apply( # nll_loss = _ParallelCrossEntropyLossFunction_3D.apply(
logits, targets, self.depth, self.output_parallel_mode) # logits, targets, self.depth, self.output_parallel_mode)
smooth_loss = -log_probs.mean(dim=-1) # smooth_loss = -log_probs.mean(dim=-1)
loss = self.confidence * nll_loss + self.smoothing * smooth_loss # loss = self.confidence * nll_loss + self.smoothing * smooth_loss
if self.reduction_mean: # if self.reduction_mean:
loss = loss.sum() # loss = loss.sum()
loss = Reduce_3D.apply(loss, self.depth, self.input_parallel_mode) # loss = Reduce_3D.apply(loss, self.depth, self.input_parallel_mode)
loss = Reduce_3D.apply(loss, self.depth, self.weight_parallel_mode) # loss = Reduce_3D.apply(loss, self.depth, self.weight_parallel_mode)
loss /= batch_size # loss /= batch_size
return loss # return loss


@ -48,8 +48,10 @@ class DelayerScheduler(_LRScheduler):
if self.finished:
if epoch is None:
self.after_scheduler.step(None)
+ self._last_lr = self.after_scheduler.get_last_lr()
else:
self.after_scheduler.step(epoch - self.delay_epochs)
+ self._last_lr = self.after_scheduler.get_last_lr()
else:
return super(DelayerScheduler, self).step(epoch)
@ -66,6 +68,7 @@ class WarmupScheduler(_LRScheduler):
:param last_epoch: The index of last epoch, defaults to -1
:type last_epoch: int, optional
"""
def __init__(self, optimizer, warmup_epochs, after_scheduler, last_epoch=-1):
self.warmup_epochs = int(warmup_epochs)
self.after_scheduler = after_scheduler
@ -85,8 +88,10 @@ class WarmupScheduler(_LRScheduler):
if self.finished:
if epoch is None:
self.after_scheduler.step(None)
+ self._last_lr = self.after_scheduler.get_last_lr()
else:
self.after_scheduler.step(epoch - self.warmup_epochs)
+ self._last_lr = self.after_scheduler.get_last_lr()
else:
return super().step(epoch)
@ -136,7 +141,9 @@ class WarmupDelayerScheduler(_LRScheduler):
if self.finished:
if epoch is None:
self.after_scheduler.step(None)
+ self._last_lr = self.after_scheduler.get_last_lr()
else:
self.after_scheduler.step(epoch - self.warmup_epochs)
+ self._last_lr = self.after_scheduler.get_last_lr()
else:
return super().step(epoch)
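A hedged usage sketch of why _last_lr is mirrored: once the wrapped scheduler takes over, get_last_lr() should report its value rather than the stale warmup value. The WarmupScheduler signature is the one shown above; the import path is an assumption.

import torch
from colossalai.nn.lr_scheduler.delayed import WarmupScheduler  # path assumed

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=90)
scheduler = WarmupScheduler(optimizer, warmup_epochs=10, after_scheduler=cosine)

for epoch in range(100):
    # ... train one epoch ...
    optimizer.step()
    scheduler.step()
    # with this fix, get_last_lr() also tracks the cosine phase after warmup
    print(epoch, scheduler.get_last_lr())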


@ -12,7 +12,6 @@ class MultiStepLR(_MultiStepLR):
number of epoch reaches one of the milestones. Notice that such decay can
happen simultaneously with other changes to the learning rate from outside
this scheduler. When last_epoch=-1, sets initial lr as lr.
:param optimizer: Wrapped optimizer
:type optimizer: torch.optim.Optimizer
:param total_steps: number of total training steps
@ -34,7 +33,6 @@ class MultiStepLR(_MultiStepLR):
@LR_SCHEDULERS.register_module
class MultiStepWarmupLR(WarmupScheduler):
"""Multi-step learning rate scheduler with warmup.
:param optimizer: Wrapped optimizer
:type optimizer: torch.optim.Optimizer
:param total_steps: number of total training steps


@ -12,28 +12,21 @@ class OneCycleLR(_OneCycleLR):
than the initial learning rate.
This policy was initially described in the paper `Super-Convergence:
Very Fast Training of Neural Networks Using Large Learning Rates`_.
The 1cycle learning rate policy changes the learning rate after every batch.
`step` should be called after a batch has been used for training.
This scheduler is not chainable.
Note also that the total number of steps in the cycle can be determined in one
of two ways (listed in order of precedence):
#. A value for total_steps is explicitly provided.
#. A number of epochs (epochs) and a number of steps per epoch
(steps_per_epoch) are provided.
In this case, the number of total steps is inferred by
total_steps = epochs * steps_per_epoch
You must either provide a value for total_steps or provide a value for both
epochs and steps_per_epoch.
The default behaviour of this scheduler follows the fastai implementation of 1cycle, which
claims that "unpublished work has shown even better results by using only two phases". To
mimic the behaviour of the original paper instead, set ``three_phase=True``.
:param optimizer: Wrapped optimizer
:type optimizer: torch.optim.Optimizer
:param total_steps: number of total training steps
@ -71,7 +64,6 @@ class OneCycleLR(_OneCycleLR):
number of *batches* computed, not the total number of epochs computed.
When last_epoch=-1, the schedule is started from the beginning, defaults to -1
:type last_epoch: int, optional
.. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates:
https://arxiv.org/abs/1708.07120
"""


@ -7,7 +7,6 @@ from .delayed import WarmupScheduler
@LR_SCHEDULERS.register_module
class PolynomialLR(_LRScheduler):
"""Polynomial learning rate scheduler.
:param optimizer: Wrapped optimizer
:type optimizer: torch.optim.Optimizer
:param total_steps: number of total training steps
@ -43,7 +42,6 @@ class PolynomialLR(_LRScheduler):
@LR_SCHEDULERS.register_module
class PolynomialWarmupLR(WarmupScheduler):
"""Polynomial learning rate scheduler with warmup.
:param optimizer: Wrapped optimizer
:type optimizer: torch.optim.Optimizer
:param total_steps: number of total training steps


@ -10,7 +10,6 @@ from colossalai.registry import LR_SCHEDULERS
class LambdaLR(_LambdaLR):
"""Sets the learning rate of each parameter group to the initial lr
times a given function. When last_epoch=-1, sets initial lr as lr.
:param optimizer: Wrapped optimizer
:type optimizer: torch.optim.Optimizer
:param total_steps: number of total training steps
@ -33,7 +32,6 @@ class LambdaLR(_LambdaLR):
class MultiplicativeLR(_MultiplicativeLR):
"""Multiply the learning rate of each parameter group by the factor given
in the specified function. When last_epoch=-1, sets initial lr as lr
:param optimizer: Wrapped optimizer
:type optimizer: torch.optim.Optimizer
:param total_steps: number of total training steps
@ -58,7 +56,6 @@ class StepLR(_StepLR):
step_size epochs. Notice that such decay can happen simultaneously with
other changes to the learning rate from outside this scheduler. When
last_epoch=-1, sets initial lr as lr
:param optimizer: Wrapped optimizer
:type optimizer: torch.optim.Optimizer
:param total_steps: number of total training steps
@ -82,7 +79,6 @@ class StepLR(_StepLR):
class ExponentialLR(_ExponentialLR):
"""Decays the learning rate of each parameter group by gamma every epoch.
When last_epoch=-1, sets initial lr as lr
:param optimizer: Wrapped optimizer
:type optimizer: torch.optim.Optimizer
:param total_steps: number of total training steps


@ -1,3 +1,3 @@
- from .base_model import BaseModel
- from .vanilla_resnet import VanillaResNet
- from .vision_transformer import *
+ from .model_from_config import ModelFromConfig
+ __all__ = ['ModelFromConfig']


@ -8,10 +8,10 @@ import torch.nn as nn
from colossalai.builder import build_layer
- class BaseModel(nn.Module, ABC):
+ class ModelFromConfig(nn.Module, ABC):
def __init__(self):
- super(BaseModel, self).__init__()
+ super(ModelFromConfig, self).__init__()
self.layers = nn.ModuleList()
self.layers_cfg = []
@ -32,7 +32,6 @@ class BaseModel(nn.Module, ABC):
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""Use this function to override the state dict for
saving checkpoints."""
return self.state_dict(destination, prefix, keep_vars)


@ -1,3 +0,0 @@
from .resnet import VanillaResNet
__all__ = ['VanillaResNet']


@ -1,163 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import List, Optional
import torch
import torch.nn as nn
from torch import Tensor
from colossalai.registry import LAYERS
from colossalai.registry import MODELS
from ..base_model import BaseModel
@MODELS.register_module
class VanillaResNet(BaseModel):
"""ResNet from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
"""
def __init__(
self,
num_cls: int,
block_type: str,
layers: List[int],
norm_layer_type: str = 'BatchNorm2d',
in_channels: int = 3,
groups: int = 1,
width_per_group: int = 64,
zero_init_residual: bool = False,
replace_stride_with_dilation: Optional[List[bool]] = None,
dilations=(1, 1, 1, 1)
) -> None:
super().__init__()
self.inplanes = 64
self.zero_init_residual = zero_init_residual
self.blocks = layers
self.block_expansion = LAYERS.get_module(block_type).expansion
self.dilations = dilations
self.reslayer_common_cfg = dict(
type='ResLayer',
block_type=block_type,
norm_layer_type=norm_layer_type,
groups=groups,
base_width=width_per_group
)
if replace_stride_with_dilation is None:
# each element in the tuple indicates if we should replace
# the 2x2 stride with a dilated convolution instead
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError("replace_stride_with_dilation should be None "
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
self.layers_cfg = [
# conv1
dict(type='Conv2d',
in_channels=in_channels,
out_channels=self.inplanes,
kernel_size=7,
stride=2,
padding=3,
bias=False),
# bn1
dict(
type=norm_layer_type,
num_features=self.inplanes
),
# relu
dict(
type='ReLU',
inplace=True
),
# maxpool
dict(
type='MaxPool2d',
kernel_size=3,
stride=2,
padding=1
),
# layer 1
dict(
inplanes=self.inplanes,
planes=64,
blocks=self.blocks[0],
dilation=self.dilations[0],
**self.reslayer_common_cfg
),
# layer 2
dict(
inplanes=64 * self.block_expansion,
planes=128,
blocks=self.blocks[1],
stride=2,
dilate=replace_stride_with_dilation[0],
dilation=self.dilations[1],
**self.reslayer_common_cfg
),
# layer 3
dict(
inplanes=128 * self.block_expansion,
planes=256,
blocks=layers[2],
stride=2,
dilate=replace_stride_with_dilation[1],
dilation=self.dilations[2],
**self.reslayer_common_cfg
),
# layer 4
dict(
inplanes=256 * self.block_expansion,
planes=512,
blocks=layers[3], stride=2,
dilate=replace_stride_with_dilation[2],
dilation=self.dilations[3],
**self.reslayer_common_cfg
),
# avg pool
dict(
type='AdaptiveAvgPool2d',
output_size=(1, 1)
),
# flatten
dict(
type='LambdaWrapper',
func=lambda mod, x: torch.flatten(x, 1)
),
# linear
dict(
type='Linear',
in_features=512 * self.block_expansion,
out_features=num_cls
)
]
def forward(self, x: Tensor):
for layer in self.layers:
x = layer(x)
return x,
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
if self.zero_init_residual:
for m in self.modules():
if isinstance(m, LAYERS.get_module('ResNetBottleneck')):
# type: ignore[arg-type]
nn.init.constant_(m.bn3.weight, 0)
elif isinstance(m, LAYERS.get_module('ResNetBasicBlock')):
# type: ignore[arg-type]
nn.init.constant_(m.bn2.weight, 0)


@ -1,3 +0,0 @@
from .vision_transformer import VisionTransformerFromConfig
__all__ = ['VisionTransformerFromConfig']


@ -1,14 +1,10 @@
- from .fp16_optimizer import FP16Optimizer
+ from .colossalai_optimizer import ColossalaiOptimizer
from .fused_adam import FusedAdam
from .fused_lamb import FusedLAMB
from .fused_sgd import FusedSGD
from .lamb import Lamb
from .lars import Lars
- from .zero_redundancy_optimizer_level_1 import ZeroRedundancyOptimizer_Level_1
- from .zero_redundancy_optimizer_level_2 import ZeroRedundancyOptimizer_Level_2
- from .zero_redundancy_optimizer_level_3 import ZeroRedundancyOptimizer_Level_3
__all__ = [
- 'ZeroRedundancyOptimizer_Level_1', 'ZeroRedundancyOptimizer_Level_2', 'ZeroRedundancyOptimizer_Level_3',
- 'FusedLAMB', 'FusedAdam', 'FusedSGD', 'Lamb', 'FP16Optimizer', 'Lars'
+ 'ColossalaiOptimizer', 'FusedLAMB', 'FusedAdam', 'FusedSGD', 'Lamb', 'Lars'
]


@ -1,168 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
from torch._six import inf
try:
import colossal_C
except:
print('Colossalai should be built with cuda extension to use the FP16 optimizer')
from ..multi_tensor_apply import multi_tensor_applier
from colossalai.constants import IS_TENSOR_PARALLEL, TENSOR_PARALLEL_ATTRIBUTES
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
def is_model_parallel_parameter(p):
return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL)
def _calc_l2_norm(grads):
norm = 0.0
if len(grads) > 0:
dummy_overflow_buf = torch.cuda.IntTensor([0])
norm, _ = multi_tensor_applier(
colossal_C.multi_tensor_l2norm,
dummy_overflow_buf,
[grads],
False # no per-parameter norm
)
return norm
def _calc_lp(grads, norm_type):
norm = 0.0
for grad in grads:
grad_norm = torch.norm(grad, norm_type)
norm += grad_norm ** norm_type
return norm
# ======== Gradient Clipping =========
def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
"""Clips gradient norm of an iterable of parameters whose gradients
are in fp32.
This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
added functionality to handle model parallel parameters. Note that
the gradients are modified in place.
Arguments:
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
single Tensor that will have gradients normalized
max_norm (float or int): max norm of the gradients
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
Returns:
Total norm of the parameters (viewed as a single vector).
"""
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
# Filter parameters based on:
# - grad should not be none
# - parameter should not be shared
# - should not be a replica due to tensor model parallelism
params = []
for param in parameters:
if param.grad is not None:
# Make sure the grads are in fp32
assert param.grad.type() == 'torch.cuda.FloatTensor'
params.append(param)
# Norm parameters.
max_norm = float(max_norm)
norm_type = float(norm_type)
# Calculate norm.
if norm_type == inf:
total_norm = max(p.grad.data.abs().max() for p in params)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
if gpc.is_initialized(ParallelMode.TENSOR):
# Take max across all model-parallel GPUs.
torch.distributed.all_reduce(total_norm_cuda,
op=torch.distributed.ReduceOp.MAX,
group=gpc.get_group(ParallelMode.TENSOR))
total_norm = total_norm_cuda[0].item()
else:
tensor_parallel_grads = []
no_tensor_parallel_grads = []
for p in params:
if is_model_parallel_parameter(p):
tensor_parallel_grads.append(p.grad.data)
else:
no_tensor_parallel_grads.append(p.grad.data)
if norm_type == 2.0:
tensor_parallel_norm = _calc_l2_norm(
tensor_parallel_grads) ** norm_type
no_tensor_parallel_norm = _calc_l2_norm(
no_tensor_parallel_grads) ** norm_type
else:
tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type)
no_tensor_parallel_norm = _calc_lp(
no_tensor_parallel_grads, norm_type)
if gpc.is_initialized(ParallelMode.TENSOR) and len(tensor_parallel_grads) > 0:
# Sum across all model-parallel GPUs.
torch.distributed.all_reduce(tensor_parallel_norm,
op=torch.distributed.ReduceOp.SUM,
group=gpc.get_group(ParallelMode.TENSOR))
total_norm = (tensor_parallel_norm +
no_tensor_parallel_norm) ** (1.0 / norm_type)
if type(total_norm) == 'torch.cuda.FloatTensor':
total_norm = total_norm.item()
# Scale.
clip_coeff = max_norm / (total_norm + 1.0e-6)
if clip_coeff < 1.0:
grads = [p.grad.detach() for p in params]
dummy_overflow_buf = torch.cuda.IntTensor([0])
multi_tensor_applier(colossal_C.multi_tensor_scale,
dummy_overflow_buf,
[grads, grads],
clip_coeff)
return total_norm
def count_zeros_fp32(parameters):
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
# Filter parameters based on:
# - grad should not be none
# - parameter should not be shared
# - should not be a replica due to tensor model parallelism
total_num_zeros = 0.0
for param in parameters:
grad_not_none = param.grad is not None
is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
if grad_not_none and is_not_tp_duplicate:
grad = param.grad.detach()
num_zeros = grad.numel() - torch.count_nonzero(grad)
total_num_zeros = num_zeros + total_num_zeros
# Sum across all model-parallel GPUs.
torch.distributed.all_reduce(total_num_zeros,
op=torch.distributed.ReduceOp.SUM,
group=gpc.get_group(ParallelMode.TENSOR))
total_num_zeros = total_num_zeros.item()
return total_num_zeros
def copy_tensor_parallel_attributes(src_tensor, dst_tensor):
for attr in TENSOR_PARALLEL_ATTRIBUTES:
if hasattr(src_tensor, attr):
val = getattr(src_tensor, attr)
setattr(dst_tensor, attr, val)
def param_is_not_tensor_parallel_duplicate(param):
return (hasattr(param, IS_TENSOR_PARALLEL) and
getattr(param, IS_TENSOR_PARALLEL)) or (
gpc.get_local_rank(ParallelMode.TENSOR) == 0)
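A small sketch of how the attribute helpers above fit together (hypothetical parameter; it assumes IS_TENSOR_PARALLEL is listed in TENSOR_PARALLEL_ATTRIBUTES, as in this module's constants):

# Illustrative only: tag a parameter as tensor-parallel and propagate the flag.
weight = torch.nn.Parameter(torch.randn(16, 16))
setattr(weight, IS_TENSOR_PARALLEL, True)
assert is_model_parallel_parameter(weight)
master_weight = weight.detach().clone()
copy_tensor_parallel_attributes(weight, master_weight)
assert param_is_not_tensor_parallel_duplicate(master_weight)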

View File

@@ -0,0 +1,47 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer
from colossalai.utils import clip_grad_norm_fp32
class ColossalaiOptimizer(Optimizer):
def __init__(self, optim: Optimizer):
self.optim = optim
@property
def param_groups(self):
return self.optim.param_groups
@property
def defaults(self):
return self.optim.defaults
def add_param_group(self, *args, **kwargs):
return self.optim.add_param_group(*args, **kwargs)
def step(self, *args, **kwargs):
return self.optim.step(*args, **kwargs)
def zero_grad(self, *args, **kwargs):
self.optim.zero_grad(*args, **kwargs)
def load_state_dict(self, *args, **kwargs):
self.optim.load_state_dict(*args, **kwargs)
def state_dict(self):
return self.optim.state_dict()
def backward(self, loss: Tensor):
loss.backward()
def backward_by_grad(self, tensor: Tensor, grad: Tensor):
torch.autograd.backward(tensors=tensor, grad_tensors=grad)
def clip_grad_norm(self, model: nn.Module, max_norm: float):
if max_norm > 0.0:
clip_grad_norm_fp32(model.parameters(), max_norm)
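A minimal usage sketch of the wrapper above (the model and hyper-parameters are placeholders): ColossalaiOptimizer delegates everything to the wrapped optimizer while exposing backward and clip_grad_norm, so the engine can drive any optimizer through one interface.

# Illustrative only: wrap a plain torch optimizer.
model = nn.Linear(8, 8).cuda()
optim = ColossalaiOptimizer(torch.optim.SGD(model.parameters(), lr=0.1))
loss = model(torch.randn(4, 8, device='cuda')).sum()
optim.backward(loss)              # delegates to loss.backward()
optim.clip_grad_norm(model, 1.0)  # clip_grad_norm_fp32 under the hood
optim.step()
optim.zero_grad()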

View File

@@ -2,7 +2,7 @@
 import torch
 from colossalai.registry import OPTIMIZERS
-from ..multi_tensor_apply import multi_tensor_applier
+from colossalai.utils import multi_tensor_applier
 @OPTIMIZERS.register_module

View File

@@ -2,7 +2,7 @@
 import torch
 from colossalai.registry import OPTIMIZERS
-from ..multi_tensor_apply import multi_tensor_applier
+from colossalai.utils import multi_tensor_applier
 @OPTIMIZERS.register_module

View File

@@ -3,7 +3,7 @@ import torch
 from torch.optim.optimizer import Optimizer, required
 from colossalai.registry import OPTIMIZERS
-from ..multi_tensor_apply import multi_tensor_applier
+from colossalai.utils import multi_tensor_applier
 @OPTIMIZERS.register_module

View File

@@ -1,707 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
from collections import defaultdict
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.optim import Optimizer
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import OPTIMIZER_WRAPPERS
from colossalai.utils import get_current_device, print_rank_0
def get_alignment_padding(flattened_lean_size, sub_partition_id, sub_partition_size):
sub_partition_high_limit = (sub_partition_id + 1) * sub_partition_size
if sub_partition_high_limit <= flattened_lean_size:
return 0
else:
return min(sub_partition_size, sub_partition_high_limit - flattened_lean_size)
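A worked example of the alignment-padding arithmetic above (numbers chosen purely for illustration): with 10 flattened elements and sub-partitions of size 4, the third sub-partition would cover indices 8-11, so 2 padding elements are required.

# Illustrative only: 10 elements, sub-partition size 4.
assert get_alignment_padding(10, 0, 4) == 0  # covers elements 0-3
assert get_alignment_padding(10, 1, 4) == 0  # covers elements 4-7
assert get_alignment_padding(10, 2, 4) == 2  # covers 8-11, but only 8-9 exist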
def get_group_alignment_padding(tensor_list, sub_partition_size, sub_partition_count):
group_paddings = []
flattened_size = sum([tensor.numel() for tensor in tensor_list])
for i in range(sub_partition_count):
padding = get_alignment_padding(flattened_size, i, sub_partition_size)
group_paddings.append(padding)
return group_paddings
def _single_range_check(current_index, start_index, end_index, tensor_size):
offset = 0
if (current_index >= start_index) and (current_index < end_index):
# Fully inside bounds
return True, offset
elif (start_index > current_index) and (start_index < (current_index + tensor_size)):
# Partially contained, compute offset
offset = start_index - current_index
return True, offset
else:
return False, offset
def _range_check(current_index, element_intervals, tensor_size):
results = []
for comm_idx, interval in enumerate(element_intervals):
start_index, end_index = interval
contained, offset = _single_range_check(
current_index, start_index, end_index, tensor_size)
if contained:
results.append((contained, offset, comm_idx))
if len(results) == 0:
return [(False, 0, -1)]
return results
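An illustrative example of the range helpers above: given a rank that owns the flat ranges [0, 8) and [16, 24), a 10-element tensor starting at flat index 10 only overlaps the second range, starting 6 elements into the tensor.

# Illustrative only: element_intervals for one rank, one entry per comm interval.
intervals = [(0, 8), (16, 24)]
assert _range_check(0, intervals, 4) == [(True, 0, 0)]    # fully inside comm 0
assert _range_check(10, intervals, 10) == [(True, 6, 1)]  # tail lands in comm 1
assert _range_check(8, intervals, 4) == [(False, 0, -1)]  # owned by another rank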
@OPTIMIZER_WRAPPERS.register_module
class ZeroRedundancyOptimizer_Level_1(Optimizer):
"""
    ZeroRedundancyOptimizer_Level_1 is designed to reduce the memory footprint
    required for training large deep learning models.
    For more details please see ZeRO: Memory Optimizations Toward Training Trillion Parameter Models
    https://arxiv.org/abs/1910.02054
    This version aligns with stage 1 in the paper above.
"""
def __init__(self,
init_optimizer: Optimizer,
dp_parallel_mode: ParallelMode = ParallelMode.DATA,
max_elements_per_comm=5e8,
verbose=False
):
# TODO: this class does not work with fp16 AMP_TYPE.PARALLEL, fix it
assert get_current_device() != 'cpu', 'ZeRO optimizer cannot be used on CPU only'
self.flatten = _flatten_dense_tensors
self.unflatten = _unflatten_dense_tensors
self.optimizer = init_optimizer
self.dp_parallel_mode = dp_parallel_mode
self.verbose = verbose
# for compatibility with pytorch optim
self.defaults = init_optimizer.defaults
# param flattened by groups
self._param_groups = []
self._param_groups_flat = []
        # parallel_sub_partitioned_groups[group-idx] -> [comm-ids] -> [rank-ids]
self.parallel_sub_partitioned_groups = []
# same underlying data as above but viewed as: [groups] -> [rank-ids] -> [comm-ids]
self.parallel_comm_sub_partitioned_groups = []
# param partition info
# parameters in each group that will not be updated by this process directly
self.params_not_local = []
# parameters that will be updated by this process directly
self.params_in_rank_sub_partitions = []
# parameter offsets for parameters in sub-partitions. Parameter
# boundaries may not align with sub-partition boundaries
# so we need to keep track of the offsets
self.params_in_rank_sub_partitions_offsets = []
# number of elements per sub-partition in each group
self.sub_partition_sizes = []
# number of communication intervals for each group
self.num_comm_intervals_per_group = []
self.local_rank = gpc.get_local_rank(self.dp_parallel_mode)
self.partition_count = self.world_size = gpc.get_world_size(
self.dp_parallel_mode)
self.group_paddings = []
self.default_device = self.optimizer.param_groups[0]['params'][0].device
# max elems per param group
self.max_elems_per_comm = []
# loop to deal with groups
for i, param_group in enumerate(self.optimizer.param_groups):
            # push this group to the list before modifying it
self._param_groups.append(param_group['params'])
            # calculate the best max elements per comm so as to minimize padding
self.max_elems_per_comm.append(
self.best_max_elems_per_comm(
num_elements=sum(t.numel() for t in self._param_groups[i]),
max_elements_per_comm=max_elements_per_comm
)
)
            # flatten all tensors into a single 1d tensor aligned with the sub-partition size for later splitting
# RS: create aligned sub-partitions
flat_aligned_params = self.flatten_dense_tensors_sub_partition_aligned(
tensor_list=self._param_groups[i],
max_elements_per_comm=self.max_elems_per_comm[i],
)
self._param_groups_flat.append(flat_aligned_params)
updated_params = self.unflatten(self._param_groups_flat[i],
self._param_groups[i])
for p, q in zip(self._param_groups[i], updated_params):
p.data = q.data
            # divide the flat weights into near-equal partitions, as many as the data parallel degree
# each process will compute on a different part of the partition
# RS: split into two layer list -> [comm-id] -> [sub-partitions per rank]
comm_partitions, dp_sub_partitions, element_intervals, sub_partition_size, num_comm_intervals = \
self.get_data_parallel_sub_partitions(
tensor=self._param_groups_flat[i],
max_elements_per_comm=self.max_elems_per_comm[i],
)
self.parallel_comm_sub_partitioned_groups.append(
comm_partitions) # comm -> rank
self.parallel_sub_partitioned_groups.append(
dp_sub_partitions) # rank -> comm
self.sub_partition_sizes.append(sub_partition_size)
self.num_comm_intervals_per_group.append(num_comm_intervals)
# Compute sub_partition paddings
sub_partition_paddings = get_group_alignment_padding(
tensor_list=self._param_groups[i],
sub_partition_size=sub_partition_size,
sub_partition_count=num_comm_intervals * self.partition_count)
self.group_paddings.append(sub_partition_paddings)
            # modify the optimizer so its param group holds this rank's flat sub-partitions
param_group['params'] = self.parallel_sub_partitioned_groups[i][self.local_rank]
# RS: divide up the sub-partitions and keep track of offsets for each param
# partition_size = len(self.fp16_groups_flat[i]) / dist.get_world_size(group=self.dp_process_group)
params_in_rank_sub_partition, params_in_rank_sub_partitions_offsets, params_not_local = self.get_all_sub_partition_info(
tensor_list=self._param_groups[i],
all_element_intervals=element_intervals,
)
self.params_in_rank_sub_partitions.append(
params_in_rank_sub_partition)
self.params_not_local.append(params_not_local)
self.params_in_rank_sub_partitions_offsets.append(
params_in_rank_sub_partitions_offsets)
self.local_sub_partitions_of_groups = [
group[self.local_rank] for group in self.parallel_sub_partitioned_groups]
self._initialize_optimizer_states()
@property
def state(self):
return self.optimizer.state
@state.setter
def state(self, value):
self.optimizer.state = value
@property
def param_groups(self):
# LSG: return the full param groups instead of local partitions
# of the param groups for compatibility with torch.cuda.amp
param_groups = []
for group_id, group in enumerate(self.optimizer.param_groups):
group_containing_all_param = {
'params': self._param_groups[group_id],
**{k: v for k, v in group.items() if k != 'params'}
}
# LSG: for compatibility with unknown bug with lr scheduler
# TODO: fix this
group_containing_all_param.setdefault('initial_lr', group['lr'])
param_groups.append(group_containing_all_param)
return param_groups
@param_groups.setter
def param_groups(self, value):
self.optimizer.param_groups = value
def _initialize_optimizer_states(self):
for group_idx, group in enumerate(self.local_sub_partitions_of_groups):
for idx, sub_partition_param in enumerate(group):
sub_partition_grad = torch.zeros(int(
self.sub_partition_sizes[group_idx]),
dtype=sub_partition_param.dtype).cuda()
sub_partition_param.grad = sub_partition_grad
self.optimizer.step()
# LSG: comment out for compatibility with torch.cuda.amp
# for group in self.local_sub_partitions_of_groups:
# for idx, sub_partition_param in enumerate(group):
# sub_partition_param.grad = None
def best_max_elems_per_comm(self, num_elements, max_elements_per_comm):
# if we use max-elems-per-comm as is, how many comm intervals will there be
max_comm_intervals = math.ceil(num_elements / max_elements_per_comm)
padding_for_max_comm = (max_elements_per_comm *
max_comm_intervals) - num_elements
# if we use 1 less comm interval how much extra comm padding would be required
min_comm_intervals = num_elements // max_elements_per_comm
if min_comm_intervals == 0:
if self.verbose:
print_rank_0(
f'Using default max_elements_per_comm {max_elements_per_comm}')
return max_elements_per_comm
padding_for_min_comm = math.ceil(
num_elements / (self.world_size * min_comm_intervals))
# choose padding that uses least amount of overhead
if padding_for_max_comm > padding_for_min_comm:
new_max_elements_per_comm = padding_for_min_comm + max_elements_per_comm
if self.verbose:
print_rank_0(
f'Updating max_elements_per_comm from {max_elements_per_comm} -> {new_max_elements_per_comm}')
return new_max_elements_per_comm
else:
if self.verbose:
print_rank_0(
f'Using default max_elements_per_comm {max_elements_per_comm}')
return max_elements_per_comm
def get_data_parallel_sub_partitions(self,
tensor,
max_elements_per_comm,
):
total_num_elements = tensor.numel()
# if total elements is less than our max, revert to splitting into dp partitions
max_elements_per_comm = min(total_num_elements, max_elements_per_comm)
sub_partition_size = int(max_elements_per_comm // self.world_size)
# Ensure partition alignment was done correctly
num_sub_partitions = int(total_num_elements // sub_partition_size)
assert total_num_elements % sub_partition_size == 0, "{} % {} != 0".format(total_num_elements,
sub_partition_size)
# Ensure comm interval alignment was done correctly.
num_comm_intervals = int(num_sub_partitions // self.world_size)
assert num_sub_partitions % self.world_size == 0, "{} % {} != 0".format(
num_sub_partitions, self.world_size)
if self.verbose:
print_rank_0("**** partition info:")
print_rank_0(f"\t total_num_elements={total_num_elements}")
print_rank_0(f"\t world_size={self.world_size}")
print_rank_0(f"\t max_elements_per_comm={max_elements_per_comm}")
print_rank_0(f"\t sub_partition_size={sub_partition_size}")
print_rank_0(f"\t num_sub_partitions={num_sub_partitions}")
print_rank_0(f"\t num_comm_intervals={num_comm_intervals}")
print_rank_0("****")
# [comm_id] -> [rank]
comm_partitions = []
for _ in range(num_comm_intervals):
comm_partitions.append([])
start = 0
comm_id = 0
element_intervals = defaultdict(
list) # [rank] -> [(start,end), (start,end), ...]
for idx in range(num_sub_partitions):
rank_id = idx % self.world_size
sub_partition = tensor.narrow(
0, start, sub_partition_size).detach()
element_intervals[rank_id].append(
(start, start + sub_partition_size))
comm_partitions[comm_id].append(sub_partition)
start = start + sub_partition_size
if rank_id == (self.world_size - 1):
comm_id += 1
# [rank] -> [comm_id]
sub_partitions = []
for _ in range(self.world_size):
sub_partitions.append([])
for comm_id, partitions in enumerate(comm_partitions):
for rank_id, partition in enumerate(partitions):
sub_partitions[rank_id].append(partition)
return comm_partitions, sub_partitions, element_intervals, sub_partition_size, num_comm_intervals
def get_all_sub_partition_info(self,
tensor_list,
all_element_intervals,
):
params_not_local = []
# [rank] -> [comm-id] -> [param/offset]
params_in_rank_sub_partition = []
params_in_rank_sub_partitions_offsets = []
for rank in range(self.world_size):
params_in_local_sub_partition = []
local_sub_partition_offsets = []
comm_tensor_list = []
comm_offset_list = []
current_index = 0
prev_comm_idx = 0
for iii, tensor in enumerate(tensor_list):
tensor_size = tensor.numel()
results_list = _range_check(current_index,
all_element_intervals[rank],
tensor_size)
for contained, offset, comm_idx in results_list:
if contained:
if prev_comm_idx != comm_idx:
params_in_local_sub_partition.append(
comm_tensor_list)
comm_tensor_list = []
local_sub_partition_offsets.append(
comm_offset_list)
comm_offset_list = []
comm_tensor_list.append(tensor)
comm_offset_list.append(offset)
prev_comm_idx = comm_idx
elif rank == self.local_rank:
params_not_local.append(tensor)
current_index = current_index + tensor_size
# assert len(comm_tensor_list) > 0
# assert len(comm_offset_list) > 0
params_in_local_sub_partition.append(comm_tensor_list)
local_sub_partition_offsets.append(comm_offset_list)
params_in_rank_sub_partition.append(params_in_local_sub_partition)
params_in_rank_sub_partitions_offsets.append(
local_sub_partition_offsets)
return params_in_rank_sub_partition, params_in_rank_sub_partitions_offsets, params_not_local
def get_flat_sub_partitions(self,
comm_tensor_list,
comm_param_offsets,
sub_partition_size,
dtype,
default_device,
num_comm_intervals=None,
return_partition_params=False):
partition_params = []
final_param_offsets = []
flat_sub_partitions = []
for tensor_list, param_offsets in zip(comm_tensor_list, comm_param_offsets):
flat_tensor_list = []
current_size = 0
my_offsets = []
my_params = []
for i, tensor in enumerate(tensor_list):
if tensor.grad is None:
tensor.grad = torch.zeros(tensor.size(),
dtype=tensor.dtype,
device=tensor.device)
param = tensor
tensor = tensor.grad
num_elements = tensor.numel()
tensor_offset = 0
# we need to offset to get to the right element
if i == 0 and param_offsets[i] > 0:
tensor_offset = param_offsets[i]
num_elements = num_elements - tensor_offset
# We don't need all elements of the tensor if this tensor is
# larger than we have space for in our curr sub-partition
if num_elements > (sub_partition_size - current_size):
num_elements = sub_partition_size - current_size
# we need a narrow view of the tensor based on the tensor offset and number of elements that
# we need from this tensor
if tensor_offset > 0 or num_elements < tensor.numel():
flat_tensor_list.append(tensor.contiguous().view(-1).narrow(
0,
int(tensor_offset),
int(num_elements)).to(dtype))
else:
flat_tensor_list.append(tensor.to(dtype))
my_params.append(param)
# remember offset into partition and #elems for this tensor
my_offsets.append((current_size, num_elements))
current_size = current_size + num_elements
            # this means it's the last partition and it does not align with the dp boundary, so we need to pad before flattening
if current_size < sub_partition_size:
my_offsets.append((None, None))
my_params.append(None)
if len(tensor_list) == 0:
                    assert default_device is not None
flat_tensor_list.append(
torch.zeros(int(sub_partition_size - current_size),
dtype=dtype,
device=default_device))
else:
flat_tensor_list.append(
torch.zeros(int(sub_partition_size - current_size),
dtype=dtype,
device=tensor_list[0].device))
            partition_params.append(my_params)
final_param_offsets.append(my_offsets)
assert len(flat_tensor_list) == len(my_offsets), "{} {}".format(
len(flat_tensor_list), len(my_offsets))
flat_sub_partitions.append(self.flatten(flat_tensor_list))
if num_comm_intervals is not None and len(
flat_sub_partitions) < num_comm_intervals:
# print("padding w. sub partitions to ensure uniform communication")
device = flat_sub_partitions[0].device
for _ in range(num_comm_intervals - len(flat_sub_partitions)):
flat_sub_partitions.append(
torch.zeros(int(sub_partition_size),
dtype=dtype,
device=device))
partition_params.append([None])
final_param_offsets.append([(None, None)])
if return_partition_params:
assert len(flat_sub_partitions) == len(partition_params)
assert len(partition_params) == len(final_param_offsets), "{} {}".format(len(partition_params),
len(final_param_offsets))
return flat_sub_partitions, partition_params, final_param_offsets
return flat_sub_partitions
def zero_grad(self, set_grads_to_None=False):
"""
Zero FP16 parameter grads.
"""
# FP32 grad should never exist.
# For speed, set model fp16 grad to None by default
for group in self._param_groups:
for p in group:
if set_grads_to_None:
p.grad = None
else:
if p.grad is not None:
p.grad.detach_()
p.grad.zero_()
def free_grad_in_param_list(self, param_list):
for p in param_list:
if isinstance(p, list):
for _p in p:
_p.grad = None
else:
p.grad = None
def flatten_dense_tensors_sub_partition_aligned(self,
tensor_list,
max_elements_per_comm
):
assert max_elements_per_comm >= self.world_size, f"max_elements_per_comm {max_elements_per_comm} < dp {self.world_size}"
num_elements = sum(t.numel() for t in tensor_list)
# Compute aligned partition size based on parameter count
aligned_param_partition_size = math.ceil(
num_elements / self.world_size)
# Compute aligned partition size based on communication size
aligned_comm_partition_size = int(
max_elements_per_comm // self.world_size)
if aligned_param_partition_size <= aligned_comm_partition_size:
sub_partition_count = 1
sub_partition_size = aligned_param_partition_size
else:
sub_partition_count = math.ceil(aligned_param_partition_size /
aligned_comm_partition_size)
sub_partition_size = aligned_comm_partition_size
# Compute required padding for alignment to dp and max_elements_per_comm
padding = (sub_partition_count * sub_partition_size *
self.world_size) - num_elements
if self.verbose:
print_rank_0(
f"sub_partition_count: {sub_partition_count}, sub_partition_size: {sub_partition_size}, padding: {padding}")
print_rank_0(
f"number of elements with padding: {num_elements} + {padding} = {num_elements + padding}")
if padding == 0:
aligned_tensor_list = tensor_list
else:
pad_tensor = torch.zeros(padding,
device=tensor_list[0].device,
dtype=tensor_list[0].dtype)
aligned_tensor_list = tensor_list + [pad_tensor]
flat_tensors = self.flatten(aligned_tensor_list)
return flat_tensors
# def reduce_gradients(self):
# # LSG: this reduce gradients method no longer works
# # after code change, please use DataParallelGradientHandler instead
#
# world_size = gpc.get_world_size(self.parallel_mode)
# local_rank = gpc.get_local_rank(self.parallel_mode)
#
# for i, group in enumerate(self._param_groups):
# num_comm_intervals = self.num_comm_intervals_per_group[i]
# all_sub_partitions = []
# for rank in range(world_size):
# # gsp is list of partitions indexed by comm_idx
# grad_sub_partitions = self.get_flat_sub_partitions(
# comm_tensor_list=self.params_in_rank_sub_partitions[i][rank],
# comm_param_offsets=self.params_in_rank_sub_partitions_offsets[i][rank],
# dtype=self.local_sub_partitions_of_groups[i][0].dtype,
# default_device=self.default_device,
# sub_partition_size=self.sub_partition_sizes[i],
# num_comm_intervals=self.num_comm_intervals_per_group[i])
# all_sub_partitions.append(grad_sub_partitions)
#
# assert len(grad_sub_partitions) == num_comm_intervals
#
# local_comm_partitions = []
# for comm_idx in range(num_comm_intervals):
# single_comm_all_partitions = []
# for rank in range(world_size):
# single_comm_all_partitions.append(all_sub_partitions[rank][comm_idx])
#
# for partition in single_comm_all_partitions:
# partition.div_(world_size)
#
# dist.reduce_scatter(output=single_comm_all_partitions[local_rank],
# input_list=single_comm_all_partitions,
# group=gpc.get_group(self.parallel_mode))
def step(self, closure=None):
local_sub_partitions_grad_groups = []
for i, group in enumerate(self._param_groups):
# RS: update free grads w.r.t. sub partitions
# free gradients for all the parameters that are not updated by this process
self.free_grad_in_param_list(self.params_not_local[i])
# create flat gradient partitions for parameters updated by this process
local_grad_sub_partitions = self.get_flat_sub_partitions(
comm_tensor_list=self.params_in_rank_sub_partitions[i][self.local_rank],
comm_param_offsets=self.params_in_rank_sub_partitions_offsets[i][self.local_rank],
sub_partition_size=self.sub_partition_sizes[i],
dtype=self.local_sub_partitions_of_groups[i][0].dtype,
num_comm_intervals=self.num_comm_intervals_per_group[i],
default_device=self.default_device)
# RS: update all our local params with sub-partition grads
for idx, sub_partition_param in enumerate(self.local_sub_partitions_of_groups[i]):
sub_partition_param.grad = local_grad_sub_partitions[idx]
# RS: update free grads for sub-partitions
            # release all the gradients since we have already created the necessary copies in dp_grad_partition
self.free_grad_in_param_list(
self.params_in_rank_sub_partitions[i][self.local_rank])
local_sub_partitions_grad_groups.append(local_grad_sub_partitions)
if closure is None:
loss = self.optimizer.step()
else:
loss = self.optimizer.step(closure=closure)
# RS: clear our sub partition grads
# LSG: not needed as amp is used instead
# get rid of the fp32 gradients. Not needed anymore
# for group in self.local_sub_partitions_of_groups:
# for idx, sub_partition_param in enumerate(group):
# sub_partition_param.grad = None
# RS: all_gather/broadcast sub-partitions in separate comm calls
# gather the updated weights from everyone
for all_sub_partitions in self.parallel_comm_sub_partitioned_groups:
for comm_id, sub_partitions in enumerate(all_sub_partitions):
dist.all_gather(sub_partitions,
sub_partitions[self.local_rank],
group=gpc.get_group(self.dp_parallel_mode))
# TODO: we probably don't need this? just to be safe
for i in range(len(self._param_groups)):
updated_params = self.unflatten(self._param_groups_flat[i],
self._param_groups[i])
for p, q in zip(self._param_groups[i], updated_params):
p.data = q.data
return loss
def _rigid_state_dict(self):
"""Returns a dict that can be loaded for continued training with same DP degree
Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
of the contained Pytorch optimizer.
Example::
checkpoint = {}
checkpoint['model'] = model.state_dict()
checkpoint['optimizer'] = optimizer.state_dict()
torch.save(checkpoint, "saved.pth")
"""
state_dict = {}
for k, v in self.optimizer.state_dict().items():
state_dict[k] = v
state_dict[
'local_sub_partitions_of_groups'] = self.local_sub_partitions_of_groups
return state_dict
def state_dict(self):
"""
Returns a dict containing the current state of this Optimizer instance.
This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
of the contained Pytorch optimizer.
Example::
checkpoint = {}
checkpoint['model'] = model.state_dict()
checkpoint['optimizer'] = optimizer.state_dict()
torch.save(checkpoint, "saved.pth")
"""
return self._rigid_state_dict()
def load_state_dict(self,
state_dict,
load_optimizer_states=True,
):
"""
Loads a state_dict created by an earlier call to state_dict().
If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
whose parameters in turn came from ``model``, it is expected that the user
will call ``model.load_state_dict()`` before
``fp16_optimizer_instance.load_state_dict()`` is called.
Example::
model = torch.nn.Linear(D_in, D_out).cuda().half()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
...
checkpoint = torch.load("saved.pth")
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
"""
self._rigid_load_state_dict(
state_dict,
load_optimizer_states)
def _rigid_load_state_dict(self, state_dict, load_optimizer_states=True):
# I think it should actually be ok to reload the optimizer before the model.
state_dict_ = state_dict.copy()
local_sub_partitions_of_groups = state_dict_.pop(
'local_sub_partitions_of_groups')
if load_optimizer_states:
self.optimizer.load_state_dict(state_dict_)
for curr_group, saved_group in zip(self.local_sub_partitions_of_groups,
local_sub_partitions_of_groups):
for curr_param, saved_param in zip(curr_group, saved_group):
curr_param.data.copy_(saved_param.data)
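Since this commit deletes the file above, the following is only a historical usage sketch of how the stage-1 optimizer was driven (model, criterion, data, target and the SGD settings are placeholders; gradient averaging across the data-parallel group is assumed to be handled by a gradient handler before step()):

# Illustrative only.
base_optim = torch.optim.SGD(model.parameters(), lr=0.01)
optimizer = ZeroRedundancyOptimizer_Level_1(base_optim,
                                            dp_parallel_mode=ParallelMode.DATA,
                                            max_elements_per_comm=5e8)
loss = criterion(model(data), target)
loss.backward()
# each rank updates only its own flat sub-partitions, then the updated
# sub-partitions are all-gathered across the data-parallel group
optimizer.step()
optimizer.zero_grad()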

Some files were not shown because too many files have changed in this diff Show More