Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-07-16 16:32:52 +00:00)
Develop/experiments (#59)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
  * fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
  * fixed trainer
  * Revert "fixed trainer" (this reverts commit 2e0b0b7699)
* improved consistency between trainer, engine and schedule (#23)
* Split conv2d, class token, positional embedding in 2d; fix random number in ddp; fix convergence in cifar10, Imagenet1000
* Integrate 1d tensor parallel in Colossal-AI (#39)
* fixed 1D and 2D convergence (#38)
  * optimized 2D operations
  * fixed 1D ViT convergence problem
* Feature/ddp (#49)
  * remove redundancy func in setup (#19) (#20)
  * use env to control the language of doc (#24) (#25)
  * Support TP-compatible Torch AMP and Update trainer API (#27)
  * add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
  * add explanation for ViT example (#35) (#36)
  * support torch ddp
  * fix loss accumulation
  * add log for ddp
  * change seed
  * modify timing hook
* Feature/pipeline (#40)
  * optimize communication of pipeline parallel
  * fix grad clip for pipeline
* optimized 3d layer to fix slow computation; tested imagenet performance with 3d; reworked lr_scheduler config definition; fixed launch args; fixed some printing issues; simplified apis of 3d layers (#51)
* Update 2.5d layer code to get a similar accuracy on imagenet-1k dataset
* update api for better usability (#58)

Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: puck_WCR <46049915+WANG-CR@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
This commit is contained in:
parent eb2f8b1f6b
commit da01c234e1
@@ -1,4 +1,4 @@
from .initialize import init_dist, initialize
from .nn import *
from .initialize import (initialize, launch, launch_from_openmpi,
                         launch_from_slurm, launch_from_torch, get_default_parser)

__version__ = '0.0.1'
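The hunk above replaces init_dist with the new launch family of entry points. A hedged sketch of how they appear to be meant to be called; the argument names follow the init_global_dist signature that shows up later in this diff and should be treated as assumptions, not the exact public API:

import colossalai

# Explicit launch: all distributed settings passed by hand (names assumed).
colossalai.launch(config='config.py', rank=0, world_size=4,
                  host='localhost', port=29500, backend='nccl')

# Assumed behaviour of launch_from_torch: rank/world size are read from the
# environment set up by `python -m torch.distributed.launch`.
colossalai.launch_from_torch(config='config.py')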
colossalai/amp/__init__.py (new file, 32 lines)
@@ -0,0 +1,32 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from .amp_type import AMP_TYPE
from colossalai.context import Config
import torch.nn as nn
from torch.optim import Optimizer
from torch.nn.modules.loss import _Loss
from .torch_amp import convert_to_torch_amp
from .apex_amp import convert_to_apex_amp
from .naive_amp import convert_to_naive_amp


def convert_to_amp(model: nn.Module,
                   optimizer: Optimizer,
                   criterion: _Loss,
                   mode: AMP_TYPE,
                   amp_config: Config = None):
    assert isinstance(mode, AMP_TYPE), \
        f'expected the argument mode be AMP_TYPE, but got {type(mode)}'

    if amp_config is None:
        amp_config = Config()

    if mode == AMP_TYPE.TORCH:
        model, optimizer, criterion = convert_to_torch_amp(model, optimizer, criterion, amp_config)
    elif mode == AMP_TYPE.APEX:
        model, optimizer = convert_to_apex_amp(model, optimizer, amp_config)
    elif mode == AMP_TYPE.NAIVE:
        model, optimizer = convert_to_naive_amp(model, optimizer, amp_config)

    return model, optimizer, criterion
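A minimal usage sketch of the new convert_to_amp entry point, assuming a plain PyTorch model, optimizer and criterion; the import path and the Config keys shown (init_scale) are illustrative assumptions rather than values taken from this diff:

import torch
import torch.nn as nn

from colossalai.amp import AMP_TYPE, convert_to_amp   # assumed import path
from colossalai.context import Config

model = nn.Linear(16, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
criterion = nn.CrossEntropyLoss()

# AMP_TYPE.TORCH wraps the triple with torch.cuda.amp; the amp_config entries are
# forwarded to GradScaler, so init_scale here is only an illustrative choice.
model, optimizer, criterion = convert_to_amp(model, optimizer, criterion,
                                             mode=AMP_TYPE.TORCH,
                                             amp_config=Config(dict(init_scale=2**16)))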
@@ -7,4 +7,4 @@ from enum import Enum
class AMP_TYPE(Enum):
    APEX = 'apex'
    TORCH = 'torch'
    PARALLEL = 'parallel'
    NAIVE = 'naive'
colossalai/amp/apex_amp/__init__.py (new file, 15 lines)
@@ -0,0 +1,15 @@
from .apex_amp import ApexAMPOptimizer
import torch.nn as nn
from torch.optim import Optimizer
import apex.amp as apex_amp


def convert_to_apex_amp(model: nn.Module,
                        optimizer: Optimizer,
                        amp_config):
    model, optimizer = apex_amp.initialize(model, optimizer, **amp_config)
    optimizer = ApexAMPOptimizer(optimizer)
    return model, optimizer


__all__ = ['convert_to_apex_amp', 'ApexAMPOptimizer']
colossalai/amp/apex_amp/apex_amp.py (new file, 23 lines)
@@ -0,0 +1,23 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch.nn as nn
try:
    import apex.amp as apex_amp
except:
    pass
from torch import Tensor

from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils import clip_grad_norm_fp32


class ApexAMPOptimizer(ColossalaiOptimizer):

    def backward(self, loss: Tensor):
        with apex_amp.scale_loss(loss, self.optim) as scaled_loss:
            scaled_loss.backward()

    def clip_grad_norm(self, model: nn.Module, max_norm: float):
        if max_norm > 0:
            clip_grad_norm_fp32(apex_amp.master_params(self.optim), max_norm)
colossalai/amp/naive_amp/__init__.py (new file, 20 lines)
@@ -0,0 +1,20 @@
import torch.nn as nn
from torch.optim import Optimizer
from colossalai.utils import is_no_pp_or_last_stage

from .naive_amp import NaiveAMPOptimizer, NaiveAMPModel


def convert_to_naive_amp(model: nn.Module,
                         optimizer: Optimizer,
                         amp_config):
    if is_no_pp_or_last_stage():
        model = NaiveAMPModel(model, output_to_fp32=True)
    else:
        model = NaiveAMPModel(model, output_to_fp32=False)

    optimizer = NaiveAMPOptimizer(optimizer, **amp_config)
    return model, optimizer


__all__ = ['convert_to_naive_amp', 'NaiveAMPOptimizer']
@@ -12,11 +12,9 @@ from torch.optim import Optimizer

from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger
from colossalai.registry import OPTIMIZER_WRAPPERS
from colossalai.utils import print_rank_0
from ._utils import copy_tensor_parallel_attributes, clip_grad_norm_fp32, count_zeros_fp32
from ..multi_tensor_apply import multi_tensor_applier
from colossalai.logging import get_dist_logger
from colossalai.utils import (print_rank_0, copy_tensor_parallel_attributes,
                              clip_grad_norm_fp32, count_zeros_fp32, multi_tensor_applier)


def _zero_grad_group_helper(group, set_to_none):
@@ -92,7 +90,7 @@ class DynamicGradScaler:
        self._growth_tracker = 0
        self._hysteresis_tracker = self.hysteresis

        self._logger = get_global_dist_logger()
        self._logger = get_dist_logger()

    @property
    def scale(self):
@@ -113,7 +111,7 @@ class DynamicGradScaler:
            if self._hysteresis_tracker <= 0:
                self._scale = torch.max(self._scale * self.backoff_factor,
                                        self.min_scale)
            self._logger.info(f'overflow occurs, loss scale is adjusted to {self._scale}')
            self._logger.info(f'overflow occurs, loss scale is adjusted to {self._scale}', ranks=[0])
        else:
            # If there is no nan/inf, increment the growth tracker.
            self._growth_tracker += 1
@@ -125,10 +123,10 @@ class DynamicGradScaler:
                # and scale up the loss scale.
                if self._max_scale is not None and self._scale >= self._max_scale:
                    self._logger.info(
                        f'Current loss scale {self._scale} has reached the max scale {self._max_scale} allowed')
                        f'Current loss scale {self._scale} has reached the max scale {self._max_scale} allowed', ranks=[0])
                else:
                    self._scale = self._scale * self.growth_factor
                    self._logger.info(f'no consecutive overflow, loss scale is adjusted to {self._scale}')
                    self._logger.info(f'no consecutive overflow, loss scale is adjusted to {self._scale}', ranks=[0])

    def state_dict(self):
        state_dict = {}
@@ -145,7 +143,6 @@ class DynamicGradScaler:
        self._max_scale = state_dict['max_scale']


@OPTIMIZER_WRAPPERS.register_module
class FP16Optimizer(Optimizer):
    """Float16 optimizer for fp16 and bf16 data types.

@@ -184,13 +181,13 @@ class FP16Optimizer(Optimizer):
                 max_scale: int = 2 ** 32):
        # default args for compatibility
        bf16 = False
        params_have_main_grad = False
        params_have_main_grad = True

        # have a defaults for compatibility with pytorch optim
        self.defaults = optimizer.defaults

        # log config
        self._logger = get_global_dist_logger()
        self._logger = get_dist_logger()
        self._logger.info(f"\n========= FP16 Optimizer Config =========\n"
                          f"Optimizer: {optimizer.__class__.__name__}\n"
                          f"clip_grad = {clip_grad}\n"
@@ -328,6 +325,7 @@ class FP16Optimizer(Optimizer):
                else:
                    if model_param.grad is not None:
                        main_param.grad = model_param.grad.float()

        # For fp32 grads, we need to reset the grads to main grad.
        if self.params_have_main_grad:
            for model_group in self.fp32_from_fp32_groups:
@@ -387,10 +385,6 @@ class FP16Optimizer(Optimizer):

    @torch.no_grad()
    def step(self):
        # for param_group in self.float16_groups:
        #     for param in param_group:
        #         print(param.grad is None)

        # Copy gradients from model params to main params.
        self._copy_model_grads_to_main_grads()
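The DynamicGradScaler hunks above only change logging, but the update rule they sit inside is worth spelling out. Below is a standalone, illustrative rendition of that backoff/growth scheme; it is not the library class, just the same idea in minimal form:

class TinyDynamicScaler:
    """Illustrative dynamic loss scaler: back off on overflow, grow after a streak of clean steps."""

    def __init__(self, init_scale=2.**16, growth_factor=2.0, backoff_factor=0.5,
                 growth_interval=1000, hysteresis=2, min_scale=1.0):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self.hysteresis = hysteresis
        self.min_scale = min_scale
        self._growth_tracker = 0
        self._hysteresis_tracker = hysteresis

    def update(self, found_overflow: bool):
        if found_overflow:
            # Overflow: reset the growth streak and only shrink after `hysteresis` overflows.
            self._growth_tracker = 0
            self._hysteresis_tracker -= 1
            if self._hysteresis_tracker <= 0:
                self.scale = max(self.scale * self.backoff_factor, self.min_scale)
        else:
            # Clean step: grow the scale after `growth_interval` consecutive clean steps.
            self._growth_tracker += 1
            if self._growth_tracker == self.growth_interval:
                self._growth_tracker = 0
                self._hysteresis_tracker = self.hysteresis
                self.scale *= self.growth_factor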
colossalai/amp/naive_amp/naive_amp.py (new file, 65 lines)
@@ -0,0 +1,65 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch
import torch.nn as nn
from torch import Tensor
from typing import Union, List, Any, Dict
from torch.optim import Optimizer
import torch.cuda.amp as torch_amp

from colossalai.nn.optimizer import ColossalaiOptimizer
from ._fp16_optimizer import FP16Optimizer


class NaiveAMPOptimizer(ColossalaiOptimizer):

    def __init__(self, optim: Optimizer, *args, **kwargs):
        optim = FP16Optimizer(optimizer=optim, *args, **kwargs)
        super().__init__(optim)

    def backward(self, loss: Tensor):
        loss = self.optim.scale_loss(loss)
        loss.backward()

    def step(self):
        self.optim.step()

    def clip_grad_norm(self, model: nn.Module, max_norm: float):
        pass


class NaiveAMPModel(nn.Module):

    def __init__(self,
                 model: nn.Module,
                 output_to_fp32: bool = True):
        super().__init__()
        self.model = model.half()
        self._output_to_fp32 = output_to_fp32

    def _convert_to_fp16(self, input_: Any):
        if isinstance(input_, Tensor) and input_.dtype == torch.float32:
            input_ = input_.half()
        return input_

    def _convert_to_fp32(self, input_: Any):
        if isinstance(input_, Tensor) and input_.dtype == torch.float16:
            input_ = input_.float()
        return input_

    def forward(self, *args, **kwargs):
        if args:
            args = [self._convert_to_fp16(arg) for arg in args]
        if kwargs:
            for k, v in kwargs.items():
                kwargs[k] = self._convert_to_fp16(v)

        out = self.model(*args, **kwargs)

        if self._output_to_fp32:
            if isinstance(out, Tensor):
                out = self._convert_to_fp32(out)
            elif isinstance(out, (tuple, list)):
                out = [self._convert_to_fp32(val) for val in out]
        return out
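A quick check of what NaiveAMPModel does to dtypes: fp32 tensor inputs are cast to fp16 on the way in, and, when output_to_fp32 is set, fp16 outputs are cast back to fp32. The snippet assumes a CUDA device, a toy linear layer, and the import path shown:

import torch
import torch.nn as nn
from colossalai.amp.naive_amp import NaiveAMPModel   # assumed import path

wrapped = NaiveAMPModel(nn.Linear(8, 8).cuda(), output_to_fp32=True)
x = torch.randn(2, 8, device='cuda')          # float32 input
y = wrapped(x)                                # inner model runs in float16
print(wrapped.model.weight.dtype, y.dtype)    # torch.float16, torch.float32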
colossalai/amp/torch_amp/__init__.py (new file, 18 lines)
@@ -0,0 +1,18 @@
import torch.nn as nn
from torch.optim import Optimizer
from torch.nn.modules.loss import _Loss
from colossalai.context import Config
from .torch_amp import TorchAMPOptimizer, TorchAMPModel, TorchAMPLoss


def convert_to_torch_amp(model: nn.Module,
                         optimizer: Optimizer,
                         criterion: _Loss,
                         amp_config: Config):
    model = TorchAMPModel(model)
    optimizer = TorchAMPOptimizer(optimizer, **amp_config)
    criterion = TorchAMPLoss(criterion)
    return model, optimizer, criterion


__all__ = ['convert_to_torch_amp', 'TorchAMPModel', 'TorchAMPLoss', 'TorchAMPOptimizer']
@@ -1,4 +1,8 @@
# modified from https://github.com/pytorch/pytorch/blob/master/torch/cuda/amp/grad_scaler.p
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# modified from https://github.com/pytorch/pytorch/blob/master/torch/cuda/amp/grad_scaler.py
# to support tensor parallel

import torch
from collections import defaultdict, abc
import warnings
colossalai/amp/torch_amp/torch_amp.py (new file, 54 lines)
@@ -0,0 +1,54 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch.nn as nn
import torch.cuda.amp as torch_amp

from torch import Tensor
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from ._grad_scaler import GradScaler

from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils import clip_grad_norm_fp32


class TorchAMPOptimizer(ColossalaiOptimizer):

    def __init__(self, optim: Optimizer, *args, **kwargs):
        super().__init__(optim)
        self.scaler = GradScaler(*args, **kwargs)

    def backward(self, loss: Tensor):
        self.scaler.scale(loss).backward()

    def step(self):
        self.scaler.step(self.optim)
        self.scaler.update()

    def clip_grad_norm(self, model: nn.Module, max_norm: float):
        if max_norm > 0.0:
            self.scaler.unscale_(self.optim)
            clip_grad_norm_fp32(model.parameters(), max_norm)


class TorchAMPModel(nn.Module):

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        self.model = model

    @torch_amp.autocast()
    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)


class TorchAMPLoss(nn.Module):

    def __init__(self, loss: _Loss):
        super().__init__()
        self.loss = loss

    @torch_amp.autocast()
    def forward(self, *args, **kwargs):
        return self.loss(*args, **kwargs)
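A minimal training-step sketch using the objects returned by convert_to_torch_amp; the data, label and max_norm values are placeholders, and zero_grad is assumed to be forwarded to the wrapped optimizer:

def train_step(model, optimizer, criterion, data, label, max_norm: float = 1.0):
    # `model`, `optimizer`, `criterion` are the wrapped objects from convert_to_torch_amp.
    optimizer.zero_grad()                      # assumed to delegate to the inner optimizer
    output = model(data)                       # forward runs under torch.cuda.amp.autocast
    loss = criterion(output, label)            # loss is also computed under autocast
    optimizer.backward(loss)                   # GradScaler.scale(loss).backward()
    optimizer.clip_grad_norm(model, max_norm)  # unscale grads, then clip in fp32
    optimizer.step()                           # GradScaler.step(optim) + update()
    return loss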
@@ -1,10 +1,10 @@
from .builder import (build_schedule, build_lr_scheduler, build_model, build_optimizer, build_optimizer_wrapper,
                      build_layer, build_loss, build_hooks, build_dataset, build_transform, build_data_sampler,
from .builder import (build_schedule, build_lr_scheduler, build_model, build_optimizer, build_layer,
                      build_loss, build_hooks, build_dataset, build_transform, build_data_sampler,
                      build_gradient_handler)
from .pipeline import ModelInitializer
from .pipeline import PipelineModelInitializer

__all__ = [
    'build_schedule', 'build_lr_scheduler', 'build_model', 'build_optimizer', 'build_optimizer_wrapper',
    'build_schedule', 'build_lr_scheduler', 'build_model', 'build_optimizer',
    'build_layer', 'build_loss', 'build_hooks', 'build_dataset', 'build_transform', 'build_data_sampler',
    'build_gradient_handler', 'ModelInitializer'
    'build_gradient_handler', 'PipelineModelInitializer'
]
@@ -106,7 +106,7 @@ def build_dataset(config):
    return build_from_registry(config, DATASETS)


def build_optimizer(config, model, params: Iterable = None, need_module=False):
def build_optimizer(config, model):
    """Returns an optimizer object of :class:`torch.optim.Optimizer` constructed from `config`,
    'model' and 'params'.

@@ -115,23 +115,12 @@ def build_optimizer(config, model, params: Iterable = None, need_module=False):
    :type config: dict or :class:`colossalai.context.Config`
    :param model: A model containing parameters for the optimizer
    :type model: :class:`nn.Module`
    :param params: A dict containing parameters for the optimizer
    :type params: dict, optional
    :param need_module: Indicates whether the optimizer needs a module
    :type params: bool, optional
    :raises AssertionError: Raises an AssertionError if both `model` and `params` are None
    :return: An object of :class:`torch.optim.Optimizer`
    :rtype: :class:`torch.optim.Optimizer`
    """
    assert model is not None or params is not None, 'arguments model and params can not both be None'
    if need_module:
        config['module'] = model
    elif model is not None:
        config['params'] = model.parameters()
    elif params is not None:
        config['params'] = params

    return build_from_registry(config, OPTIMIZERS)
    config_ = config.copy()
    config_['params'] = model.parameters()
    return build_from_registry(config_, OPTIMIZERS)


def build_gradient_handler(config, model, optimizer):
@@ -149,8 +138,9 @@ def build_gradient_handler(config, model, optimizer):
    :rtype: :class:`BaseGradientHandler`
    """
    config_ = config.copy()
    mod_type = config_.pop('type')
    return GRADIENT_HANDLER.get_module(mod_type)(model, optimizer, **config_)
    config_['model'] = model
    config_['optimizer'] = optimizer
    return build_from_registry(config_, GRADIENT_HANDLER)


def build_hooks(config, trainer):
@@ -164,8 +154,9 @@ def build_hooks(config, trainer):
    :return: An object of :class:`BaseHook`
    :rtype: :class:`BaseHook`
    """
    config['trainer'] = trainer
    return build_from_registry(config, HOOKS)
    config_ = config.copy()
    config_['trainer'] = trainer
    return build_from_registry(config_, HOOKS)


def build_transform(config):
@@ -195,32 +186,8 @@ def build_data_sampler(config, dataset):
    :rtype: :class:`colossalai.nn.data.sampler.BaseSampler`
    """
    config_ = config.copy()
    mod_type = config_.pop('type')
    return SAMPLERS.get_module(mod_type)(dataset, **config_)


def build_optimizer_wrapper(config, optimizer, model=None):
    """Returns an optimizer wrapper object of :class:`torch.optim.Optimizer` constructed
    from `config`, `model` and `optimizer`.

    :param config: A python dict or a :class:`colossalai.context.Config` object
        containing information used in the construction of the return object
    :type config: dict or :class:`colossalai.context.Config`
    :param optimizer: An optimizer object containing parameters for the gradient handler
    :type optimizer: :class:`torch.optim.Optimizer`
    :param model: A model containing parameters for the gradient handler
    :type model: :class:`nn.Module`, optional
    :return: An object of :class:`torch.optim.Optimizer`
    :rtype: :class:`torch.optim.Optimizer`
    """
    config_ = config.copy()
    mod_type = config_.pop('type')

    # LSG: special treatment for zeor level 3
    if mod_type == 'ZeroRedundancyOptimizer_Level_3':
        return OPTIMIZER_WRAPPERS.get_module(mod_type)(model, optimizer, **config_)
    else:
        return OPTIMIZER_WRAPPERS.get_module(mod_type)(optimizer, **config_)
    config_['dataset'] = dataset
    return build_from_registry(config_, DATA_SAMPLERS)


def build_lr_scheduler(config, optimizer):
@@ -241,8 +208,8 @@ def build_lr_scheduler(config, optimizer):
    :rtype: :class:`torch.optim.lr_scheduler`
    """
    config_ = config.copy()
    mod_type = config_.pop('type')
    return LR_SCHEDULERS.get_module(mod_type)(optimizer, **config_)
    config_['optimizer'] = optimizer
    return build_from_registry(config_, LR_SCHEDULERS)


def build_schedule(config):
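The builder hunks above converge on one pattern: copy the config, inject the runtime object under a named key, and dispatch through build_from_registry. A self-contained sketch of that pattern, with a toy registry standing in for colossalai.registry (all names here are illustrative, not the library's):

class Registry:
    """Toy registry standing in for colossalai.registry (illustrative only)."""
    def __init__(self):
        self._modules = {}

    def register_module(self, cls):
        self._modules[cls.__name__] = cls
        return cls

    def get_module(self, name):
        return self._modules[name]


def build_from_registry(config: dict, registry: Registry):
    # Copy so the caller's config dict is never mutated, then pop the type key.
    config_ = dict(config)
    mod_type = config_.pop('type')
    return registry.get_module(mod_type)(**config_)


LR_SCHEDULERS = Registry()

@LR_SCHEDULERS.register_module
class DummyScheduler:
    def __init__(self, optimizer, total_steps):
        self.optimizer, self.total_steps = optimizer, total_steps

# build_lr_scheduler-style call: inject the runtime object, then dispatch.
cfg = {'type': 'DummyScheduler', 'total_steps': 100}
cfg = dict(cfg, optimizer='my_optimizer_placeholder')
scheduler = build_from_registry(cfg, LR_SCHEDULERS)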
@@ -4,7 +4,7 @@ import heapq
from colossalai.builder import build_model, build_layer
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger
from colossalai.logging import get_dist_logger
from colossalai.utils import set_to_cuda


@@ -111,21 +111,21 @@ def _binary_search(weights, num):
    return intervals


def _partition_uniform(num_items, num_parts, num_chunks):
def _partition_uniform(num_items, pipeline_parallel_size, num_chunks):
    assert num_items % num_chunks == 0, \
        "Layer length should be divided by the number of chunks, otherwise parameter method is recomended"

    logger = get_global_dist_logger()
    parts = [[] for _ in range(num_parts)]
    logger = get_dist_logger()
    parts = [[] for _ in range(pipeline_parallel_size)]
    partition_items = num_items // num_chunks
    for idx in range(num_chunks):
        base_idx = idx * partition_items
        chunk_size = partition_items // num_parts
        left = num_parts - partition_items % num_parts
        chunk_size = partition_items // pipeline_parallel_size
        left = pipeline_parallel_size - partition_items % pipeline_parallel_size
        if chunk_size == 0:
            logger.warning("Some nodes in Pipeline have no requests")

        for p in range(num_parts):
        for p in range(pipeline_parallel_size):
            st = base_idx
            base_idx += chunk_size + (p >= left)
            parts[p].append((st, base_idx))
@@ -133,34 +133,34 @@ def _partition_uniform(num_items, num_parts, num_chunks):
    return parts


def _partition_balanced(weights, num_parts, num_chunks):
    num_total = num_parts * num_chunks
def _partition_balanced(weights, pipeline_parallel_size, num_chunks):
    num_total = pipeline_parallel_size * num_chunks
    num_items = len(weights)
    if num_items <= num_total:
        return _partition_uniform(num_items, num_parts, num_chunks)
        return _partition_uniform(num_items, pipeline_parallel_size, num_chunks)

    intervals = _binary_search(weights, num_total)

    current = 0
    parts = [[] for _ in range(num_parts)]
    parts = [[] for _ in range(pipeline_parallel_size)]
    for inter in intervals:
        parts[current].append(inter)
        current = (current + 1) % num_parts
        current = (current + 1) % pipeline_parallel_size

    return parts


class ModelInitializer():
class PipelineModelInitializer():
    def __init__(self, config, num_chunks, verbose=False):
        self.num_chunks = num_chunks
        self.ori_model = build_model(config)
        self.layers = self.ori_model.layers_cfg
        layer_length = len(self.layers)
        self.verbose = verbose
        self._logger = get_global_dist_logger()
        self._logger = get_dist_logger()
        self._logger.info(f"The total length of layers is {layer_length}", ranks=[0])

    def model_initialize(self, partition_method='parameter'):
    def initialize(self, partition_method='parameter'):
        # Some space for initializing comunication groups
        self._interval = None
        self._partition_layers(method=partition_method)
@@ -198,7 +198,7 @@ class ModelInitializer():
            for st, ed in self.parts[stage]:
                for idx, layer in enumerate(self.layers[st: ed]):
                    log_str += f'\t{idx + st:2d}: {layer}\n'
            self._logger.info(log_str)
            self._logger.info(log_str, ranks=[0])

        # Save the partition
        self._interval = self.parts[pipeline_rank]
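For intuition, here is a tiny standalone rendition of the uniform partitioning rule used above: layer index ranges are dealt out per chunk, with the remainder going to the later stages. It mirrors the logic of _partition_uniform but is not the library function itself:

def partition_uniform(num_items: int, pipeline_parallel_size: int, num_chunks: int):
    assert num_items % num_chunks == 0
    parts = [[] for _ in range(pipeline_parallel_size)]
    partition_items = num_items // num_chunks
    for idx in range(num_chunks):
        base_idx = idx * partition_items
        chunk_size = partition_items // pipeline_parallel_size
        left = pipeline_parallel_size - partition_items % pipeline_parallel_size
        for p in range(pipeline_parallel_size):
            st = base_idx
            base_idx += chunk_size + (p >= left)
            parts[p].append((st, base_idx))
    return parts

# 8 layers, 4 pipeline stages, 1 chunk: each stage gets a contiguous pair of layers.
print(partition_uniform(8, 4, 1))   # [[(0, 2)], [(2, 4)], [(4, 6)], [(6, 8)]]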
@@ -1,4 +1,4 @@
from .collective import all_gather, reduce_scatter, scatter
from .collective import all_gather, reduce_scatter, all_reduce
from .p2p import (send_forward, send_forward_recv_forward, send_backward_recv_forward,
                  send_backward, send_backward_recv_backward, send_forward_recv_backward,
                  send_forward_backward_recv_forward_backward, recv_forward, recv_backward)
@@ -6,7 +6,7 @@ from .ring import ring_forward
from .utils import send_tensor_meta, recv_tensor_meta

__all__ = [
    'all_gather', 'reduce_scatter', 'scatter',
    'all_gather', 'reduce_scatter', 'all_reduce',
    'send_forward', 'send_forward_recv_forward', 'send_forward_backward_recv_forward_backward',
    'send_backward', 'send_backward_recv_backward', 'send_backward_recv_forward',
    'send_forward_recv_backward', 'recv_backward', 'recv_forward',
@@ -11,7 +11,7 @@ from colossalai.utils import get_current_device


def all_gather(tensor: Tensor, dim: int,
               parallel_mode: ParallelMode) -> Tensor:
               parallel_mode: ParallelMode, async_op=False) -> Tensor:
    """Gathers all tensors from the parallel group and concatenates them in a
    specific dimension.

@@ -26,18 +26,28 @@ def all_gather(tensor: Tensor, dim: int,
    """
    depth = gpc.get_world_size(parallel_mode)
    temp = tensor.clone()
    shape = list(temp.shape)
    shape[dim] *= depth
    out = torch.empty(shape, dtype=temp.dtype, device=get_current_device())
    out = list(torch.chunk(out, depth, dim=dim))
    out = [val.contiguous() for val in out]
    dist.all_gather(out, temp, group=gpc.get_group(parallel_mode))
    out = torch.cat(out, dim=dim)
    return out
    # shape = list(temp.shape)
    # shape[dim] *= depth
    # out = torch.zeros(shape, dtype=temp.dtype, device=get_current_device())
    # out = list(torch.chunk(out, depth, dim=dim))
    # out = [val.contiguous() for val in out]
    shape = [1] * len(tensor.shape)
    shape[dim] = depth
    out = tensor.repeat(shape)
    out = list(map(lambda x: x.contiguous(), torch.chunk(out, depth, dim=dim)))
    op = dist.all_gather(tensor_list=out,
                         tensor=temp,
                         group=gpc.get_group(parallel_mode),
                         async_op=async_op)
    # out = torch.cat(out, dim=dim)
    if async_op:
        return out, op
    else:
        return out


def reduce_scatter(tensor: Tensor, dim: int,
                   parallel_mode: ParallelMode) -> Tensor:
                   parallel_mode: ParallelMode, async_op=False) -> Tensor:
    """Reduces all tensors then scatters it in a specific dimension to all
    members in the parallel group.

@@ -51,34 +61,52 @@ def reduce_scatter(tensor: Tensor, dim: int,
    :rtype: Tensor
    """
    depth = gpc.get_world_size(parallel_mode)
    temp = list(torch.chunk(tensor, depth, dim=dim))
    temp = [val.contiguous() for val in temp]
    out = torch.empty(temp[0].shape,
                      dtype=temp[0].dtype,
                      device=get_current_device())
    dist.reduce_scatter(output=out,
                        input_list=temp,
                        group=gpc.get_group(parallel_mode))
    return out
    # temp = list(torch.chunk(tensor, depth, dim=dim))
    # temp = [val.contiguous() for val in temp]
    # out = torch.zeros(temp[0].shape,
    #                   dtype=temp[0].dtype,
    #                   device=get_current_device())
    temp = list(map(lambda x: x.contiguous(), torch.chunk(tensor, depth, dim=dim)))
    out = temp[0].clone()
    op = dist.reduce_scatter(output=out,
                             input_list=temp,
                             group=gpc.get_group(parallel_mode),
                             async_op=async_op)
    if async_op:
        return out, op
    else:
        return out


def scatter(tensor: Tensor, src: int, dim: int,
            parallel_mode: ParallelMode) -> Tensor:
    """Scatters in a specific dimension from source rank to all ranks in
    the parallel group.
def all_reduce(tensor: Tensor,
               parallel_mode: ParallelMode,
               async_op=False) -> Tensor:
    op = dist.all_reduce(tensor,
                         group=gpc.get_group(parallel_mode),
                         async_op=async_op)
    if async_op:
        return tensor, op
    else:
        return tensor


# def scatter(tensor: Tensor, src: int, dim: int,
#             parallel_mode: ParallelMode) -> Tensor:
#     """Scatters in a specific dimension from source rank to all ranks in
#     the parallel group.

    :param tensor: Tensor to be scattered
    :param dim: The dimension scattering in
    :param parallel_mode: Parallel group mode used in this communication
    :type tensor: Tensor
    :type dim: int
    :type parallel_mode: ParallelMode
    :return: The tensor generated by scatter
    :rtype: Tensor
    """
    depth = gpc.get_world_size(parallel_mode)
    temp = tensor.clone()
    dist.broadcast(temp, src=src, group=gpc.get_group(parallel_mode))
    rank = gpc.get_local_rank(parallel_mode)
    out = torch.chunk(temp, depth, dim=dim)[rank].contiguous()
    return out
#     :param tensor: Tensor to be scattered
#     :param dim: The dimension scattering in
#     :param parallel_mode: Parallel group mode used in this communication
#     :type tensor: Tensor
#     :type dim: int
#     :type parallel_mode: ParallelMode
#     :return: The tensor generated by scatter
#     :rtype: Tensor
#     """
#     depth = gpc.get_world_size(parallel_mode)
#     temp = tensor.clone()
#     dist.broadcast(temp, src=src, group=gpc.get_group(parallel_mode))
#     rank = gpc.get_local_rank(parallel_mode)
#     out = torch.chunk(temp, depth, dim=dim)[rank].contiguous()
#     return out
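The collective hunks above add an async_op flag and make all_gather return a list of chunks rather than a concatenated tensor. A hedged usage sketch; ParallelMode.PARALLEL_1D and the surrounding initialization are assumed:

# Asynchronous all-gather: overlap communication with independent work, then wait.
out, handle = all_gather(tensor, dim=0, parallel_mode=ParallelMode.PARALLEL_1D, async_op=True)
# ... do computation that does not depend on `out` ...
handle.wait()
gathered = torch.cat(out, dim=0)   # the new API returns chunks; concatenate explicitly if needed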
@ -17,8 +17,6 @@ def _communicate(tensor_send_next=None,
|
||||
recv_next_shape=None,
|
||||
prev_rank=None,
|
||||
next_rank=None,
|
||||
up_group=None,
|
||||
down_group=None,
|
||||
dtype=None):
|
||||
"""
|
||||
Adapted from megatron.p2p_communication.
|
||||
@ -59,60 +57,44 @@ def _communicate(tensor_send_next=None,
|
||||
if prev_rank is None:
|
||||
prev_rank = gpc.get_prev_global_rank(
|
||||
ParallelMode.PIPELINE)
|
||||
if up_group is None:
|
||||
up_group = gpc.get_group(ParallelMode.PIPELINE_PREV)
|
||||
|
||||
if tensor_send_next is not None or recv_next:
|
||||
if next_rank is None:
|
||||
next_rank = gpc.get_next_global_rank(
|
||||
ParallelMode.PIPELINE)
|
||||
if down_group is None:
|
||||
down_group = gpc.get_group(ParallelMode.PIPELINE_NEXT)
|
||||
|
||||
# rank = dist.get_rank()
|
||||
rank = gpc.get_global_rank()
|
||||
|
||||
ops = []
|
||||
if tensor_send_prev is not None:
|
||||
send_prev_op = dist.broadcast(tensor_send_prev,
|
||||
src=rank,
|
||||
group=up_group,
|
||||
async_op=True)
|
||||
send_prev_op = dist.P2POp(dist.isend, tensor_send_prev, prev_rank)
|
||||
ops.append(send_prev_op)
|
||||
if tensor_recv_prev is not None:
|
||||
recv_prev_op = dist.broadcast(tensor_recv_prev,
|
||||
src=prev_rank,
|
||||
group=up_group,
|
||||
async_op=True)
|
||||
recv_prev_op = dist.P2POp(dist.irecv, tensor_recv_prev, prev_rank)
|
||||
ops.append(recv_prev_op)
|
||||
if tensor_recv_next is not None:
|
||||
recv_next_op = dist.broadcast(tensor_recv_next,
|
||||
src=next_rank,
|
||||
group=down_group,
|
||||
async_op=True)
|
||||
recv_next_op = dist.P2POp(dist.irecv, tensor_recv_next, next_rank)
|
||||
ops.append(recv_next_op)
|
||||
if tensor_send_next is not None:
|
||||
send_next_op = dist.broadcast(tensor_send_next,
|
||||
src=rank,
|
||||
group=down_group,
|
||||
async_op=True)
|
||||
send_next_op = dist.P2POp(dist.isend, tensor_send_next, next_rank)
|
||||
ops.append(send_next_op)
|
||||
for req in ops:
|
||||
req.wait()
|
||||
if len(ops) > 0:
|
||||
reqs = dist.batch_isend_irecv(ops)
|
||||
for req in reqs:
|
||||
req.wait()
|
||||
# To protect against race condition when using batch_isend_irecv().
|
||||
torch.cuda.synchronize()
|
||||
return tensor_recv_prev, tensor_recv_next
|
||||
|
||||
|
||||
def recv_forward(input_tensor_shape, prev_rank=None, up_group=None):
|
||||
def recv_forward(input_tensor_shape, prev_rank=None):
|
||||
"""Receives the input tensor from the previous member in pipeline.
|
||||
|
||||
|
||||
:param input_tensor_shape: The shape of the tensor to be recieved
|
||||
:param prev_rank: The rank of the source of the tensor
|
||||
:param up_group: Communication group including the previous member in pipeline parallel group
|
||||
:type input_tensor_shape: torch.Size
|
||||
:type prev_rank: int, optional
|
||||
:type up_group: ProcessGroup, optional
|
||||
:return: The input tensor in forward step
|
||||
:rtype: Tensor
|
||||
"""
|
||||
@ -121,20 +103,17 @@ def recv_forward(input_tensor_shape, prev_rank=None, up_group=None):
|
||||
else:
|
||||
input_tensor, _ = _communicate(recv_prev=True,
|
||||
recv_prev_shape=input_tensor_shape,
|
||||
prev_rank=prev_rank,
|
||||
up_group=up_group)
|
||||
prev_rank=prev_rank)
|
||||
return input_tensor
|
||||
|
||||
|
||||
def recv_backward(output_grad_shape, next_rank=None, down_group=None):
|
||||
def recv_backward(output_grad_shape, next_rank=None):
|
||||
"""Receives the grad tensor from the next member in pipeline.
|
||||
|
||||
|
||||
:param output_grad_shape: The shape of the tensor to be recieved
|
||||
:param next_rank: The rank of the source of the tensor
|
||||
:param down_group: Communication group including the next member in pipeline parallel group
|
||||
:type output_grad_shape: torch.Size
|
||||
:type next_rank: int, optional
|
||||
:type down_group: ProcessGroup, optional
|
||||
:return: The grad of output tensor in forward step
|
||||
:rtype: Tensor
|
||||
"""
|
||||
@ -143,56 +122,44 @@ def recv_backward(output_grad_shape, next_rank=None, down_group=None):
|
||||
else:
|
||||
_, output_tensor_grad = _communicate(recv_next=True,
|
||||
recv_next_shape=output_grad_shape,
|
||||
next_rank=next_rank,
|
||||
down_group=down_group)
|
||||
next_rank=next_rank)
|
||||
return output_tensor_grad
|
||||
|
||||
|
||||
def send_forward(output_tensor,
|
||||
next_rank=None,
|
||||
down_group=None):
|
||||
def send_forward(output_tensor, next_rank=None):
|
||||
"""Sends the input tensor to the next member in pipeline.
|
||||
|
||||
|
||||
:param output_tensor: Tensor to be sent
|
||||
:param next_rank: The rank of the recipient of the tensor
|
||||
:param down_group: Communication group including the next member in pipeline parallel group
|
||||
:type output_tensor: Tensor
|
||||
:type next_rank: int, optional
|
||||
:type down_group: ProcessGroup, optional
|
||||
"""
|
||||
if not gpc.is_last_rank(ParallelMode.PIPELINE):
|
||||
_communicate(tensor_send_next=output_tensor,
|
||||
next_rank=next_rank,
|
||||
down_group=down_group)
|
||||
next_rank=next_rank)
|
||||
|
||||
|
||||
def send_backward(input_tensor_grad,
|
||||
prev_rank=None,
|
||||
up_group=None):
|
||||
def send_backward(input_tensor_grad, prev_rank=None):
|
||||
"""Sends the grad tensor to the previous member in pipeline.
|
||||
|
||||
|
||||
:param input_tensor_grad: Tensor to be sent
|
||||
:param prev_rank: The rank of the recipient of the tensor
|
||||
:param up_group: Communication group including the previous member in pipeline parallel group
|
||||
:type input_tensor_grad: Tensor
|
||||
:type prev_rank: int, optional
|
||||
:type up_group: ProcessGroup, optional
|
||||
"""
|
||||
if not gpc.is_first_rank(ParallelMode.PIPELINE):
|
||||
_communicate(tensor_send_prev=input_tensor_grad,
|
||||
prev_rank=prev_rank,
|
||||
up_group=up_group)
|
||||
prev_rank=prev_rank)
|
||||
|
||||
|
||||
def send_forward_recv_backward(output_tensor,
|
||||
output_grad_shape,
|
||||
recv_next=True,
|
||||
next_rank=None,
|
||||
down_group=None):
|
||||
next_rank=None):
|
||||
"""Batched communication operation. Sends the input tensor to the
|
||||
next member in pipeline, while recieves the grad tensor from the
|
||||
next member in pipeline.
|
||||
|
||||
|
||||
:param output_tensor: Tensor to be sent
|
||||
:param output_grad_shape: The shape of the tensor to be recieved
|
||||
:type output_tensor: Tensor
|
||||
@ -206,20 +173,18 @@ def send_forward_recv_backward(output_tensor,
|
||||
_, output_tensor_grad = _communicate(tensor_send_next=output_tensor,
|
||||
recv_next=recv_next,
|
||||
recv_next_shape=output_grad_shape,
|
||||
next_rank=next_rank,
|
||||
down_group=down_group)
|
||||
next_rank=next_rank)
|
||||
return output_tensor_grad
|
||||
|
||||
|
||||
def send_backward_recv_forward(input_tensor_grad,
|
||||
input_tensor_shape,
|
||||
recv_prev=True,
|
||||
prev_rank=None,
|
||||
up_group=None):
|
||||
prev_rank=None):
|
||||
"""Batched communication operation. Sends the grad tensor to the
|
||||
previous member in pipeline, while recieves the input tensor from the
|
||||
previous member in pipeline.
|
||||
|
||||
|
||||
:param input_tensor_grad: Tensor to be sent
|
||||
:param input_tensor_shape: The shape of the tensor to be recieved
|
||||
:type input_tensor_grad: Tensor
|
||||
@ -233,8 +198,7 @@ def send_backward_recv_forward(input_tensor_grad,
|
||||
input_tensor, _ = _communicate(tensor_send_prev=input_tensor_grad,
|
||||
recv_prev=recv_prev,
|
||||
recv_prev_shape=input_tensor_shape,
|
||||
prev_rank=prev_rank,
|
||||
up_group=up_group)
|
||||
prev_rank=prev_rank)
|
||||
return input_tensor
|
||||
|
||||
|
||||
@ -242,13 +206,11 @@ def send_forward_recv_forward(output_tensor,
|
||||
input_tensor_shape,
|
||||
recv_prev=True,
|
||||
prev_rank=None,
|
||||
next_rank=None,
|
||||
up_group=None,
|
||||
down_group=None):
|
||||
next_rank=None):
|
||||
"""Batched communication operation. Sends the input tensor to the
|
||||
next member in pipeline, while recieves the input tensor from the
|
||||
previous member in pipeline.
|
||||
|
||||
|
||||
:param output_tensor: Tensor to be sent
|
||||
:param input_tensor_shape: The shape of the tensor to be recieved
|
||||
:type output_tensor: Tensor
|
||||
@ -260,9 +222,7 @@ def send_forward_recv_forward(output_tensor,
|
||||
recv_prev=recv_prev,
|
||||
recv_prev_shape=input_tensor_shape,
|
||||
prev_rank=prev_rank,
|
||||
next_rank=next_rank,
|
||||
up_group=up_group,
|
||||
down_group=down_group)
|
||||
next_rank=next_rank)
|
||||
return input_tensor
|
||||
|
||||
|
||||
@ -270,13 +230,11 @@ def send_backward_recv_backward(input_tensor_grad,
|
||||
output_grad_shape,
|
||||
recv_next=True,
|
||||
prev_rank=None,
|
||||
next_rank=None,
|
||||
up_group=None,
|
||||
down_group=None):
|
||||
next_rank=None):
|
||||
"""Batched communication operation. Sends the grad tensor to the
|
||||
previous member in pipeline, while recieves the grad tensor from the
|
||||
next member in pipeline.
|
||||
|
||||
|
||||
:param input_tensor_grad: Tensor to be sent
|
||||
:param output_grad_shape: The shape of the tensor to be recieved
|
||||
:type input_tensor_grad: Tensor
|
||||
@ -288,9 +246,7 @@ def send_backward_recv_backward(input_tensor_grad,
|
||||
recv_next=recv_next,
|
||||
recv_next_shape=output_grad_shape,
|
||||
prev_rank=prev_rank,
|
||||
next_rank=next_rank,
|
||||
up_group=up_group,
|
||||
down_group=down_group)
|
||||
next_rank=next_rank)
|
||||
return output_tensor_grad
|
||||
|
||||
|
||||
@ -301,13 +257,11 @@ def send_forward_backward_recv_forward_backward(output_tensor,
|
||||
recv_prev=True,
|
||||
recv_next=True,
|
||||
prev_rank=None,
|
||||
next_rank=None,
|
||||
up_group=None,
|
||||
down_group=None):
|
||||
next_rank=None):
|
||||
"""Batched communication operation. Sends the input tensor to the next and
|
||||
the grad tensor to the previous, while recieves the grad tensor from the
|
||||
next and the input tensor from the previous.
|
||||
|
||||
|
||||
:param output_tensor: Tensor sent to the next
|
||||
:param input_tensor_grad: Tensor sent to the previous
|
||||
:param input_tensor_shape: The shape of the tensor recieved from the previous
|
||||
@ -327,7 +281,5 @@ def send_forward_backward_recv_forward_backward(output_tensor,
|
||||
recv_prev_shape=input_tensor_shape,
|
||||
recv_next_shape=output_grad_shape,
|
||||
prev_rank=prev_rank,
|
||||
next_rank=next_rank,
|
||||
up_group=up_group,
|
||||
down_group=down_group)
|
||||
next_rank=next_rank)
|
||||
return input_tensor, output_tensor_grad
|
||||
|
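The p2p hunks above replace broadcast-based pipeline exchange with point-to-point ops batched through dist.batch_isend_irecv. A minimal standalone sketch of that primitive, assuming an initialized default process group and one tensor exchanged with each pipeline neighbour (function and argument names are illustrative):

import torch
import torch.distributed as dist

def exchange_with_neighbours(send_next, recv_prev, next_rank, prev_rank):
    # Queue isend/irecv ops and launch them as one batch, as the new _communicate does.
    ops = [
        dist.P2POp(dist.isend, send_next, next_rank),
        dist.P2POp(dist.irecv, recv_prev, prev_rank),
    ]
    reqs = dist.batch_isend_irecv(ops)
    for req in reqs:
        req.wait()
    # Guard against the race condition noted upstream when using batch_isend_irecv().
    torch.cuda.synchronize()
    return recv_prev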
@ -6,7 +6,7 @@ from colossalai.core import global_context as gpc
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
def send_tensor_meta(tensor, need_meta=True, down_group=None):
|
||||
def send_tensor_meta(tensor, need_meta=True, next_rank=None):
|
||||
"""Sends tensor meta information before sending a specific tensor.
|
||||
Since the recipient must know the shape of the tensor in p2p communications,
|
||||
meta information of the tensor should be sent before communications. This function
|
||||
@ -14,31 +14,34 @@ def send_tensor_meta(tensor, need_meta=True, down_group=None):
|
||||
|
||||
:param tensor: Tensor to be sent
|
||||
:param need_meta: If False, meta information won't be sent
|
||||
:param down_group: Communication group including the next member in pipeline parallel group
|
||||
:param next_rank: The rank of the next member in pipeline parallel group
|
||||
:type tensor: Tensor
|
||||
:type need_meta: bool, optional
|
||||
:type down_group: ProcessGroup, optional
|
||||
:type next_rank: int
|
||||
:return: False
|
||||
:rtype: bool
|
||||
"""
|
||||
if need_meta:
|
||||
rank = gpc.get_global_rank()
|
||||
|
||||
if down_group is None:
|
||||
down_group = gpc.get_group(ParallelMode.PIPELINE_NEXT)
|
||||
if next_rank is None:
|
||||
next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
|
||||
|
||||
tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()}
|
||||
|
||||
send_shape = torch.tensor(tensor.size(), **tensor_kwargs)
|
||||
send_ndims = torch.tensor(len(tensor.size()), **tensor_kwargs)
|
||||
|
||||
dist.broadcast(send_ndims, src=rank, group=down_group)
|
||||
dist.broadcast(send_shape, src=rank, group=down_group)
|
||||
ops = [
|
||||
dist.P2POp(dist.isend, send_ndims, next_rank),
|
||||
dist.P2POp(dist.isend, send_shape, next_rank)
|
||||
]
|
||||
reqs = dist.batch_isend_irecv(ops)
|
||||
for req in reqs:
|
||||
req.wait()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def recv_tensor_meta(tensor_shape, prev_rank=None, up_group=None):
|
||||
def recv_tensor_meta(tensor_shape, prev_rank=None):
|
||||
"""Recieves tensor meta information before recieving a specific tensor.
|
||||
Since the recipient must know the shape of the tensor in p2p communications,
|
||||
meta information of the tensor should be recieved before communications. This function
|
||||
@ -46,27 +49,21 @@ def recv_tensor_meta(tensor_shape, prev_rank=None, up_group=None):
|
||||
|
||||
:param tensor_shape: The shape of the tensor to be recieved
|
||||
:param prev_rank: The rank of the source of the tensor
|
||||
:param up_group: Communication group including the previous member in pipeline parallel group
|
||||
:type tensor_shape: torch.Size
|
||||
:type prev_rank: int, optional
|
||||
:type up_group: ProcessGroup, optional
|
||||
:return: The shape of the tensor to be recieved
|
||||
:rtype: torch.Size
|
||||
"""
|
||||
if tensor_shape is None:
|
||||
if prev_rank is None:
|
||||
prev_rank = gpc.get_prev_global_rank(
|
||||
ParallelMode.PIPELINE)
|
||||
if up_group is None:
|
||||
up_group = gpc.get_group(ParallelMode.PIPELINE_PREV)
|
||||
prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
|
||||
|
||||
tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()}
|
||||
|
||||
recv_ndims = torch.empty((), **tensor_kwargs)
|
||||
dist.broadcast(recv_ndims, src=prev_rank, group=up_group)
|
||||
|
||||
dist.recv(recv_ndims, prev_rank)
|
||||
recv_shape = torch.empty(recv_ndims, **tensor_kwargs)
|
||||
dist.broadcast(recv_shape, src=prev_rank, group=up_group)
|
||||
dist.recv(recv_shape, prev_rank)
|
||||
|
||||
tensor_shape = torch.Size(recv_shape)
|
||||
|
||||
|
@@ -25,7 +25,11 @@ TESSERACT_DEP = 'TESSERACT_DEP'

# 3D parallel
DEPTH_3D = 'DEPTH_3D'
INPUT_GROUP_3D = 'PARALLEL_3D_INPUT'
WEIGHT_GROUP_3D = 'PARALLEL_3D_WEIGHT'
OUTPUT_GROUP_3D = 'PARALLEL_3D_OUTPUT'

# Tensor parallel attributes
IS_TENSOR_PARALLEL = 'is_tensor_parallel'
TENSOR_PARALLEL_ATTRIBUTES = [IS_TENSOR_PARALLEL]
NUM_PARTITIONS = 'num_partitions'
TENSOR_PARALLEL_ATTRIBUTES = [IS_TENSOR_PARALLEL, NUM_PARTITIONS]
|
||||
|
@@ -1,5 +1,5 @@
from .config import Config
from .config import Config, ConfigException
from .parallel_context import ParallelContext
from .parallel_context import ParallelMode
from .parallel_mode import ParallelMode
from .process_group_initializer import *
from .random import *
|
||||
|
@ -1,70 +0,0 @@
|
||||
import math
|
||||
|
||||
|
||||
def set_parallel_size(obj, config: dict, key: str, attr_name: str):
|
||||
if key in config:
|
||||
ele = config[key]
|
||||
if isinstance(ele, int):
|
||||
setattr(obj, attr_name, ele)
|
||||
elif isinstance(ele, dict):
|
||||
setattr(obj, attr_name, ele['size'])
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Parallel configuration does not support this kind of argument, please use int or dict"
|
||||
)
|
||||
|
||||
|
||||
def add_tensor_pg(pg_init, mode, size, depth=None):
|
||||
if mode == '1d':
|
||||
pg_init.append(dict(
|
||||
type='Initializer1D',
|
||||
parallel_size=size
|
||||
))
|
||||
elif mode == '2d':
|
||||
dim = math.floor(math.sqrt(size))
|
||||
pg_init.append(dict(
|
||||
type='Initializer2D_Col',
|
||||
summa_dim=dim
|
||||
))
|
||||
pg_init.append(dict(
|
||||
type='Initializer2D_Row',
|
||||
summa_dim=dim
|
||||
))
|
||||
elif mode == '2.5d':
|
||||
dim = math.floor(math.sqrt(size // depth))
|
||||
pg_init.append(dict(
|
||||
type='Initializer_Tesseract_ROW',
|
||||
tesseract_dim=dim,
|
||||
tesseract_dep=depth
|
||||
))
|
||||
pg_init.append(dict(
|
||||
type='Initializer_Tesseract_COL',
|
||||
tesseract_dim=dim,
|
||||
tesseract_dep=depth
|
||||
))
|
||||
pg_init.append(dict(
|
||||
type='Initializer_Tesseract_DEP',
|
||||
tesseract_dim=dim,
|
||||
tesseract_dep=depth
|
||||
))
|
||||
pg_init.append(dict(
|
||||
type='Initializer_Tesseract_XZ',
|
||||
tesseract_dim=dim,
|
||||
tesseract_dep=depth
|
||||
))
|
||||
elif mode == '3d':
|
||||
dim = math.floor(math.pow(size, 1.0 / 3.0) + 0.5)
|
||||
pg_init.append(dict(
|
||||
type='ParallelInitializer3D_Input',
|
||||
depth=dim
|
||||
))
|
||||
pg_init.append(dict(
|
||||
type='ParallelInitializer3D_Weight',
|
||||
depth=dim
|
||||
))
|
||||
pg_init.append(dict(
|
||||
type='ParallelInitializer3D_Output',
|
||||
depth=dim
|
||||
))
|
||||
else:
|
||||
raise NotImplementedError("This kind of tensor splitting has not been implemented yet")
|
@@ -97,3 +97,7 @@ class Config(dict):
        sys.path.pop(0)

        return config


class ConfigException(Exception):
    pass
|
||||
|
@ -1,7 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import os
|
||||
import random
|
||||
from typing import Union
|
||||
|
||||
@ -11,8 +10,8 @@ import torch.distributed as dist
|
||||
|
||||
from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING
|
||||
from colossalai.context.config import Config
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.registry import DIST_GROUP_INITIALIZER
|
||||
from ._utils import set_parallel_size
|
||||
from .parallel_mode import ParallelMode
|
||||
from .random import add_seed, get_seeds, set_mode
|
||||
|
||||
@ -21,11 +20,24 @@ class ParallelContext:
|
||||
"""This class provides interface functions for users to get the parallel context,
|
||||
such as the global rank, the local rank, the world size, etc. of each device.
|
||||
|
||||
:param args: The distributed arguments in the system
|
||||
:type args: dict
|
||||
"""
|
||||
|
||||
def __init__(self, args=None):
|
||||
__instance = None
|
||||
|
||||
@staticmethod
|
||||
def get_instance():
|
||||
if ParallelContext.__instance is None:
|
||||
ParallelContext()
|
||||
return ParallelContext.__instance
|
||||
|
||||
def __init__(self):
|
||||
# create a singleton instance
|
||||
if ParallelContext.__instance is not None:
|
||||
raise Exception(
|
||||
'ParallelContext is a singleton class, you should get the instance by colossalai.core.global_context')
|
||||
else:
|
||||
ParallelContext.__instance = self
|
||||
|
||||
# distributed settings
|
||||
self._global_ranks = dict()
|
||||
self._local_ranks = dict()
|
||||
@ -34,7 +46,6 @@ class ParallelContext:
|
||||
self._ranks_in_group = dict()
|
||||
|
||||
# load config from file
|
||||
self._dist_args = args
|
||||
self._config = None
|
||||
|
||||
# default 3D parallel args, will be overwritten during process group intialization
|
||||
@ -43,10 +54,22 @@ class ParallelContext:
|
||||
self.pipeline_parallel_size = 1
|
||||
self.tensor_parallel_size = 1
|
||||
|
||||
# logging
|
||||
self._verbose = False
|
||||
self._logger = get_dist_logger()
|
||||
|
||||
@property
|
||||
def config(self):
|
||||
return self._config
|
||||
|
||||
@property
|
||||
def verbose(self):
|
||||
return self._verbose
|
||||
|
||||
@verbose.setter
|
||||
def verbose(self, verbose_: bool):
|
||||
self._verbose = verbose_
|
||||
|
||||
def load_config(self, config: Union[dict, str]):
|
||||
"""Loads the configuration from either a dict or a file.
|
||||
|
||||
@ -62,14 +85,6 @@ class ParallelContext:
|
||||
else:
|
||||
raise TypeError("Invalid type for config, only dictionary or string is supported")
|
||||
|
||||
def set_dist_args(self, args):
|
||||
"""Sets the distributed arguments.
|
||||
|
||||
:param args: The distributed arguments in the system
|
||||
:type args: dict
|
||||
"""
|
||||
self._dist_args = args
|
||||
|
||||
@staticmethod
|
||||
def _check_parallel_mode(parallel_mode: ParallelMode):
|
||||
assert isinstance(parallel_mode, ParallelMode)
|
||||
@ -268,32 +283,36 @@ class ParallelContext:
|
||||
self._check_parallel_mode(parallel_mode)
|
||||
self._ranks_in_group[parallel_mode] = ranks
|
||||
|
||||
def init_global_dist(self, addr=None, port=None):
|
||||
"""Initializes the global distributed environment.
|
||||
|
||||
:param addr: The IP address of the current device
|
||||
:type addr: str, optional
|
||||
:param port: The port to be used in the system of the current device
|
||||
:type port: int, optional
|
||||
def init_global_dist(self,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
backend: str,
|
||||
host: str,
|
||||
port: int
|
||||
):
|
||||
"""Initializes the global distributed environment
|
||||
:param rank: rank for the default process group
|
||||
:type rank: int
|
||||
:param world_size: world size of the default process group
|
||||
:type world_size: int
|
||||
:param host: the master address for distributed training
|
||||
:type host: str
|
||||
:param port: the master port for distributed training
|
||||
:type port: str
|
||||
:param backend: backend for torch.distributed
|
||||
:type backend: str
|
||||
"""
|
||||
# get config
|
||||
rank = self._dist_args.local_rank
|
||||
world_size = self._dist_args.world_size
|
||||
# default env config, overwrite by exporting
|
||||
# them in your bash script
|
||||
addr = os.getenv('MASTER_ADDR', 'localhost') if addr is None else addr
|
||||
port = os.getenv('MASTER_PORT', '8008') if port is None else port
|
||||
init_method = f'tcp://{addr}:{port}'
|
||||
|
||||
dist.init_process_group(backend=self._dist_args.backend,
|
||||
rank=rank,
|
||||
# initialize the default process group
|
||||
init_method = f'tcp://{host}:{port}'
|
||||
dist.init_process_group(rank=rank,
|
||||
world_size=world_size,
|
||||
backend=backend,
|
||||
init_method=init_method)
|
||||
|
||||
# None will give the default global process group for pytorch dist operations
|
||||
self._register_dist(rank, world_size, None,
|
||||
list(range(world_size)), ParallelMode.GLOBAL)
|
||||
self._global_ranks[ParallelMode.GLOBAL] = rank
|
||||
self.add_global_rank(ParallelMode.GLOBAL, rank)
|
||||
|
||||
def _register_dist(self, local_rank, world_size,
|
||||
process_group, ranks_in_group, mode):
|
||||
@ -312,7 +331,20 @@ class ParallelContext:
|
||||
pps = self.pipeline_parallel_size
|
||||
tps = self.tensor_parallel_size
|
||||
ws = self.world_size
|
||||
assert ws == dps * pps * tps, f"Expected the world size {ws} to be equal to data parallel size ({dps}) * pipeline parallel size ({pps}) * tensor parallel size ({tps})"
|
||||
assert ws == dps * pps * \
|
||||
tps, f"Expected the world size {ws} to be equal to data parallel size ({dps}) * pipeline parallel size ({pps}) * tensor parallel size ({tps})"
|
||||
|
||||
def _set_parallel_size_from_config(self, config: dict, key: str, attr_name: str):
|
||||
if key in config:
|
||||
ele = config[key]
|
||||
if isinstance(ele, int):
|
||||
setattr(self, attr_name, ele)
|
||||
elif isinstance(ele, dict):
|
||||
setattr(self, attr_name, ele['size'])
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Parallel configuration does not support this kind of argument, please use int or dict"
|
||||
)
|
||||
|
||||
def init_parallel_groups(self):
|
||||
"""Initializes the parallel groups.
|
||||
@ -325,21 +357,20 @@ class ParallelContext:
|
||||
world_size = self.get_world_size(ParallelMode.GLOBAL)
|
||||
self.world_size = world_size
|
||||
|
||||
assert hasattr(self.config, 'parallel'), 'Expected the field parallel to be present in the config file'
|
||||
|
||||
# set parallel size as attributes for global context
|
||||
parallel_config = self.config.parallel
|
||||
set_parallel_size(self, parallel_config, 'pipeline',
|
||||
'pipeline_parallel_size')
|
||||
set_parallel_size(self, parallel_config, 'tensor',
|
||||
'tensor_parallel_size')
|
||||
parallel_config = self.config.get('parallel', None)
|
||||
if parallel_config is not None:
|
||||
self._set_parallel_size_from_config(parallel_config, 'pipeline', 'pipeline_parallel_size')
|
||||
self._set_parallel_size_from_config(parallel_config, 'tensor', 'tensor_parallel_size')
|
||||
|
||||
# the user should not set the data parallel size manually
|
||||
# instead, it should be calculated based on other parallel config
|
||||
self.data_parallel_size = self.world_size // (self.pipeline_parallel_size * self.tensor_parallel_size)

# get the tensor parallel mode and check
tensor_parallel_mode = parallel_config['tensor'].get('mode', None)
tensor_parallel_mode = None
if parallel_config is not None and 'tensor' in parallel_config and 'mode' in parallel_config['tensor']:
tensor_parallel_mode = parallel_config['tensor']['mode']
assert tensor_parallel_mode in ALLOWED_MODES, f"mode in the parallel config must be set to one of {ALLOWED_MODES}"
self.check_sanity()

@ -400,23 +431,21 @@ class ParallelContext:
# destroy global process group
dist.destroy_process_group()

def set_device(self):
def set_device(self, device_ordinal: int = None):
"""Sets distributed processes to be bound to devices.
"""
devices_per_node = torch.cuda.device_count()
global_rank = self.get_global_rank()
device = global_rank % devices_per_node
torch.cuda.set_device(device)
print(f'process rank {global_rank} is bound to device {device}')
if device_ordinal is None:
devices_per_node = torch.cuda.device_count()
device_ordinal = global_rank % devices_per_node

def set_seed(self):
torch.cuda.set_device(device_ordinal)
if self._verbose:
self._logger.info(f'process rank {global_rank} is bound to device {device_ordinal}')

def set_seed(self, seed: int):
"""Sets seeds for all random libraries.
"""
if hasattr(self.config, 'seed'):
seed = getattr(self.config, 'seed')
else:
seed = 2  # default seed

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
@ -444,11 +473,18 @@ class ParallelContext:
seeds = get_seeds()
seed_str = ', '.join([f'{k}: {v}' for k, v in seeds.items()])

print(f"initialized seed on rank {global_rank}, "
f"numpy: {seed}, python random: {seed}, {seed_str},"
f"the default parallel seed is {ParallelMode.DATA}.", flush=True)
if self._verbose:
self._logger.info(
f"initialized seed on rank {global_rank}, "
f"numpy: {seed}, python random: {seed}, {seed_str},"
f"the default parallel seed is {ParallelMode.DATA}.",
ranks=[0])
else:
print(f"initialized seed on rank {global_rank}, "
f"numpy: {seed}, python random: {seed}, pytorch: {seed}", flush=True)
print('WARNING: CUDA is not available, thus CUDA RNG cannot be used to track CUDA random number states',
flush=True)
if self._verbose:
self._logger.info(
f"initialized seed on rank {global_rank}, "
f"numpy: {seed}, python random: {seed}, pytorch: {seed}",
ranks=[0])
self._logger.info(
'WARNING: CUDA is not available, thus CUDA RNG cannot be used to track CUDA random number states',
ranks=[0])
@ -4,7 +4,6 @@
import torch.distributed as dist

from colossalai.context import Config
from colossalai.core import global_context as gpc
from colossalai.registry import DIST_GROUP_INITIALIZER
from .process_group_initializer import ProcessGroupInitializer
from ..parallel_mode import ParallelMode

@ -8,7 +8,6 @@ import torch.distributed as dist

from colossalai.constants import TESSERACT_DIM, TESSERACT_DEP
from colossalai.context import Config
from colossalai.core import global_context as gpc
from colossalai.registry import DIST_GROUP_INITIALIZER
from .process_group_initializer import ProcessGroupInitializer
from ..parallel_mode import ParallelMode
@ -42,8 +41,6 @@ class Initializer_2p5D_ROW(ProcessGroupInitializer):
tesseract_dep: int,
*args):
super(Initializer_2p5D_ROW, self).__init__(*args)

self.tensor_parallel_size = gpc.tensor_parallel_size
self.num_group = self.world_size // self.tensor_parallel_size
self.tesseract_dep = tesseract_dep
self.tesseract_dim = tesseract_dim
@ -66,7 +63,7 @@ class Initializer_2p5D_ROW(ProcessGroupInitializer):
for j in range(self.tesseract_dim):
for k in range(self.tesseract_dep):
ranks = [h * self.tensor_parallel_size + i + self.tesseract_dim * (
j + self.tesseract_dim * k) for i in range(self.tesseract_dim)]
group = dist.new_group(ranks)

if self.rank in ranks:
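A worked instance of the rank formula above may help (illustrative numbers, not taken from this diff): with tesseract_dim = 2, tesseract_dep = 2, tensor_parallel_size = 8 and h = 0, the row group for j = 1, k = 0 is

    # [0 * 8 + i + 2 * (1 + 2 * 0) for i in range(2)]  ->  ranks [2, 3]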
@ -81,13 +78,12 @@ class Initializer_2p5D_ROW(ProcessGroupInitializer):
class Initializer_2p5D_Col(ProcessGroupInitializer):
'''2p5d tensor parallel initialization among cols.
'''

def __init__(self,
tesseract_dim: int,
tesseract_dep: int,
*args):
super(Initializer_2p5D_Col, self).__init__(*args)

self.tensor_parallel_size = gpc.tensor_parallel_size
self.num_group = self.world_size // self.tensor_parallel_size
self.tesseract_dep = tesseract_dep
self.tesseract_dim = tesseract_dim
@ -110,7 +106,7 @@ class Initializer_2p5D_Col(ProcessGroupInitializer):
for i in range(self.tesseract_dim):
for k in range(self.tesseract_dep):
ranks = [h * self.tensor_parallel_size + i + self.tesseract_dim * (
j + self.tesseract_dim * k) for j in range(self.tesseract_dim)]
group = dist.new_group(ranks)

if self.rank in ranks:
@ -125,13 +121,12 @@ class Initializer_2p5D_Col(ProcessGroupInitializer):
class Initializer_2p5D_Dep(ProcessGroupInitializer):
'''2p5D tensor parallel initialization among depths.
'''

def __init__(self,
tesseract_dim: int,
tesseract_dep: int,
*args):
super(Initializer_2p5D_Dep, self).__init__(*args)

self.tensor_parallel_size = gpc.tensor_parallel_size
self.num_group = self.world_size // self.tensor_parallel_size
self.tesseract_dep = tesseract_dep
self.tesseract_dim = tesseract_dim
@ -154,7 +149,7 @@ class Initializer_2p5D_Dep(ProcessGroupInitializer):
for i in range(self.tesseract_dim):
for j in range(self.tesseract_dim):
ranks = [h * self.tensor_parallel_size + i + self.tesseract_dim * (
j + self.tesseract_dim * k) for k in range(self.tesseract_dep)]
group = dist.new_group(ranks)

if self.rank in ranks:
@ -170,13 +165,12 @@ class Initializer_2p5D_Dep(ProcessGroupInitializer):
class Initializer_2p5D_XZ(ProcessGroupInitializer):
'''2p5d tensor parallel initialization among cols times dep.
'''

def __init__(self,
tesseract_dim: int,
tesseract_dep: int,
*args):
super(Initializer_2p5D_XZ, self).__init__(*args)

self.tensor_parallel_size = gpc.tensor_parallel_size
self.num_group = self.world_size // self.tensor_parallel_size
self.tesseract_dep = tesseract_dep
self.tesseract_dim = tesseract_dim
@ -198,8 +192,8 @@ class Initializer_2p5D_XZ(ProcessGroupInitializer):
for h in range(self.num_group):
for i in range(self.tesseract_dim):
ranks = [h * self.tensor_parallel_size + i + self.tesseract_dim * (
j + self.tesseract_dim * k) for k in range(self.tesseract_dep) for j in
range(self.tesseract_dim)]
group = dist.new_group(ranks)

if self.rank in ranks:

@ -5,7 +5,7 @@ import math
import os

import torch.distributed as dist
from colossalai.constants import DEPTH_3D
from colossalai.constants import DEPTH_3D, INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D
from colossalai.registry import DIST_GROUP_INITIALIZER

from ..parallel_mode import ParallelMode
@ -18,7 +18,7 @@ def _check_depth_env_var(depth):

if env_depth:
assert int(env_depth) == depth, \
'SUMMA_DIM has been set in the current environment and ' \
'DEPTH_3D has been set in the current environment and ' \
'does not match with the value passed to this initialized'
else:
os.environ[DEPTH_3D] = str(depth)
@ -43,6 +43,7 @@ class Initializer_3D_Input(ProcessGroupInitializer):
process_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_3D_INPUT
os.environ[INPUT_GROUP_3D] = INPUT_GROUP_3D

for h in range(self.num_group):
for i in range(self.depth):
@ -82,6 +83,7 @@ class Initializer_3D_Weight(ProcessGroupInitializer):
process_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_3D_WEIGHT
os.environ[WEIGHT_GROUP_3D] = WEIGHT_GROUP_3D

for h in range(self.num_group):
for k in range(self.depth):
@ -121,6 +123,7 @@ class Initializer_3D_Output(ProcessGroupInitializer):
process_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_3D_OUTPUT
os.environ[OUTPUT_GROUP_3D] = OUTPUT_GROUP_3D

for h in range(self.num_group):
for i in range(self.depth):

@ -3,14 +3,4 @@

from colossalai.context import ParallelContext

global_context = ParallelContext()


def set_global_context(context: ParallelContext):
'''Reset global context to be identical to a given :class:ParallelContext.

:param context: Parallel context to generate our global parallel context.
:type context: ParallelContext
'''
global global_context
global_context = context
global_context = ParallelContext.get_instance()

@ -1,7 +1,5 @@
from ._base_engine import Engine
from .gradient_handler import *
from .schedule import *
from .amp import *


__all__ = ['Engine']
@ -1,17 +1,17 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-


import torch
from typing import List
from torch.nn import Module
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer

from colossalai.builder import build_gradient_handler
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger
from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)
from .schedule import BaseSchedule
from colossalai.logging import get_dist_logger
from colossalai.utils import is_using_ddp, is_using_pp
from torch import Tensor


class Engine:
@ -20,74 +20,40 @@ class Engine:
It controls an iteration in training.

:param model: The neural network model
:type model: ``torch.nn.Module``
:param optimizer: Optimizer for updating the parameters
:param step_schedule: Running schedule in :meth:`step`
:param gradient_accumulation: Steps of gradient accumulation
:type optimizer: ``torch.optim.Optimizer``
:param criterion: Loss function for calculating loss
:type criterion: ``torch.nn.modules.loss._Loss``
:param gradient_clipping: The norm of gradient clipping
:type model: Module
:type optimizer: Optimizer
:type step_schedule: BaseSchedule, optional
:type gradient_accumulation: int, optional
:type gradient_clipping: float, optional
:param verbose: whether to display log info
:type verbose: bool
"""

def __init__(self,
model: Module,
optimizer: Optimizer,
criterion: _Loss,
step_schedule: BaseSchedule,
gradient_handlers: list = None,
gradient_accumulation: int = 1,
gradient_clipping: float = 0.0,
gradient_handlers: List = None,
clip_grad_norm: float = 0.0,
verbose: bool = True
):
self._model = model
self._optimizer = optimizer
self._criterion = criterion
self._schedule = step_schedule

# schedule initialize
self._schedule.initialize(model, optimizer)
self._clip_grad_norm = clip_grad_norm
self._verbose = verbose
self._logger = get_dist_logger()

# state
self.training = True  # default

# gradient accumulation
assert gradient_accumulation > 0, 'gradient accumulation size must be larger than 0'
self._grad_accum_size = gradient_accumulation
self._grad_clip = gradient_clipping
self._logger = get_global_dist_logger()

# build gradient handler
self._gradient_handlers = []

if gradient_handlers is not None:
assert isinstance(gradient_handlers, list), \
f'argument gradient_handler_cfg expected type list, ' \
f'but got type {type(gradient_handlers)}'
elif isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)):
gradient_handlers = [dict(type='ZeROGradientHandler')]
self._logger.info(
"Training with zero is detected, ZeROGradientHandler is automatically "
"added even though not specified in the configuration",
ranks=[0])
elif gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(
ParallelMode.DATA) > 1:
gradient_handlers = [dict(type='DataParallelGradientHandler')]
self._logger.info(
"Data parallel training is detected, DataParallelGradientHandler is automatically "
"added even though not specified in the configuration",
ranks=[0])

if gradient_handlers is None:
self._logger.warning(
"No gradient handler is set up, please make sure you do not need "
"to all-reduce the gradients after a training step.",
ranks=[0])
if gradient_handlers:
self._gradient_handlers = gradient_handlers
else:
for cfg in gradient_handlers:
handler = build_gradient_handler(cfg, model, optimizer)
self._gradient_handlers.append(handler)
self._gradient_handlers = []

@property
def model(self):
@ -105,11 +71,27 @@ class Engine:
def schedule(self):
return self._schedule

@property
def gradient_accumulation(self):
return self._grad_accum_size
def zero_grad(self):
self.optimizer.zero_grad()

def handle_gradient(self):
def step(self):
self._all_reduce_gradients()
self.optimizer.clip_grad_norm(self.model, self._clip_grad_norm)
self.optimizer.step()

def backward(self, loss: Tensor):
return self.optimizer.backward(loss)

def backward_by_grad(self, tensor, grad):
return self.optimizer.backward_by_grad(tensor, grad)

def calc_loss(self, *args, **kwargs):
return self.criterion(*args, **kwargs)

def __call__(self, *args, **kwargs):
return self.model(*args, **kwargs)

def _all_reduce_gradients(self):
"""Handles all-reduce operations of gradients across different parallel groups.
"""
for handler in self._gradient_handlers:
@ -126,51 +108,3 @@ class Engine:
"""
self.training = False
self._model.eval()

def step(self,
data_iter,
is_last_iteration: bool = False,
return_loss=True):
"""A running step based on the schedule. Usually, it runs a training or
evaluation over a batch of dataset.

:param data_iter: Data iterator of the dataset
:param is_last_iteration: If True, this iteration is the last iteration in the epoch
:param return_loss: loss will be returned if True
:type data_iter: Iterator
:type is_last_iteration: bool, optional
:type return_loss: bool, optional
:return: (output, label, loss)
"""
if self.training:
self._optimizer.zero_grad()

# differentiate training and eval with grad accum
if self.training:
for i in range(self._grad_accum_size):
output, label, loss = self._schedule.forward_backward_step(
data_iter, self._model, self._criterion, self._optimizer,
forward_only=False,
grad_accum_size=self._grad_accum_size,
return_loss=return_loss)

if i == self._grad_accum_size - 1:
# all reduce gradients
self.handle_gradient()
self._schedule.optimizer_step(self._model, self._optimizer, self._grad_clip)
else:
output, label, loss = self._schedule.forward_backward_step(
data_iter, self._model, self._criterion, self._optimizer,
forward_only=True,
grad_accum_size=1,
return_loss=return_loss)

# consume the remaining dataset left out due to gradient accumulation
if is_last_iteration:
while True:
try:
_ = next(data_iter)
except StopIteration:
break

return output, label, loss
@ -1,2 +0,0 @@
from .grad_scaler import GradScaler
from .amp_type import AMP_TYPE
@ -1,5 +1,5 @@
from ._base_schedule import BaseSchedule
from ._no_pipeline import NoPipelineSchedule
from ._pipeline import PipelineSchedule
from ._pipeline_schedule import PipelineSchedule
from ._non_pipeline_schedule import NonPipelineSchedule

__all__ = ['BaseSchedule', 'NoPipelineSchedule', 'PipelineSchedule']
__all__ = ['BaseSchedule', 'PipelineSchedule', 'NonPipelineSchedule']

@ -5,8 +5,10 @@ from abc import ABC, abstractmethod

import torch

from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger
from torch import Tensor
from typing import Iterable, Union, List, Callable
from .._base_engine import Engine
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device


@ -18,8 +20,9 @@ class BaseSchedule(ABC):
control of FP16 in class schedule.
"""

def __init__(self):
self.logger = get_global_dist_logger()
def __init__(self, batch_data_process_func: Callable = None):
self.logger = get_dist_logger()
self.batch_data_process_func = batch_data_process_func

@staticmethod
def _move_tensor(element):
@ -35,6 +38,11 @@ class BaseSchedule(ABC):
data = data.to(get_current_device()).detach()
return data

def _to_list(self, data):
if torch.is_tensor(data):
return [data]
return data

def load_batch(self, data_iter):
"""Loads a batch from data iterator. It returns the data and labels which are
already on the same GPU as the model.
@ -44,46 +52,34 @@ class BaseSchedule(ABC):
"""
if data_iter is None:
raise RuntimeError('Dataloader is not defined.')
data, label = next(data_iter)
batch_data = next(data_iter)

if self.batch_data_process_func:
data, label = self.batch_data_process_func(batch_data)
else:
data, label = batch_data

data, label = self._to_list(data), self._to_list(label)
return self._move_to_device(data), self._move_to_device(label)
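As a sketch of how the new batch_data_process_func hook can be used (the function and key names here are invented for illustration): a dataloader that yields dict-style batches can be adapted to the (data, label) pair that load_batch expects, and the hook is then passed to the schedule constructor:

    # hypothetical example, not part of this diff
    def dict_batch_to_pair(batch_data):
        return batch_data['input'], batch_data['label']

    schedule = NonPipelineSchedule(batch_data_process_func=dict_batch_to_pair)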

def initialize(self, model, optimizer):
"""Initializes the model and the optimizer before training.
This is often used in FP16 training.

:param model: The neural network model
:param optimizer: Optimizer for updating the parameters
def pre_processing(self, engine: Engine):
"""To perform actions before running the schedule.
"""
return model, optimizer
pass

@abstractmethod
def forward_backward_step(self,
data_iter,
model,
criterion,
optimizer=None,
forward_only=False,
grad_accum_size: int = 1,
return_loss=True):
engine: Engine,
data_iter: Iterable,
forward_only: bool,
return_loss: bool = True
):
"""The process function over a batch of dataset for training or evaluation.

:param data_iter: Data iterator of the dataset
:param model: Model used in training or evaluation
:param optimizer: Optimizer used in training or evaluation
:param criterion: Loss function
:param engine: Colossalai training engine
:param inputs: input data
:param labels: ground truth
:param forward_only: If True, the process won't include backward
:param grad_accum_size: Steps of gradient accumulation
:param return_loss: If False, the loss won't be returned
"""
pass

@abstractmethod
def optimizer_step(self, model, optimizer, grad_clipping: float = 0.0):
"""Updates the parameters with the optimizer.

:param model: The neural network model
:param optimizer: Optimizer for updating the parameters
:param grad_clipping: The norm of gradient clipping
:type grad_clipping: float, optional
"""
pass
@ -1,188 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

try:
import apex.amp as apex_amp
except:
pass

try:
import torch.cuda.amp as torch_amp
except:
pass

from typing import Iterable

import torch.nn as nn
from torch.optim import Optimizer

from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)
from colossalai.nn.optimizer._utils import clip_grad_norm_fp32
from ._base_schedule import BaseSchedule
from ._utils import convert_to_fp16, convert_to_fp32
from ..amp import AMP_TYPE, GradScaler


class NoPipelineSchedule(BaseSchedule):
"""A helper schedule class for no pipeline parallelism running environment.
During one process, it loads a batch of dataset and feeds it to the model.
After getting the output and calculating the loss, it will use :meth:`step`
to update the parameters if it is in training mode.

:param amp_type: The type of automatic mixed precision
:param amp_config: The configuration of automatic mixed procision
:type amp_type: AMP_TYPE
:type amp_config: dict
"""

def __init__(
self,
amp_type: AMP_TYPE = None,
amp_config: dict = None,
):
super().__init__()

# mixed precision training
assert amp_type is None or isinstance(amp_type, AMP_TYPE), \
'unrecognised value for argument fp16, it can only be None, torch or apex'

self.use_zero_level_2_3 = False

if amp_type is not None:
self.fp16 = True
self.amp_type = amp_type

if amp_config is not None:
assert isinstance(amp_config, dict), \
f'expected argument fp16_config to be type dictionary, but got {type(amp_config)}'

if self.amp_type == AMP_TYPE.TORCH:
# torch apex
if amp_config is None:
amp_config = dict()
self.amp_cfg = amp_config
elif self.amp_type == AMP_TYPE.APEX:
# apex amp
if amp_config is None:
amp_config = dict(opt_level='O2')
self.logger.warning(
'apex is deprecated, please consider using torch.cuda.amp instead.'
)
self.amp_cfg = amp_config
elif self.amp_type == AMP_TYPE.PARALLEL:
# use fp16 optimizer for tensor parallel training
if amp_config is None:
amp_config = dict()
self.amp_cfg = amp_config
else:
self.fp16 = False
self.amp_type = None

def initialize(self, model: nn.Module, optimizer: Optimizer):
if isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)):
self.use_zero_level_2_3 = True
assert self.amp_type != AMP_TYPE.PARALLEL, \
'ZeRO Level 2 and 3 are mutually exclusive with AMP_TYPE.PARALLEL'

if self.fp16:
if self.amp_type == AMP_TYPE.TORCH:
self._torch_amp_scaler = GradScaler(**self.amp_cfg)
elif self.amp_type == AMP_TYPE.APEX:
model, optimizer = apex_amp.initialize(model, optimizer, **self.amp_cfg)

return model, optimizer

def forward_backward_step(self,
data_iter: Iterable,
model: nn.Module,
criterion: nn.modules.loss._Loss,
optimizer: Optimizer = None,
forward_only: bool = False,
grad_accum_size: int = 1,
return_loss: bool = True):
"""The process function that loads loads a batch of dataset and feeds it to the model.
The returned labels and loss will None if :attr:`return_loss` is False.

:param data_iter: Data iterator of the dataloader, e.g. iter(dataloader)
:param model: Model for training and inference
:param criterion: Loss function for training
:param optimizer: Optimizer used for training
:param forward_only: If True, the model is run for the forward pass, else back propagation will be executed
:param grad_accum_size: The number of iterations for gradient accumulation
:param return_loss: Loss will be returned if True
:type data_iter: Iterator
:type model: torch.nn.Module
:type criterion: torch.nn.modules.loss._Loss
:type optimizer: torch.optim.Optimizer
:type forward_only: bool, optional
:type grad_accum_size: int
:type return_loss: bool, optional
:return: (output, label, loss)
"""
assert forward_only or return_loss, \
'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'

data, label = self.load_batch(data_iter)
loss = None

# forward
if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
with torch_amp.autocast():
output = model(*data)
if not isinstance(output, (tuple, list)):
output = (output,)
if return_loss:
loss = criterion(*output, *label)
else:
if self.use_zero_level_2_3 or self.amp_type == AMP_TYPE.PARALLEL:
data = convert_to_fp16(data)

output = model(*data)

if self.use_zero_level_2_3 or self.amp_type == AMP_TYPE.PARALLEL:
output = convert_to_fp32(output)

if not isinstance(output, (tuple, list)):
output = (output,)
if return_loss:
loss = criterion(*output, *label)

loss /= grad_accum_size

if not forward_only:
# backward
if self.use_zero_level_2_3:
optimizer.backward(loss)
elif self.fp16:
if self.amp_type == AMP_TYPE.APEX:
with apex_amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
elif self.amp_type == AMP_TYPE.TORCH:
self._torch_amp_scaler.scale(loss).backward()
elif self.amp_type == AMP_TYPE.PARALLEL:
loss = optimizer.scale_loss(loss)
loss.backward()
# scale back to display the original value in logs
loss.div_(optimizer.grad_scaler.scale)
else:
loss.backward()

if return_loss:
return output, label, loss * grad_accum_size
else:
return output, None, None

def optimizer_step(self, model: nn.Module, optimizer: Optimizer, grad_clipping: float = 0.0):
# step optimizer
if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
if grad_clipping > 0.0:
self._torch_amp_scaler.unscale_(optimizer)
clip_grad_norm_fp32(model.parameters(), grad_clipping)
self._torch_amp_scaler.step(optimizer)
self._torch_amp_scaler.update()
else:
if not self.fp16 and not self.use_zero_level_2_3 and grad_clipping > 0.0:
clip_grad_norm_fp32(model.parameters(), grad_clipping)
optimizer.step()
colossalai/engine/schedule/_non_pipeline_schedule.py (new file, 61 lines)
@ -0,0 +1,61 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from typing import Iterable

import torch

import torch.nn as nn
from colossalai.engine import Engine
from torch.optim import Optimizer
from ._base_schedule import BaseSchedule
from colossalai.utils import conditional_context


class NonPipelineSchedule(BaseSchedule):
"""A helper schedule class for no pipeline parallelism running environment.
During one process, it loads a batch of dataset and feeds it to the model.
After getting the output and calculating the loss, it will use :meth:`step`
to update the parameters if it is in training mode.
:param amp_type: The type of automatic mixed precision
:param amp_config: The configuration of automatic mixed precision
:type amp_type: AMP_TYPE
:type amp_config: dict
"""

def forward_backward_step(self,
engine: Engine,
data_iter: Iterable,
forward_only: bool = False,
return_loss: bool = True):
"""The process function that loads a batch of dataset and feeds it to the model.
The returned labels and loss will be None if :attr:`return_loss` is False.
:param engine: Model for training and inference
:param data_iter: Data iterator of the dataloader, e.g. iter(dataloader)
:param forward_only: If True, the model is run for the forward pass, else back propagation will be executed
:param return_loss: Loss will be returned if True
:type engine: Engine
:type data_iter: Iterator
:type forward_only: bool, optional
:type return_loss: bool, optional
:return: (output, label, loss)
"""
assert forward_only or return_loss, \
"The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
data, label = self.load_batch(data_iter)

# forward
with conditional_context(torch.no_grad(), enable=forward_only):
output = engine(*data)
if not isinstance(output, (tuple, list)):
output = (output,)
if return_loss:
loss = engine.criterion(*output, *label)

if not forward_only:
engine.backward(loss)

if return_loss:
return output, label, loss
else:
return output, None, None
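A minimal sketch of how this schedule is driven with the reworked Engine API (the wiring below is assumed for illustration; engine and train_dataloader would come from colossalai.initialize):

    # illustrative only, not part of this diff
    schedule = NonPipelineSchedule()
    schedule.pre_processing(engine)
    data_iter = iter(train_dataloader)

    engine.zero_grad()
    output, label, loss = schedule.forward_backward_step(engine, data_iter,
                                                         forward_only=False, return_loss=True)
    engine.step()   # all-reduce via gradient handlers, clip gradients, then optimizer.step()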
@ -10,12 +10,12 @@ from torch import Tensor
from colossalai.communication import *
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)
from colossalai.amp.naive_amp import NaiveAMPModel
from colossalai.zero import (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)
from colossalai.utils import get_current_device
from ._base_schedule import BaseSchedule
from ._utils import convert_to_fp16
from ..amp import AMP_TYPE
from colossalai.amp import AMP_TYPE


def squeeze(x: Union[Tensor, tuple, list]):
@ -28,32 +28,25 @@ def squeeze(x: Union[Tensor, tuple, list]):
class PipelineSchedule(BaseSchedule):
"""A helper schedule class for pipeline parallelism running environment.
It uses non-interleaved 1F1B strategy. Other properties are similar as
:class:`NoPipelineSchedule`.
:class:`NonPipelineSchedule`.

:param num_microbatches: The number of microbatches
:param amp_type: The type of automatic mixed precision
:param amp_config: The configuration of automatic mixed precision
:param sync_data: If set to `True`, will sync data every batch over pipeline stages
:type num_microbatches: int
:type amp_type: AMP_TYPE
:type amp_config: dict
:type sync_data: bool
"""

def __init__(self,
num_microbatches,
amp_type: AMP_TYPE = None,
amp_config: dict = None):
sync_data: bool = True):
super().__init__()

self.num_microbatches = num_microbatches
self.data_sync = True  # close after making sure data is identical

# amp
# LSGL: amp_config is not used, but leave here for future extension
self.amp_type = amp_type
self.amp_config = amp_config

if self.amp_type is not None:
assert self.amp_type == AMP_TYPE.PARALLEL, 'We only support AMP_TYPE.PARALLEL for pipeline training for now'
self.sync_data = sync_data

def _move_to_device(self, data):
if isinstance(data, (
@ -67,30 +60,37 @@ class PipelineSchedule(BaseSchedule):
return data

def _sync_data(self):
reqs = []
if gpc.is_first_rank(ParallelMode.PIPELINE):
src_rank = gpc.get_global_rank()
dist.broadcast(
reqs.append(dist.broadcast(
tensor=self.batch_data,
src=src_rank,
group=gpc.get_group(ParallelMode.PIPELINE_PREV)
)
dist.broadcast(
group=gpc.get_group(ParallelMode.PIPELINE_PREV),
async_op=True
))
reqs.append(dist.broadcast(
tensor=self.batch_label,
src=src_rank,
group=gpc.get_group(ParallelMode.PIPELINE_PREV)
)
group=gpc.get_group(ParallelMode.PIPELINE_PREV),
async_op=True
))
if gpc.is_last_rank(ParallelMode.PIPELINE):
src_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
dist.broadcast(
reqs.append(dist.broadcast(
tensor=self.batch_data,
src=src_rank,
group=gpc.get_group(ParallelMode.PIPELINE_NEXT)
)
dist.broadcast(
group=gpc.get_group(ParallelMode.PIPELINE_NEXT),
async_op=True
))
reqs.append(dist.broadcast(
tensor=self.batch_label,
src=src_rank,
group=gpc.get_group(ParallelMode.PIPELINE_NEXT)
)
group=gpc.get_group(ParallelMode.PIPELINE_NEXT),
async_op=True
))
for req in reqs:
req.wait()

# Pipeline schedule just puts data in memory
def load_batch(self, data_iter):
@ -104,7 +104,7 @@ class PipelineSchedule(BaseSchedule):
assert batch_size % self.num_microbatches == 0, \
"Batch size should be divisible by the number of microbatches"
self.microbatch_size = batch_size // self.num_microbatches
if self.data_sync:
if self.sync_data:
self._sync_data()

def _get_data_slice(self, tensor):
@ -116,21 +116,20 @@ class PipelineSchedule(BaseSchedule):
self.batch_pos += self.microbatch_size
return (data,), (label,)

def initialize(self, model, optimizer):
if isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)):
def pre_processing(self, engine):
if isinstance(engine.optimizer, (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)):
raise TypeError(
"Pipeline schedule is currently not compatible with ZeRO Level 2 and Level 3"
)

# LSG: set default dtype to fp16 for communication
if self.amp_type == AMP_TYPE.PARALLEL:
if isinstance(engine.model, NaiveAMPModel):
torch.set_default_dtype(torch.half)
self.logger.info(
self.logger.warning(
'default tensor dtype is set to torch.half for fp16 training',
ranks=[0])

def forward_step(self, model, criterion, input_tensor, return_tensors,
grad_accum_size, return_loss=True):
def forward_step(self, engine, input_tensor, return_tensors, return_loss=True):
"""Forward step for passed-in model. If it is the first stage, the input tensor
is obtained from data_iterator, otherwise the passed-in input_tensor is used.
Returns output tensor. This is a helper function and can be ignored by users.
@ -138,17 +137,16 @@ class PipelineSchedule(BaseSchedule):

if input_tensor is None:
input_tensor, label = self.load_micro_batch()
if self.amp_type == AMP_TYPE.PARALLEL:
input_tensor = convert_to_fp16(input_tensor)
input_tensor = squeeze(input_tensor)
output_tensor = model(input_tensor)
output_tensor = engine(input_tensor)
output_tensor = squeeze(output_tensor)

if gpc.is_last_rank(ParallelMode.PIPELINE):
if return_loss:
input_tensor, label = self.load_micro_batch()
loss_reduced = criterion(output_tensor, *label) \
/ (self.num_microbatches * grad_accum_size)
loss_reduced = engine.criterion(output_tensor, *label) \
/ self.num_microbatches

return_tensors.append(
tuple((output_tensor, label[0], loss_reduced)))
return loss_reduced
@ -159,7 +157,7 @@ class PipelineSchedule(BaseSchedule):
else:
return output_tensor

def backward_step(self, optimizer, input_tensor, output_tensor, output_tensor_grad):
def backward_step(self, engine, input_tensor, output_tensor, output_tensor_grad):
"""Backward step through the passed-in output tensor. If it is the last stage, the
output_tensor_grad is None, otherwise it is the gradients with respect to stage's output tensor.
Returns the gradients with respect to the input tensor (None if first stage).
@ -171,9 +169,10 @@ class PipelineSchedule(BaseSchedule):
input_tensor.retain_grad()

# Backward pass.
if output_tensor_grad is None and self.amp_type == AMP_TYPE.PARALLEL:
output_tensor = optimizer.scale_loss(output_tensor)
torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
if output_tensor_grad is None:
engine.backward(output_tensor)
else:
engine.backward_by_grad(output_tensor, output_tensor_grad)

# Collect the grad of the input_tensor.
input_tensor_grad = None
@ -183,12 +182,9 @@ class PipelineSchedule(BaseSchedule):
return input_tensor_grad

def forward_backward_step(self,
engine,
data_iter,
model,
criterion,
optimizer=None,
forward_only=False,
grad_accum_size: int = 1,
return_loss=True):
"""Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
Returns a tuple with losses if the last stage, an empty tuple otherwise.
@ -226,9 +222,8 @@ class PipelineSchedule(BaseSchedule):
ft_shape = recv_tensor_meta(ft_shape)
input_tensor = recv_forward(ft_shape)
output_tensor = self.forward_step(
model, criterion,
input_tensor, return_tensors,
grad_accum_size, return_loss=return_loss
engine, input_tensor, return_tensors,
return_loss=return_loss
)
if not gpc.is_last_rank(ParallelMode.PIPELINE):
bt_shape = output_tensor.shape
@ -252,9 +247,8 @@ class PipelineSchedule(BaseSchedule):
last_iteration = (i == (num_microbatches_remaining - 1))

output_tensor = self.forward_step(
model, criterion,
input_tensor, return_tensors,
grad_accum_size, return_loss=return_loss
engine, input_tensor, return_tensors,
return_loss=return_loss
)
if forward_only:
send_forward(output_tensor)
@ -276,7 +270,7 @@ class PipelineSchedule(BaseSchedule):
output_tensor = output_tensors.pop(0)

input_tensor_grad = self.backward_step(
optimizer,
engine,
input_tensor, output_tensor,
output_tensor_grad
)
@ -297,7 +291,7 @@ class PipelineSchedule(BaseSchedule):
output_tensor_grad = recv_backward(bt_shape)

input_tensor_grad = self.backward_step(
optimizer,
engine,
input_tensor, output_tensor,
output_tensor_grad
)
@ -309,11 +303,8 @@ class PipelineSchedule(BaseSchedule):
output, label, loss = tuple(map(list, zip(*return_tensors)))
return (torch.cat(output, dim=0),
torch.cat(label, dim=0),
sum(loss) * grad_accum_size)
sum(loss))
else:
return tuple((torch.cat(return_tensors, dim=0), None, None))
else:
return tuple((None, None, None))

def optimizer_step(self, model, optimizer, grad_clipping: float = 0.0):
optimizer.step()
@ -1,27 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from typing import Union, List

from torch import Tensor


def convert_to_fp16(data: Union[Tensor, List[Tensor]]):
if isinstance(data, Tensor):
ret = data.half()
elif isinstance(data, (list, tuple)):
ret = [val.half() for val in data]
else:
raise TypeError(f"Expected argument 'data' to be a Tensor or a list/tuple of Tensor, but got {type(data)}")
return ret


def convert_to_fp32(data: Union[Tensor, List[Tensor]]):
if isinstance(data, Tensor):
ret = data.float()
elif isinstance(data, (list, tuple)):
ret = [val.float() for val in data]
else:
raise TypeError(f"Expected argument 'data' to be a Tensor or a list/tuple of Tensor, but got {type(data)}")
return ret
@ -3,377 +3,326 @@

import argparse
import pprint
import random
from pathlib import Path
from typing import Callable, Iterable, Optional, Union
from typing import Tuple

import os
from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
import numpy as np
import torch
from torch.utils.data import DataLoader
import torch.nn as nn

from colossalai.engine import AMP_TYPE, NoPipelineSchedule, PipelineSchedule
from pathlib import Path
from typing import Iterable, Union, Optional, Tuple, List, Dict

from colossalai.amp import convert_to_amp, AMP_TYPE
from colossalai.context import Config, ParallelMode, ConfigException
from colossalai.core import global_context as gpc
from colossalai.engine import Engine
from colossalai.logging import get_global_dist_logger, init_global_dist_logger
from colossalai.nn import DataParallelSampler
from colossalai.nn.model.base_model import BaseModel
from .builder import (ModelInitializer, build_dataset, build_loss,
build_model, build_optimizer,
build_optimizer_wrapper, build_schedule)
from .context import Config, ParallelMode
from .core import global_context as gpc
from .utils import get_current_device, sync_model_param_in_dp
from colossalai.logging import get_dist_logger
from colossalai.utils import (accumulate_gradient, get_current_device,
sync_model_param_in_dp, is_using_ddp, is_using_pp)
from colossalai.zero import convert_to_zero, ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3
from colossalai.builder.builder import build_gradient_handler
from torch.optim.optimizer import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from torch.nn.modules.loss import _Loss
from torch.nn.parallel import DistributedDataParallel as DDP


def parse_args():
def get_default_parser():
'''Reads user command line and uses an argument parser to parse the input arguments.
Input arguments include configuration, host, port, world size, local rank, backend for torch.distributed.

:return: call the parse arguments function
:return: returns the parser with the default arguments, the user may add customized arguments into this parser
:rtype: Namespace
'''
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, help='path to the config file')
parser.add_argument('--host',
type=str,
default=None,
help='the master address for distributed training')
parser.add_argument('--port',
type=str,
default=None,
type=int,
help='the master port for distributed training')
parser.add_argument('--world_size', type=int, help='world size for ')
parser.add_argument('--world_size', type=int, help='world size for distributed training')
parser.add_argument('--rank', type=int, help='rank for the default process group')
parser.add_argument('--local_rank',
type=int,
help='rank for the default process group')
help='local rank on the node')
parser.add_argument('--backend',
type=str,
default='nccl',
help='backend for torch.distributed')
return parser.parse_args()
help='backend for distributed communication')
return parser
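A short usage sketch for the parser returned above (the extra flag is invented for illustration):

    # illustrative only, not part of this diff
    parser = get_default_parser()
    parser.add_argument('--epochs', type=int, default=10)   # user-defined extra argument
    args = parser.parse_args()
    # args.host, args.port, args.world_size, args.rank, args.local_rank, args.backend are now available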


def init_dist(config: Union[str, dict] = None,
local_rank: int = None,
world_size: int = None,
host: str = None,
port: str = None,
backend: str = None):
def launch(config: Union[str, Path, Config, Dict],
rank: int,
world_size: int,
host: str,
port: int,
backend: str = 'nccl',
local_rank: int = None,
seed: int = 1024,
verbose: bool = True):
'''This function first parses the configuration arguments, using :func:parse_args() in case one of the input arguments are not given.
Then initialize and set distributed environment by calling global_context's functions.

:param config: config file or config file path are both acceptable
:type config: Union[str, dict], optional
:param local_rank: rank for the default process group, defaults to None
:type config: Union[str, dict, Config]
:param rank: rank for the default process group
:type rank: int
:param world_size: world size of the default process group
:type world_size: int
:param host: the master address for distributed training
:type host: str
:param port: the master port for distributed training
:type port: str
:param backend: backend for torch.distributed
:type backend: str
:param local_rank: rank for the process on the node and is used to set the default CUDA device,
defaults to None. If local_rank = None, the default device ordinal will be calculated automatically
:type local_rank: int, optional
:param world_size: world size of GPUs, defaults to None
:type world_size: int, optional
:param host: the master address for distributed training, defaults to None
:type host: str, optional
:param port: the master port for distributed training, defaults to None
:type port: str, optional
:param backend: backend for torch.distributed, defaults to None
:type backend: str, optional
:raises Exception: raise exception when config type is wrong
'''
args = [config, local_rank, world_size, host, port, backend]
arg_given = [arg is not None for arg in args]

if not all(arg_given):
args = parse_args()

if config is None:
config = args.config
if local_rank is None:
local_rank = args.local_rank
if world_size is None:
world_size = args.world_size
if host is None:
host = args.host
if port is None:
port = args.port
if backend is None:
backend = args.backend
args = Config(
dict(config=config,
host=host,
port=port,
world_size=world_size,
local_rank=local_rank,
backend=backend))

# set distributed settings
dist_args = Config(
dict(local_rank=args.local_rank,
world_size=args.world_size,
backend=args.backend))

gpc.set_dist_args(dist_args)
gpc.verbose = verbose

# set config
if isinstance(args.config, dict):
cfg = args.config
elif isinstance(args.config, (str, Path)):
cfg = Config.from_file(args.config)
else:
raise Exception('Config type error: {}'.format(type(args.config)))
gpc.load_config(cfg)
assert isinstance(config, (Config, str, Path, dict)), \
f'expected argument config to be Config, str or Path, but got {type(config)}'
if not isinstance(config, Config) and isinstance(config, dict):
config = Config(config)
if isinstance(config, (str, Path)):
config = Config.from_file(config)
gpc.load_config(config)

# init dist groups
gpc.init_global_dist(args.host, args.port)
# init default process group
gpc.init_global_dist(rank, world_size, backend, host, port)

# init process groups for different parallel modes from config
gpc.init_parallel_groups()

# init dist logger
init_global_dist_logger()

# set cuda device
if torch.cuda.is_available():
gpc.set_device()
# if local rank is not given, calculate automatically
gpc.set_device(local_rank)

gpc.set_seed(seed)

if verbose:
logger = get_dist_logger()
logger.info(f'Distributed environment is initialized, '
f'data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, '
f'tensor parallel size: {gpc.tensor_parallel_size}', ranks=[0])
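A minimal sketch of how launch is expected to be called once per process (paths and addresses are placeholders; whether it is re-exported as colossalai.launch is an assumption here, otherwise import it from this module):

    # illustrative only, not part of this diff
    import colossalai
    colossalai.launch(config='./config.py',    # placeholder path
                      rank=rank,               # e.g. parsed from get_default_parser()
                      world_size=world_size,
                      host='127.0.0.1',
                      port=29500,
                      backend='nccl')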


def get_dataloader(dataset, seed=1024, add_sampler_if_possible=False, **kwargs):
'''Set up a deterministic dataloader (also configure seed workers, samplers and whether shuffle or not)

.. note: when pipeline parallel is enabled, shuffle cannot be True
as it will result in mismatch between input data on the 1st
stage and label on the last stage

:param dataset: a :class:utils.data.dataset dataset
:param seed: random worker seed, defaults to 1024
:type seed: int, optional
:param add_sampler_if_possible: add a DataParallelSampler to the dataloader if data parallelism is enabled, defaults to False
:type add_sampler_if_possible: bool, optional
:return: a :class:utils.data.dataset dataloader
:rtype: torch.utils.data.dataset
'''
_kwargs = kwargs.copy()
if 'shuffle' in _kwargs:
shuffle = _kwargs.pop('shuffle')
else:
shuffle = False

if add_sampler_if_possible and gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1:
sampler = DataParallelSampler(dataset, shuffle=shuffle)
else:
sampler = None

# Deterministic dataloader
def seed_worker(worker_id):
worker_seed = seed
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)
random.seed(worker_seed)

if sampler is None:
return DataLoader(dataset,
worker_init_fn=seed_worker,
shuffle=shuffle,
**_kwargs)
else:
return DataLoader(dataset,
sampler=sampler,
worker_init_fn=seed_worker,
**_kwargs)
def launch_from_slurm(config: Union[str, Path, Config, Dict],
host: str,
port: int,
backend: str = 'nccl',
seed: int = 1024,
verbose: bool = True):
rank = int(os.environ['SLURM_PROCID'])
world_size = int(os.environ['SLURM_NPROCS'])
launch(config=config,
rank=rank,
world_size=world_size,
host=host,
port=port,
backend=backend,
seed=seed,
verbose=verbose)


def initialize(config: Union[str, dict] = None,
local_rank: int = None,
world_size: int = None,
host: str = None,
port: str = None,
backend: str = None,
train_dataloader: Optional[Union[Iterable, Callable]] = None,
test_dataloader: Optional[Union[Iterable, Callable]] = None,
def launch_from_openmpi(config: Union[str, Path, Config, Dict],
host: str,
port: int,
backend: str = 'nccl',
seed: int = 1024,
verbose: bool = True):
rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
launch(config=config,
local_rank=local_rank,
rank=rank,
world_size=world_size,
host=host,
port=port,
backend=backend,
seed=seed,
verbose=verbose)


def launch_from_torch(config: Union[str, Path, Config, Dict],
host: str,
port: int,
backend: str = 'nccl',
seed: int = 1024,
verbose: bool = True):
rank = int(os.environ['RANK'])
local_rank = int(os.environ['LOCAL_RANK'])
world_size = int(os.environ['WORLD_SIZE'])
launch(config=config,
local_rank=local_rank,
rank=rank,
world_size=world_size,
host=host,
port=port,
backend=backend,
seed=seed,
verbose=verbose)
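For context, launch_from_torch reads RANK, LOCAL_RANK and WORLD_SIZE, which are the variables the PyTorch distributed launcher exports; a typical single-node invocation would therefore look like the sketch below (the script name is a placeholder, and --use_env is needed on older launchers so that LOCAL_RANK is passed as an environment variable):

    # illustrative only: train.py calls launch_from_torch(config=...) internally
    python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py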
|
||||
|
||||
|
||||
def initialize(model: Union[nn.Module, List[nn.Module]],
|
||||
optimizer: Union[Optimizer, List[Optimizer]],
|
||||
criterion: Union[_Loss, List[_Loss]],
|
||||
train_dataloader: Optional[Union[Iterable, List[Iterable]]] = None,
|
||||
test_dataloader: Optional[Union[Iterable, List[Iterable]]] = None,
|
||||
lr_scheduler: _LRScheduler = None,
|
||||
verbose: bool = True
|
||||
) -> Tuple[Engine, DataLoader, DataLoader]:
|
||||
'''Core function that initializes distributed environment, logger, cudnn, data, model, loss function, optimizer, and lr_scheduler(their configs are in gpc.config).
|
||||
''' Core function to wrap the essential training components with our functionality based on the config which is loaded into gpc.config.
|
||||
|
||||
:param config: config file or config file path are both acceptable
|
||||
:type config: Union[str, dict], optional
|
||||
:param local_rank: rank for the default process group, defaults to None
|
||||
:type local_rank: int, optional
|
||||
:param world_size: world size of GPUs, defaults to None
|
||||
:type world_size: int, optional
|
||||
:param host: the master address for distributed training, defaults to None
|
||||
:type host: str, optional
|
||||
:param port: the master port for distributed training, defaults to None
|
||||
:type port: str, optional
|
||||
:param backend: backend for torch.distributed, defaults to None
|
||||
:type backend: str, optional
|
||||
:param train_dataloader: If None, the config is used to build a dataloder; Else, it should be a dataloader object or a function with no arguments which can build a dataloader, defaults to None
|
||||
:type train_dataloader: Optional[Union[Iterable, Callable]], optional
|
||||
:param test_dataloader: If None, the config is used to build a dataloder; Else, it should be a dataloader object or a function with no arguments which can build a dataloader, defaults to None
|
||||
:type test_dataloader: Optional[Union[Iterable, Callable]], optional
|
||||
:return: (engine, train_dataloader, test_dataloader, criterion)
|
||||
:param model: your model instance
|
||||
:type model: a single or a list of ``torch.nn.Module`` objects
|
||||
:param optimizer: your optimizer instance
|
||||
:type optimizer: a single or a list of ``torch.optim.optimizer.Optimizer`` objects
|
||||
:param criterion: your criterion instance
|
||||
:type criterion: a single or a list of ``torch.nn.modules.loss._Loss`` objects
|
||||
:param train_dataloader: dataloaders for training data
|
||||
:type train_dataloader: a single or a list of ``torch.utils.data.DataLoader`` objects, defaults to None
|
||||
:param train_dataloader: dataloaders for testing data
|
||||
:type train_dataloader: a single or a list of ``torch.utils.data.DataLoader`` objects, defaults to None
|
||||
:return: (engine, criterion, train_dataloader, test_dataloader)
|
||||
:rtype: tuple
|
||||
'''
|
||||
# initialize distributed environment
|
||||
init_dist(config=config,
|
||||
local_rank=local_rank,
|
||||
world_size=world_size,
|
||||
host=host,
|
||||
port=port,
|
||||
backend=backend)
|
||||
# get logger
|
||||
logger = get_dist_logger()
|
||||
gpc.verbose = verbose
|
||||
|
||||
# init logger
|
||||
logger = get_global_dist_logger()
|
||||
logger.info(f'Distributed environment is initialized, '
|
||||
f'data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, '
|
||||
f'tensor parallel size: {gpc.tensor_parallel_size}', ranks=[0])
|
||||
# get config from gpc
|
||||
config = gpc.config
|
||||
|
||||
# print config
|
||||
logger.info(f"\n========== Your Config ========\n"
|
||||
f"{pprint.pformat(gpc.config)}\n"
|
||||
f"================================", ranks=[0])
|
||||
if verbose:
|
||||
logger.info(f"\n========== Your Config ========\n"
|
||||
f"{pprint.pformat(gpc.config)}\n"
|
||||
f"================================\n", ranks=[0])
|
||||
|
||||
# cudnn
|
||||
cudnn_benchmark = gpc.config.get('cudnn_benchmark', True)
cudnn_deterministic = gpc.config.get('cudnn_deterministic', False)
|
||||
torch.backends.cudnn.benchmark = cudnn_benchmark
|
||||
torch.backends.cudnn.deterministic = cudnn_deterministic
|
||||
if verbose:
    logger.info(
        f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])
|
||||
|
||||
# set seed, cuda seed is only set when cuda is avail
|
||||
gpc.set_seed()
|
||||
|
||||
# return_items = list()
|
||||
|
||||
# check fp16 and zero
|
||||
should_convert_model_to_half = False
|
||||
should_wrap_fp16_optimizer = False
|
||||
should_wrap_zero_optimizer_level_2_3 = False
|
||||
|
||||
if hasattr(gpc.config, 'fp16'):
|
||||
fp16_mode = gpc.config.fp16.mode
|
||||
if fp16_mode == AMP_TYPE.PARALLEL:
|
||||
should_convert_model_to_half = True
|
||||
should_wrap_fp16_optimizer = True
|
||||
|
||||
if hasattr(gpc.config, 'zero'):
|
||||
should_wrap_zero_optimizer_level_2_3 = True
|
||||
zero_type = gpc.config.zero.type
|
||||
if zero_type in ['ZeroRedundancyOptimizer_Level_2', 'ZeroRedundancyOptimizer_Level_3']:
|
||||
should_convert_model_to_half = True
|
||||
assert not should_wrap_fp16_optimizer, \
|
||||
'AMP_TYPE.PARALLEL is mutually exclusive with zero level 2 and 3'
|
||||
|
||||
# build model
|
||||
logger.info('Building model ...', ranks=[0])
|
||||
assert hasattr(
|
||||
gpc.config, 'model'), "Build error: configuration 'model' is missing"
|
||||
if gpc.pipeline_parallel_size > 1:
|
||||
model = ModelInitializer(gpc.config.model, 1, verbose=True)
|
||||
model = model.model_initialize()
|
||||
else:
|
||||
model = build_model(gpc.config.model)
|
||||
if isinstance(model, BaseModel):
|
||||
model.build_from_cfg()
|
||||
model = model.to(get_current_device())
|
||||
# first sync model across dp ranks
|
||||
model.to(get_current_device())
|
||||
sync_model_param_in_dp(model)
|
||||
logger.info('Model is created', ranks=[0])
|
||||
|
||||
if should_convert_model_to_half:
|
||||
model = model.half()
|
||||
logger.info("Model is cast to fp16", ranks=[0])
|
||||
# check amp and zero
|
||||
fp16_cfg = gpc.config.get('fp16', None)
|
||||
zero_cfg = gpc.config.get('zero', None)
|
||||
|
||||
# training data
|
||||
if callable(train_dataloader):
|
||||
logger.info(
|
||||
f'Build train data loader from {train_dataloader}', ranks=[0])
|
||||
train_dataloader = train_dataloader()
|
||||
if train_dataloader is None and hasattr(gpc.config, 'train_data'):
|
||||
logger.info('Preparing data ...', ranks=[0])
|
||||
# assert hasattr(gpc.config, 'train_data'), "Build error: configuration 'train_data' is missing."
|
||||
train_dataset = build_dataset(gpc.config.train_data.dataset)
|
||||
logger.info('Train dataset is ready.', ranks=[0])
|
||||
if fp16_cfg is not None and fp16_cfg.mode is not None and zero_cfg is not None:
|
||||
raise ConfigException(
|
||||
"It is not allowed to set fp16 and zero configuration in your config file at the same time")
|
||||
|
||||
train_dataloader = get_dataloader(train_dataset,
|
||||
gpc.config.get('seed', 1024),
|
||||
True,
|
||||
**gpc.config.train_data.dataloader,
|
||||
)
|
||||
logger.info(
|
||||
f'Loaded {len(train_dataset)} samples in {len(train_dataloader)} batches for training', ranks=[0])
|
||||
# initialize amp
|
||||
amp_mode = None
|
||||
if fp16_cfg is not None and fp16_cfg.mode is not None:
|
||||
cfg_ = fp16_cfg.copy()
|
||||
amp_mode = cfg_.pop('mode')
|
||||
model, optimizer, criterion = convert_to_amp(model=model,
|
||||
optimizer=optimizer,
|
||||
criterion=criterion,
|
||||
mode=amp_mode,
|
||||
amp_config=cfg_)
|
||||
|
||||
if callable(test_dataloader):
|
||||
logger.info(
|
||||
f'Build test data loader from {test_dataloader}', ranks=[0])
|
||||
test_dataloader = test_dataloader()
|
||||
# testing data, allowed to be None
|
||||
if test_dataloader is None and hasattr(gpc.config, 'test_data'):
|
||||
test_dataset = build_dataset(gpc.config.test_data.dataset)
|
||||
test_dataloader = get_dataloader(
|
||||
test_dataset, add_sampler_if_possible=True, **gpc.config.test_data.dataloader)
|
||||
logger.info(
|
||||
f'Loaded {len(test_dataset)} samples in {len(test_dataloader)} batches for testing', ranks=[0])
|
||||
if zero_cfg is not None:
|
||||
cfg_ = zero_cfg.copy()
|
||||
level = cfg_.pop('level')
|
||||
model, optimizer = convert_to_zero(model=model,
|
||||
optimizer=optimizer,
|
||||
level=level,
|
||||
zero_config=cfg_
|
||||
)
|
||||
|
||||
# build loss function
|
||||
assert hasattr(gpc.config, 'loss'), \
|
||||
'Build error: configuration \'loss\' is missing.'
|
||||
criterion = build_loss(gpc.config.loss)
|
||||
logger.info('Loss function is created', ranks=[0])
|
||||
|
||||
# build optimizer
|
||||
assert hasattr(gpc.config, 'optimizer'), \
|
||||
"Build error: configuration 'optimizer' is missing."
|
||||
optim_type = gpc.config.optimizer.type
|
||||
is_pytorch_native_zero_level_1 = optim_type == 'ZeroRedundancyOptimizer'
|
||||
if is_pytorch_native_zero_level_1:
|
||||
original_cfg_copy = gpc.config.optimizer.copy()
|
||||
original_cfg_copy.pop('type')
|
||||
cfg = dict(type=optim_type, process_group=gpc.get_group(
|
||||
ParallelMode.DATA), **original_cfg_copy)
|
||||
optimizer = build_optimizer(cfg, model)
|
||||
# gradient handler
|
||||
gradient_handler_cfg = gpc.config.get('gradient_handler', None)
|
||||
if gradient_handler_cfg is None:
|
||||
# if gradient handler is not specified in the configuration file,
|
||||
# check in the following order
|
||||
# 1. if optimizer is ZERO, then use zero grad handler
|
||||
# 2. if dp size is larger than 1 and pipeline is not used, use pytorch ddp
|
||||
# 3. if using pipeline and dp size larger than 1, use data parallel grad handler
|
||||
if isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2,
|
||||
ZeroRedundancyOptimizer_Level_3)):
|
||||
gradient_handler_cfg = [dict(type='ZeROGradientHandler')]
|
||||
if verbose:
|
||||
logger.info(
|
||||
"Training with zero is detected, ZeROGradientHandler is automatically "
|
||||
"added even though not specified in the configuration",
|
||||
ranks=[0])
|
||||
elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE:
|
||||
model = DDP(model, process_group=gpc.get_group(ParallelMode.DATA))
|
||||
if verbose:
|
||||
logger.info(
|
||||
'Model is using torch.nn.parallel.DistributedDataParallel', ranks=[0])
|
||||
elif is_using_ddp():
|
||||
gradient_handler_cfg = [dict(type='DataParallelGradientHandler')]
|
||||
if verbose:
|
||||
logger.info(
|
||||
"Data parallel training is detected when using pipeline parallel, DataParallelGradientHandler is automatically "
|
||||
"added even though not specified in the configuration",
|
||||
ranks=[0])
|
||||
else:
|
||||
optimizer = build_optimizer(gpc.config.optimizer, model)
|
||||
if not isinstance(gradient_handler_cfg, list):
|
||||
raise ConfigException(
|
||||
f"expected gradient_handler in the configuration file to be a list but got {type(gradient_handler_cfg)}")
|
||||
|
||||
if should_wrap_zero_optimizer_level_2_3:
|
||||
optimizer = build_optimizer_wrapper(gpc.config.zero, optimizer, model)
|
||||
|
||||
if should_wrap_fp16_optimizer:
|
||||
# replace the field mode with type
|
||||
fp16_cfg = gpc.config.fp16.copy()
|
||||
amp_type = fp16_cfg.pop('mode')
|
||||
assert amp_type == AMP_TYPE.PARALLEL, 'FP Optimizer should only be used for AMP_TYPE.PARALLEL'
|
||||
fp16_cfg['type'] = 'FP16Optimizer'
|
||||
optimizer = build_optimizer_wrapper(fp16_cfg, optimizer)
|
||||
logger.info('Optimizer is created', ranks=[0])
|
||||
|
||||
# build schedule and engine
|
||||
if hasattr(gpc.config, 'fp16'):
|
||||
amp_type = gpc.config.fp16.mode
|
||||
amp_cfg = gpc.config.fp16.copy()
|
||||
amp_cfg.pop('mode')
|
||||
if gradient_handler_cfg is None:
|
||||
gradient_handlers = None
|
||||
if verbose and not isinstance(model, DDP):
|
||||
logger.warning(
|
||||
"No PyTorch DDP or gradient handler is set up, please make sure you do not need "
|
||||
"to all-reduce the gradients after a training step.",
|
||||
ranks=[0])
|
||||
else:
|
||||
amp_type = None
|
||||
amp_cfg = None
|
||||
gradient_handlers = [build_gradient_handler(cfg, model, optimizer) for cfg in gradient_handler_cfg]
|
||||
|
||||
engine_cfg = gpc.config.get('engine', dict())
|
||||
schedule_cfg = engine_cfg.pop('schedule', None)
|
||||
# check if optimizer is ColossalaiOptimizer
|
||||
if not isinstance(optimizer, (ColossalaiOptimizer, ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)):
|
||||
optimizer = ColossalaiOptimizer(optim=optimizer)
|
||||
|
||||
schedule_type = None
|
||||
if schedule_cfg is not None:
|
||||
schedule_type = schedule_cfg.get('type', None)
|
||||
# gradient accumulation
|
||||
grad_accum_size = gpc.config.get('gradient_accumulation', None)
|
||||
if grad_accum_size is not None:
|
||||
optimizer, train_dataloader, gradient_handlers, lr_scheduler = accumulate_gradient(model=model,
|
||||
optimizer=optimizer,
|
||||
dataloader=train_dataloader,
|
||||
accumulate_size=grad_accum_size,
|
||||
gradient_handlers=gradient_handlers,
|
||||
lr_scheduler=lr_scheduler)
|
||||
|
||||
if schedule_type is not None:
|
||||
# run customized schedule
|
||||
schedule_cfg['amp_type'] = amp_type
|
||||
schedule_cfg['amp_config'] = amp_cfg
|
||||
schedule = build_schedule(schedule_cfg)
|
||||
elif gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
|
||||
assert schedule_cfg is not None, \
|
||||
"Config 'engine.schedule' not found in your configuration file for pipeline parallel training"
|
||||
schedule = PipelineSchedule(
|
||||
amp_type=amp_type, amp_config=amp_cfg, **schedule_cfg.copy())
|
||||
else:
|
||||
schedule = NoPipelineSchedule(amp_type=amp_type, amp_config=amp_cfg)
|
||||
# clip grad norm
|
||||
clip_grad_norm = gpc.config.get('clip_grad_norm', 0.0)
|
||||
if clip_grad_norm > 0:
|
||||
if zero_cfg is not None:
|
||||
raise ConfigException(
|
||||
"clip_grad_norm should be specified with zero, you should specify clip_grad in zero configuration")
|
||||
elif fp16_cfg is not None and fp16_cfg.mode == AMP_TYPE.NAIVE:
|
||||
raise ConfigException(
|
||||
"clip_grad_norm should be specified with AMP_TYPE.NAIVE, you should specify clip_grad in fp16 configuration")
|
||||
|
||||
engine = Engine(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
criterion=criterion,
|
||||
step_schedule=schedule,
|
||||
**gpc.config.get('engine', dict()),
|
||||
gradient_handlers=gradient_handlers,
|
||||
clip_grad_norm=clip_grad_norm
|
||||
)
|
||||
|
||||
return engine, train_dataloader, test_dataloader, lr_scheduler
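# Example: a minimal sketch of calling the wrapped-components API above, assuming
# the distributed environment has already been launched so that gpc.config is
# populated (config sections such as fp16/zero/gradient_handler are optional).
# The model, optimizer and criterion here are toy placeholders.
import torch
import colossalai

model = torch.nn.Linear(32, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
criterion = torch.nn.CrossEntropyLoss()

engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(
    model=model,
    optimizer=optimizer,
    criterion=criterion)   # dataloaders fall back to gpc.config.train_data / test_data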
|
||||
|
@ -1,26 +1,10 @@
|
||||
from colossalai.core import global_context as gpc
|
||||
from .logging import DistributedLogger
|
||||
|
||||
__all__ = ['get_global_dist_logger', 'get_dist_logger', 'DistributedLogger', 'init_global_dist_logger']
|
||||
|
||||
_GLOBAL_LOGGER: DistributedLogger = None
|
||||
__all__ = ['get_dist_logger', 'DistributedLogger']
|
||||
|
||||
|
||||
def get_dist_logger(name, level='INFO', root_path: str = None, mode='a'):
|
||||
return DistributedLogger(name=name, level=level, root_path=root_path, mode=mode)
|
||||
|
||||
|
||||
def get_global_dist_logger():
|
||||
assert _GLOBAL_LOGGER is not None, 'Global distributed logger is not initialized'
|
||||
return _GLOBAL_LOGGER
|
||||
|
||||
|
||||
def init_global_dist_logger():
|
||||
rank = gpc.get_global_rank()
|
||||
if hasattr(gpc.config, 'logging'):
|
||||
logger = get_dist_logger(name=f'rank_{rank}', **gpc.config.logging)
|
||||
else:
|
||||
logger = get_dist_logger(name=f'rank_{rank}', level='INFO')
|
||||
global _GLOBAL_LOGGER
|
||||
assert _GLOBAL_LOGGER is None, 'Global distributed logger has already been initialized'
|
||||
_GLOBAL_LOGGER = logger
|
||||
def get_dist_logger(name='root'):
|
||||
"""Get logger instance based on name. The DistributedLogger will create singleton instances,
|
||||
which means that only one logger instance is created per name.
|
||||
"""
|
||||
return DistributedLogger.get_instance(name=name)
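# Example: the singleton accessor above returns the same logger object for a given
# name; the `ranks` filter assumes an initialized distributed context.
from colossalai.logging import get_dist_logger

logger = get_dist_logger()            # instance named 'root'
assert logger is get_dist_logger()    # repeated calls return the same singleton
logger.info('training started', ranks=[0])   # only logged on rank 0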
|
||||
|
@ -1,11 +1,13 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import colossalai
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
|
||||
|
||||
_FORMAT = 'colossalai - %(name)s - %(asctime)s %(levelname)s: %(message)s'
|
||||
logging.basicConfig(level=logging.INFO, format=_FORMAT)
|
||||
@ -16,40 +18,92 @@ class DistributedLogger:
|
||||
|
||||
:param name: The name of the logger
|
||||
:type name: str
|
||||
:param level: The threshold for the logger. Logging messages which are less severe than `level`
|
||||
will be ignored
|
||||
:type level: str
|
||||
:param root_path: The root path where logs are stored
|
||||
:type root_path: str, optional
|
||||
:param mode: The mode that the file is opened in. Defaults to 'a'
|
||||
:type mode: str, optional
|
||||
"""
|
||||
|
||||
def __init__(self, name, level='INFO', root_path: str = None, mode='a'):
|
||||
self._logger = logging.getLogger(name)
|
||||
__instances = dict()
|
||||
|
||||
@staticmethod
|
||||
def get_instance(name: str):
|
||||
"""Get the unique single logger instance based on name.
|
||||
:param name: The name of the logger
|
||||
:type name: str
|
||||
:return: a DistributedLogger object
|
||||
:rtype: DistributedLogger
|
||||
"""
|
||||
if name in DistributedLogger.__instances:
|
||||
return DistributedLogger.__instances[name]
|
||||
else:
|
||||
logger = DistributedLogger(name=name)
|
||||
return logger
|
||||
|
||||
def __init__(self, name):
|
||||
if name in DistributedLogger.__instances:
|
||||
raise Exception('Logger with the same name has been created, you should use colossalai.logging.get_dist_logger')
|
||||
else:
|
||||
self._name = name
|
||||
self._logger = logging.getLogger(name)
|
||||
DistributedLogger.__instances[name] = self
|
||||
|
||||
@staticmethod
|
||||
def _check_valid_logging_level(level: str):
|
||||
assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR'], 'found invalid logging level'
|
||||
|
||||
def set_level(self, level: str):
|
||||
"""Set the logging level
|
||||
:param level: can only be INFO, DEBUG, WARNING and ERROR
|
||||
:type level: str
|
||||
"""
|
||||
self._check_valid_logging_level(level)
|
||||
self._logger.setLevel(getattr(logging, level))
|
||||
|
||||
if root_path is not None:
|
||||
log_root_path = Path(root_path)
|
||||
# create path if not exists
|
||||
log_root_path.mkdir(parents=True, exist_ok=True)
|
||||
log_path = log_root_path.joinpath(f'{name}.log')
|
||||
file_handler = logging.FileHandler(log_path, mode)
|
||||
file_handler.setLevel(getattr(logging, level))
|
||||
formatter = logging.Formatter(_FORMAT)
|
||||
file_handler.setFormatter(formatter)
|
||||
self._logger.addHandler(file_handler)
|
||||
def log_to_file(self,
|
||||
path: Union[str, Path],
|
||||
mode: str = 'a',
|
||||
level: str = 'INFO',
|
||||
suffix: str = None):
|
||||
"""Save the logs to file
|
||||
:param path: the file to save the log
|
||||
:type path: a string or pathlib.Path object
|
||||
:param mode: the mode to write log into the file
|
||||
:type mode: str
|
||||
:param level: can only be INFO, DEBUG, WARNING and ERROR
|
||||
:type level: str
|
||||
"""
|
||||
assert isinstance(path, (str, Path)), \
|
||||
f'expected argument path to be type str or Path, but got {type(path)}'
|
||||
self._check_valid_logging_level(level)
|
||||
if isinstance(path, str):
|
||||
path = Path(path)
|
||||
|
||||
# set the default file name if path is a directory
|
||||
if not colossalai.core.global_context.is_initialized(ParallelMode.GLOBAL):
|
||||
rank = 0
|
||||
else:
|
||||
rank = colossalai.core.global_context.get_global_rank()
|
||||
|
||||
if suffix is not None:
|
||||
log_file_name = f'rank_{rank}_{suffix}.log'
|
||||
else:
|
||||
log_file_name = f'rank_{rank}.log'
|
||||
path = path.joinpath(log_file_name)
|
||||
|
||||
# add file handler
|
||||
file_handler = logging.FileHandler(path, mode)
|
||||
file_handler.setLevel(getattr(logging, level))
|
||||
formatter = logging.Formatter(_FORMAT)
|
||||
file_handler.setFormatter(formatter)
|
||||
self._logger.addHandler(file_handler)
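# Example: write this rank's log under ./logs (the directory is created here by
# the caller, since log_to_file itself only composes the rank-based file name).
from pathlib import Path

from colossalai.logging import get_dist_logger

Path('./logs').mkdir(parents=True, exist_ok=True)
logger = get_dist_logger()
logger.log_to_file('./logs', mode='a', level='INFO', suffix='train')
# -> ./logs/rank_<global_rank>_train.log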
|
||||
|
||||
def _log(self, level, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
|
||||
if ranks is None:
|
||||
getattr(self._logger, level)(message)
|
||||
else:
|
||||
local_rank = colossalai.core.global_context.get_local_rank(parallel_mode)
|
||||
if local_rank in ranks:
|
||||
getattr(self._logger, level)(message)
|
||||
|
||||
def info(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
|
||||
"""Stores an info log message.
|
||||
"""Log an info message.
|
||||
|
||||
:param message:
|
||||
:type message:
|
||||
@ -61,7 +115,7 @@ class DistributedLogger:
|
||||
self._log('info', message, parallel_mode, ranks)
|
||||
|
||||
def warning(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
|
||||
"""Stores a warning log message.
|
||||
"""Log a warning message.
|
||||
|
||||
:param message: The message to be logged
|
||||
:type message: str
|
||||
@ -73,7 +127,7 @@ class DistributedLogger:
|
||||
self._log('warning', message, parallel_mode, ranks)
|
||||
|
||||
def debug(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
|
||||
"""Stores a debug log message.
|
||||
"""Log a debug message.
|
||||
|
||||
:param message: The message to be logged
|
||||
:type message: str
|
||||
@ -85,7 +139,7 @@ class DistributedLogger:
|
||||
self._log('debug', message, parallel_mode, ranks)
|
||||
|
||||
def error(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
|
||||
"""Stores an error log message.
|
||||
"""Log an error message.
|
||||
|
||||
:param message: The message to be logged
|
||||
:type message: str
|
||||
|
@ -1,4 +1,3 @@
|
||||
from .data import *
|
||||
from .layer import *
|
||||
from .loss import *
|
||||
from .lr_scheduler import *
|
||||
|
@ -1,3 +0,0 @@
|
||||
from .caltech101_dataset import Caltech101Dataset
|
||||
from .cifar10_dataset import CIFAR10Dataset
|
||||
from .sampler import *
|
@ -1,14 +0,0 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
def pil_img_to_numpy(pil_img):
|
||||
"""convert a PIL image to numpy nd-array
|
||||
|
||||
:param pil_img: a PIL image
|
||||
:type pil_img: PIL.Image
|
||||
:return: a nd-array
|
||||
:rtype: numpy.ndarray
|
||||
"""
|
||||
np_img = np.array(pil_img)
|
||||
np_img = np.rollaxis(np_img, 2) # HWC to CHW
|
||||
return np_img
|
@ -1,17 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from abc import ABC
|
||||
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision.transforms import transforms
|
||||
|
||||
from colossalai.builder import build_transform
|
||||
|
||||
|
||||
class BaseDataset(Dataset, ABC):
|
||||
|
||||
def __init__(self, transform_pipeline: list):
|
||||
transform_list = [build_transform(cfg) for cfg in transform_pipeline]
|
||||
transform = transforms.Compose(transform_list)
|
||||
self._transform_pipeline = transform
|
@ -1,43 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import torch.distributed as dist
|
||||
from torchvision.datasets import Caltech101
|
||||
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.registry import DATASETS
|
||||
from .base_dataset import BaseDataset
|
||||
|
||||
|
||||
@DATASETS.register_module
|
||||
class Caltech101Dataset(BaseDataset):
|
||||
"""`Caltech 101 <http://www.vision.caltech.edu/Image_Datasets/Caltech101/>`_ Dataset.
|
||||
|
||||
:param transform_pipeline: A list of transform function configs, each of which takes in a PIL image
|
||||
and returns a transformed version
|
||||
:type transform_pipeline: list
|
||||
"""
|
||||
|
||||
def __init__(self, transform_pipeline: list, *args, **kwargs):
|
||||
super().__init__(transform_pipeline)
|
||||
if gpc.is_initialized(ParallelMode.GLOBAL) and gpc.get_global_rank() != 0:
|
||||
dist.barrier()
|
||||
self._dataset = Caltech101(
|
||||
transform=self._transform_pipeline, *args, **kwargs)
|
||||
if gpc.is_initialized(ParallelMode.GLOBAL) and gpc.get_global_rank() == 0:
|
||||
dist.barrier()
|
||||
|
||||
def __len__(self):
|
||||
return len(self._dataset)
|
||||
|
||||
def __getitem__(self, item):
|
||||
"""
|
||||
|
||||
:param item: Index
|
||||
:type item: int
|
||||
:return: ((image,), (target,)) where the type of target is specified by target_type.
|
||||
:rtype: tuple
|
||||
"""
|
||||
img, label = self._dataset.__getitem__(item)
|
||||
return (img,), (label,)
|
@ -1,44 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import torch.distributed as dist
|
||||
from torchvision.datasets import CIFAR10
|
||||
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.registry import DATASETS
|
||||
from .base_dataset import BaseDataset
|
||||
|
||||
|
||||
@DATASETS.register_module
|
||||
class CIFAR10Dataset(BaseDataset):
|
||||
"""`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
|
||||
|
||||
:param transform_pipeline: A list of transform function configs, each of which takes in a PIL image
|
||||
and returns a transformed version
|
||||
:type transform_pipeline: list
|
||||
"""
|
||||
|
||||
def __init__(self, transform_pipeline: list, *args, **kwargs):
|
||||
super().__init__(transform_pipeline)
|
||||
if gpc.is_initialized(ParallelMode.GLOBAL) and gpc.get_global_rank() != 0:
|
||||
dist.barrier()
|
||||
self._dataset = CIFAR10(transform=self._transform_pipeline,
|
||||
*args,
|
||||
**kwargs)
|
||||
if gpc.is_initialized(ParallelMode.GLOBAL) and gpc.get_global_rank() == 0:
|
||||
dist.barrier()
|
||||
|
||||
def __len__(self):
|
||||
return len(self._dataset)
|
||||
|
||||
def __getitem__(self, item):
|
||||
"""
|
||||
|
||||
:param item: Index
|
||||
:type item: int
|
||||
:return: ((image,), (target,)) where the type of target is specified by target_type.
|
||||
:rtype: tuple
|
||||
"""
|
||||
img, label = self._dataset.__getitem__(item)
|
||||
return (img,), (label,)
|
@ -1,4 +0,0 @@
|
||||
from .base_sampler import BaseSampler
|
||||
from .data_parallel_sampler import DataParallelSampler
|
||||
|
||||
__all__ = ['BaseSampler', 'DataParallelSampler']
|
33
colossalai/nn/init.py
Normal file
@ -0,0 +1,33 @@
|
||||
import math
|
||||
|
||||
from torch import Tensor
|
||||
from torch.nn import init as init
|
||||
|
||||
|
||||
def init_weight_(tensor: Tensor, fan_in: int, fan_out: int = None, init_method: str = 'torch'):
|
||||
if init_method == 'torch':
|
||||
a = math.sqrt(5)
|
||||
nonlinearity = 'leaky_relu'
|
||||
std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
|
||||
bound = math.sqrt(3.0) * std
|
||||
init.uniform_(tensor, -bound, bound)
|
||||
elif init_method == 'jax':
|
||||
std = math.sqrt(2.0 / float(fan_in + fan_out))
|
||||
a = math.sqrt(3.0) * std
|
||||
init.uniform_(tensor, -a, a)
|
||||
elif init_method == 'jax_embed':
|
||||
std = math.sqrt(1.0 / fan_in)
|
||||
init.trunc_normal_(tensor, std=std / .87962566103423978)
|
||||
elif init_method == 'zero':
|
||||
init.zeros_(tensor)
|
||||
|
||||
def init_bias_(tensor: Tensor, fan_in: int, init_method: str = 'torch'):
|
||||
if init_method == 'torch':
|
||||
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
|
||||
init.uniform_(tensor, -bound, bound)
|
||||
elif init_method == 'jax':
|
||||
init.normal_(tensor, std=1e-6)
|
||||
elif init_method == 'jax_embed':
|
||||
init.trunc_normal_(tensor, std=.02)
|
||||
elif init_method == 'zero':
|
||||
init.zeros_(tensor)
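# Example: applying the helpers above to a plain linear layer, e.g. to mimic the
# 'jax' (ViT-style) initialization used elsewhere in this PR.
from torch import nn

linear = nn.Linear(256, 1024)
init_weight_(linear.weight, fan_in=256, fan_out=1024, init_method='jax')
init_bias_(linear.bias, fan_in=256, init_method='jax')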
|
@ -1,9 +1,8 @@
|
||||
from .fused_bias_gelu import bias_gelu_impl
|
||||
from .parallel_1d import *
|
||||
from .parallel_2d import *
|
||||
from .parallel_2p5d import *
|
||||
from .parallel_3d import *
|
||||
from .parallel_sequence import *
|
||||
from .parallel_vision_transformer import *
|
||||
from .vanilla_resnet import *
|
||||
from .vanilla_vision_transformer import *
|
||||
from .non_parallel_layers import *
|
||||
from .wrapper import *
|
||||
|
@ -2,40 +2,14 @@
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import math
|
||||
|
||||
import collections.abc
|
||||
from itertools import repeat
|
||||
import numpy as np
|
||||
from colossalai.utils.common import print_rank_0
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch import nn
|
||||
from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS
|
||||
from colossalai.utils import checkpoint
|
||||
|
||||
from colossalai.constants import IS_TENSOR_PARALLEL
|
||||
|
||||
|
||||
def divide(numerator, denominator):
|
||||
""" only allow exact division """
|
||||
assert numerator % denominator == 0, \
|
||||
'{} is not divisible by {}'.format(numerator, denominator)
|
||||
return numerator // denominator
|
||||
|
||||
|
||||
def gelu(x: Tensor) -> Tensor:
|
||||
"""Implementation of the gelu activation function.
|
||||
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
|
||||
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
||||
"""
|
||||
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
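# Quick check of the note above: the erf-based gelu and OpenAI GPT's tanh
# approximation agree to within ~1e-3 over a typical activation range.
import math
import torch

x = torch.linspace(-3, 3, steps=7)
exact = x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
approx = 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
assert torch.allclose(exact, approx, atol=1e-3)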
|
||||
|
||||
|
||||
def swish(x: Tensor) -> Tensor:
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
|
||||
|
||||
|
||||
def set_tensor_parallel_attribute(param):
|
||||
if not hasattr(param, IS_TENSOR_PARALLEL):
|
||||
setattr(param, IS_TENSOR_PARALLEL, True)
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
class CheckpointModule(nn.Module):
|
||||
@ -44,15 +18,15 @@ class CheckpointModule(nn.Module):
|
||||
self.checkpoint = checkpoint
|
||||
self._use_checkpoint = checkpoint
|
||||
|
||||
def _forward(self, *args):
|
||||
def _forward(self, *args, **kwargs):
|
||||
raise NotImplementedError(
|
||||
'CheckpointModule should implement _forward method instead of origin forward')
|
||||
|
||||
def forward(self, *args):
|
||||
def forward(self, *args, **kwargs):
|
||||
if self._use_checkpoint:
|
||||
return checkpoint(self._forward, *args)
|
||||
return checkpoint(self._forward, *args, **kwargs)
|
||||
else:
|
||||
return self._forward(*args)
|
||||
return self._forward(*args, **kwargs)
|
||||
|
||||
def train(self, mode: bool = True):
|
||||
self._use_checkpoint = self.checkpoint
|
||||
@ -61,3 +35,38 @@ class CheckpointModule(nn.Module):
|
||||
def eval(self):
|
||||
self._use_checkpoint = False
|
||||
return super().eval()
|
||||
|
||||
def divide(numerator, denominator):
|
||||
""" only allow exact division """
|
||||
assert numerator % denominator == 0, \
|
||||
'{} is not divisible by {}'.format(numerator, denominator)
|
||||
return numerator // denominator
|
||||
|
||||
|
||||
def swish(x: Tensor) -> Tensor:
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}
|
||||
|
||||
|
||||
def set_tensor_parallel_attribute_by_size(param, size):
|
||||
setattr(param, IS_TENSOR_PARALLEL, True)
|
||||
setattr(param, NUM_PARTITIONS, size // np.prod(param.shape))
|
||||
|
||||
|
||||
def set_tensor_parallel_attribute_by_partition(param, num_partitions):
|
||||
setattr(param, IS_TENSOR_PARALLEL, True)
|
||||
setattr(param, NUM_PARTITIONS, num_partitions)
|
||||
|
||||
# From PyTorch internals
|
||||
def _ntuple(n):
|
||||
def parse(x):
|
||||
if isinstance(x, collections.abc.Iterable):
|
||||
return x
|
||||
return tuple(repeat(x, n))
|
||||
|
||||
return parse
|
||||
|
||||
|
||||
to_2tuple = _ntuple(2)
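# Example of the helpers above: `divide` guards exact partitioning of a dimension
# across tensor-parallel ranks, and `to_2tuple` normalizes scalar image/patch sizes.
assert divide(3072, 4) == 768           # e.g. splitting an MLP dim over 4 ranks
assert to_2tuple(16) == (16, 16)        # patch_size=16 -> (16, 16)
assert to_2tuple((14, 16)) == (14, 16)  # already a tuple, returned unchanged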
|
||||
|
35
colossalai/nn/layer/fused_bias_gelu.py
Normal file
@ -0,0 +1,35 @@
|
||||
# adapted from Megatron-LM
|
||||
# https://github.com/NVIDIA/Megatron-LM/blob/b31e1296354e979722627a6c4dedafe19b51fa97/megatron/model/fused_bias_gelu.py
|
||||
|
||||
import torch
|
||||
|
||||
@torch.jit.script
|
||||
def bias_gelu(bias, y):
|
||||
x = bias + y
|
||||
return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
|
||||
|
||||
# gradient of tanh approximation of gelu
|
||||
# gradient of actual gelu is:
|
||||
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
|
||||
@torch.jit.script
|
||||
def bias_gelu_back(g, bias, y):
|
||||
x = bias + y
|
||||
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
|
||||
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
|
||||
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
|
||||
return ff*g
|
||||
|
||||
class GeLUFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
# bias is an optional argument
|
||||
def forward(ctx, input, bias):
|
||||
ctx.save_for_backward(input, bias)
|
||||
return bias_gelu(bias, input)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, bias = ctx.saved_tensors
|
||||
tmp = bias_gelu_back(grad_output, bias, input)
|
||||
return tmp, tmp
|
||||
|
||||
bias_gelu_impl = GeLUFunction.apply
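# Example: the fused bias + GeLU function above on a toy tensor. It uses the tanh
# approximation, so values differ slightly from the exact erf-based GeLU.
import torch

x = torch.randn(8, 3072, requires_grad=True)
bias = torch.zeros(3072, requires_grad=True)
out = bias_gelu_impl(x, bias)
out.sum().backward()                   # gradients flow to both x and bias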
|
8
colossalai/nn/layer/non_parallel_layers/__init__.py
Normal file
@ -0,0 +1,8 @@
|
||||
from ._vit import (ViTBlock, VanillaViTAttention, VanillaViTBlock, VanillaViTDropPath,
|
||||
VanillaViTHead, VanillaViTMLP, VanillaViTPatchEmbedding)
|
||||
|
||||
|
||||
__all__ = [
|
||||
'ViTBlock', 'VanillaViTAttention', 'VanillaViTBlock', 'VanillaViTDropPath',
|
||||
'VanillaViTHead', 'VanillaViTMLP', 'VanillaViTPatchEmbedding'
|
||||
]
|
@ -1,23 +1,47 @@
|
||||
import collections.abc
|
||||
from itertools import repeat
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
|
||||
from colossalai.builder import build_layer
|
||||
from colossalai.registry import LAYERS
|
||||
from .._common_utils import to_2tuple
|
||||
|
||||
|
||||
# From PyTorch internals
|
||||
def _ntuple(n):
|
||||
def parse(x):
|
||||
if isinstance(x, collections.abc.Iterable):
|
||||
return x
|
||||
return tuple(repeat(x, n))
|
||||
@LAYERS.register_module
|
||||
class ViTBlock(nn.Module):
|
||||
"""Vision Transformer block
|
||||
|
||||
return parse
|
||||
:param attention_cfg: config of attention layer
|
||||
:type attention_cfg: dict
|
||||
:param droppath_cfg: config of drop path
|
||||
:type droppath_cfg: dict
|
||||
:param mlp_cfg: config of MLP layer
|
||||
:type mlp_cfg: dict
|
||||
:param norm_cfg: config of normalization layer
|
||||
:type norm_cfg: dict
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
attention_cfg: dict,
|
||||
droppath_cfg: dict,
|
||||
mlp_cfg: dict,
|
||||
norm_cfg: dict,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm1 = build_layer(norm_cfg)
|
||||
self.attn = build_layer(attention_cfg)
|
||||
self.drop_path = build_layer(
|
||||
droppath_cfg) if droppath_cfg['drop_path'] > 0. else nn.Identity()
|
||||
self.norm2 = build_layer(norm_cfg)
|
||||
self.mlp = build_layer(mlp_cfg)
|
||||
|
||||
to_2tuple = _ntuple(2)
|
||||
def forward(self, x):
|
||||
x = x + self.drop_path(self.attn(self.norm1(x)))
|
||||
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
@LAYERS.register_module
|
@ -1,5 +1,11 @@
|
||||
from .layers import Linear1D_Col, Linear1D_Row
|
||||
from .layers import MixedFusedLayerNorm1D as LayerNorm1D
|
||||
from ._transformer import TransformerMLP1D, TransformerSelfAttention1D, TransformerLayer1D
|
||||
from ._vit import ViTMLP1D, ViTSelfAttention1D, ViTHead1D, ViTPatchEmbedding1D, ViTTokenFuser1D, ViTHead
|
||||
|
||||
|
||||
|
||||
__all__ = [
|
||||
'Linear1D_Col', 'Linear1D_Row',
|
||||
'Linear1D_Col', 'Linear1D_Row', 'ViTMLP1D', 'ViTSelfAttention1D', 'ViTHead1D', 'ViTPatchEmbedding1D', 'ViTTokenFuser1D',
|
||||
'TransformerMLP1D', 'TransformerSelfAttention1D', 'TransformerLayer1D', 'LayerNorm1D', 'ViTHead'
|
||||
]
|
||||
|
34
colossalai/nn/layer/parallel_1d/_operation.py
Normal file
@ -0,0 +1,34 @@
|
||||
import torch
|
||||
|
||||
try:
|
||||
import fused_mix_prec_layer_norm_cuda
|
||||
except:
|
||||
fused_mix_prec_layer_norm_cuda = None
|
||||
|
||||
|
||||
class FusedLayerNormAffineFunction1D(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input, weight, bias, normalized_shape, eps):
|
||||
ctx.normalized_shape = normalized_shape
|
||||
ctx.eps = eps
|
||||
input_ = input.contiguous()
|
||||
weight_ = weight.contiguous()
|
||||
bias_ = bias.contiguous()
|
||||
output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(
|
||||
input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
|
||||
ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
|
||||
return output
|
||||
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input_, weight_, bias_, mean, invvar = ctx.saved_tensors
|
||||
grad_input = grad_weight = grad_bias = None
|
||||
grad_input, grad_weight, grad_bias \
|
||||
= fused_mix_prec_layer_norm_cuda.backward_affine(
|
||||
grad_output.contiguous(), mean, invvar,
|
||||
input_, ctx.normalized_shape,
|
||||
weight_, bias_, ctx.eps)
|
||||
|
||||
return grad_input, grad_weight, grad_bias, None, None
|
220
colossalai/nn/layer/parallel_1d/_transformer.py
Normal file
@ -0,0 +1,220 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.nn.init as init
|
||||
import math
|
||||
from torch import Tensor
|
||||
from torch.nn.parameter import Parameter
|
||||
from typing import Tuple
|
||||
|
||||
from colossalai.context import seed, ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.registry import LAYERS
|
||||
from colossalai.utils import get_current_device
|
||||
from .._common_utils import divide, ACT2FN
|
||||
from .._parallel_utilities import reduce_grad, reduce_input, gather_forward_split_backward, \
|
||||
split_forward_gather_backward
|
||||
from ..base_layer import ParallelLayer
|
||||
from .layers import Linear1D_Col, Linear1D_Row
|
||||
from .layers import MixedFusedLayerNorm1D as LayerNorm1D
|
||||
|
||||
@LAYERS.register_module
|
||||
class TransformerMLP1D(ParallelLayer):
|
||||
"""MLP.
|
||||
MLP takes the input with h hidden size, projects it to mlp_ratio * h
hidden dimension, applies a nonlinear transformation, and projects the
state back into h hidden dimension.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
mlp_ratio: int = 4.0,
|
||||
act_func: str = 'gelu',
|
||||
dropout_prob: float = 0.,
|
||||
dtype=None,
|
||||
skip_bias_add: bool = False
|
||||
):
|
||||
super(TransformerMLP1D, self).__init__()
|
||||
self.in_features = in_features
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.skip_bias_add = skip_bias_add
|
||||
# Project to h * mlp_ratio.
|
||||
self.dense_1 = Linear1D_Col(
|
||||
self.in_features,
|
||||
int(self.mlp_ratio * self.in_features),
|
||||
bias=not skip_bias_add,
|
||||
dtype=dtype,
|
||||
gather_output = False,
|
||||
)
|
||||
|
||||
assert act_func in ACT2FN.keys(), f'Invalid value for argument act_func, ' \
|
||||
f'activation function can only be {list(ACT2FN.keys())}'
|
||||
self.activation_func = ACT2FN[act_func]
|
||||
|
||||
# Project back to h.
|
||||
self.dense_2 = Linear1D_Row(
|
||||
int(self.mlp_ratio * self.in_features),
|
||||
self.in_features,
|
||||
bias=not skip_bias_add,
|
||||
dtype=dtype,
|
||||
parallel_input = True,
|
||||
)
|
||||
self.dropout = nn.Dropout(dropout_prob)
|
||||
# self.layernorm = LayerNorm1D(in_features, dtype=dtype)
|
||||
self.layernorm = nn.LayerNorm(in_features, dtype=dtype)
|
||||
def forward(self, x):
|
||||
if self.skip_bias_add:
|
||||
intermediate_output, _ = self.dense_1(x)
|
||||
else:
|
||||
intermediate_output = self.dense_1(x)
|
||||
|
||||
intermediate_output = self.activation_func(intermediate_output)
|
||||
|
||||
if self.skip_bias_add:
|
||||
output, _ = self.dense_2(intermediate_output)
|
||||
else:
|
||||
output = self.dense_2(intermediate_output)
|
||||
|
||||
with seed(ParallelMode.TENSOR):
|
||||
output = self.dropout(output)
|
||||
output = self.layernorm(x + output)
|
||||
return output
|
||||
|
||||
@LAYERS.register_module
|
||||
class TransformerSelfAttention1D(ParallelLayer):
|
||||
"""Self attention layer for 1D parallel Transformer
|
||||
|
||||
:param hidden_size: hidden size
|
||||
:type hidden_size: int
|
||||
:param num_attention_heads: number of attention heads
|
||||
:type num_attention_heads: int
|
||||
:param attention_dropout_prob: dropout probability for attention layer
|
||||
:type attention_dropout_prob: float
|
||||
:param hidden_dropout_prob: dropout probability for hidden layer
|
||||
:type hidden_dropout_prob: float
|
||||
:param dtype: dtype of parameters, defaults to None
|
||||
:type dtype: torch.dtype, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
attention_dropout_prob: float,
|
||||
hidden_dropout_prob: float,
|
||||
dtype=None,
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
self.num_attention_heads = divide(num_attention_heads, gpc.tensor_parallel_size)
|
||||
self.attention_head_size = divide(hidden_size, num_attention_heads)
|
||||
self.hidden_size_per_partition = divide(hidden_size, gpc.tensor_parallel_size)
|
||||
|
||||
self.query_key_value = Linear1D_Col(
|
||||
hidden_size,
|
||||
3 * hidden_size,
|
||||
dtype=dtype,
|
||||
)
|
||||
self.attention_dropout = nn.Dropout(attention_dropout_prob)
|
||||
self.dense = Linear1D_Row(
|
||||
hidden_size,
|
||||
hidden_size,
|
||||
dtype=dtype,
|
||||
parallel_input=True,
|
||||
)
|
||||
self.dropout = nn.Dropout(hidden_dropout_prob)
|
||||
|
||||
# need to re-enable torch grad to enable fused optimization.
|
||||
# self.layernorm = LayerNorm1D(
|
||||
# hidden_size,
|
||||
# dtype=dtype)
|
||||
self.layernorm = nn.LayerNorm(
|
||||
hidden_size,
|
||||
dtype=dtype)
|
||||
|
||||
def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
|
||||
query_key_value = self.query_key_value(hidden_states)
|
||||
new_qkv_shape = query_key_value.shape[:-1] + \
|
||||
(self.num_attention_heads, 3 * self.attention_head_size)
|
||||
query_key_value = query_key_value.view(new_qkv_shape)
|
||||
query_key_value = query_key_value.permute((0, 2, 1, 3))
|
||||
query_layer, key_layer, value_layer = torch.chunk(
|
||||
query_key_value, 3, dim=-1)
|
||||
|
||||
attention_scores = torch.matmul(
|
||||
query_layer, key_layer.transpose(-1, -2))
|
||||
attention_scores = attention_scores / \
|
||||
math.sqrt(self.attention_head_size)
|
||||
attention_scores = attention_scores + attention_mask
|
||||
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
||||
with seed(ParallelMode.TENSOR):
|
||||
attention_probs = self.attention_dropout(attention_probs)
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
context_layer = context_layer.permute((0, 2, 1, 3)).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[
|
||||
:-2] + (self.hidden_size_per_partition,)
|
||||
context_layer = context_layer.view(*new_context_layer_shape)
|
||||
|
||||
output = self.dense(context_layer)
|
||||
with seed(ParallelMode.TENSOR):
|
||||
output = self.dropout(output)
|
||||
attention_output = self.layernorm(hidden_states + output)
|
||||
|
||||
return attention_output
|
||||
|
||||
@LAYERS.register_module
|
||||
class TransformerLayer1D(ParallelLayer):
|
||||
"""Transformer layer which contains a self-attention layer and a MLP layer
|
||||
|
||||
:param hidden_size: hidden size
|
||||
:type hidden_size: int
|
||||
:param num_attention_heads: number of attention heads
|
||||
:type num_attention_heads: int
|
||||
:param act_func: activation function, defaults to 'gelu'
|
||||
:type act_func: str, optional
|
||||
:param mlp_ratio: hidden size of MLP divided by embedding dim, defaults to 4.0
|
||||
:type mlp_ratio: float, optional
|
||||
:param attention_dropout_prob: dropout probability for attention layer, defaults to 0.
|
||||
:type attention_dropout_prob: float, optional
|
||||
:param hidden_dropout_prob: dropout probability for attention layer, defaults to 0.
|
||||
:type hidden_dropout_prob: float, optional
|
||||
:param dtype: dtype of parameters, defaults to None
|
||||
:type dtype: torch.dtype, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
act_func: str = 'gelu',
|
||||
mlp_ratio: float = 4.0,
|
||||
attention_dropout_prob: float = 0.,
|
||||
hidden_dropout_prob: float = 0.,
|
||||
dtype=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attention = TransformerSelfAttention1D(
|
||||
hidden_size=hidden_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_dropout_prob=attention_dropout_prob,
|
||||
hidden_dropout_prob=hidden_dropout_prob,
|
||||
dtype=dtype,
|
||||
)
|
||||
self.mlp = TransformerMLP1D(
|
||||
in_features=hidden_size,
|
||||
dropout_prob=hidden_dropout_prob,
|
||||
act_func=act_func,
|
||||
mlp_ratio=mlp_ratio,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
|
||||
attention_output = self.attention(hidden_states, attention_mask)
|
||||
output = self.mlp(attention_output)
|
||||
return output
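# Illustrative usage of the layer above; constructing it assumes an initialized
# Colossal-AI context (gpc.tensor_parallel_size and the 1D parallel group),
# so this is a sketch rather than a standalone script.
import torch

layer = TransformerLayer1D(hidden_size=768, num_attention_heads=12)
hidden_states = torch.randn(4, 197, 768, device='cuda')     # (batch, seq, hidden)
attention_mask = torch.zeros(4, 1, 1, 197, device='cuda')   # additive mask, 0 = keep
output = layer(hidden_states, attention_mask)               # same shape as the input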
|
@ -13,3 +13,6 @@ def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank):
|
||||
def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
|
||||
per_partition_vocab_size = divide(global_vocab_size, world_size)
|
||||
return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank)
|
||||
|
||||
|
||||
|
||||
|
411
colossalai/nn/layer/parallel_1d/_vit.py
Normal file
@ -0,0 +1,411 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import math
|
||||
from colossalai import context
|
||||
|
||||
import torch
|
||||
from torch import nn as nn, Tensor, distributed as dist
|
||||
from torch.nn.init import _calculate_fan_in_and_fan_out
|
||||
|
||||
from colossalai.context import seed, ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.nn.layer._common_utils import divide, ACT2FN
|
||||
from colossalai.registry import LAYERS
|
||||
from colossalai.utils import checkpoint
|
||||
from colossalai.utils import get_current_device
|
||||
from .layers import Linear1D_Col, Linear1D_Row
|
||||
from ..base_layer import ParallelLayer
|
||||
from .._common_utils import to_2tuple
|
||||
from ..fused_bias_gelu import bias_gelu_impl
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class ViTMLP1D(ParallelLayer):
|
||||
"""MLP layer for 1D parallel Vision Transformer
|
||||
|
||||
:param in_features: size of each input sample
|
||||
:type in_features: int
|
||||
:param mlp_ratio: hidden size of MLP divided by embedding dim
|
||||
:type mlp_ratio: int
|
||||
:param act_func: activation function, defaults to 'gelu'
|
||||
:type act_func: str, optional
|
||||
:param dropout_prob: dropout probability, defaults to 0.
|
||||
:type dropout_prob: float, optional
|
||||
:param dtype: The dtype of parameters, defaults to None
|
||||
:type dtype: torch.dtype, optional
|
||||
:param checkpoint: whether to checkpoint the layer, defaults to False
|
||||
:type checkpoint: bool, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
mlp_ratio: int,
|
||||
act_func: str = 'gelu',
|
||||
dropout_prob: float = 0.,
|
||||
dtype=None,
|
||||
checkpoint: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
weight_init='torch'
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.in_features = in_features
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.checkpoint = checkpoint
|
||||
self.skip_bias_add = skip_bias_add
|
||||
assert weight_init in ('torch', 'jax')
|
||||
|
||||
if act_func == 'fused_gelu':
|
||||
self.act = bias_gelu_impl
|
||||
skip_dense_1_add_bias = True
|
||||
else:
|
||||
self.act = ACT2FN[act_func]
|
||||
skip_dense_1_add_bias = False
|
||||
|
||||
# Project to mlp_ratio * h.
|
||||
self.dense_1 = Linear1D_Col(
|
||||
self.in_features,
|
||||
int(self.mlp_ratio * self.in_features),
|
||||
dtype=dtype,
|
||||
gather_output=False,
|
||||
skip_bias_add=skip_dense_1_add_bias,
|
||||
init_weight=weight_init,
|
||||
init_bias=weight_init
|
||||
)
|
||||
|
||||
# Project back to h.
|
||||
self.dense_2 = Linear1D_Row(
|
||||
int(self.mlp_ratio * self.in_features),
|
||||
self.in_features,
|
||||
dtype=dtype,
|
||||
parallel_input=True,
|
||||
init_weight=weight_init, init_bias=weight_init
|
||||
)
|
||||
|
||||
self.dropout = nn.Dropout(dropout_prob)
|
||||
|
||||
def _forward(self, hidden_states: Tensor) -> Tensor:
|
||||
if self.act == bias_gelu_impl:
|
||||
intermediate_output, bias = self.dense_1(hidden_states)
|
||||
intermediate_output = self.act(intermediate_output, bias)
|
||||
else:
|
||||
intermediate_output = self.dense_1(hidden_states)
|
||||
intermediate_output = self.act(intermediate_output)
|
||||
|
||||
with seed(ParallelMode.TENSOR):
|
||||
intermediate_output = self.dropout(intermediate_output)
|
||||
output = self.dense_2(intermediate_output)
|
||||
output = self.dropout(output)
|
||||
return output
|
||||
|
||||
def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
|
||||
return checkpoint(self._forward, hidden_states)
|
||||
|
||||
def forward(self, hidden_states: Tensor) -> Tensor:
|
||||
if self.checkpoint:
|
||||
return self._checkpoint_forward(hidden_states)
|
||||
else:
|
||||
return self._forward(hidden_states)
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class ViTSelfAttention1D(ParallelLayer):
|
||||
"""Self-attention layer for 1D parallel Vision Transformer
|
||||
|
||||
:param hidden_size: hidden size
|
||||
:type hidden_size: int
|
||||
:param num_attention_heads: number of attention heads
|
||||
:type num_attention_heads: int
|
||||
:param attention_dropout_prob: dropout probability for attention layers
|
||||
:type attention_dropout_prob: float
|
||||
:param hidden_dropout_prob: dropout probability for hidden layers
|
||||
:type hidden_dropout_prob: float
|
||||
:param dtype: dtype of parameters, defaults to None
|
||||
:type dtype: torch.dtype, optional
|
||||
:param checkpoint: whether to checkpoint the layer, defaults to False
|
||||
:type checkpoint: bool, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
attention_dropout_prob: float,
|
||||
hidden_dropout_prob: float,
|
||||
dtype=None,
|
||||
checkpoint: bool = False,
|
||||
weight_init='torch'
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.attention_head_size = divide(hidden_size, num_attention_heads)
|
||||
self.num_attention_heads_per_partition = divide(num_attention_heads, gpc.tensor_parallel_size)
|
||||
self.hidden_size_per_partition = divide(hidden_size, gpc.tensor_parallel_size)
|
||||
|
||||
self.checkpoint = checkpoint
|
||||
assert weight_init in ('torch', 'jax')
|
||||
if weight_init == 'jax':
|
||||
init_bias = 'zero'
|
||||
else:
|
||||
init_bias = weight_init
|
||||
|
||||
self.query_key_value = Linear1D_Col(
|
||||
hidden_size,
|
||||
3 * hidden_size,
|
||||
dtype=dtype,
|
||||
init_weight=weight_init,
|
||||
init_bias=init_bias
|
||||
)
|
||||
self.attention_dropout = nn.Dropout(attention_dropout_prob)
|
||||
self.dense = Linear1D_Row(
|
||||
hidden_size,
|
||||
hidden_size,
|
||||
dtype=dtype,
|
||||
parallel_input=True,
|
||||
init_weight=weight_init, init_bias=init_bias
|
||||
)
|
||||
self.dropout = nn.Dropout(hidden_dropout_prob)
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
|
||||
def _forward(self, hidden_states: Tensor) -> Tensor:
|
||||
query_key_value = self.query_key_value(hidden_states)
|
||||
new_qkv_shape = query_key_value.shape[:-1] + \
|
||||
(self.num_attention_heads_per_partition, 3 * self.attention_head_size)
|
||||
query_key_value = query_key_value.view(new_qkv_shape)
|
||||
query_key_value = query_key_value.permute((0, 2, 1, 3))
|
||||
query_layer, key_layer, value_layer = torch.chunk(
|
||||
query_key_value, 3, dim=-1)
|
||||
|
||||
attention_scores = torch.matmul(
|
||||
query_layer, key_layer.transpose(-1, -2))
|
||||
attention_scores = attention_scores / \
|
||||
math.sqrt(self.attention_head_size)
|
||||
|
||||
attention_probs = self.softmax(attention_scores)
|
||||
|
||||
with seed(ParallelMode.TENSOR):
|
||||
attention_probs = self.attention_dropout(attention_probs)
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
context_layer = context_layer.transpose(1, 2)
|
||||
new_context_layer_shape = context_layer.size()[
|
||||
:-2] + (self.hidden_size_per_partition,)
|
||||
context_layer = context_layer.reshape(new_context_layer_shape)
|
||||
output = self.dense(context_layer)
|
||||
output = self.dropout(output)
|
||||
|
||||
return output
|
||||
|
||||
def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
|
||||
return checkpoint(self._forward, hidden_states)
|
||||
|
||||
def forward(self, hidden_states: Tensor) -> Tensor:
|
||||
if self.checkpoint:
|
||||
return self._checkpoint_forward(hidden_states)
|
||||
else:
|
||||
return self._forward(hidden_states)
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class ViTHead1D(ParallelLayer):
|
||||
"""Output layer for 1D parallel Vision Transformer
|
||||
|
||||
:param hidden_size: hidden size
|
||||
:type hidden_size: int
|
||||
:param num_classes: number of classes
|
||||
:type num_classes: int
|
||||
:param dtype: dtype of parameters, defaults to None
|
||||
:type dtype: torch.dtype, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
hidden_size,
|
||||
num_classes,
|
||||
dtype=None,
|
||||
weight_init='torch'
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
assert weight_init in ('torch', 'jax')
|
||||
if weight_init == 'jax':
|
||||
init_weight = 'zero'
|
||||
init_bias = 'zero'
|
||||
else:
|
||||
init_weight = weight_init
|
||||
init_bias = weight_init
|
||||
|
||||
self.linear = Linear1D_Col(
|
||||
hidden_size,
|
||||
num_classes,
|
||||
dtype=dtype,
|
||||
gather_output=True,
|
||||
init_weight=init_weight,
|
||||
init_bias=init_bias
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
x = x[:, 0]
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class ViTHead(ParallelLayer):
|
||||
"""Output layer for 1D parallel Vision Transformer
|
||||
|
||||
:param hidden_size: hidden size
|
||||
:type hidden_size: int
|
||||
:param num_classes: number of classes
|
||||
:type num_classes: int
|
||||
:param dtype: dtype of parameters, defaults to None
|
||||
:type dtype: torch.dtype, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
hidden_size,
|
||||
num_classes,
|
||||
dtype=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(
|
||||
hidden_size,
|
||||
num_classes,
|
||||
dtype=dtype
|
||||
)
|
||||
self._broadcast_linear_params()
|
||||
|
||||
def _broadcast_linear_params(self) -> None:
|
||||
self.to(get_current_device())
|
||||
ranks = gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)
|
||||
|
||||
dist.broadcast(self.linear.weight, src=ranks[0],
|
||||
group=gpc.get_group(ParallelMode.PARALLEL_1D))
|
||||
dist.broadcast(self.linear.bias, src=ranks[0],
|
||||
group=gpc.get_group(ParallelMode.PARALLEL_1D))
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
x = x[:, 0]
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class ViTPatchEmbedding1D(ParallelLayer):
|
||||
""" 2D Image to Patch Embedding
|
||||
|
||||
:param img_size: image size
|
||||
:type img_size: int
|
||||
:param patch_size: patch size
|
||||
:type patch_size: int
|
||||
:param embed_dim: dimension of embedding
|
||||
:type embed_dim: int
|
||||
:param in_chans: number of channels of input image, defaults to 3
|
||||
:type in_chans: int, optional
|
||||
:param flatten: whether to flatten output tensor, defaults to True
|
||||
:type flatten: bool, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
img_size,
|
||||
patch_size,
|
||||
embed_dim,
|
||||
in_chans=3,
|
||||
flatten=True,
|
||||
weight_init='torch'):
|
||||
super().__init__()
|
||||
img_size = to_2tuple(img_size)
|
||||
patch_size = to_2tuple(patch_size)
|
||||
|
||||
self.img_size = img_size
|
||||
self.patch_size = patch_size
|
||||
self.grid_size = (img_size[0] // patch_size[0],
|
||||
img_size[1] // patch_size[1])
|
||||
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
||||
self.flatten = flatten
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
self.proj = nn.Conv2d(in_chans,
|
||||
self.embed_dim,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size
|
||||
)
|
||||
|
||||
if weight_init == 'jax':
|
||||
fan_in, _ = _calculate_fan_in_and_fan_out(self.proj.weight)
|
||||
std = math.sqrt(1.0 / fan_in)
|
||||
nn.init.trunc_normal_(self.proj.weight, std=std / .87962566103423978)
|
||||
nn.init.zeros_(self.proj.bias)
|
||||
|
||||
# sync
|
||||
self._broadcast_conv_params()
|
||||
|
||||
def _broadcast_conv_params(self) -> None:
|
||||
self.to(get_current_device())
|
||||
ranks = gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)
|
||||
|
||||
dist.broadcast(self.proj.weight, src=ranks[0],
|
||||
group=gpc.get_group(ParallelMode.PARALLEL_1D))
|
||||
dist.broadcast(self.proj.bias, src=ranks[0],
|
||||
group=gpc.get_group(ParallelMode.PARALLEL_1D))
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
B, C, H, W = x.shape
|
||||
assert H == self.img_size[0] and W == self.img_size[1], \
|
||||
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
||||
x = self.proj(x)
|
||||
if self.flatten:
|
||||
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
|
||||
return x
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class ViTTokenFuser1D(ParallelLayer):
|
||||
"""
|
||||
Fuse the cls token and position embedding into the input sequence
|
||||
|
||||
:param img_size: image size
|
||||
:type img_size: int
|
||||
:param patch_size: patch size
|
||||
:type patch_size: int
|
||||
:param embed_dim: dimension of embedding
|
||||
:type embed_dim: int
|
||||
:param drop_rate: dropout probability, defaults to 0.
|
||||
:type drop_rate: float, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
img_size,
|
||||
patch_size,
|
||||
embed_dim,
|
||||
drop_rate=0.
|
||||
):
|
||||
super().__init__()
|
||||
img_size = to_2tuple(img_size)
|
||||
patch_size = to_2tuple(patch_size)
|
||||
|
||||
self.img_size = img_size
|
||||
self.patch_size = patch_size
|
||||
self.grid_size = (img_size[0] // patch_size[0],
|
||||
img_size[1] // patch_size[1])
|
||||
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
self.cls_token = nn.Parameter(torch.zeros(
|
||||
1, 1, self.embed_dim))
|
||||
self.pos_embed = nn.Parameter(torch.empty(
|
||||
1, self.num_patches + 1, self.embed_dim))
|
||||
nn.init.trunc_normal_(self.pos_embed, std=.02)
|
||||
|
||||
# move to cuda before broadcast
|
||||
self.to(get_current_device())
|
||||
dist.broadcast(self.pos_embed,
|
||||
src=gpc.get_ranks_in_group(ParallelMode.TENSOR)[0],
|
||||
group=gpc.get_group(ParallelMode.TENSOR))
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
|
||||
x = torch.cat((cls_token, x), dim=1)
|
||||
x = self.pos_drop(x + self.pos_embed)
|
||||
return x.contiguous()
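# Illustrative shape flow for the two modules above (their constructors broadcast
# parameters over the tensor-parallel group, so an initialized context is assumed):
#   ViTPatchEmbedding1D(img_size=224, patch_size=16, embed_dim=768):
#       (B, 3, 224, 224) -> (B, 196, 768)
#   ViTTokenFuser1D(img_size=224, patch_size=16, embed_dim=768):
#       (B, 196, 768)    -> (B, 197, 768)   # cls token prepended, pos embed added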
|
@ -1,24 +1,30 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import math
|
||||
import numbers
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.nn.init as init
|
||||
from torch import Tensor
|
||||
from torch.nn.parameter import Parameter
|
||||
from typing import Tuple
|
||||
import importlib
|
||||
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.context import seed, ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.registry import LAYERS
|
||||
from colossalai.utils import get_current_device
|
||||
from .._common_utils import divide
|
||||
from ._operation import FusedLayerNormAffineFunction1D
|
||||
from .._common_utils import divide, set_tensor_parallel_attribute_by_partition
|
||||
from .._parallel_utilities import reduce_grad, reduce_input, gather_forward_split_backward, \
|
||||
split_forward_gather_backward
|
||||
from ..base_layer import ParallelLayer
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class Linear1D_Col(ParallelLayer):
|
||||
"""Linear layer with column parallelism.
|
||||
|
||||
@ -44,23 +50,29 @@ class Linear1D_Col(ParallelLayer):
|
||||
output_size: int,
|
||||
bias: bool = True,
|
||||
dtype: torch.dtype = None,
|
||||
gather_output: bool = False):
|
||||
gather_output: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
init_weight='torch',
|
||||
init_bias='torch'
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Keep input parameters
|
||||
self.input_size = in_features
|
||||
self.output_size = output_size
|
||||
self.in_features = in_features
|
||||
self.out_features = output_size
|
||||
self.gather_output = gather_output
|
||||
self.skip_bias_add = not bias
|
||||
self.skip_bias_add = skip_bias_add
|
||||
|
||||
world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
|
||||
self.output_size_per_partition = divide(output_size, world_size)
|
||||
if skip_bias_add and not bias:
|
||||
raise ValueError('cannot skip bias addition if bias is None')
|
||||
|
||||
self.output_size_per_partition = divide(output_size, gpc.tensor_parallel_size)
|
||||
|
||||
# Parameters.
|
||||
# Initialize weight.
|
||||
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
|
||||
self.weight = Parameter(torch.empty(
|
||||
self.output_size_per_partition, self.input_size,
|
||||
self.output_size_per_partition, self.in_features,
|
||||
**factory_kwargs))
|
||||
|
||||
if bias:
|
||||
@ -72,6 +84,45 @@ class Linear1D_Col(ParallelLayer):
|
||||
self.bias.zero_()
|
||||
else:
|
||||
self.register_parameter('bias', None)
|
||||
with seed(ParallelMode.TENSOR):
|
||||
self.reset_parameters(init_weight, init_bias)
|
||||
self._set_tensor_parallel_attributes()
|
||||
|
||||
def reset_parameters(self, init_weight, init_bias) -> None:
|
||||
assert init_weight in ('torch', 'jax', 'zero')
|
||||
assert init_bias in ('torch', 'jax', 'zero')
|
||||
# setting
|
||||
fan_in, fan_out = self.in_features, self.out_features
|
||||
|
||||
# init weight
|
||||
if init_weight == 'torch':
|
||||
a = math.sqrt(5)
|
||||
nonlinearity = 'leaky_relu'
|
||||
std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
|
||||
bound = math.sqrt(3.0) * std
|
||||
init.uniform_(self.weight, -bound, bound)
|
||||
elif init_weight == 'jax':
|
||||
std = math.sqrt(2.0 / float(fan_in + fan_out))
|
||||
a = math.sqrt(3.0) * std
|
||||
init.uniform_(self.weight, -a, a)
|
||||
elif init_weight == 'zero':
|
||||
init.zeros_(self.weight)
|
||||
|
||||
# init bias
|
||||
if self.bias is not None:
|
||||
if init_bias == 'torch':
|
||||
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
|
||||
init.uniform_(self.bias, -bound, bound)
|
||||
elif init_bias == 'jax':
|
||||
init.normal_(self.bias, std=1e-6)
|
||||
elif init_bias == 'zero':
|
||||
init.zeros_(self.bias)
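# The three initializers accepted above ('torch', 'jax', 'zero') reduce to the
# formulas below; this is a standalone restatement with plain torch.nn.init
# calls, not a call into the parallel layer itself. Bias handling is analogous:
# uniform within +/- 1/sqrt(fan_in) for 'torch', normal(std=1e-6) for 'jax', zeros otherwise.
import math
import torch
from torch.nn import init

def init_weight_(weight: torch.Tensor, fan_in: int, fan_out: int, mode: str = 'torch') -> None:
    if mode == 'torch':        # kaiming-uniform with a=sqrt(5), as in torch.nn.Linear
        std = init.calculate_gain('leaky_relu', math.sqrt(5)) / math.sqrt(fan_in)
        bound = math.sqrt(3.0) * std
        init.uniform_(weight, -bound, bound)
    elif mode == 'jax':        # xavier/glorot-uniform
        std = math.sqrt(2.0 / float(fan_in + fan_out))
        a = math.sqrt(3.0) * std
        init.uniform_(weight, -a, a)
    elif mode == 'zero':
        init.zeros_(weight)

w = torch.empty(32, 64)        # [out_features_per_partition, in_features]
init_weight_(w, fan_in=64, fan_out=32, mode='jax')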
|
||||
|
||||
def _set_tensor_parallel_attributes(self):
|
||||
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
|
||||
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
|
||||
if self.bias is not None:
|
||||
set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
# Set up backprop all-reduce.
|
||||
@ -104,7 +155,7 @@ class Linear1D_Row(ParallelLayer):
|
||||
:type bias: bool, optional
|
||||
:param dtype: The dtype of parameters, defaults to None
|
||||
:type dtype: torch.dtype, optional
|
||||
:param parallel_input: If set to ``False``, it's assumed that the input is splitted, defaults to False
|
||||
:param parallel_input: If set to ``True``, it is assumed that the input has already been split, defaults to False
|
||||
:type parallel_input: bool, optional
|
||||
"""
|
||||
|
||||
@ -113,7 +164,10 @@ class Linear1D_Row(ParallelLayer):
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
dtype: torch.dtype = None,
|
||||
parallel_input: bool = False
|
||||
parallel_input: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
init_weight='torch',
|
||||
init_bias='torch'
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -121,11 +175,13 @@ class Linear1D_Row(ParallelLayer):
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.parallel_input = parallel_input
|
||||
self.skip_bias_add = not bias
|
||||
self.skip_bias_add = skip_bias_add
|
||||
|
||||
if skip_bias_add and not bias:
|
||||
raise ValueError('cannot skip bias addition if bias is None')
|
||||
|
||||
# Divide the weight matrix along the last dimension.
|
||||
world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
|
||||
self.input_size_per_partition = divide(in_features, world_size)
|
||||
self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size)
|
||||
|
||||
# Parameters.
|
||||
# Initialize weight.
|
||||
@ -146,9 +202,46 @@ class Linear1D_Row(ParallelLayer):
|
||||
self.bias.zero_()
|
||||
else:
|
||||
self.register_parameter('bias', None)
|
||||
with seed(ParallelMode.TENSOR):
|
||||
self.reset_parameters(init_weight, init_bias)
|
||||
self._set_tensor_parallel_attributes()
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
init.xavier_normal_(self.weight)
|
||||
def reset_parameters(self, init_weight, init_bias) -> None:
|
||||
assert init_weight in ('torch', 'jax', 'zero')
|
||||
assert init_bias in ('torch', 'jax', 'zero')
|
||||
# setting
|
||||
fan_in, fan_out = self.in_features, self.out_features
|
||||
|
||||
# init weight
|
||||
if init_weight == 'torch':
|
||||
a = math.sqrt(5)
|
||||
nonlinearity = 'leaky_relu'
|
||||
std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
|
||||
bound = math.sqrt(3.0) * std
|
||||
init.uniform_(self.weight, -bound, bound)
|
||||
elif init_weight == 'jax':
|
||||
std = math.sqrt(2.0 / float(fan_in + fan_out))
|
||||
a = math.sqrt(3.0) * std
|
||||
init.uniform_(self.weight, -a, a)
|
||||
elif init_weight == 'zero':
|
||||
init.zeros_(self.weight)
|
||||
|
||||
# init bias
|
||||
if self.bias is not None:
|
||||
if init_bias == 'torch':
|
||||
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
|
||||
init.uniform_(self.bias, -bound, bound)
|
||||
elif init_bias == 'jax':
|
||||
init.normal_(self.bias, std=1e-6)
|
||||
elif init_bias == 'zero':
|
||||
init.zeros_(self.bias)
|
||||
dist.broadcast(self.bias,
|
||||
src=gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0],
|
||||
group=gpc.get_group(ParallelMode.PARALLEL_1D))
|
||||
|
||||
def _set_tensor_parallel_attributes(self):
|
||||
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
|
||||
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
# Set up backprop all-reduce.
|
||||
@ -163,4 +256,29 @@ class Linear1D_Row(ParallelLayer):
|
||||
|
||||
if not self.skip_bias_add:
|
||||
output = output + self.bias
|
||||
return output
|
||||
return output
|
||||
else:
|
||||
return output, self.bias
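# A single-process check of the column-parallel -> row-parallel split realized by
# Linear1D_Col / Linear1D_Row, simulated with torch.chunk instead of a process
# group; in the distributed version the final sum is an all-reduce over the
# 1D tensor-parallel group.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
world_size, B, H = 4, 2, 16
x = torch.randn(B, H)
w1, b1 = torch.randn(4 * H, H), torch.randn(4 * H)       # column-parallel layer
w2, b2 = torch.randn(H, 4 * H), torch.randn(H)           # row-parallel layer

ref = F.linear(F.linear(x, w1, b1), w2, b2)

w1_parts, b1_parts = w1.chunk(world_size, dim=0), b1.chunk(world_size, dim=0)
w2_parts = w2.chunk(world_size, dim=1)                    # split along input dim
partial = [F.linear(F.linear(x, w1_parts[r], b1_parts[r]), w2_parts[r])
           for r in range(world_size)]
out = sum(partial) + b2                                    # "all-reduce", then one bias add
assert torch.allclose(out, ref, atol=1e-4)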
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class MixedFusedLayerNorm1D(torch.nn.Module):
|
||||
|
||||
def __init__(self, normalized_shape, eps=1e-5):
|
||||
super(MixedFusedLayerNorm1D, self).__init__()
|
||||
|
||||
if isinstance(normalized_shape, numbers.Integral):
|
||||
normalized_shape = (normalized_shape,)
|
||||
self.normalized_shape = torch.Size(normalized_shape)
|
||||
self.eps = eps
|
||||
self.weight = Parameter(torch.Tensor(*normalized_shape))
|
||||
self.bias = Parameter(torch.Tensor(*normalized_shape))
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
init.ones_(self.weight)
|
||||
init.zeros_(self.bias)
|
||||
|
||||
def forward(self, input):
|
||||
return FusedLayerNormAffineFunction1D.apply(
|
||||
input, self.weight, self.bias, self.normalized_shape, self.eps)
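# With weight reset to ones and bias to zeros as above, the affine fused layer
# norm is numerically a plain layer norm; a quick check against
# torch.nn.functional.layer_norm (the fused CUDA kernel itself is not exercised here).
import torch
import torch.nn.functional as F

hidden, eps = 16, 1e-5
x = torch.randn(4, 10, hidden)
weight, bias = torch.ones(hidden), torch.zeros(hidden)

reference = F.layer_norm(x, (hidden,), weight, bias, eps)
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, unbiased=False, keepdim=True)
manual = (x - mean) / torch.sqrt(var + eps) * weight + bias
assert torch.allclose(reference, manual, atol=1e-6)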
|
||||
|
@ -20,7 +20,6 @@ def matmul_2d(a,
|
||||
col_parallel_mode=ParallelMode.PARALLEL_2D_COL,
|
||||
):
|
||||
"""Matrix multiplication for 2D parallelism
|
||||
|
||||
:param a: matrix :math:`A`
|
||||
:type a: torch.tensor
|
||||
:param b: matrix :math:`B`
|
||||
@ -86,25 +85,30 @@ class Matmul_AB_2D(torch.autograd.Function):
|
||||
ctx.save_for_backward(A, B)
|
||||
|
||||
A_shape = A.shape
|
||||
A = A.reshape((-1, A_shape[-1]))
|
||||
A = A.reshape((-1, A_shape[-1])).contiguous()
|
||||
B_shape = B.shape
|
||||
B = B.reshape((-1, B_shape[-1]))
|
||||
B = B.reshape((-1, B_shape[-1])).contiguous()
|
||||
C_shape = (A.shape[0], B.shape[-1])
|
||||
C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device())
|
||||
|
||||
for i in range(summa_dim):
|
||||
A_temp = A.clone()
|
||||
B_temp = B.clone()
|
||||
src_a = i + summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
|
||||
pipeline_parallel_rank * tensor_parallel_size
|
||||
dist.broadcast(A_temp, src=src_a,
|
||||
group=gpc.get_group(row_parallel_mode))
|
||||
src_b = col_rank + summa_dim * i + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
|
||||
pipeline_parallel_rank * tensor_parallel_size
|
||||
dist.broadcast(B_temp, src=src_b,
|
||||
group=gpc.get_group(col_parallel_mode))
|
||||
torch.addmm(C, A_temp, B_temp, out=C)
|
||||
A_list = [torch.empty_like(A) for _ in range(gpc.get_world_size(row_parallel_mode)-1)]
|
||||
B_list = [torch.empty_like(B) for _ in range(gpc.get_world_size(col_parallel_mode)-1)]
|
||||
A_list.insert(gpc.get_local_rank(row_parallel_mode), A)
|
||||
B_list.insert(gpc.get_local_rank(col_parallel_mode), B)
|
||||
op_a = dist.all_gather(A_list, A, group=gpc.get_group(row_parallel_mode), async_op=True)
|
||||
op_a.wait()
|
||||
op_b = dist.all_gather(B_list, B, group=gpc.get_group(col_parallel_mode), async_op=True)
|
||||
for op in [op_a, op_b]:
|
||||
op.wait()
|
||||
|
||||
for i in range(summa_dim):
|
||||
src_a = i + summa_dim * row_rank
|
||||
src_b = i + summa_dim * col_rank
|
||||
src_a = src_a % summa_dim
|
||||
src_b = src_b % summa_dim
|
||||
A_temp = A_list[src_a]
|
||||
B_temp = B_list[src_b]
|
||||
torch.addmm(C, A_temp, B_temp, out=C)
|
||||
out = C.reshape(out_shape)
|
||||
|
||||
if ctx:
|
||||
@ -499,36 +503,61 @@ class _LayerNorm_2D(torch.autograd.Function):
|
||||
# return input_grad, None, None, None, None, None
|
||||
|
||||
|
||||
class _ViT_Split_Input_2D(torch.autograd.Function):
|
||||
class AllGatherLast(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx: Any,
|
||||
inputs: Tensor,
|
||||
batch_size: int,
|
||||
summa_dim: int,
|
||||
col_parallel_mode: ParallelMode) -> Tensor:
|
||||
# inputs: [b, s, h/q]
|
||||
# output: [b/q, s, h/q]
|
||||
|
||||
ctx.BATCH_SIZE = batch_size
|
||||
ctx.summa_dim = summa_dim
|
||||
ctx.col_parallel_mode = col_parallel_mode
|
||||
row_rank = gpc.get_local_rank(col_parallel_mode)
|
||||
output = torch.chunk(inputs, summa_dim, dim=0)[row_rank]
|
||||
output = output.clone()
|
||||
return output
|
||||
ctx.row_rank = gpc.get_local_rank(col_parallel_mode)
|
||||
|
||||
last_dim = summa_dim * inputs.size(-1)
|
||||
outputs_shape = (last_dim,) + inputs.shape[:-1]
|
||||
outputs = torch.empty(
|
||||
outputs_shape, dtype=inputs.dtype, device=get_current_device())
|
||||
dist.all_gather(
|
||||
list(outputs.chunk(summa_dim, dim=0)),
|
||||
inputs.permute(2, 0, 1).contiguous(),
|
||||
group=gpc.get_group(col_parallel_mode)
|
||||
)
|
||||
outputs = outputs.permute(1, 2, 0).contiguous()
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
# output_grad: [b/q, s, h/q]
|
||||
# grads: [b, s, h/q]
|
||||
grads_shape = (ctx.BATCH_SIZE,) + output_grad.shape[1:]
|
||||
grads = torch.empty(grads_shape,
|
||||
dtype=output_grad.dtype,
|
||||
device=get_current_device())
|
||||
dist.all_gather(list(grads.chunk(ctx.summa_dim, dim=0)),
|
||||
output_grad.contiguous(),
|
||||
group=gpc.get_group(ctx.col_parallel_mode))
|
||||
return grads, None, None, None
|
||||
grad = output_grad.chunk(ctx.summa_dim, dim=-1)[ctx.row_rank]
|
||||
return grad.contiguous(), None, None
|
||||
|
||||
|
||||
class SplitFirst(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx: Any,
|
||||
inputs: Tensor,
|
||||
summa_dim: int,
|
||||
col_parallel_mode: ParallelMode) -> Tensor:
|
||||
ctx.summa_dim = summa_dim
|
||||
ctx.batch_size = inputs.size(0)
|
||||
ctx.para_mode = col_parallel_mode
|
||||
row_rank = gpc.get_local_rank(col_parallel_mode)
|
||||
|
||||
outputs = inputs.chunk(summa_dim, dim=0)[row_rank]
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
grad_shape = (ctx.batch_size,) + output_grad.shape[1:]
|
||||
grad = torch.empty(
|
||||
grad_shape, dtype=output_grad.dtype, device=get_current_device())
|
||||
dist.all_gather(
|
||||
list(grad.chunk(ctx.summa_dim, dim=0)),
|
||||
output_grad.contiguous(),
|
||||
group=gpc.get_group(ctx.para_mode)
|
||||
)
|
||||
return grad, None, None
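# The combined effect of AllGatherLast followed by SplitFirst (this is how
# ViTInputSplitter2D uses them) shown with torch.chunk / torch.cat on one
# process: a tensor partitioned along the hidden dimension is re-partitioned
# along the batch dimension.
import torch

q, b, s, h = 2, 4, 3, 8                     # q plays the role of summa_dim
full = torch.randn(b, s, h)
shards = list(full.chunk(q, dim=-1))        # what each column rank holds: [b, s, h/q]

for rank in range(q):
    gathered = torch.cat(shards, dim=-1)            # AllGatherLast -> [b, s, h]
    local = gathered.chunk(q, dim=0)[rank]          # SplitFirst    -> [b/q, s, h]
    assert torch.equal(local, full.chunk(q, dim=0)[rank])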
|
||||
|
@ -5,19 +5,21 @@ import math
|
||||
|
||||
import torch
|
||||
from torch import nn as nn, Tensor, distributed as dist
|
||||
from torch.nn.init import _calculate_fan_in_and_fan_out
|
||||
|
||||
from colossalai.context import seed, ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.nn.layer._common_utils import divide, ACT2FN
|
||||
from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization, get_summa_dim_from_env
|
||||
from colossalai.nn.layer.vanilla_vision_transformer.layers import to_2tuple
|
||||
|
||||
from colossalai.registry import LAYERS
|
||||
from colossalai.utils import checkpoint
|
||||
from colossalai.utils import get_current_device
|
||||
from ._operation import _ViT_Split_Input_2D
|
||||
from colossalai.core import global_context as gpc
|
||||
from ._operation import AllGatherLast, SplitFirst
|
||||
from .layers import Linear2D
|
||||
from .._common_utils import set_tensor_parallel_attribute
|
||||
from .._common_utils import set_tensor_parallel_attribute_by_partition, to_2tuple
|
||||
from ..base_layer import ParallelLayer
|
||||
from ..fused_bias_gelu import bias_gelu_impl
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
@ -44,8 +46,8 @@ class ViTMLP2D(ParallelLayer):
|
||||
act_func: str = 'gelu',
|
||||
dropout_prob: float = 0.,
|
||||
dtype=None,
|
||||
checkpoint: bool = False
|
||||
):
|
||||
checkpoint: bool = False,
|
||||
weight_init='torch'):
|
||||
super().__init__()
|
||||
|
||||
assert_summa_initialization()
|
||||
@ -53,27 +55,40 @@ class ViTMLP2D(ParallelLayer):
|
||||
self.in_features = in_features
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.checkpoint = checkpoint
|
||||
assert weight_init in ('torch', 'jax')
|
||||
|
||||
if act_func == 'fused_gelu':
|
||||
self.act = bias_gelu_impl
|
||||
skip_dense_1_add_bias = True
|
||||
else:
|
||||
self.act = ACT2FN[act_func]
|
||||
skip_dense_1_add_bias = False
|
||||
|
||||
# Project to mlp_ratio * h.
|
||||
self.dense_1 = Linear2D(
|
||||
self.in_features,
|
||||
self.mlp_ratio * self.in_features,
|
||||
dtype=dtype,
|
||||
init_weight=weight_init, init_bias=weight_init,
|
||||
skip_bias_add=skip_dense_1_add_bias
|
||||
)
|
||||
|
||||
self.act = ACT2FN[act_func]
|
||||
|
||||
# Project back to h.
|
||||
self.dense_2 = Linear2D(
|
||||
self.mlp_ratio * self.in_features,
|
||||
self.in_features,
|
||||
dtype=dtype,
|
||||
init_weight=weight_init, init_bias=weight_init
|
||||
)
|
||||
self.dropout = nn.Dropout(dropout_prob)
|
||||
|
||||
def _forward(self, hidden_states: Tensor) -> Tensor:
|
||||
intermediate_output = self.dense_1(hidden_states)
|
||||
intermediate_output = self.act(intermediate_output)
|
||||
if self.act == bias_gelu_impl:
|
||||
intermediate_output, bias = self.dense_1(hidden_states)
|
||||
intermediate_output = self.act(intermediate_output, bias)
|
||||
else:
|
||||
intermediate_output = self.dense_1(hidden_states)
|
||||
intermediate_output = self.act(intermediate_output)
|
||||
|
||||
with seed(ParallelMode.TENSOR):
|
||||
intermediate_output = self.dropout(intermediate_output)
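# What skip_bias_add buys here: dense_1 returns its matmul result and bias
# separately so the bias add can be fused into the activation. A plain-PyTorch
# sketch of the equivalence, using the exact GeLU (the fused kernel may use a
# tanh approximation, so small numerical differences are expected there).
import torch
import torch.nn.functional as F

def bias_gelu(y: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # unfused reference for a fused bias + GeLU kernel
    return F.gelu(y + bias)

w, b = torch.randn(32, 16), torch.randn(32)
x = torch.randn(2, 5, 16)
deferred = bias_gelu(F.linear(x, w), b)      # bias skipped in the linear, added in the act
plain = F.gelu(F.linear(x, w, b))            # ordinary linear-then-activation
assert torch.allclose(deferred, plain, atol=1e-6)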
|
||||
@ -117,8 +132,8 @@ class ViTSelfAttention2D(ParallelLayer):
|
||||
attention_dropout_prob: float,
|
||||
hidden_dropout_prob: float,
|
||||
dtype=None,
|
||||
checkpoint: bool = False
|
||||
):
|
||||
checkpoint: bool = False,
|
||||
weight_init='torch'):
|
||||
super().__init__()
|
||||
|
||||
assert_summa_initialization()
|
||||
@ -128,17 +143,24 @@ class ViTSelfAttention2D(ParallelLayer):
|
||||
self.attention_head_size = divide(hidden_size, num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
self.checkpoint = checkpoint
|
||||
assert weight_init in ('torch', 'jax')
|
||||
if weight_init == 'jax':
|
||||
self.init_bias = 'zero'
|
||||
else:
|
||||
self.init_bias = weight_init
|
||||
|
||||
self.query_key_value = Linear2D(
|
||||
hidden_size,
|
||||
3 * hidden_size,
|
||||
dtype=dtype,
|
||||
init_weight=weight_init, init_bias=self.init_bias
|
||||
)
|
||||
self.attention_dropout = nn.Dropout(attention_dropout_prob)
|
||||
self.dense = Linear2D(
|
||||
hidden_size,
|
||||
hidden_size,
|
||||
dtype=dtype,
|
||||
init_weight=weight_init, init_bias=self.init_bias
|
||||
)
|
||||
self.dropout = nn.Dropout(hidden_dropout_prob)
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
@ -146,7 +168,7 @@ class ViTSelfAttention2D(ParallelLayer):
|
||||
def _forward(self, hidden_states: Tensor) -> Tensor:
|
||||
query_key_value = self.query_key_value(hidden_states)
|
||||
new_qkv_shape = query_key_value.shape[:-1] + \
|
||||
(self.num_attention_heads, 3 * self.attention_head_size)
|
||||
(self.num_attention_heads, 3 * self.attention_head_size)
|
||||
query_key_value = query_key_value.view(new_qkv_shape)
|
||||
query_key_value = query_key_value.permute((0, 2, 1, 3))
|
||||
query_layer, key_layer, value_layer = torch.chunk(
|
||||
@ -155,7 +177,7 @@ class ViTSelfAttention2D(ParallelLayer):
|
||||
attention_scores = torch.matmul(
|
||||
query_layer, key_layer.transpose(-1, -2))
|
||||
attention_scores = attention_scores / \
|
||||
math.sqrt(self.attention_head_size)
|
||||
math.sqrt(self.attention_head_size)
|
||||
|
||||
attention_probs = self.softmax(attention_scores)
|
||||
|
||||
@ -165,7 +187,7 @@ class ViTSelfAttention2D(ParallelLayer):
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
context_layer = context_layer.transpose(1, 2)
|
||||
new_context_layer_shape = context_layer.size()[
|
||||
:-2] + (self.all_head_size,)
|
||||
:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.reshape(new_context_layer_shape)
|
||||
|
||||
output = self.dense(context_layer)
|
||||
@ -199,14 +221,22 @@ class ViTHead2D(ParallelLayer):
|
||||
hidden_size,
|
||||
num_classes,
|
||||
dtype=None,
|
||||
):
|
||||
weight_init='torch'):
|
||||
super().__init__()
|
||||
assert_summa_initialization()
|
||||
assert weight_init in ('torch', 'jax')
|
||||
if weight_init == 'jax':
|
||||
self.init_weight = 'zero'
|
||||
self.init_bias = 'zero'
|
||||
else:
|
||||
self.init_weight = weight_init
|
||||
self.init_bias = weight_init
|
||||
self.summa_dim = get_summa_dim_from_env()
|
||||
self.linear = Linear2D(
|
||||
hidden_size,
|
||||
num_classes,
|
||||
dtype=dtype,
|
||||
init_weight=self.init_weight, init_bias=self.init_bias
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
@ -236,7 +266,8 @@ class ViTPatchEmbedding2D(ParallelLayer):
|
||||
patch_size,
|
||||
embed_dim,
|
||||
in_chans=3,
|
||||
flatten=True):
|
||||
flatten=True,
|
||||
weight_init='torch'):
|
||||
super().__init__()
|
||||
img_size = to_2tuple(img_size)
|
||||
patch_size = to_2tuple(patch_size)
|
||||
@ -249,39 +280,28 @@ class ViTPatchEmbedding2D(ParallelLayer):
|
||||
img_size[1] // patch_size[1])
|
||||
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
||||
self.flatten = flatten
|
||||
self.embed_dim = embed_dim // self.summa_dim
|
||||
self.embed_dim = embed_dim // (self.summa_dim ** 2)
|
||||
|
||||
with seed(ParallelMode.TENSOR):
|
||||
# ensure the partitions are initialized differently
|
||||
self.proj = nn.Conv2d(in_chans,
|
||||
self.embed_dim,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size
|
||||
stride=patch_size,
|
||||
device=get_current_device()
|
||||
)
|
||||
self._set_tensor_parallel_attribute()
|
||||
|
||||
# sync
|
||||
self._broadcast_conv_params()
|
||||
self.proj.weight.register_hook(self._sync_grad_during_backward)
|
||||
self.proj.bias.register_hook(self._sync_grad_during_backward)
|
||||
if weight_init == 'jax':
|
||||
with seed(ParallelMode.TENSOR):
|
||||
fan_in, _ = _calculate_fan_in_and_fan_out(self.proj.weight)
|
||||
std = math.sqrt(1.0 / fan_in)
|
||||
nn.init.trunc_normal_(self.proj.weight, std=std / .87962566103423978)
|
||||
nn.init.zeros_(self.proj.bias)
|
||||
|
||||
def _set_tensor_parallel_attribute(self):
|
||||
set_tensor_parallel_attribute(self.proj.weight)
|
||||
set_tensor_parallel_attribute(self.proj.bias)
|
||||
|
||||
def _broadcast_conv_params(self) -> None:
|
||||
self.to(get_current_device())
|
||||
ranks_in_col = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2D_COL)
|
||||
|
||||
dist.broadcast(self.proj.weight, src=ranks_in_col[0],
|
||||
group=gpc.get_group(ParallelMode.PARALLEL_2D_COL))
|
||||
dist.broadcast(self.proj.bias, src=ranks_in_col[0],
|
||||
group=gpc.get_group(ParallelMode.PARALLEL_2D_COL))
|
||||
|
||||
def _sync_grad_during_backward(self, grad: Tensor) -> None:
|
||||
dist.all_reduce(grad, group=gpc.get_group(
|
||||
ParallelMode.PARALLEL_2D_COL))
|
||||
grad = grad / self.summa_dim
|
||||
return grad
|
||||
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
|
||||
set_tensor_parallel_attribute_by_partition(self.proj.weight, num_partition)
|
||||
set_tensor_parallel_attribute_by_partition(self.proj.bias, num_partition)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
B, C, H, W = x.shape
|
||||
@ -293,6 +313,24 @@ class ViTPatchEmbedding2D(ParallelLayer):
|
||||
return x
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class ViTInputSplitter2D(ParallelLayer):
|
||||
"""Split the input tensor for 2D parallel Vision Transformer
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
assert_summa_initialization()
|
||||
self.summa_dim = get_summa_dim_from_env()
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
x = AllGatherLast.apply(
|
||||
x, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
|
||||
x = SplitFirst.apply(
|
||||
x, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
|
||||
return x
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class ViTTokenFuser2D(ParallelLayer):
|
||||
"""
|
||||
@ -328,64 +366,32 @@ class ViTTokenFuser2D(ParallelLayer):
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
self.cls_token = nn.Parameter(torch.zeros(
|
||||
1, 1, self.embed_dim // self.summa_dim))
|
||||
self.pos_embed = nn.Parameter(torch.zeros(
|
||||
1, self.num_patches + 1, self.embed_dim // self.summa_dim))
|
||||
(1, 1, self.embed_dim // (self.summa_dim ** 2)),
|
||||
device=get_current_device()))
|
||||
self.pos_embed = nn.Parameter(torch.empty(
|
||||
(1, self.num_patches + 1, self.embed_dim // (self.summa_dim ** 2)),
|
||||
device=get_current_device()))
|
||||
with seed(ParallelMode.TENSOR):
|
||||
nn.init.trunc_normal_(self.pos_embed, std=.02)
|
||||
|
||||
# move to cuda before broadcast
|
||||
self.to(get_current_device())
|
||||
|
||||
# sync param in both forward and backward
|
||||
_cls_token = self.cls_token.view(-1)
|
||||
_pos_embed = self.pos_embed.view(-1)
|
||||
self._param = torch.cat([_cls_token, _pos_embed], dim=0)
|
||||
|
||||
self._broadcast_params(self._param)
|
||||
self._param.register_hook(self._sync_grad_hook)
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
self._set_tensor_parallel_attribute()
|
||||
|
||||
def _set_tensor_parallel_attribute(self):
|
||||
set_tensor_parallel_attribute(self.cls_token)
|
||||
set_tensor_parallel_attribute(self.pos_embed)
|
||||
|
||||
def _broadcast_params(self, param) -> None:
|
||||
" broadcast to all column ranks for data consistency "
|
||||
ranks_in_col = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2D_COL)
|
||||
col_group = gpc.get_group(ParallelMode.PARALLEL_2D_COL)
|
||||
dist.broadcast(param, src=ranks_in_col[0],
|
||||
group=col_group)
|
||||
|
||||
def _sync_grad_hook(self, grad) -> None:
|
||||
dist.all_reduce(grad, group=gpc.get_group(
|
||||
ParallelMode.PARALLEL_2D_COL))
|
||||
grad = grad / self.summa_dim
|
||||
return grad
|
||||
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
|
||||
set_tensor_parallel_attribute_by_partition(self.cls_token, num_partition)
|
||||
set_tensor_parallel_attribute_by_partition(self.pos_embed, num_partition)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
# stole cls_tokens impl from Phil Wang, thanks
|
||||
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
|
||||
cls_token = AllGatherLast.apply(
|
||||
self.cls_token, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
|
||||
cls_token = cls_token.expand(x.shape[0], -1, -1)
|
||||
x = torch.cat((cls_token, x), dim=1)
|
||||
|
||||
pos_embed = AllGatherLast.apply(
|
||||
self.pos_embed, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
|
||||
x = x + pos_embed
|
||||
with seed(ParallelMode.TENSOR):
|
||||
x = self.pos_drop(x + self.pos_embed)
|
||||
x = self.pos_drop(x)
|
||||
return x
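# In the fuser above, cls_token and pos_embed keep only embed_dim // summa_dim**2
# of the hidden dimension per rank and are gathered over the column group
# (AllGatherLast) right before use, which lines them up with activations of
# shape [b/q, n+1, h/q]; the shape bookkeeping on one process:
import torch

summa_dim, embed_dim, num_patches, batch = 2, 768, 196, 8
local_dim = embed_dim // summa_dim ** 2
pos_local = torch.empty(1, num_patches + 1, local_dim).normal_(std=.02)

# stand-in for dist.all_gather over the summa_dim column ranks
pos_gathered = torch.cat([pos_local] * summa_dim, dim=-1)

x = torch.randn(batch // summa_dim, num_patches + 1, embed_dim // summa_dim)
assert pos_gathered.size(-1) == x.size(-1)
out = x + pos_gathered                        # broadcasts over the local batch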
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class ViTInputSplitter2D(ParallelLayer):
|
||||
"""Split the input tensor for 2D parallel Vision Transformer
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
assert_summa_initialization()
|
||||
self.summa_dim = get_summa_dim_from_env()
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
batch_size = x.size(0)
|
||||
return _ViT_Split_Input_2D.apply(
|
||||
x,
|
||||
batch_size,
|
||||
self.summa_dim,
|
||||
ParallelMode.PARALLEL_2D_COL
|
||||
)
|
||||
|
@ -11,7 +11,7 @@ from colossalai.registry import LAYERS
|
||||
from colossalai.utils import get_current_device
|
||||
from ._operation import Matmul_AB_2D, Add_Bias_2D, _LayerNorm_2D
|
||||
from ._utils import get_summa_dim_from_env, assert_summa_initialization
|
||||
from .._common_utils import divide, set_tensor_parallel_attribute
|
||||
from .._common_utils import divide, set_tensor_parallel_attribute_by_partition
|
||||
from ..base_layer import ParallelLayer
|
||||
|
||||
|
||||
@ -36,8 +36,9 @@ class Linear2D(ParallelLayer):
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
dtype=None,
|
||||
skip_bias_add: bool = False
|
||||
):
|
||||
skip_bias_add: bool = False,
|
||||
init_weight='torch',
|
||||
init_bias='torch'):
|
||||
super().__init__()
|
||||
|
||||
self.in_features = in_features
|
||||
@ -72,31 +73,45 @@ class Linear2D(ParallelLayer):
|
||||
self.register_parameter('bias', None)
|
||||
|
||||
# initialize parameters
|
||||
self.reset_parameters()
|
||||
with seed(ParallelMode.TENSOR):
|
||||
self.reset_parameters(init_weight, init_bias)
|
||||
self._set_tensor_parallel_attributes()
|
||||
|
||||
def _set_tensor_parallel_attributes(self):
|
||||
set_tensor_parallel_attribute(self.weight)
|
||||
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
|
||||
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
|
||||
if self.bias is not None:
|
||||
set_tensor_parallel_attribute(self.bias)
|
||||
set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
def reset_parameters(self, init_weight, init_bias) -> None:
|
||||
assert init_weight in ('torch', 'jax', 'zero')
|
||||
assert init_bias in ('torch', 'jax', 'zero')
|
||||
# setting
|
||||
fan_in = self.in_features
|
||||
a = math.sqrt(5)
|
||||
nonlinearity = 'leaky_relu'
|
||||
fan_in, fan_out = self.in_features, self.out_features
|
||||
|
||||
# init weight
|
||||
std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
|
||||
bound = math.sqrt(3.0) * std
|
||||
with seed(ParallelMode.TENSOR):
|
||||
if init_weight == 'torch':
|
||||
a = math.sqrt(5)
|
||||
nonlinearity = 'leaky_relu'
|
||||
std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
|
||||
bound = math.sqrt(3.0) * std
|
||||
init.uniform_(self.weight, -bound, bound)
|
||||
elif init_weight == 'jax':
|
||||
std = math.sqrt(2.0 / float(fan_in + fan_out))
|
||||
a = math.sqrt(3.0) * std
|
||||
init.uniform_(self.weight, -a, a)
|
||||
elif init_weight == 'zero':
|
||||
init.zeros_(self.weight)
|
||||
|
||||
# init bias
|
||||
if self.bias is not None:
|
||||
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
|
||||
with seed(ParallelMode.TENSOR):
|
||||
if init_bias == 'torch':
|
||||
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
|
||||
init.uniform_(self.bias, -bound, bound)
|
||||
elif init_bias == 'jax':
|
||||
init.normal_(self.bias, std=1e-6)
|
||||
elif init_bias == 'zero':
|
||||
init.zeros_(self.bias)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
# input: [m/q, n/q, k/q]
|
||||
@ -192,28 +207,19 @@ class LayerNorm2D(ParallelLayer):
|
||||
# create parameters
|
||||
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
|
||||
|
||||
if self.row_rank == 0:
|
||||
self.gamma = Parameter(torch.ones(
|
||||
self.partitioned_partition,
|
||||
**factory_kwargs))
|
||||
self.beta = Parameter(torch.zeros(
|
||||
self.partitioned_partition,
|
||||
**factory_kwargs))
|
||||
else:
|
||||
self.gamma = Parameter(torch.tensor(
|
||||
1.0,
|
||||
requires_grad=True,
|
||||
**factory_kwargs))
|
||||
self.beta = Parameter(torch.tensor(
|
||||
1.0,
|
||||
requires_grad=True,
|
||||
**factory_kwargs))
|
||||
self.gamma = Parameter(torch.ones(
|
||||
self.partitioned_partition,
|
||||
**factory_kwargs))
|
||||
self.beta = Parameter(torch.zeros(
|
||||
self.partitioned_partition,
|
||||
**factory_kwargs))
|
||||
|
||||
self._set_tensor_parallel_attributes()
|
||||
|
||||
def _set_tensor_parallel_attributes(self):
|
||||
set_tensor_parallel_attribute(self.gamma)
|
||||
set_tensor_parallel_attribute(self.beta)
|
||||
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
|
||||
set_tensor_parallel_attribute_by_partition(self.gamma, num_partition)
|
||||
set_tensor_parallel_attribute_by_partition(self.beta, num_partition)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
with torch.no_grad():
|
||||
|
@ -1,11 +1,10 @@
|
||||
from ._operation import Matmul_AB_2p5D, Matmul_ABT_2p5D, Matmul_ATB_2p5D, Sum_2p5D, Add_Bias_2p5D
|
||||
from ._operation import Matmul_AB_2p5D, Matmul_ABT_2p5D, Matmul_ATB_2p5D, Add_Bias_2p5D
|
||||
from ._transformer import TransformerMLP2p5D, TransformerSelfAttention2p5D, TransformerLayer2p5D
|
||||
from ._vit import (ViTMLP2p5D, ViTSelfAttention2p5D, ViTHead2p5D, ViTPatchEmbedding2p5D, ViTTokenFuser2p5D,
|
||||
ViTInputSplitter2p5D)
|
||||
from ._vit import ViTMLP2p5D, ViTSelfAttention2p5D, ViTHead2p5D, ViTPatchEmbedding2p5D, ViTTokenFuser2p5D, ViTInputSplitter2p5D
|
||||
from .layers import Linear2p5D, LayerNorm2p5D
|
||||
|
||||
__all__ = [
|
||||
'Matmul_AB_2p5D', 'Matmul_ABT_2p5D', 'Matmul_ATB_2p5D', 'Sum_2p5D', 'Add_Bias_2p5D',
|
||||
'Matmul_AB_2p5D', 'Matmul_ABT_2p5D', 'Matmul_ATB_2p5D', 'Add_Bias_2p5D',
|
||||
'TransformerMLP2p5D', 'TransformerSelfAttention2p5D', 'TransformerLayer2p5D',
|
||||
'ViTMLP2p5D', 'ViTSelfAttention2p5D', 'ViTHead2p5D', 'ViTPatchEmbedding2p5D', 'ViTTokenFuser2p5D',
|
||||
'ViTInputSplitter2p5D',
|
||||
|
@ -6,7 +6,8 @@ from torch import Tensor
|
||||
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.utils import get_current_device, empty_cache
|
||||
from colossalai.utils import get_current_device
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
|
||||
def get_parallel_group(parallel_mode: ParallelMode):
|
||||
@ -26,18 +27,17 @@ class Matmul_AB_2p5D(torch.autograd.Function):
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx: Any,
|
||||
A: Tensor,
|
||||
B: Tensor,
|
||||
tesseract_dim: int,
|
||||
tesseract_dep: int,
|
||||
out_shape: Tuple[int, ...],
|
||||
row_rank: int,
|
||||
col_rank: int,
|
||||
dep_rank: int,
|
||||
row_parallel_mode: ParallelMode,
|
||||
col_parallel_mode: ParallelMode,
|
||||
dep_parallel_mode: ParallelMode,
|
||||
data_parallel_rank: int,
|
||||
pipeline_parallel_rank: int,
|
||||
pipeline_parallel_size: int,
|
||||
@ -49,41 +49,43 @@ class Matmul_AB_2p5D(torch.autograd.Function):
|
||||
assert A.shape[-1] == B.shape[-2], \
|
||||
'Invalid shapes: A={}, B={} for AB.'.format(A.shape, B.shape)
|
||||
|
||||
empty_cache()
|
||||
if ctx:
|
||||
ctx.save_for_backward(A, B)
|
||||
|
||||
A_shape = A.shape
|
||||
A = A.reshape((-1, A_shape[-1]))
|
||||
A = A.reshape((-1, A_shape[-1])).contiguous()
|
||||
B_shape = B.shape
|
||||
B = B.reshape((-1, B_shape[-1]))
|
||||
B = B.reshape((-1, B_shape[-1])).contiguous()
|
||||
C_shape = (A.shape[0], B.shape[-1])
|
||||
C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device())
|
||||
|
||||
for i in range(tesseract_dim):
|
||||
A_temp = A.clone()
|
||||
B_temp = B.clone()
|
||||
src_a = i + row_rank * tesseract_dim + dep_rank * (
|
||||
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + pipeline_parallel_rank * tensor_parallel_size
|
||||
dist.broadcast(A_temp, src=src_a,
|
||||
group=get_parallel_group(row_parallel_mode))
|
||||
src_b = col_rank + i * tesseract_dim + dep_rank * (
|
||||
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + pipeline_parallel_rank * tensor_parallel_size
|
||||
dist.broadcast(B_temp, src=src_b,
|
||||
group=get_parallel_group(col_parallel_mode))
|
||||
torch.addmm(C, A_temp, B_temp, out=C)
|
||||
A_list = [torch.empty_like(A) for _ in range(gpc.get_world_size(row_parallel_mode)-1)]
|
||||
B_list = [torch.empty_like(B) for _ in range(gpc.get_world_size(col_parallel_mode)-1)]
|
||||
A_list.insert(gpc.get_local_rank(row_parallel_mode), A)
|
||||
B_list.insert(gpc.get_local_rank(col_parallel_mode), B)
|
||||
op_a = dist.all_gather(A_list, A, group=gpc.get_group(row_parallel_mode), async_op=True)
|
||||
op_a.wait()
|
||||
op_b = dist.all_gather(B_list, B, group=gpc.get_group(col_parallel_mode), async_op=True)
|
||||
for op in [op_a, op_b]:
|
||||
op.wait()
|
||||
|
||||
for i in range(tesseract_dim):
|
||||
src_a = i + tesseract_dim * row_rank
|
||||
src_b = i + tesseract_dim * col_rank
|
||||
src_a = src_a % tesseract_dim
|
||||
src_b = src_b % tesseract_dim
|
||||
A_temp = A_list[src_a]
|
||||
B_temp = B_list[src_b]
|
||||
torch.addmm(C, A_temp, B_temp, out=C)
|
||||
out = C.reshape(out_shape)
|
||||
|
||||
if ctx:
|
||||
ctx.tesseract_dim = tesseract_dim
|
||||
ctx.tesseract_dep = tesseract_dep
|
||||
ctx.row_rank = row_rank
|
||||
ctx.col_rank = col_rank
|
||||
ctx.dep_rank = dep_rank
|
||||
ctx.row_parallel_mode = row_parallel_mode
|
||||
ctx.col_parallel_mode = col_parallel_mode
|
||||
ctx.dep_parallel_mode = dep_parallel_mode
|
||||
ctx.A_shape = A_shape
|
||||
ctx.B_shape = B_shape
|
||||
ctx.data_parallel_rank = data_parallel_rank
|
||||
@ -94,34 +96,32 @@ class Matmul_AB_2p5D(torch.autograd.Function):
|
||||
return out
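# What the all_gather rewrite above computes, simulated on one process: each
# (row, col) position owns one block of A and B, gathers the full block-row of A
# and block-column of B, and accumulates C[row][col] = sum_i A[row][i] @ B[i][col];
# q stands for tesseract_dim here (or summa_dim in the 2D case).
import torch

torch.manual_seed(0)
q, m, k, n = 2, 4, 6, 8
A, B = torch.randn(m, k), torch.randn(k, n)

A_blocks = [list(r.chunk(q, dim=1)) for r in A.chunk(q, dim=0)]
B_blocks = [list(r.chunk(q, dim=1)) for r in B.chunk(q, dim=0)]

C = torch.zeros(m, n)
C_blocks = [list(r.chunk(q, dim=1)) for r in C.chunk(q, dim=0)]   # views into C
for row in range(q):
    for col in range(q):
        A_list = [A_blocks[row][i] for i in range(q)]   # plays the dist.all_gather result
        B_list = [B_blocks[i][col] for i in range(q)]
        acc = torch.zeros_like(C_blocks[row][col])
        for i in range(q):
            acc = torch.addmm(acc, A_list[i], B_list[i])
        C_blocks[row][col].copy_(acc)

assert torch.allclose(C, A @ B, atol=1e-5)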
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
A, B = ctx.saved_tensors
|
||||
A_grad = Matmul_ABT_2p5D.forward(
|
||||
None,
|
||||
output_grad, B,
|
||||
ctx.tesseract_dim, ctx.tesseract_dep, ctx.A_shape,
|
||||
ctx.row_rank, ctx.col_rank, ctx.dep_rank,
|
||||
ctx.row_parallel_mode,
|
||||
ctx.col_parallel_mode,
|
||||
ctx.dep_parallel_mode,
|
||||
ctx.data_parallel_rank,
|
||||
ctx.pipeline_parallel_rank,
|
||||
ctx.pipeline_parallel_size,
|
||||
ctx.tensor_parallel_size
|
||||
)
|
||||
B_grad = Matmul_ATB_2p5D.forward(
|
||||
None,
|
||||
A, output_grad,
|
||||
ctx.tesseract_dim, ctx.tesseract_dep, ctx.B_shape,
|
||||
ctx.row_rank, ctx.col_rank, ctx.dep_rank,
|
||||
ctx.row_parallel_mode,
|
||||
ctx.col_parallel_mode,
|
||||
ctx.dep_parallel_mode,
|
||||
ctx.data_parallel_rank,
|
||||
ctx.pipeline_parallel_rank,
|
||||
ctx.pipeline_parallel_size,
|
||||
ctx.tensor_parallel_size
|
||||
)
|
||||
with torch.no_grad():
|
||||
A_grad = Matmul_ABT_2p5D.apply(
|
||||
output_grad, B,
|
||||
ctx.tesseract_dim, ctx.A_shape,
|
||||
ctx.row_rank, ctx.col_rank, ctx.dep_rank,
|
||||
ctx.row_parallel_mode,
|
||||
ctx.col_parallel_mode,
|
||||
ctx.data_parallel_rank,
|
||||
ctx.pipeline_parallel_rank,
|
||||
ctx.pipeline_parallel_size,
|
||||
ctx.tensor_parallel_size
|
||||
)
|
||||
B_grad = Matmul_ATB_2p5D.apply(
|
||||
A, output_grad,
|
||||
ctx.tesseract_dim, ctx.B_shape,
|
||||
ctx.row_rank, ctx.col_rank, ctx.dep_rank,
|
||||
ctx.row_parallel_mode,
|
||||
ctx.col_parallel_mode,
|
||||
ctx.data_parallel_rank,
|
||||
ctx.pipeline_parallel_rank,
|
||||
ctx.pipeline_parallel_size,
|
||||
ctx.tensor_parallel_size
|
||||
)
|
||||
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None
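# The backward above now re-enters the sibling matmul Functions through .apply()
# under torch.no_grad() rather than calling .forward() with a ctx of None; the
# same composition pattern in miniature, checked with gradcheck:
import torch

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, s):
        ctx.s = s
        return x * s

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * ctx.s, None

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        with torch.no_grad():
            # delegate d(x^2)/dx * grad = 2 * x * grad to another Function
            return Scale.apply(x * grad_out, 2.0)

x = torch.randn(5, dtype=torch.double, requires_grad=True)
assert torch.autograd.gradcheck(Square.apply, (x,))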
|
||||
|
||||
|
||||
@ -130,18 +130,17 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx: Any,
|
||||
A: Tensor,
|
||||
B: Tensor,
|
||||
tesseract_dim: int,
|
||||
tesseract_dep: int,
|
||||
out_shape: Tuple[int, ...],
|
||||
row_rank: int,
|
||||
col_rank: int,
|
||||
dep_rank: int,
|
||||
row_parallel_mode: ParallelMode,
|
||||
col_parallel_mode: ParallelMode,
|
||||
dep_parallel_mode: ParallelMode,
|
||||
data_parallel_rank: int,
|
||||
pipeline_parallel_rank: int,
|
||||
pipeline_parallel_size: int,
|
||||
@ -151,7 +150,6 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
|
||||
assert A.shape[-1] == B.shape[-1], \
|
||||
'Invalid shapes: A={}, B={} for ABT.'.format(A.shape, B.shape)
|
||||
|
||||
empty_cache()
|
||||
if ctx:
|
||||
ctx.save_for_backward(A, B)
|
||||
|
||||
@ -180,13 +178,11 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
|
||||
|
||||
if ctx:
|
||||
ctx.tesseract_dim = tesseract_dim
|
||||
ctx.tesseract_dep = tesseract_dep
|
||||
ctx.row_rank = row_rank
|
||||
ctx.col_rank = col_rank
|
||||
ctx.dep_rank = dep_rank
|
||||
ctx.row_parallel_mode = row_parallel_mode
|
||||
ctx.col_parallel_mode = col_parallel_mode
|
||||
ctx.dep_parallel_mode = dep_parallel_mode
|
||||
ctx.A_shape = A_shape
|
||||
ctx.B_shape = B_shape
|
||||
ctx.data_parallel_rank = data_parallel_rank
|
||||
@ -197,34 +193,32 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
A, B = ctx.saved_tensors
|
||||
A_grad = Matmul_AB_2p5D.forward(
|
||||
None,
|
||||
output_grad, B,
|
||||
ctx.tesseract_dim, ctx.tesseract_dep, ctx.A_shape,
|
||||
ctx.row_rank, ctx.col_rank, ctx.dep_rank,
|
||||
ctx.row_parallel_mode,
|
||||
ctx.col_parallel_mode,
|
||||
ctx.dep_parallel_mode,
|
||||
ctx.data_parallel_rank,
|
||||
ctx.pipeline_parallel_rank,
|
||||
ctx.pipeline_parallel_size,
|
||||
ctx.tensor_parallel_size
|
||||
)
|
||||
B_grad = Matmul_ATB_2p5D.forward(
|
||||
None,
|
||||
output_grad, A,
|
||||
ctx.tesseract_dim, ctx.tesseract_dep, ctx.B_shape,
|
||||
ctx.row_rank, ctx.col_rank, ctx.dep_rank,
|
||||
ctx.row_parallel_mode,
|
||||
ctx.col_parallel_mode,
|
||||
ctx.dep_parallel_mode,
|
||||
ctx.data_parallel_rank,
|
||||
ctx.pipeline_parallel_rank,
|
||||
ctx.pipeline_parallel_size,
|
||||
ctx.tensor_parallel_size
|
||||
)
|
||||
with torch.no_grad():
|
||||
A_grad = Matmul_AB_2p5D.apply(
|
||||
output_grad, B,
|
||||
ctx.tesseract_dim, ctx.A_shape,
|
||||
ctx.row_rank, ctx.col_rank, ctx.dep_rank,
|
||||
ctx.row_parallel_mode,
|
||||
ctx.col_parallel_mode,
|
||||
ctx.data_parallel_rank,
|
||||
ctx.pipeline_parallel_rank,
|
||||
ctx.pipeline_parallel_size,
|
||||
ctx.tensor_parallel_size
|
||||
)
|
||||
B_grad = Matmul_ATB_2p5D.apply(
|
||||
output_grad, A,
|
||||
ctx.tesseract_dim, ctx.B_shape,
|
||||
ctx.row_rank, ctx.col_rank, ctx.dep_rank,
|
||||
ctx.row_parallel_mode,
|
||||
ctx.col_parallel_mode,
|
||||
ctx.data_parallel_rank,
|
||||
ctx.pipeline_parallel_rank,
|
||||
ctx.pipeline_parallel_size,
|
||||
ctx.tensor_parallel_size
|
||||
)
|
||||
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
@ -233,18 +227,17 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx: Any,
|
||||
A: Tensor,
|
||||
B: Tensor,
|
||||
tesseract_dim: int,
|
||||
tesseract_dep: int,
|
||||
out_shape: Tuple[int, ...],
|
||||
row_rank: int,
|
||||
col_rank: int,
|
||||
dep_rank: int,
|
||||
row_parallel_mode: ParallelMode,
|
||||
col_parallel_mode: ParallelMode,
|
||||
dep_parallel_mode: ParallelMode,
|
||||
data_parallel_rank: int,
|
||||
pipeline_parallel_rank: int,
|
||||
pipeline_parallel_size: int,
|
||||
@ -253,7 +246,6 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
|
||||
assert A.shape[-2] == B.shape[-2], \
|
||||
'Invalid shapes: A={}, B={} for ATB.'.format(A.shape, B.shape)
|
||||
|
||||
empty_cache()
|
||||
if ctx:
|
||||
ctx.save_for_backward(A, B)
|
||||
|
||||
@ -284,13 +276,11 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
|
||||
|
||||
if ctx:
|
||||
ctx.tesseract_dim = tesseract_dim
|
||||
ctx.tesseract_dep = tesseract_dep
|
||||
ctx.row_rank = row_rank
|
||||
ctx.col_rank = col_rank
|
||||
ctx.dep_rank = dep_rank
|
||||
ctx.row_parallel_mode = row_parallel_mode
|
||||
ctx.col_parallel_mode = col_parallel_mode
|
||||
ctx.dep_parallel_mode = dep_parallel_mode
|
||||
ctx.A_shape = A_shape
|
||||
ctx.B_shape = B_shape
|
||||
ctx.data_parallel_rank = data_parallel_rank
|
||||
@ -301,34 +291,32 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
A, B = ctx.saved_tensors
|
||||
A_grad = Matmul_ABT_2p5D.forward(
|
||||
None,
|
||||
B, output_grad,
|
||||
ctx.tesseract_dim, ctx.tesseract_dep, ctx.A_shape,
|
||||
ctx.row_rank, ctx.col_rank, ctx.dep_rank,
|
||||
ctx.row_parallel_mode,
|
||||
ctx.col_parallel_mode,
|
||||
ctx.dep_parallel_mode,
|
||||
ctx.data_parallel_rank,
|
||||
ctx.pipeline_parallel_rank,
|
||||
ctx.pipeline_parallel_size,
|
||||
ctx.tensor_parallel_size
|
||||
)
|
||||
B_grad = Matmul_AB_2p5D.forward(
|
||||
None,
|
||||
A, output_grad,
|
||||
ctx.tesseract_dim, ctx.tesseract_dep, ctx.B_shape,
|
||||
ctx.row_rank, ctx.col_rank, ctx.dep_rank,
|
||||
ctx.row_parallel_mode,
|
||||
ctx.col_parallel_mode,
|
||||
ctx.dep_parallel_mode,
|
||||
ctx.data_parallel_rank,
|
||||
ctx.pipeline_parallel_rank,
|
||||
ctx.pipeline_parallel_size,
|
||||
ctx.tensor_parallel_size
|
||||
)
|
||||
with torch.no_grad():
|
||||
A_grad = Matmul_ABT_2p5D.apply(
|
||||
B, output_grad,
|
||||
ctx.tesseract_dim, ctx.A_shape,
|
||||
ctx.row_rank, ctx.col_rank, ctx.dep_rank,
|
||||
ctx.row_parallel_mode,
|
||||
ctx.col_parallel_mode,
|
||||
ctx.data_parallel_rank,
|
||||
ctx.pipeline_parallel_rank,
|
||||
ctx.pipeline_parallel_size,
|
||||
ctx.tensor_parallel_size
|
||||
)
|
||||
B_grad = Matmul_AB_2p5D.apply(
|
||||
A, output_grad,
|
||||
ctx.tesseract_dim, ctx.B_shape,
|
||||
ctx.row_rank, ctx.col_rank, ctx.dep_rank,
|
||||
ctx.row_parallel_mode,
|
||||
ctx.col_parallel_mode,
|
||||
ctx.data_parallel_rank,
|
||||
ctx.pipeline_parallel_rank,
|
||||
ctx.pipeline_parallel_size,
|
||||
ctx.tensor_parallel_size
|
||||
)
|
||||
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
@ -337,18 +325,16 @@ class Add_Bias_2p5D(torch.autograd.Function):
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx: Any,
|
||||
input: Tensor,
|
||||
bias: Tensor,
|
||||
output_size_per_partition: int,
|
||||
tesseract_dim: int,
|
||||
tesseract_dep: int,
|
||||
row_rank: int,
|
||||
col_rank: int,
|
||||
dep_rank: int,
|
||||
row_parallel_mode: ParallelMode,
|
||||
col_parallel_mode: ParallelMode,
|
||||
dep_parallel_mode: ParallelMode,
|
||||
skip_bias_add: bool,
|
||||
data_parallel_rank: int,
|
||||
pipeline_parallel_rank: int,
|
||||
@ -371,10 +357,7 @@ class Add_Bias_2p5D(torch.autograd.Function):
|
||||
ctx.col_rank = col_rank
|
||||
ctx.dep_rank = dep_rank
|
||||
ctx.tesseract_dim = tesseract_dim
|
||||
ctx.tesseract_dep = tesseract_dep
|
||||
ctx.row_parallel_mode = row_parallel_mode
|
||||
ctx.col_parallel_mode = col_parallel_mode
|
||||
ctx.dep_parallel_mode = dep_parallel_mode
|
||||
ctx.bias = skip_bias_add
|
||||
ctx.data_parallel_rank = data_parallel_rank
|
||||
ctx.pipeline_parallel_rank = pipeline_parallel_rank
|
||||
@ -388,15 +371,13 @@ class Add_Bias_2p5D(torch.autograd.Function):
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, output_grad):
|
||||
@custom_bwd
|
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
row_rank = ctx.row_rank
|
||||
col_rank = ctx.col_rank
|
||||
dep_rank = ctx.dep_rank
|
||||
tesseract_dim = ctx.tesseract_dim
|
||||
tesseract_dep = ctx.tesseract_dep
|
||||
row_parallel_mode = ctx.row_parallel_mode
|
||||
col_parallel_mode = ctx.col_parallel_mode
|
||||
dep_parallel_mode = ctx.dep_parallel_mode
|
||||
data_parallel_rank = ctx.data_parallel_rank
|
||||
pipeline_parallel_rank = ctx.pipeline_parallel_rank
|
||||
pipeline_parallel_size = ctx.pipeline_parallel_size
|
||||
@ -428,29 +409,25 @@ class Add_Bias_2p5D(torch.autograd.Function):
|
||||
|
||||
class _LayerNorm_2p5D(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float32)
|
||||
def forward(ctx: Any,
|
||||
input: Tensor,
|
||||
E_x: Tensor,
|
||||
Var_x: Tensor,
|
||||
hidden_size: int,
|
||||
row_parallel_mode: ParallelMode,
|
||||
col_parallel_mode: ParallelMode,
|
||||
dep_parallel_mode: ParallelMode) -> Tensor:
|
||||
row_parallel_mode: ParallelMode) -> Tensor:
|
||||
input = input - E_x
|
||||
# in here, input = x - E[x], Var_x = 1 / sqrt(Var[x] + eps)
|
||||
ctx.hidden_size = hidden_size
|
||||
output = input * Var_x
|
||||
ctx.save_for_backward(output, Var_x)
|
||||
ctx.row_parallel_mode = row_parallel_mode
|
||||
ctx.col_parallel_mode = col_parallel_mode
|
||||
ctx.dep_parallel_mode = dep_parallel_mode
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, output_grad):
|
||||
row_parallel_mode = ctx.row_parallel_mode
|
||||
col_parallel_mode = ctx.col_parallel_mode
|
||||
dep_parallel_mode = ctx.dep_parallel_mode
|
||||
x, Var_x = ctx.saved_tensors
|
||||
# in here, Var_x = 1 / sqrt(Var[x] + eps), x = (x - E[x]) * Var_x
|
||||
with torch.no_grad():
|
||||
@ -473,63 +450,122 @@ class _LayerNorm_2p5D(torch.autograd.Function):
|
||||
return input_grad, None, None, None, None, None, None
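# The 2.5D layer norm receives E[x] and Var_x = 1/sqrt(Var[x] + eps) that were
# produced with all-reduces over the row group, and only applies
# (x - E[x]) * Var_x here; the factorization matches a plain (non-affine) layer
# norm, as this one-process check shows.
import torch
import torch.nn.functional as F

hidden, eps = 16, 1e-5
x = torch.randn(4, 10, hidden)

E_x = x.mean(dim=-1, keepdim=True)                                  # all-reduced mean
Var_x = 1.0 / torch.sqrt(x.var(dim=-1, unbiased=False, keepdim=True) + eps)
output = (x - E_x) * Var_x

assert torch.allclose(output, F.layer_norm(x, (hidden,), eps=eps), atol=1e-6)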
|
||||
|
||||
|
||||
class Sum_2p5D(torch.autograd.Function):
|
||||
"""Compute the sum of input tensors
|
||||
"""
|
||||
# class Sum_2p5D(torch.autograd.Function):
|
||||
# """Compute the sum of input tensors
|
||||
# """
|
||||
|
||||
# @staticmethod
|
||||
# def forward(ctx,
|
||||
# inputs,
|
||||
# dim,
|
||||
# tesseract_dim,
|
||||
# row_parallel_mode,
|
||||
# keepdim=False):
|
||||
# # input: [b/q, s, h/q]
|
||||
# ctx.save_for_backward(inputs)
|
||||
# # sum: [b/q, s]
|
||||
# out = torch.sum(inputs, dim=dim, keepdim=keepdim)
|
||||
# torch.distributed.all_reduce(
|
||||
# out, group=gpc.get_group(row_parallel_mode))
|
||||
# return out
|
||||
|
||||
# @staticmethod
|
||||
# def backward(ctx, output_grad):
|
||||
# with torch.no_grad():
|
||||
# inputs = ctx.saved_tensors
|
||||
# input_grad = torch.ones(inputs.shape, dtype=output_grad.dtype)
|
||||
# return input_grad, None, None, None, None, None
|
||||
|
||||
|
||||
# class _ViT_Split_2p5D(torch.autograd.Function):
|
||||
# @staticmethod
|
||||
# @custom_fwd(cast_inputs=torch.float16)
|
||||
# def forward(ctx, inputs, batch_size,
|
||||
# tesseract_dim, tesseract_dep,
|
||||
# xz_parallel_mode):
|
||||
# # inputs: [b, s, h/q]
|
||||
# # output: [b/dq, s, h/q]
|
||||
|
||||
# ctx.BATCH_SIZE = batch_size
|
||||
# ctx.tesseract_dim = tesseract_dim
|
||||
# ctx.tesseract_dep = tesseract_dep
|
||||
# ctx.xz_parallel_mode = xz_parallel_mode
|
||||
# xz_rank = gpc.get_local_rank(xz_parallel_mode)
|
||||
# output = torch.chunk(inputs, tesseract_dep *
|
||||
# tesseract_dim, dim=0)[xz_rank]
|
||||
# output = output.clone()
|
||||
# return output
|
||||
|
||||
# @staticmethod
|
||||
# @custom_bwd
|
||||
# def backward(ctx, output_grad):
|
||||
# # output_grad: [b/dq, s, h/q]
|
||||
# # grads: [b, s, h/q]
|
||||
# # *
|
||||
# grads_shape = (ctx.BATCH_SIZE,) + output_grad.shape[1:]
|
||||
# grads = torch.empty(grads_shape,
|
||||
# dtype=output_grad.dtype,
|
||||
# device=get_current_device())
|
||||
# dist.all_gather(list(grads.chunk(ctx.tesseract_dim * ctx.tesseract_dep, dim=0)),
|
||||
# output_grad.contiguous(),
|
||||
# group=get_parallel_group(ctx.xz_parallel_mode))
|
||||
# return grads, None, None, None, None
|
||||
|
||||
class AllGatherLast(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx,
|
||||
inputs,
|
||||
dim,
|
||||
tesseract_dim,
|
||||
row_parallel_mode,
|
||||
keepdim=False):
|
||||
# input: [b/q, s, h/q]
|
||||
empty_cache()
|
||||
ctx.save_for_backward(inputs)
|
||||
# sum: [b/q, s]
|
||||
out = torch.sum(inputs, dim=dim, keepdim=keepdim)
|
||||
torch.distributed.all_reduce(
|
||||
out, group=gpc.get_group(row_parallel_mode))
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, output_grad):
|
||||
with torch.no_grad():
|
||||
inputs = ctx.saved_tensors
|
||||
input_grad = torch.ones(inputs.shape, dtype=output_grad.dtype)
|
||||
return input_grad, None, None, None, None, None
|
||||
|
||||
|
||||
class _ViT_Split_2p5D(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, inputs, batch_size,
|
||||
tesseract_dim, tesseract_dep,
|
||||
xz_parallel_mode):
|
||||
# inputs: [b, s, h/q]
|
||||
# output: [b/dq, s, h/q]
|
||||
empty_cache()
|
||||
|
||||
ctx.batch_size = batch_size
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx: Any,
|
||||
inputs: Tensor,
|
||||
tesseract_dim: int,
|
||||
col_parallel_mode: ParallelMode) -> Tensor:
|
||||
ctx.tesseract_dim = tesseract_dim
|
||||
ctx.tesseract_dep = tesseract_dep
|
||||
ctx.xz_parallel_mode = xz_parallel_mode
|
||||
xz_rank = gpc.get_local_rank(xz_parallel_mode)
|
||||
output = torch.chunk(inputs, tesseract_dep *
|
||||
tesseract_dim, dim=0)[xz_rank]
|
||||
output = output.clone()
|
||||
return output
|
||||
ctx.row_rank = gpc.get_local_rank(col_parallel_mode)
|
||||
|
||||
last_dim = tesseract_dim * inputs.size(-1)
|
||||
outputs_shape = (last_dim,) + inputs.shape[:-1]
|
||||
outputs = torch.empty(
|
||||
outputs_shape, dtype=inputs.dtype, device=get_current_device())
|
||||
dist.all_gather(
|
||||
list(outputs.chunk(tesseract_dim, dim=0)),
|
||||
inputs.permute(2, 0, 1).contiguous(),
|
||||
group=gpc.get_group(col_parallel_mode)
|
||||
)
|
||||
outputs = outputs.permute(1, 2, 0).contiguous()
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, output_grad):
|
||||
# output_grad: [b/dq, s, h/q]
|
||||
# grads: [b, s, h/q]
|
||||
# *
|
||||
grads_shape = (ctx.batch_size,) + output_grad.shape[1:]
|
||||
grads = torch.empty(grads_shape,
|
||||
dtype=output_grad.dtype,
|
||||
device=get_current_device())
|
||||
dist.all_gather(list(grads.chunk(ctx.tesseract_dim * ctx.tesseract_dep, dim=0)),
|
||||
output_grad.contiguous(),
|
||||
group=get_parallel_group(ctx.xz_parallel_mode))
|
||||
return grads, None, None, None, None
|
||||
@custom_bwd
|
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
grad = output_grad.chunk(ctx.tesseract_dim, dim=-1)[ctx.row_rank]
|
||||
return grad.contiguous(), None, None
|
||||
|
||||
|
||||
class SplitFirst(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx: Any,
|
||||
inputs: Tensor,
|
||||
tesseract_dim: int,
|
||||
col_parallel_mode: ParallelMode) -> Tensor:
|
||||
ctx.tesseract_dim = tesseract_dim
|
||||
ctx.batch_size = inputs.size(0)
|
||||
ctx.para_mode = col_parallel_mode
|
||||
row_rank = gpc.get_local_rank(col_parallel_mode)
|
||||
|
||||
outputs = inputs.chunk(tesseract_dim, dim=0)[row_rank]
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
grad_shape = (ctx.batch_size,) + output_grad.shape[1:]
|
||||
grad = torch.empty(
|
||||
grad_shape, dtype=output_grad.dtype, device=get_current_device())
|
||||
dist.all_gather(
|
||||
list(grad.chunk(ctx.tesseract_dim, dim=0)),
|
||||
output_grad.contiguous(),
|
||||
group=gpc.get_group(ctx.para_mode)
|
||||
)
|
||||
return grad, None, None
|
@ -12,10 +12,11 @@ from ._utils import assert_tesseract_initialization, \
|
||||
get_tesseract_dim_dep_from_env
|
||||
from .layers import Linear2p5D, LayerNorm2p5D
|
||||
from .._common_utils import ACT2FN
|
||||
from ..base_layer import ParallelLayer
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class TransformerMLP2p5D(nn.Module):
|
||||
class TransformerMLP2p5D(ParallelLayer):
|
||||
"""
|
||||
MLP will take the input with h hidden state, project it to mlp_ratio * h
|
||||
hidden dimension, perform nonlinear transformation, and project the
|
||||
@ -36,21 +37,24 @@ class TransformerMLP2p5D(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
mlp_ratio: int,
|
||||
mlp_ratio: int = 4.0,
|
||||
act_func: str = 'gelu',
|
||||
dropout_prob: float = 0.,
|
||||
dtype=None,
|
||||
skip_bias_add: bool = False
|
||||
):
|
||||
super().__init__()
|
||||
assert_tesseract_initialization()
|
||||
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
|
||||
self.in_features = in_features
|
||||
self.skip_bias_add = skip_bias_add
|
||||
|
||||
# Project to h * mlp_ratio.
|
||||
self.dense_1 = Linear2p5D(
|
||||
in_features,
|
||||
mlp_ratio * in_features,
|
||||
dtype=dtype
|
||||
int(mlp_ratio * in_features),
|
||||
dtype=dtype,
|
||||
skip_bias_add=skip_bias_add
|
||||
)
|
||||
|
||||
assert act_func in ACT2FN.keys(), f'Invalid value for argument act_func, ' \
|
||||
@ -59,24 +63,34 @@ class TransformerMLP2p5D(nn.Module):
|
||||
|
||||
# Project back to h.
|
||||
self.dense_2 = Linear2p5D(
|
||||
mlp_ratio * in_features,
|
||||
int(mlp_ratio * in_features),
|
||||
in_features,
|
||||
dtype=dtype
|
||||
dtype=dtype,
|
||||
skip_bias_add=skip_bias_add
|
||||
)
|
||||
self.dropout = nn.Dropout(dropout_prob)
|
||||
self.layernorm = LayerNorm2p5D(in_features, dtype=dtype)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
intermediate_output = self.dense_1(x)
|
||||
if self.skip_bias_add:
|
||||
intermediate_output, _ = self.dense_1(x)
|
||||
else:
|
||||
intermediate_output = self.dense_1(x)
|
||||
|
||||
intermediate_output = self.activation_func(intermediate_output)
|
||||
output = self.dense_2(intermediate_output)
|
||||
|
||||
if self.skip_bias_add:
|
||||
output, _ = self.dense_2(intermediate_output)
|
||||
else:
|
||||
output = self.dense_2(intermediate_output)
|
||||
|
||||
output = self.dropout(output)
|
||||
output = self.layernorm(x + output)
|
||||
return output
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class TransformerSelfAttention2p5D(nn.Module):
|
||||
class TransformerSelfAttention2p5D(ParallelLayer):
|
||||
"""Self attention layer for 2.5D parallel Transformer
|
||||
|
||||
:param hidden_size: hidden size
|
||||
@ -92,10 +106,10 @@ class TransformerSelfAttention2p5D(nn.Module):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
hidden_size,
|
||||
num_attention_heads,
|
||||
attention_dropout_prob,
|
||||
hidden_dropout_prob,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
attention_dropout_prob: float,
|
||||
hidden_dropout_prob: float,
|
||||
dtype=None,
|
||||
):
|
||||
super().__init__()
|
||||
@ -127,7 +141,7 @@ class TransformerSelfAttention2p5D(nn.Module):
|
||||
def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
|
||||
query_key_value = self.query_key_value(hidden_states)
|
||||
new_qkv_shape = query_key_value.shape[:-1] + \
|
||||
(self.num_attention_heads, 3 * self.attention_head_size)
|
||||
(self.num_attention_heads, 3 * self.attention_head_size)
|
||||
query_key_value = query_key_value.view(new_qkv_shape)
|
||||
query_key_value = query_key_value.permute((0, 2, 1, 3))
|
||||
query_layer, key_layer, value_layer = torch.chunk(
|
||||
@ -136,7 +150,7 @@ class TransformerSelfAttention2p5D(nn.Module):
|
||||
attention_scores = torch.matmul(
|
||||
query_layer, key_layer.transpose(-1, -2))
|
||||
attention_scores = attention_scores / \
|
||||
math.sqrt(self.attention_head_size)
|
||||
math.sqrt(self.attention_head_size)
|
||||
attention_scores = attention_scores + attention_mask
|
||||
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
||||
attention_probs = self.attention_dropout(attention_probs)
|
||||
@ -144,7 +158,7 @@ class TransformerSelfAttention2p5D(nn.Module):
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
context_layer = context_layer.permute((0, 2, 1, 3)).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[
|
||||
:-2] + (self.all_head_size,)
|
||||
:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(*new_context_layer_shape)
|
||||
|
||||
output = self.dense(context_layer)
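# --- Illustrative sketch (not part of this commit): the per-head QKV reshape
# and scaled dot-product computed in the forward above, on a single device.
# num_heads here is the full head count, not the per-partition count.
import math
import torch


def scaled_dot_product_attention(qkv: torch.Tensor, num_heads: int, mask: torch.Tensor) -> torch.Tensor:
    b, s, _ = qkv.shape
    head_size = qkv.shape[-1] // (3 * num_heads)
    qkv = qkv.view(b, s, num_heads, 3 * head_size).permute(0, 2, 1, 3)
    q, k, v = torch.chunk(qkv, 3, dim=-1)            # each [b, heads, s, head_size]

    scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(head_size)
    scores = scores + mask                           # additive attention mask
    probs = torch.softmax(scores, dim=-1)
    context = torch.matmul(probs, v)                 # [b, heads, s, head_size]
    return context.permute(0, 2, 1, 3).reshape(b, s, num_heads * head_size)
# --- end of sketch ---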
@ -155,7 +169,7 @@ class TransformerSelfAttention2p5D(nn.Module):
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class TransformerLayer2p5D(nn.Module):
|
||||
class TransformerLayer2p5D(ParallelLayer):
|
||||
"""Transformer layer which contains a self-attention layer and a MLP layer
|
||||
|
||||
:param hidden_size: hidden size
|
||||
@ -175,10 +189,10 @@ class TransformerLayer2p5D(nn.Module):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
hidden_size,
|
||||
num_attention_heads,
|
||||
act_func='gelu',
|
||||
mlp_ratio=4,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
act_func: str = 'gelu',
|
||||
mlp_ratio: float = 4.0,
|
||||
attention_dropout_prob: float = 0.,
|
||||
hidden_dropout_prob: float = 0.,
|
||||
dtype=None,
|
||||
|
@ -5,22 +5,25 @@ import math
|
||||
|
||||
import torch
|
||||
from torch import nn as nn, Tensor, distributed as dist
|
||||
from torch.nn.init import _calculate_fan_in_and_fan_out
|
||||
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.context import seed, ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.nn.layer.vanilla_vision_transformer.layers import to_2tuple
|
||||
from colossalai.registry import LAYERS
|
||||
from colossalai.utils import checkpoint
|
||||
from colossalai.utils import get_current_device
|
||||
from ._operation import _ViT_Split_2p5D
|
||||
from ._operation import AllGatherLast, SplitFirst
|
||||
from ._utils import assert_tesseract_initialization, \
|
||||
get_tesseract_dim_dep_from_env
|
||||
from .layers import Linear2p5D
|
||||
from .._common_utils import ACT2FN, divide, CheckpointModule
|
||||
from .._common_utils import set_tensor_parallel_attribute
|
||||
from ..base_layer import ParallelLayer
|
||||
from ..fused_bias_gelu import bias_gelu_impl
|
||||
from .._common_utils import (ACT2FN, divide, to_2tuple,
|
||||
set_tensor_parallel_attribute_by_partition)
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class ViTMLP2p5D(CheckpointModule):
|
||||
class ViTMLP2p5D(ParallelLayer):
|
||||
"""MLP layer for 2.5D parallel Vision Transformer
|
||||
|
||||
:param in_features: size of each input sample
|
||||
@ -43,19 +46,32 @@ class ViTMLP2p5D(CheckpointModule):
|
||||
act_func: str = 'gelu',
|
||||
dropout_prob: float = 0.,
|
||||
dtype=None,
|
||||
checkpoint: bool = False
|
||||
checkpoint: bool = False,
|
||||
weight_init='torch'
|
||||
):
|
||||
super().__init__(checkpoint=checkpoint)
|
||||
super().__init__()
|
||||
|
||||
assert_tesseract_initialization()
|
||||
self.in_features = in_features
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.checkpoint = checkpoint
|
||||
assert weight_init in ('torch', 'jax')
|
||||
|
||||
if act_func == 'fused_gelu':
|
||||
self.act = bias_gelu_impl
|
||||
skip_dense_1_add_bias = True
|
||||
else:
|
||||
self.act = ACT2FN[act_func]
|
||||
skip_dense_1_add_bias = False
|
||||
|
||||
# Project to mlp_ratio * h.
|
||||
self.dense_1 = Linear2p5D(
|
||||
self.in_features,
|
||||
self.mlp_ratio * self.in_features,
|
||||
dtype=dtype,
|
||||
init_weight=weight_init,
|
||||
init_bias=weight_init,
|
||||
skip_bias_add=skip_dense_1_add_bias
|
||||
)
|
||||
|
||||
self.act = ACT2FN[act_func]
|
||||
@ -65,20 +81,39 @@ class ViTMLP2p5D(CheckpointModule):
|
||||
self.mlp_ratio * self.in_features,
|
||||
self.in_features,
|
||||
dtype=dtype,
|
||||
init_weight=weight_init,
|
||||
init_bias=weight_init
|
||||
)
|
||||
self.dropout = nn.Dropout(dropout_prob)
|
||||
|
||||
def _forward(self, hidden_states: Tensor) -> Tensor:
|
||||
intermediate_output = self.dense_1(hidden_states)
|
||||
intermediate_output = self.act(intermediate_output)
|
||||
intermediate_output = self.dropout(intermediate_output)
|
||||
if self.act == bias_gelu_impl:
|
||||
intermediate_output, bias = self.dense_1(hidden_states)
|
||||
intermediate_output = self.act(intermediate_output, bias)
|
||||
else:
|
||||
intermediate_output = self.dense_1(hidden_states)
|
||||
intermediate_output = self.act(intermediate_output)
|
||||
|
||||
with seed(ParallelMode.TENSOR):
|
||||
intermediate_output = self.dropout(intermediate_output)
|
||||
output = self.dense_2(intermediate_output)
|
||||
output = self.dropout(output)
|
||||
|
||||
with seed(ParallelMode.TENSOR):
|
||||
output = self.dropout(output)
|
||||
return output
|
||||
|
||||
def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
|
||||
return checkpoint(self._forward, hidden_states)
|
||||
|
||||
def forward(self, hidden_states: Tensor) -> Tensor:
|
||||
if self.checkpoint:
|
||||
return self._checkpoint_forward(hidden_states)
|
||||
else:
|
||||
return self._forward(hidden_states)
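# --- Illustrative sketch (not part of this commit): the checkpoint dispatch
# above trades memory for recompute. This stand-alone version uses
# torch.utils.checkpoint instead of colossalai.utils.checkpoint; the module
# body is a placeholder.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint as torch_checkpoint


class CheckpointedBlock(nn.Module):
    def __init__(self, hidden: int, use_checkpoint: bool = False):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(hidden, hidden), nn.GELU(),
                                  nn.Linear(hidden, hidden))
        self.use_checkpoint = use_checkpoint

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.body(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_checkpoint:
            # activations inside _forward are recomputed during backward
            return torch_checkpoint(self._forward, x)
        return self._forward(x)
# --- end of sketch ---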
@LAYERS.register_module
|
||||
class ViTSelfAttention2p5D(CheckpointModule):
|
||||
class ViTSelfAttention2p5D(ParallelLayer):
|
||||
"""Self-attention layer for 2.5D parallel Vision Transformer
|
||||
|
||||
:param hidden_size: hidden size
|
||||
@ -101,9 +136,10 @@ class ViTSelfAttention2p5D(CheckpointModule):
|
||||
attention_dropout_prob,
|
||||
hidden_dropout_prob,
|
||||
dtype=None,
|
||||
checkpoint: bool = False
|
||||
checkpoint: bool = False,
|
||||
weight_init='torch'
|
||||
):
|
||||
super().__init__(checkpoint=checkpoint)
|
||||
super().__init__()
|
||||
|
||||
assert_tesseract_initialization()
|
||||
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
|
||||
@ -112,19 +148,30 @@ class ViTSelfAttention2p5D(CheckpointModule):
|
||||
num_attention_heads, self.tesseract_dim) # *
|
||||
self.attention_head_size = divide(hidden_size, num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
self.checkpoint = checkpoint
|
||||
assert weight_init in ('torch', 'jax')
|
||||
if weight_init == 'jax':
|
||||
self.init_bias = 'zero'
|
||||
else:
|
||||
self.init_bias = weight_init
|
||||
|
||||
self.query_key_value = Linear2p5D(
|
||||
hidden_size,
|
||||
3 * hidden_size,
|
||||
dtype=dtype,
|
||||
init_weight=weight_init,
|
||||
init_bias=self.init_bias
|
||||
)
|
||||
self.attention_dropout = nn.Dropout(attention_dropout_prob)
|
||||
self.dense = Linear2p5D(
|
||||
hidden_size,
|
||||
hidden_size,
|
||||
dtype=dtype,
|
||||
init_weight=weight_init,
|
||||
init_bias=self.init_bias
|
||||
)
|
||||
self.dropout = nn.Dropout(hidden_dropout_prob)
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
|
||||
def _forward(self, hidden_states: Tensor) -> Tensor:
|
||||
query_key_value = self.query_key_value(hidden_states)
|
||||
@ -140,8 +187,10 @@ class ViTSelfAttention2p5D(CheckpointModule):
|
||||
attention_scores = attention_scores / \
|
||||
math.sqrt(self.attention_head_size)
|
||||
|
||||
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
||||
attention_probs = self.attention_dropout(attention_probs)
|
||||
attention_probs = self.softmax(attention_scores)
|
||||
|
||||
with seed(ParallelMode.TENSOR):
|
||||
attention_probs = self.attention_dropout(attention_probs)
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
context_layer = context_layer.transpose(1, 2)
|
||||
@ -150,12 +199,22 @@ class ViTSelfAttention2p5D(CheckpointModule):
|
||||
context_layer = context_layer.reshape(new_context_layer_shape)
|
||||
|
||||
output = self.dense(context_layer)
|
||||
output = self.dropout(output)
|
||||
with seed(ParallelMode.TENSOR):
|
||||
output = self.dropout(output)
|
||||
return output
|
||||
|
||||
def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
|
||||
return checkpoint(self._forward, hidden_states)
|
||||
|
||||
def forward(self, hidden_states: Tensor) -> Tensor:
|
||||
if self.checkpoint:
|
||||
return self._checkpoint_forward(hidden_states)
|
||||
else:
|
||||
return self._forward(hidden_states)
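# --- Illustrative sketch (not part of this commit): the `with seed(ParallelMode.TENSOR)`
# blocks above run dropout under a dedicated RNG stream. As a rough single-process
# analogue, fork_rng plus a fixed seed isolates the dropout mask from the global
# RNG state; the real seed manager is per parallel mode, not a constant.
import torch
import torch.nn.functional as F


def dropout_with_private_seed(x: torch.Tensor, p: float, seed_value: int) -> torch.Tensor:
    devices = [x.device] if x.is_cuda else []
    with torch.random.fork_rng(devices=devices):
        torch.manual_seed(seed_value)
        return F.dropout(x, p=p, training=True)
# --- end of sketch ---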
@LAYERS.register_module
|
||||
class ViTHead2p5D(nn.Module):
|
||||
class ViTHead2p5D(ParallelLayer):
|
||||
"""Output layer for 2.5D parallel Vision Transformer
|
||||
|
||||
:param hidden_size: hidden size
|
||||
@ -170,13 +229,24 @@ class ViTHead2p5D(nn.Module):
|
||||
hidden_size,
|
||||
num_classes,
|
||||
dtype=None,
|
||||
weight_init='torch'
|
||||
):
|
||||
super().__init__()
|
||||
assert_tesseract_initialization()
|
||||
assert weight_init in ('torch', 'jax')
|
||||
if weight_init == 'jax':
|
||||
self.init_weight = 'zero'
|
||||
self.init_bias = 'zero'
|
||||
else:
|
||||
self.init_weight = weight_init
|
||||
self.init_bias = weight_init
|
||||
|
||||
self.linear = Linear2p5D(
|
||||
hidden_size,
|
||||
num_classes,
|
||||
dtype=dtype,
|
||||
init_weight=self.init_weight,
|
||||
init_bias=self.init_bias
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
@ -186,7 +256,7 @@ class ViTHead2p5D(nn.Module):
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class ViTPatchEmbedding2p5D(nn.Module):
|
||||
class ViTPatchEmbedding2p5D(ParallelLayer):
|
||||
""" 2.5D Image to Patch Embedding
|
||||
|
||||
:param img_size: image size
|
||||
@ -206,7 +276,8 @@ class ViTPatchEmbedding2p5D(nn.Module):
|
||||
patch_size,
|
||||
embed_dim,
|
||||
in_chans=3,
|
||||
flatten=True):
|
||||
flatten=True,
|
||||
weight_init='torch'):
|
||||
super().__init__()
|
||||
img_size = to_2tuple(img_size)
|
||||
patch_size = to_2tuple(patch_size)
|
||||
@ -219,34 +290,28 @@ class ViTPatchEmbedding2p5D(nn.Module):
|
||||
img_size[1] // patch_size[1])
|
||||
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
||||
self.flatten = flatten
|
||||
self.embed_dim = embed_dim // self.tesseract_dim # *
|
||||
self.embed_dim = embed_dim // (self.tesseract_dep * self.tesseract_dim ** 2) # *
|
||||
|
||||
self.proj = nn.Conv2d(in_chans,
|
||||
self.embed_dim,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
)
|
||||
with seed(ParallelMode.TENSOR):
|
||||
self.proj = nn.Conv2d(in_chans,
|
||||
self.embed_dim,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
device=get_current_device()
|
||||
)
|
||||
self._set_tensor_parallel_attribute()
|
||||
|
||||
# move self to cuda before sync
|
||||
self.to(get_current_device())
|
||||
if weight_init == 'jax':
|
||||
with seed(ParallelMode.TENSOR):
|
||||
fan_in, _ = _calculate_fan_in_and_fan_out(self.proj.weight)
|
||||
std = math.sqrt(1.0 / fan_in)
|
||||
nn.init.trunc_normal_(self.proj.weight, std=std / .87962566103423978)
|
||||
nn.init.zeros_(self.proj.bias)
|
||||
|
||||
# sync
|
||||
self._broadcast_conv_params()
|
||||
self.proj.weight.register_hook(self._sync_grad_during_backward)
|
||||
self.proj.bias.register_hook(self._sync_grad_during_backward)
|
||||
|
||||
def _broadcast_conv_params(self) -> None:
|
||||
xz_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2P5D_XZ)
|
||||
dist.broadcast(self.proj.weight, src=xz_rank[0],
|
||||
group=gpc.get_group(ParallelMode.PARALLEL_2P5D_XZ))
|
||||
dist.broadcast(self.proj.bias, src=xz_rank[0],
|
||||
group=gpc.get_group(ParallelMode.PARALLEL_2P5D_XZ))
|
||||
|
||||
def _sync_grad_during_backward(self, grad: Tensor) -> None:
|
||||
dist.all_reduce(grad, group=gpc.get_group(
|
||||
ParallelMode.PARALLEL_2P5D_XZ))
|
||||
grad = grad / self.tesseract_dim / self.tesseract_dep # *
|
||||
return grad
|
||||
def _set_tensor_parallel_attribute(self):
|
||||
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
|
||||
set_tensor_parallel_attribute_by_partition(self.proj.weight, num_partition)
|
||||
set_tensor_parallel_attribute_by_partition(self.proj.bias, num_partition)
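# --- Illustrative sketch (not part of this commit): the "replicate then keep
# in sync" pattern used for the conv projection above: broadcast parameters
# once at init, then average their gradients across the replica group with a
# hook. The group handle and the src=0 choice are assumptions for illustration.
import torch
import torch.distributed as dist


def replicate_module(module: torch.nn.Module, group=None) -> torch.nn.Module:
    world_size = dist.get_world_size(group=group)
    for param in module.parameters():
        # start every rank from identical weights
        dist.broadcast(param.data, src=0, group=group)

        def _average_grad(grad, world_size=world_size, group=group):
            dist.all_reduce(grad, group=group)
            return grad / world_size

        param.register_hook(_average_grad)
    return module
# --- end of sketch ---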
def forward(self, x: Tensor) -> Tensor:
|
||||
B, C, H, W = x.shape
|
||||
@ -259,7 +324,25 @@ class ViTPatchEmbedding2p5D(nn.Module):
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class ViTTokenFuser2p5D(nn.Module):
|
||||
class ViTInputSplitter2p5D(ParallelLayer):
|
||||
"""Split the input tensor for 2D parallel Vision Transformer
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
assert_tesseract_initialization()
|
||||
self.tesseract_dim, _ = get_tesseract_dim_dep_from_env()
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
x = AllGatherLast.apply(
|
||||
x, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
|
||||
x = SplitFirst.apply(
|
||||
x, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
|
||||
return x
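# --- Illustrative sketch (not part of this commit): a shape-level simulation
# (no communication) of the AllGatherLast -> SplitFirst redistribution above.
# Gathering the hidden dimension across q ranks and then splitting the batch
# leaves each rank with the full hidden size on 1/q of the batch. The toy
# shapes and q are arbitrary.
import torch

q = 2
local_shards = [torch.randn(8, 16, 64 // q) for _ in range(q)]   # [b, s, h/q] per rank

gathered = torch.cat(local_shards, dim=-1)             # AllGatherLast: [b, s, h]
after_split = list(torch.chunk(gathered, q, dim=0))    # SplitFirst:    [b/q, s, h] per rank
assert after_split[0].shape == (4, 16, 64)
# --- end of sketch ---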
@LAYERS.register_module
|
||||
class ViTTokenFuser2p5D(ParallelLayer):
|
||||
"""
|
||||
Fuse the cls token and position embedding into the input
|
||||
|
||||
@ -293,59 +376,46 @@ class ViTTokenFuser2p5D(nn.Module):
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
self.cls_token = nn.Parameter(torch.zeros(
|
||||
1, 1, self.embed_dim // self.tesseract_dim)) # *
|
||||
self.pos_embed = nn.Parameter(torch.zeros(
|
||||
1, self.num_patches + 1, self.embed_dim // self.tesseract_dim)) # *
|
||||
(1, 1, self.embed_dim // (self.tesseract_dep * self.tesseract_dim ** 2)),
|
||||
device=get_current_device()))
|
||||
self.pos_embed = nn.Parameter(torch.empty(
|
||||
(1, self.num_patches + 1, self.embed_dim // (self.tesseract_dep * self.tesseract_dim ** 2)),
|
||||
device=get_current_device()))
|
||||
with seed(ParallelMode.TENSOR):
|
||||
nn.init.trunc_normal_(self.pos_embed, std=.02)
|
||||
|
||||
# move to cuda before broadcast
|
||||
self.to(get_current_device())
|
||||
|
||||
self._broadcast_params()
|
||||
self.cls_token.register_hook(self._sync_grad_hook)
|
||||
self.pos_embed.register_hook(self._sync_grad_hook)
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
self._set_tensor_parallel_attribute()
|
||||
|
||||
def _set_tensor_parallel_attribute(self):
|
||||
set_tensor_parallel_attribute(self.cls_token)
|
||||
set_tensor_parallel_attribute(self.pos_embed)
|
||||
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
|
||||
set_tensor_parallel_attribute_by_partition(self.cls_token, num_partition)
|
||||
set_tensor_parallel_attribute_by_partition(self.pos_embed, num_partition)
|
||||
|
||||
def _broadcast_params(self) -> None:
|
||||
def _broadcast_params(self, param) -> None:
|
||||
" broadcast to all column ranks for data consistency "
|
||||
xz_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2P5D_XZ)
|
||||
dist.broadcast(self.cls_token, src=xz_rank[0],
|
||||
group=gpc.get_group(ParallelMode.PARALLEL_2P5D_XZ))
|
||||
dist.broadcast(self.pos_embed, src=xz_rank[0],
|
||||
group=gpc.get_group(ParallelMode.PARALLEL_2P5D_XZ))
|
||||
if self.tesseract_dep > 1:
|
||||
xz_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2P5D_XZ)
|
||||
xz_group = gpc.get_group(ParallelMode.PARALLEL_2P5D_XZ)
|
||||
dist.broadcast(param, src=xz_rank[0],
|
||||
group=xz_group)
|
||||
|
||||
def _sync_grad_hook(self, grad) -> None:
|
||||
dist.all_reduce(grad, group=gpc.get_group(
|
||||
ParallelMode.PARALLEL_2P5D_XZ))
|
||||
grad = grad / self.tesseract_dim / self.tesseract_dep # *
|
||||
grad = grad / self.tesseract_dim # / self.tesseract_dep # *
|
||||
return grad
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
# stole cls_tokens impl from Phil Wang, thanks
|
||||
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
|
||||
cls_token = AllGatherLast.apply(
|
||||
self.cls_token, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
|
||||
cls_token = cls_token.expand(x.shape[0], -1, -1)
|
||||
x = torch.cat((cls_token, x), dim=1)
|
||||
x = self.pos_drop(x + self.pos_embed)
|
||||
|
||||
pos_embed = AllGatherLast.apply(
|
||||
self.pos_embed, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
|
||||
x = x + pos_embed
|
||||
with seed(ParallelMode.TENSOR):
|
||||
x = self.pos_drop(x)
|
||||
return x
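# --- Illustrative sketch (not part of this commit): the token fusion performed
# above, on a single device: prepend a class token, add positional embeddings,
# then apply dropout. Shapes are illustrative; the parallel version shards the
# embedding dimension and all-gathers it before use.
import torch
import torch.nn as nn


class TokenFuser(nn.Module):
    def __init__(self, num_patches: int, embed_dim: int, drop_rate: float = 0.):
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        nn.init.trunc_normal_(self.pos_embed, std=.02)
        self.pos_drop = nn.Dropout(p=drop_rate)

    def forward(self, x: torch.Tensor) -> torch.Tensor:   # x: [b, num_patches, embed_dim]
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)               # [b, 1 + num_patches, embed_dim]
        return self.pos_drop(x + self.pos_embed)
# --- end of sketch ---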
@LAYERS.register_module
|
||||
class ViTInputSplitter2p5D(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
assert_tesseract_initialization()
|
||||
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
batch_size = x.size(0)
|
||||
return _ViT_Split_2p5D.apply(
|
||||
x,
|
||||
batch_size,
|
||||
self.tesseract_dim,
|
||||
self.tesseract_dep,
|
||||
ParallelMode.PARALLEL_2P5D_XZ,
|
||||
)
|
||||
|
@ -10,7 +10,7 @@ from colossalai.registry import LAYERS
|
||||
from colossalai.utils import get_current_device
|
||||
from ._operation import Matmul_AB_2p5D, Add_Bias_2p5D, _LayerNorm_2p5D
|
||||
from ._utils import get_tesseract_dim_dep_from_env, assert_tesseract_initialization
|
||||
from .._common_utils import divide, set_tensor_parallel_attribute
|
||||
from .._common_utils import divide, set_tensor_parallel_attribute_by_partition
|
||||
from ..base_layer import ParallelLayer
|
||||
|
||||
|
||||
@ -33,7 +33,9 @@ class Linear2p5D(ParallelLayer):
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
dtype=None,
|
||||
skip_bias_add: bool = False
|
||||
skip_bias_add: bool = False,
|
||||
init_weight='torch',
|
||||
init_bias='torch'
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -46,7 +48,7 @@ class Linear2p5D(ParallelLayer):
|
||||
self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
|
||||
self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
|
||||
self.dep_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
|
||||
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
|
||||
self.tesseract_dim, _ = get_tesseract_dim_dep_from_env()
|
||||
|
||||
# partitioning dimension
|
||||
self.input_size_per_partition = divide(in_features, self.tesseract_dim)
|
||||
@ -69,46 +71,59 @@ class Linear2p5D(ParallelLayer):
|
||||
self.register_parameter('bias', None)
|
||||
|
||||
# initialize parameters
|
||||
self.reset_parameters()
|
||||
with seed(ParallelMode.TENSOR):
|
||||
self.reset_parameters(init_weight, init_bias)
|
||||
self._set_tensor_parallel_attributes()
|
||||
|
||||
def _set_tensor_parallel_attributes(self):
|
||||
set_tensor_parallel_attribute(self.weight)
|
||||
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
|
||||
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
|
||||
if self.bias is not None:
|
||||
set_tensor_parallel_attribute(self.bias)
|
||||
set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
def reset_parameters(self, init_weight, init_bias) -> None:
|
||||
assert init_weight in ('torch', 'jax', 'zero')
|
||||
assert init_bias in ('torch', 'jax', 'zero')
|
||||
# setting
|
||||
fan_in = self.in_features
|
||||
a = math.sqrt(5)
|
||||
nonlinearity = 'leaky_relu'
|
||||
fan_in, fan_out = self.in_features, self.out_features
|
||||
|
||||
# init weight
|
||||
std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
|
||||
bound = math.sqrt(3.0) * std
|
||||
with seed(ParallelMode.TENSOR):
|
||||
if init_weight == 'torch':
|
||||
a = math.sqrt(5)
|
||||
nonlinearity = 'leaky_relu'
|
||||
std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
|
||||
bound = math.sqrt(3.0) * std
|
||||
init.uniform_(self.weight, -bound, bound)
|
||||
elif init_weight == 'jax':
|
||||
std = math.sqrt(2.0 / float(fan_in + fan_out))
|
||||
a = math.sqrt(3.0) * std
|
||||
init.uniform_(self.weight, -a, a)
|
||||
elif init_weight == 'zero':
|
||||
init.zeros_(self.weight)
|
||||
|
||||
# init bias
|
||||
if self.bias is not None:
|
||||
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
|
||||
with seed(ParallelMode.TENSOR):
|
||||
if init_bias == 'torch':
|
||||
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
|
||||
init.uniform_(self.bias, -bound, bound)
|
||||
elif init_bias == 'jax':
|
||||
init.normal_(self.bias, std=1e-6)
|
||||
elif init_bias == 'zero':
|
||||
init.zeros_(self.bias)
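# --- Illustrative sketch (not part of this commit): the three weight-init
# schemes selected above ('torch', 'jax', 'zero') as a stand-alone helper.
# fan_in/fan_out are the full, unpartitioned dimensions, which is why the
# layer passes them in instead of reading the sharded weight's shape.
import math
import torch
from torch.nn import init


def init_linear_weight(weight: torch.Tensor, fan_in: int, fan_out: int, scheme: str = 'torch') -> None:
    if scheme == 'torch':
        # PyTorch's default kaiming-uniform used by nn.Linear
        std = init.calculate_gain('leaky_relu', math.sqrt(5)) / math.sqrt(fan_in)
        bound = math.sqrt(3.0) * std
        init.uniform_(weight, -bound, bound)
    elif scheme == 'jax':
        # xavier-uniform, matching the JAX/flax ViT reference
        std = math.sqrt(2.0 / float(fan_in + fan_out))
        a = math.sqrt(3.0) * std
        init.uniform_(weight, -a, a)
    elif scheme == 'zero':
        init.zeros_(weight)
    else:
        raise ValueError(f'unknown init scheme: {scheme}')
# --- end of sketch ---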
def forward(self, x: Tensor) -> Tensor:
|
||||
# input: [m/dq, n/q, k/q]
|
||||
# output: [m/dq, n/q, h/q]
|
||||
out_shape = x.shape[:-1] + (self.hidden_size_per_partition,)
|
||||
|
||||
output = Matmul_AB_2p5D.apply(
|
||||
x,
|
||||
self.weight,
|
||||
self.tesseract_dim,
|
||||
self.tesseract_dep,
|
||||
out_shape,
|
||||
self.row_rank, self.col_rank, self.dep_rank,
|
||||
ParallelMode.PARALLEL_2P5D_ROW,
|
||||
ParallelMode.PARALLEL_2P5D_COL,
|
||||
ParallelMode.PARALLEL_2P5D_DEP,
|
||||
self.data_parallel_rank,
|
||||
self.pipeline_parallel_rank,
|
||||
self.pipeline_parallel_size,
|
||||
@ -121,11 +136,9 @@ class Linear2p5D(ParallelLayer):
|
||||
None,
|
||||
self.bias,
|
||||
self.hidden_size_per_partition,
|
||||
self.tesseract_dim, self.tesseract_dep,
|
||||
self.tesseract_dim,
|
||||
self.row_rank, self.col_rank, self.dep_rank,
|
||||
ParallelMode.PARALLEL_2P5D_ROW,
|
||||
ParallelMode.PARALLEL_2P5D_COL,
|
||||
ParallelMode.PARALLEL_2P5D_DEP,
|
||||
True,
|
||||
self.data_parallel_rank,
|
||||
self.pipeline_parallel_rank,
|
||||
@ -138,11 +151,9 @@ class Linear2p5D(ParallelLayer):
|
||||
output,
|
||||
self.bias,
|
||||
self.hidden_size_per_partition,
|
||||
self.tesseract_dim, self.tesseract_dep,
|
||||
self.tesseract_dim,
|
||||
self.row_rank, self.col_rank, self.dep_rank,
|
||||
ParallelMode.PARALLEL_2P5D_ROW,
|
||||
ParallelMode.PARALLEL_2P5D_COL,
|
||||
ParallelMode.PARALLEL_2P5D_DEP,
|
||||
False,
|
||||
self.data_parallel_rank,
|
||||
self.pipeline_parallel_rank,
|
||||
@ -168,6 +179,7 @@ class LayerNorm2p5D(ParallelLayer):
|
||||
:param dtype: The dtype of parameters, defaults to None
|
||||
:type dtype: torch.dtype, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
normalized_shape: int,
|
||||
eps: float = 1e-05,
|
||||
@ -184,7 +196,7 @@ class LayerNorm2p5D(ParallelLayer):
|
||||
self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
|
||||
self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
|
||||
self.dep_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
|
||||
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
|
||||
self.tesseract_dim, _ = get_tesseract_dim_dep_from_env()
|
||||
|
||||
# partitioning dimension
|
||||
self.partitioned_partition = divide(
|
||||
@ -193,27 +205,19 @@ class LayerNorm2p5D(ParallelLayer):
|
||||
# create parameters
|
||||
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
|
||||
|
||||
if self.row_rank == 0:
|
||||
self.gamma = Parameter(torch.ones(
|
||||
self.partitioned_partition,
|
||||
**factory_kwargs))
|
||||
self.beta = Parameter(torch.zeros(
|
||||
self.partitioned_partition,
|
||||
**factory_kwargs))
|
||||
else:
|
||||
self.gamma = Parameter(torch.tensor(
|
||||
1.0,
|
||||
requires_grad=True,
|
||||
**factory_kwargs))
|
||||
self.beta = Parameter(torch.tensor(
|
||||
1.0,
|
||||
requires_grad=True,
|
||||
**factory_kwargs))
|
||||
self.gamma = Parameter(torch.ones(
|
||||
self.partitioned_partition,
|
||||
**factory_kwargs))
|
||||
self.beta = Parameter(torch.zeros(
|
||||
self.partitioned_partition,
|
||||
**factory_kwargs))
|
||||
|
||||
self._set_tensor_parallel_attribute()
|
||||
|
||||
def _set_tensor_parallel_attribute(self):
|
||||
set_tensor_parallel_attribute(self.gamma)
|
||||
set_tensor_parallel_attribute(self.beta)
|
||||
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
|
||||
set_tensor_parallel_attribute_by_partition(self.gamma, num_partition)
|
||||
set_tensor_parallel_attribute_by_partition(self.beta, num_partition)
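# --- Illustrative sketch (not part of this commit): how layer-norm statistics
# can be computed when the hidden dimension is sharded across a process group,
# as in the forward below: reduce local partial sums, all-reduce them, then
# form E[x] and Var[x] from the global sums. The group handle is an assumption.
import torch
import torch.distributed as dist


def sharded_layernorm_stats(x_local: torch.Tensor, full_hidden: int, group=None, eps: float = 1e-5):
    partial_sum = x_local.sum(dim=-1, keepdim=True)
    partial_sq = x_local.pow(2).sum(dim=-1, keepdim=True)
    dist.all_reduce(partial_sum, group=group)
    dist.all_reduce(partial_sq, group=group)
    mean = partial_sum / full_hidden
    var = partial_sq / full_hidden - mean.pow(2)       # E[x^2] - E[x]^2
    inv_std = 1.0 / torch.sqrt(var + eps)
    return mean, inv_std
# --- end of sketch ---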
def forward(self, x: Tensor) -> Tensor:
|
||||
with torch.no_grad():
|
||||
@ -233,16 +237,12 @@ class LayerNorm2p5D(ParallelLayer):
|
||||
Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon)
|
||||
|
||||
output = _LayerNorm_2p5D.apply(x, E_x, Var_x, self.normalized_shape,
|
||||
ParallelMode.PARALLEL_2P5D_ROW,
|
||||
ParallelMode.PARALLEL_2P5D_COL,
|
||||
ParallelMode.PARALLEL_2P5D_DEP)
|
||||
ParallelMode.PARALLEL_2P5D_ROW)
|
||||
bias = Add_Bias_2p5D.apply(
|
||||
None, self.beta, self.partitioned_partition,
|
||||
self.tesseract_dim, self.tesseract_dep,
|
||||
self.tesseract_dim,
|
||||
self.row_rank, self.col_rank, self.dep_rank,
|
||||
ParallelMode.PARALLEL_2P5D_ROW,
|
||||
ParallelMode.PARALLEL_2P5D_COL,
|
||||
ParallelMode.PARALLEL_2P5D_DEP,
|
||||
True,
|
||||
self.data_parallel_rank,
|
||||
self.pipeline_parallel_rank,
|
||||
@ -251,11 +251,9 @@ class LayerNorm2p5D(ParallelLayer):
|
||||
)
|
||||
scale = Add_Bias_2p5D.apply(
|
||||
None, self.gamma, self.partitioned_partition,
|
||||
self.tesseract_dim, self.tesseract_dep,
|
||||
self.tesseract_dim,
|
||||
self.row_rank, self.col_rank, self.dep_rank,
|
||||
ParallelMode.PARALLEL_2P5D_ROW,
|
||||
ParallelMode.PARALLEL_2P5D_COL,
|
||||
ParallelMode.PARALLEL_2P5D_DEP,
|
||||
True,
|
||||
self.data_parallel_rank,
|
||||
self.pipeline_parallel_rank,
|
||||
|
@ -1,21 +1,223 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from typing import Any, Tuple
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossalai.communication import all_gather, reduce_scatter, scatter
|
||||
from colossalai.communication import all_gather, all_reduce, reduce_scatter
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.utils import empty_cache, get_current_device
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
|
||||
class linear_3d(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx: Any,
|
||||
input_: Tensor,
|
||||
weight: Tensor,
|
||||
bias: Optional[Tensor],
|
||||
input_parallel_mode: ParallelMode,
|
||||
weight_parallel_mode: ParallelMode,
|
||||
output_parallel_mode: ParallelMode,
|
||||
input_dim: int = 0,
|
||||
weight_dim: int = -1,
|
||||
output_dim: int = 0) -> Tensor:
|
||||
assert input_.shape[-1] == weight.shape[0], \
|
||||
'Invalid shapes: input = {}, weight = {}.'.format(input_.shape, weight.shape)
|
||||
|
||||
ctx.use_bias = bias is not None
|
||||
|
||||
input_ = all_gather(input_, input_dim, input_parallel_mode)
|
||||
input_ = torch.cat(input_, dim=input_dim)
|
||||
# weight = all_gather(weight, weight_dim, weight_parallel_mode)
|
||||
ctx.save_for_backward(input_, weight)
|
||||
|
||||
output = torch.matmul(input_, weight)
|
||||
output = reduce_scatter(output, output_dim, output_parallel_mode)
|
||||
|
||||
if bias is not None:
|
||||
# ranks_in_group = gpc.get_ranks_in_group(output_parallel_mode)
|
||||
# src_rank = ranks_in_group[gpc.get_local_rank(input_parallel_mode)]
|
||||
# dist.broadcast(bias,
|
||||
# src=src_rank,
|
||||
# group=gpc.get_group(output_parallel_mode))
|
||||
# bias = all_gather(bias, -1, weight_parallel_mode)
|
||||
output += bias
|
||||
# ctx.src_rank = src_rank
|
||||
|
||||
# ctx.save_for_backward(input_, weight)
|
||||
# output = torch.matmul(input_, weight)
|
||||
# dist.all_reduce(output, group=gpc.get_group(output_parallel_mode))
|
||||
# output += bias
|
||||
|
||||
ctx.input_parallel_mode = input_parallel_mode
|
||||
ctx.weight_parallel_mode = weight_parallel_mode
|
||||
ctx.output_parallel_mode = output_parallel_mode
|
||||
ctx.input_dim = input_dim
|
||||
ctx.weight_dim = weight_dim
|
||||
ctx.output_dim = output_dim
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
input_, weight = ctx.saved_tensors
|
||||
with torch.no_grad():
|
||||
# input_grad = torch.matmul(output_grad, weight.transpose(0, 1))
|
||||
# dist.all_reduce(input_grad,
|
||||
# group=gpc.get_group(ctx.input_parallel_mode))
|
||||
# weight_grad = torch.matmul(
|
||||
# input_.reshape(-1, input_.shape[-1]).transpose(0, 1),
|
||||
# output_grad.reshape(-1, output_grad.shape[-1]))
|
||||
# dist.all_reduce(weight_grad,
|
||||
# group=gpc.get_group(ctx.weight_parallel_mode))
|
||||
|
||||
# bias_grad = torch.sum(output_grad,
|
||||
# dim=tuple(
|
||||
# range(len(output_grad.shape))[:-1]))
|
||||
# bias_grad = reduce_scatter(bias_grad, -1,
|
||||
# ctx.weight_parallel_mode)
|
||||
# dist.reduce(bias_grad,
|
||||
# dst=ctx.src_rank,
|
||||
# group=gpc.get_group(ctx.output_parallel_mode))
|
||||
# if gpc.get_local_rank(
|
||||
# ctx.output_parallel_mode) != gpc.get_local_rank(
|
||||
# ctx.input_parallel_mode):
|
||||
# bias_grad = None
|
||||
|
||||
# input_ = all_gather(input_, ctx.input_dim, ctx.input_parallel_mode)
|
||||
# weight = all_gather(weight, ctx.weight_dim,
|
||||
# ctx.weight_parallel_mode)
|
||||
|
||||
output_grad = all_gather(output_grad, ctx.output_dim,
|
||||
ctx.output_parallel_mode)
|
||||
output_grad = torch.cat(output_grad, dim=ctx.output_dim)
|
||||
|
||||
input_grad = torch.matmul(output_grad, weight.transpose(0, 1))
|
||||
|
||||
input_grad, input_op = reduce_scatter(input_grad, ctx.input_dim,
|
||||
ctx.input_parallel_mode,
|
||||
async_op=True)
|
||||
weight_grad = torch.matmul(
|
||||
input_.reshape(-1, input_.shape[-1]).transpose(0, 1),
|
||||
output_grad.reshape(-1, output_grad.shape[-1]))
|
||||
|
||||
# weight_grad = torch.matmul(
|
||||
# input_.reshape(-1, input_.shape[-1]).transpose(0, 1),
|
||||
# output_grad.reshape(-1, output_grad.shape[-1]))
|
||||
# weight_grad = reduce_scatter(weight_grad, ctx.weight_dim,
|
||||
# ctx.weight_parallel_mode)
|
||||
if ctx.use_bias:
|
||||
bias_grad = torch.sum(output_grad,
|
||||
dim=tuple(
|
||||
range(len(output_grad.shape))[:-1]))
|
||||
# bias_grad =all_reduce(bias_grad, ctx.output_parallel_mode)
|
||||
# dist.all_reduce(bias_grad,
|
||||
# group=gpc.get_group(ctx.weight_parallel_mode))
|
||||
weight_grad = torch.cat([weight_grad, torch.unsqueeze(bias_grad, dim=0)])
|
||||
|
||||
weight_grad, weight_op = all_reduce(weight_grad, ctx.weight_parallel_mode, async_op=True)
|
||||
|
||||
input_op.wait()
|
||||
weight_op.wait()
|
||||
if ctx.use_bias:
|
||||
bias_grad = weight_grad[-1]
|
||||
weight_grad = weight_grad[:-1]
|
||||
|
||||
return input_grad, weight_grad, bias_grad, None, None, None, None, None, None
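# --- Illustrative sketch (not part of this commit): the communication/compute
# overlap used in the backward above: launch collectives with async_op=True,
# run the independent matmuls, then wait on the returned work handles. This
# sketch uses all_reduce in place of the reduce-scatter above; the group
# handles are assumptions for illustration.
import torch
import torch.distributed as dist


def overlapped_linear_backward(output_grad, input_, weight, input_group=None, weight_group=None):
    # gradient w.r.t. the input; start reducing it while weight_grad is computed
    input_grad = torch.matmul(output_grad, weight.t())
    input_work = dist.all_reduce(input_grad, group=input_group, async_op=True)

    # gradient w.r.t. the weight, computed while input_grad is in flight
    weight_grad = torch.matmul(
        input_.reshape(-1, input_.shape[-1]).t(),
        output_grad.reshape(-1, output_grad.shape[-1]))
    weight_work = dist.all_reduce(weight_grad, group=weight_group, async_op=True)

    input_work.wait()
    weight_work.wait()
    return input_grad, weight_grad
# --- end of sketch ---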
class layer_norm_3d(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx: Any, input_: Tensor, weight: Tensor, bias: Tensor,
|
||||
normalized_shape: int, eps: float,
|
||||
input_parallel_mode: ParallelMode,
|
||||
weight_parallel_mode: ParallelMode,
|
||||
output_parallel_mode: ParallelMode) -> Tensor:
|
||||
# mean = torch.sum(input_, dim=-1)
|
||||
# dist.all_reduce(mean, group=gpc.get_group(output_parallel_mode))
|
||||
# mean /= normalized_shape
|
||||
# mu = input_ - mean
|
||||
# var = torch.sum(torch.pow(mu, 2), dim=-1)
|
||||
# dist.all_reduce(var, group=gpc.get_group(output_parallel_mode))
|
||||
# var /= normalized_shape
|
||||
# std_dev = torch.sqrt(var + eps)
|
||||
# ctx.save_for_backward(input_, mu, std_dev, weight)
|
||||
|
||||
# output = weight * mu / std_dev + bias
|
||||
|
||||
mean = all_reduce(torch.sum(input_, dim=-1, keepdim=True),
|
||||
output_parallel_mode) / normalized_shape
|
||||
mu = input_ - mean
|
||||
var = all_reduce(torch.sum(mu**2, dim=-1, keepdim=True),
|
||||
output_parallel_mode) / normalized_shape
|
||||
sigma = torch.sqrt(var + eps)
|
||||
|
||||
# ranks_in_group = gpc.get_ranks_in_group(input_parallel_mode)
|
||||
# src_rank = ranks_in_group[gpc.get_local_rank(output_parallel_mode)]
|
||||
# transforms = torch.stack([weight, bias]).contiguous()
|
||||
# dist.broadcast(transforms,
|
||||
# src=src_rank,
|
||||
# group=gpc.get_group(input_parallel_mode))
|
||||
# transforms = all_gather(transforms, -1, weight_parallel_mode)
|
||||
# weight, bias = transforms[0], transforms[1]
|
||||
|
||||
ctx.save_for_backward(mu, sigma, weight)
|
||||
|
||||
z = mu / sigma
|
||||
output = weight * z + bias
|
||||
|
||||
# ctx.src_rank = src_rank
|
||||
ctx.normalized_shape = normalized_shape
|
||||
ctx.input_parallel_mode = input_parallel_mode
|
||||
ctx.weight_parallel_mode = weight_parallel_mode
|
||||
ctx.output_parallel_mode = output_parallel_mode
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
mu, sigma, weight = ctx.saved_tensors
|
||||
with torch.no_grad():
|
||||
bias_grad, weight_grad = output_grad, output_grad * mu / sigma
|
||||
grads = torch.stack([bias_grad, weight_grad]).contiguous()
|
||||
grads = torch.sum(grads, dim=tuple(range(len(grads.shape))[1:-1]))
|
||||
grads = all_reduce(grads, ctx.weight_parallel_mode)
|
||||
grads = all_reduce(grads, ctx.input_parallel_mode)
|
||||
bias_grad, weight_grad = grads[0], grads[1]
|
||||
|
||||
# grads = reduce_scatter(grads, -1, ctx.weight_parallel_mode)
|
||||
# dist.reduce(grads,
|
||||
# dst=ctx.src_rank,
|
||||
# group=gpc.get_group(ctx.input_parallel_mode))
|
||||
# if gpc.get_local_rank(
|
||||
# ctx.input_parallel_mode) == gpc.get_local_rank(
|
||||
# ctx.output_parallel_mode):
|
||||
# bias_grad, weight_grad = grads[0], grads[1]
|
||||
# else:
|
||||
# bias_grad, weight_grad = None, None
|
||||
|
||||
dz = output_grad * weight
|
||||
dvar = dz * mu * (-0.5) * sigma**(-3)
|
||||
dvar = all_reduce(torch.sum(dvar, dim=-1, keepdim=True), ctx.output_parallel_mode)
|
||||
dmean = dz * (-1 / sigma) + dvar * -2 * mu / ctx.normalized_shape
|
||||
dmean = all_reduce(torch.sum(dmean, dim=-1, keepdim=True), ctx.output_parallel_mode)
|
||||
|
||||
input_grad = dz / sigma + dvar * 2 * mu / ctx.normalized_shape + dmean / ctx.normalized_shape
|
||||
|
||||
return input_grad, weight_grad, bias_grad, None, None, None, None, None
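# --- Illustrative sanity check (not part of this commit): verify the manual
# layer-norm gradient above against autograd on a single device. Shapes, eps
# and dtype are arbitrary; the closed form mirrors the backward code.
import torch

x = torch.randn(4, 8, 16, dtype=torch.float64, requires_grad=True)
weight = torch.randn(16, dtype=torch.float64)
bias = torch.randn(16, dtype=torch.float64)
eps, h = 1e-5, x.shape[-1]

mean = x.mean(dim=-1, keepdim=True)
mu = x - mean
var = mu.pow(2).mean(dim=-1, keepdim=True)
sigma = torch.sqrt(var + eps)
out = weight * (mu / sigma) + bias
grad_out = torch.randn_like(out)
out.backward(grad_out)

dz = grad_out * weight
dvar = (dz * mu * (-0.5) * sigma.pow(-3)).sum(dim=-1, keepdim=True)
dmean = (dz * (-1.0 / sigma) + dvar * (-2.0) * mu / h).sum(dim=-1, keepdim=True)
manual = dz / sigma + dvar * 2 * mu / h + dmean / h
assert torch.allclose(manual, x.grad, atol=1e-8)
# --- end of sketch ---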
class Matmul_AB_3D(torch.autograd.Function):
|
||||
"""Matrix multiplication for :math:`C = AB`
|
||||
"""
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx: Any,
|
||||
A: Tensor,
|
||||
B: Tensor,
|
||||
@ -29,7 +231,6 @@ class Matmul_AB_3D(torch.autograd.Function):
|
||||
# A: [m/q^2, n, k/q]
|
||||
# B: [k/q, h/q^2]
|
||||
# C: [m/q^2, n, h/q]
|
||||
empty_cache()
|
||||
ctx.save_for_backward(A, B)
|
||||
|
||||
assert A.shape[-1] == B.shape[0], \
|
||||
@ -52,6 +253,7 @@ class Matmul_AB_3D(torch.autograd.Function):
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
A, B = ctx.saved_tensors
|
||||
with torch.no_grad():
|
||||
@ -72,6 +274,7 @@ class Matmul_ABT_3D(torch.autograd.Function):
|
||||
"""Matrix multiplication for :math:`C = AB^T`
|
||||
"""
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx: Any,
|
||||
A: Tensor,
|
||||
B: Tensor,
|
||||
@ -85,7 +288,6 @@ class Matmul_ABT_3D(torch.autograd.Function):
|
||||
# A: [m/q^2, n, h/q]
|
||||
# B: [k/q, h/q^2]
|
||||
# C: [m/q^2, n, k/q]
|
||||
empty_cache()
|
||||
ctx.save_for_backward(A, B)
|
||||
|
||||
A_temp = all_gather(A, input_dim, input_parallel_mode)
|
||||
@ -105,6 +307,7 @@ class Matmul_ABT_3D(torch.autograd.Function):
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
A, B = ctx.saved_tensors
|
||||
with torch.no_grad():
|
||||
@ -125,6 +328,7 @@ class Matmul_ATB_3D(torch.autograd.Function):
|
||||
"""Matrix multiplication for :math:`C = A^TB`
|
||||
"""
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx: Any,
|
||||
A: Tensor,
|
||||
B: Tensor,
|
||||
@ -138,7 +342,6 @@ class Matmul_ATB_3D(torch.autograd.Function):
|
||||
# A: [m/q^2, n, k/q]
|
||||
# B: [m/q^2, n, h/q]
|
||||
# C: [k/q, h/q^2]
|
||||
empty_cache()
|
||||
ctx.save_for_backward(A, B)
|
||||
|
||||
A_temp = all_gather(A, input_dim, input_parallel_mode)
|
||||
@ -160,6 +363,7 @@ class Matmul_ATB_3D(torch.autograd.Function):
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
A, B = ctx.saved_tensors
|
||||
with torch.no_grad():
|
||||
@ -180,6 +384,7 @@ class Add_3D(torch.autograd.Function):
|
||||
"""Matrix add bias: :math:`C = A + b`
|
||||
"""
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx: Any, input_: Tensor, bias: Tensor, depth: int,
|
||||
input_parallel_mode: ParallelMode,
|
||||
weight_parallel_mode: ParallelMode,
|
||||
@ -206,6 +411,7 @@ class Add_3D(torch.autograd.Function):
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
# output_grad: [m/q^2, n, h/q]
|
||||
with torch.no_grad():
|
||||
@ -217,8 +423,8 @@ class Add_3D(torch.autograd.Function):
|
||||
dst=ctx.src_rank,
|
||||
group=gpc.get_group(ctx.A_group_parallel_mode))
|
||||
if gpc.get_local_rank(
|
||||
ctx.A_group_parallel_mode) != gpc.get_local_rank(
|
||||
ctx.C_group_parallel_mode):
|
||||
ctx.A_group_parallel_mode) != gpc.get_local_rank(
|
||||
ctx.C_group_parallel_mode):
|
||||
bias_grad = None
|
||||
return output_grad, bias_grad, None, None, None, None
|
||||
|
||||
@ -227,6 +433,7 @@ class Mul_3D(torch.autograd.Function):
|
||||
"""Matrix multiplication for :math:`C = A * b`
|
||||
"""
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx: Any, input_: Tensor, bias: Tensor, depth: int,
|
||||
input_parallel_mode: ParallelMode,
|
||||
weight_parallel_mode: ParallelMode,
|
||||
@ -243,7 +450,7 @@ class Mul_3D(torch.autograd.Function):
|
||||
# [h/q]
|
||||
bias_temp = all_gather(bias_temp, -1, weight_parallel_mode)
|
||||
|
||||
empty_cache()
|
||||
# empty_cache()
|
||||
ctx.save_for_backward(input_, bias_temp)
|
||||
|
||||
out = torch.mul(input_, bias_temp)
|
||||
@ -257,6 +464,7 @@ class Mul_3D(torch.autograd.Function):
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
# output_grad: [m/q^2, n, h/q]
|
||||
with torch.no_grad():
|
||||
@ -272,8 +480,8 @@ class Mul_3D(torch.autograd.Function):
|
||||
dst=ctx.src_rank,
|
||||
group=gpc.get_group(ctx.A_group_parallel_mode))
|
||||
if gpc.get_local_rank(
|
||||
ctx.A_group_parallel_mode) != gpc.get_local_rank(
|
||||
ctx.C_group_parallel_mode):
|
||||
ctx.A_group_parallel_mode) != gpc.get_local_rank(
|
||||
ctx.C_group_parallel_mode):
|
||||
bias_grad = None
|
||||
return input_grad, bias_grad, None, None, None, None
|
||||
|
||||
@ -282,6 +490,7 @@ class Sum_3D(torch.autograd.Function):
|
||||
"""Compute the sum of input tensors
|
||||
"""
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx: Any,
|
||||
input_: Tensor,
|
||||
dim: int,
|
||||
@ -299,6 +508,7 @@ class Sum_3D(torch.autograd.Function):
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
with torch.no_grad():
|
||||
output_grad = output_grad.contiguous()
|
||||
@ -315,35 +525,39 @@ class Reduce_3D(torch.autograd.Function):
|
||||
"""Reduce input tensors
|
||||
"""
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx: Any, input_: Tensor, depth: int,
|
||||
parallel_mode: ParallelMode) -> Tensor:
|
||||
dist.all_reduce(input_, group=gpc.get_group(parallel_mode))
|
||||
return input_.clone()
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
return output_grad, None, None
|
||||
|
||||
|
||||
class Slice_3D(torch.autograd.Function):
|
||||
"""Slice input tensor
|
||||
"""
|
||||
@staticmethod
|
||||
def forward(ctx: Any, input_: Tensor, dim: int, depth: int,
|
||||
parallel_mode: ParallelMode) -> Tensor:
|
||||
rank = gpc.get_local_rank(parallel_mode)
|
||||
out = torch.chunk(input_, depth, dim=dim)[rank].contiguous()
|
||||
# class Slice_3D(torch.autograd.Function):
|
||||
# """Slice input tensor
|
||||
# """
|
||||
# @staticmethod
|
||||
# @custom_fwd(cast_inputs=torch.float16)
|
||||
# def forward(ctx: Any, input_: Tensor, dim: int, depth: int,
|
||||
# parallel_mode: ParallelMode) -> Tensor:
|
||||
# rank = gpc.get_local_rank(parallel_mode)
|
||||
# out = torch.chunk(input_, depth, dim=dim)[rank].contiguous()
|
||||
|
||||
ctx.depth = depth
|
||||
ctx.parallel_mode = parallel_mode
|
||||
ctx.dim = dim
|
||||
ctx.input_shape = input_.shape
|
||||
# ctx.depth = depth
|
||||
# ctx.parallel_mode = parallel_mode
|
||||
# ctx.dim = dim
|
||||
# ctx.input_shape = input_.shape
|
||||
|
||||
return out
|
||||
# return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
with torch.no_grad():
|
||||
input_grad = all_gather(output_grad, ctx.dim, ctx.parallel_mode)
|
||||
input_grad.reshape(ctx.input_shape)
|
||||
return input_grad, None, None, None
|
||||
# @staticmethod
|
||||
# @custom_bwd
|
||||
# def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
# with torch.no_grad():
|
||||
# input_grad = all_gather(output_grad, ctx.dim, ctx.parallel_mode)
|
||||
# input_grad.reshape(ctx.input_shape)
|
||||
# return input_grad, None, None, None
|
||||
|
@ -3,7 +3,8 @@
|
||||
|
||||
import os
|
||||
|
||||
from colossalai.constants import DEPTH_3D
|
||||
from colossalai.constants import (DEPTH_3D, INPUT_GROUP_3D, OUTPUT_GROUP_3D,
|
||||
WEIGHT_GROUP_3D)
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from torch import Tensor
|
||||
@ -23,6 +24,10 @@ def get_depth_from_env() -> int:
|
||||
)
|
||||
|
||||
|
||||
def get_parallel_mode_from_env(group):
|
||||
return getattr(ParallelMode, os.environ[group])
|
||||
|
||||
|
||||
def get_last_group(a, b):
|
||||
mapping = {
|
||||
ParallelMode.PARALLEL_3D_INPUT: 'A',
|
||||
@ -41,6 +46,11 @@ def get_last_group(a, b):
|
||||
return ParallelMode.PARALLEL_3D_OUTPUT
|
||||
|
||||
|
||||
def swap_in_out_group():
|
||||
os.environ[INPUT_GROUP_3D], os.environ[OUTPUT_GROUP_3D] = \
|
||||
os.environ[OUTPUT_GROUP_3D], os.environ[INPUT_GROUP_3D]
|
||||
|
||||
|
||||
def dbg_check_shape(tensor: Tensor, shape: tuple):
|
||||
rank = gpc.get_global_rank()
|
||||
if rank == 0:
|
||||
|
@ -1,17 +1,20 @@
|
||||
import math
|
||||
from typing import Tuple
|
||||
import os
|
||||
from typing import Tuple, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossalai.constants import (INPUT_GROUP_3D, OUTPUT_GROUP_3D,
|
||||
WEIGHT_GROUP_3D)
|
||||
from colossalai.context import ParallelMode, seed
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.registry import LAYERS
|
||||
from colossalai.nn.init import init_bias_, init_weight_
|
||||
from colossalai.utils import checkpoint, get_current_device
|
||||
from torch import Tensor, dtype, nn
|
||||
|
||||
from .._common_utils import ACT2FN, divide, set_tensor_parallel_attribute
|
||||
from ..vanilla_vision_transformer.layers import to_2tuple
|
||||
from ._utils import get_depth_from_env
|
||||
from .._common_utils import ACT2FN, divide, set_tensor_parallel_attribute_by_size, to_2tuple
|
||||
from ._utils import get_depth_from_env, get_parallel_mode_from_env, get_last_group
|
||||
from .layers import Linear3D
|
||||
|
||||
|
||||
@ -32,34 +35,42 @@ class ViTPatchEmbedding3D(nn.Module):
|
||||
:param flatten: whether to flatten output tensor, defaults to True
|
||||
:type flatten: bool, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
img_size: int,
|
||||
patch_size: int,
|
||||
in_chans: int,
|
||||
embed_size: int,
|
||||
drop_prob: float,
|
||||
flatten: bool = True):
|
||||
flatten: bool = True,
|
||||
init_method: str = 'torch'):
|
||||
super().__init__()
|
||||
self.depth = get_depth_from_env()
|
||||
self.input_parallel_mode = ParallelMode.PARALLEL_3D_INPUT
|
||||
self.weight_parallel_mode = ParallelMode.PARALLEL_3D_WEIGHT
|
||||
self.output_parallel_mode = ParallelMode.PARALLEL_3D_OUTPUT
|
||||
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
|
||||
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
|
||||
self.output_parallel_mode = get_last_group(self.input_parallel_mode,
|
||||
self.weight_parallel_mode)
|
||||
img_size = to_2tuple(img_size)
|
||||
patch_size = to_2tuple(patch_size)
|
||||
self.img_size = img_size
|
||||
self.patch_size = patch_size
|
||||
self.grid_size = (img_size[0] // patch_size[0],
|
||||
img_size[1] // patch_size[1])
|
||||
self.in_chans = in_chans
|
||||
self.embed_size = embed_size
|
||||
self.embed_size_per_partition = divide(self.embed_size, self.depth)
|
||||
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
||||
self.flatten = flatten
|
||||
self.init_weight = 'torch'
|
||||
self.init_bias = 'torch'
|
||||
if init_method == 'jax':
|
||||
self.init_weight = 'jax_embed'
|
||||
self.init_bias = 'zero'
|
||||
|
||||
with seed(ParallelMode.TENSOR):
|
||||
self.proj = nn.Conv2d(in_chans,
|
||||
self.embed_size_per_partition,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size)
|
||||
self.proj = nn.Conv2d(self.in_chans,
|
||||
self.embed_size_per_partition,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size)
|
||||
|
||||
self.cls_token = nn.Parameter(
|
||||
torch.zeros(1, 1, self.embed_size_per_partition))
|
||||
@ -68,23 +79,26 @@ class ViTPatchEmbedding3D(nn.Module):
|
||||
self.embed_size_per_partition))
|
||||
self.pos_drop = nn.Dropout(drop_prob)
|
||||
|
||||
self._sync_parameters()
|
||||
self.proj.weight.register_hook(self._sync_grad_hook)
|
||||
self.proj.bias.register_hook(self._sync_grad_hook)
|
||||
self.cls_token.register_hook(self._sync_grad_hook)
|
||||
self.pos_embed.register_hook(self._sync_grad_hook)
|
||||
self._set_tensor_parallel_attribute()
|
||||
self.reset_parameters(self.init_weight, self.init_bias)
|
||||
self._set_tensor_parallel_attributes()
|
||||
|
||||
def _set_tensor_parallel_attribute(self):
|
||||
set_tensor_parallel_attribute(self.proj.weight)
|
||||
set_tensor_parallel_attribute(self.proj.bias)
|
||||
set_tensor_parallel_attribute(self.cls_token)
|
||||
set_tensor_parallel_attribute(self.pos_embed)
|
||||
def _set_tensor_parallel_attributes(self):
|
||||
set_tensor_parallel_attribute_by_size(self.proj.weight, self.in_chans * self.embed_size * self.num_patches)
|
||||
set_tensor_parallel_attribute_by_size(self.proj.bias, self.embed_size)
|
||||
set_tensor_parallel_attribute_by_size(self.cls_token, 1 * 1 * self.embed_size)
|
||||
set_tensor_parallel_attribute_by_size(self.pos_embed, 1 * (self.num_patches + 1) * self.embed_size)
|
||||
|
||||
def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
|
||||
return self.input_parallel_mode, self.weight_parallel_mode
|
||||
def reset_parameters(self, init_weight, init_bias):
|
||||
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.proj.weight)
|
||||
# std = math.sqrt(1.0 / fan_in)
|
||||
# nn.init.trunc_normal_(self.proj.weight, std=std / .87962566103423978)
|
||||
# nn.init.zeros_(self.proj.bias)
|
||||
if init_weight != 'torch':
|
||||
init_weight_(self.proj.weight, fan_in, init_method=init_weight)
|
||||
init_bias_(self.pos_embed, fan_in, init_method=init_weight)
|
||||
if init_bias != 'torch':
|
||||
init_bias_(self.proj.bias, fan_in, init_method=init_bias)
|
||||
|
||||
def _sync_parameters(self):
|
||||
self.to(get_current_device())
|
||||
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
|
||||
dist.broadcast(self.proj.weight,
|
||||
@ -100,10 +114,11 @@ class ViTPatchEmbedding3D(nn.Module):
|
||||
dist.broadcast(self.proj.bias,
|
||||
src=input_src_rank,
|
||||
group=gpc.get_group(self.input_parallel_mode))
|
||||
set_tensor_parallel_attribute(self.proj.weight)
|
||||
set_tensor_parallel_attribute(self.proj.bias)
|
||||
set_tensor_parallel_attribute(self.cls_token)
|
||||
set_tensor_parallel_attribute(self.pos_embed)
|
||||
|
||||
self.proj.weight.register_hook(self._sync_grad_hook)
|
||||
self.proj.bias.register_hook(self._sync_grad_hook)
|
||||
self.cls_token.register_hook(self._sync_grad_hook)
|
||||
self.pos_embed.register_hook(self._sync_grad_hook)
|
||||
|
||||
def _sync_grad_hook(self, grad) -> None:
|
||||
dist.all_reduce(grad, group=gpc.get_group(self.input_parallel_mode))
|
||||
@ -111,6 +126,12 @@ class ViTPatchEmbedding3D(nn.Module):
|
||||
return grad
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
# split a partition from inputs
|
||||
x = torch.chunk(x, self.depth, dim=0)[gpc.get_local_rank(
|
||||
self.weight_parallel_mode)].contiguous()
|
||||
x = torch.chunk(x, self.depth, dim=0)[gpc.get_local_rank(
|
||||
self.input_parallel_mode)].contiguous()
|
||||
|
||||
B, C, H, W = x.shape
|
||||
assert H == self.img_size[0] and W == self.img_size[1], \
|
||||
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
@ -118,12 +139,6 @@ class ViTPatchEmbedding3D(nn.Module):
|
||||
if self.flatten:
|
||||
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
|
||||
|
||||
# split a partition from embedded states
|
||||
x = torch.chunk(x, self.depth, dim=0)[gpc.get_local_rank(
|
||||
self.weight_parallel_mode)].contiguous()
|
||||
x = torch.chunk(x, self.depth, dim=0)[gpc.get_local_rank(
|
||||
self.input_parallel_mode)].contiguous()
|
||||
|
||||
# add cls token & pos embedding
|
||||
# [b/q^2,s,h/q] --> [b/q^2, 1+s, h/q]
|
||||
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
|
||||
@ -158,6 +173,7 @@ class ViTSelfAttention3D(nn.Module):
|
||||
:param bias: whether to add bias, defaults to True
|
||||
:type bias: bool, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
@ -165,41 +181,52 @@ class ViTSelfAttention3D(nn.Module):
|
||||
hidden_dropout_prob: float,
|
||||
dtype: dtype = None,
|
||||
bias: bool = True,
|
||||
checkpoint: bool = False):
|
||||
checkpoint: bool = False,
|
||||
init_method: str = 'torch'):
|
||||
super().__init__()
|
||||
self.depth = get_depth_from_env()
|
||||
self.input_parallel_mode = ParallelMode.PARALLEL_3D_INPUT
|
||||
self.weight_parallel_mode = ParallelMode.PARALLEL_3D_WEIGHT
|
||||
self.output_parallel_mode = ParallelMode.PARALLEL_3D_OUTPUT
|
||||
# self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
|
||||
# self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
|
||||
# self.output_parallel_mode = get_last_group(self.input_parallel_mode,
|
||||
# self.weight_parallel_mode)
|
||||
self.hidden_size = hidden_size
|
||||
self.num_attention_heads = divide(num_attention_heads, self.depth)
|
||||
self.attention_head_size = divide(hidden_size, num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
self.checkpoint = checkpoint
|
||||
self.init_weight = 'torch'
|
||||
self.init_bias = 'torch'
|
||||
if init_method == 'jax':
|
||||
self.init_weight = 'jax'
|
||||
self.init_bias = 'zero'
|
||||
|
||||
self.query_key_value = Linear3D(self.hidden_size,
|
||||
3 * self.hidden_size,
|
||||
self.input_parallel_mode,
|
||||
self.weight_parallel_mode,
|
||||
# self.input_parallel_mode,
|
||||
# self.weight_parallel_mode,
|
||||
dtype=dtype,
|
||||
bias=bias)
|
||||
bias=bias,
|
||||
init_weight=self.init_weight,
|
||||
init_bias=self.init_bias)
|
||||
self.attention_dropout = nn.Dropout(attention_probs_dropout_prob)
|
||||
self.dense = Linear3D(self.hidden_size,
|
||||
self.hidden_size,
|
||||
self.output_parallel_mode,
|
||||
self.weight_parallel_mode,
|
||||
# self.output_parallel_mode,
|
||||
# self.weight_parallel_mode,
|
||||
dtype=dtype,
|
||||
bias=bias)
|
||||
bias=bias,
|
||||
init_weight=self.init_weight,
|
||||
init_bias=self.init_bias)
|
||||
self.dropout = nn.Dropout(hidden_dropout_prob)
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
|
||||
def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
|
||||
return self.input_parallel_mode, self.weight_parallel_mode
|
||||
# def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
|
||||
# return self.input_parallel_mode, self.weight_parallel_mode
|
||||
|
||||
def _forward(self, hidden_states: Tensor) -> Tensor:
|
||||
query_key_value = self.query_key_value(hidden_states)
|
||||
new_qkv_shape = query_key_value.shape[:-1] + \
|
||||
(self.num_attention_heads, 3 * self.attention_head_size)
|
||||
(self.num_attention_heads, 3 * self.attention_head_size)
|
||||
query_key_value = query_key_value.view(new_qkv_shape)
|
||||
query_key_value = query_key_value.permute((0, 2, 1, 3))
|
||||
query_layer, key_layer, value_layer = torch.chunk(query_key_value,
|
||||
@ -259,6 +286,7 @@ class ViTMLP3D(nn.Module):
|
||||
:param bias: whether to add bias, defaults to True
|
||||
:type bias: bool, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
mlp_ratio: int,
|
||||
@ -266,33 +294,41 @@ class ViTMLP3D(nn.Module):
|
||||
hidden_act: str = 'gelu',
|
||||
dtype: dtype = None,
|
||||
bias: bool = True,
|
||||
checkpoint: bool = False):
|
||||
checkpoint: bool = False,
|
||||
init_method: str = 'torch'):
|
||||
super().__init__()
|
||||
self.depth = get_depth_from_env()
|
||||
self.input_parallel_mode = ParallelMode.PARALLEL_3D_INPUT
|
||||
self.weight_parallel_mode = ParallelMode.PARALLEL_3D_WEIGHT
|
||||
self.output_parallel_mode = ParallelMode.PARALLEL_3D_OUTPUT
|
||||
# self.depth = get_depth_from_env()
|
||||
# self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
|
||||
# self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
|
||||
# self.output_parallel_mode = get_last_group(self.input_parallel_mode,
|
||||
# self.weight_parallel_mode)
|
||||
self.hidden_size = hidden_size
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.checkpoint = checkpoint
|
||||
self.init_weight = init_method
|
||||
self.init_bias = init_method
|
||||
|
||||
self.dense_1 = Linear3D(self.hidden_size,
|
||||
self.mlp_ratio * self.hidden_size,
|
||||
self.input_parallel_mode,
|
||||
self.weight_parallel_mode,
|
||||
# self.input_parallel_mode,
|
||||
# self.weight_parallel_mode,
|
||||
dtype=dtype,
|
||||
bias=bias)
|
||||
bias=bias,
|
||||
init_weight=self.init_weight,
|
||||
init_bias=self.init_bias)
|
||||
self.activation_func = ACT2FN[hidden_act]
|
||||
self.dense_2 = Linear3D(self.mlp_ratio * self.hidden_size,
|
||||
self.hidden_size,
|
||||
self.output_parallel_mode,
|
||||
self.weight_parallel_mode,
|
||||
# self.output_parallel_mode,
|
||||
# self.weight_parallel_mode,
|
||||
dtype=dtype,
|
||||
bias=bias)
|
||||
bias=bias,
|
||||
init_weight=self.init_weight,
|
||||
init_bias=self.init_bias)
|
||||
self.dropout = nn.Dropout(hidden_dropout_prob)
|
||||
|
||||
def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
|
||||
return self.input_parallel_mode, self.weight_parallel_mode
|
||||
# def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
|
||||
# return self.input_parallel_mode, self.weight_parallel_mode
|
||||
|
||||
def _forward(self, hidden_states: Tensor) -> Tensor:
|
||||
intermediate_output = self.dense_1(hidden_states)
|
||||
@ -331,37 +367,46 @@ class ViTHead3D(nn.Module):
|
||||
:param bias: whether to add bias, defaults to True
|
||||
:type bias: bool, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
num_classes: int,
|
||||
dtype: dtype = None,
|
||||
bias: bool = True):
|
||||
bias: bool = True,
|
||||
init_method: str = 'torch'):
|
||||
super().__init__()
|
||||
self.depth = get_depth_from_env()
|
||||
self.input_parallel_mode = ParallelMode.PARALLEL_3D_INPUT
|
||||
self.weight_parallel_mode = ParallelMode.PARALLEL_3D_WEIGHT
|
||||
self.output_parallel_mode = ParallelMode.PARALLEL_3D_OUTPUT
|
||||
# self.depth = get_depth_from_env()
|
||||
# self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
|
||||
# self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
|
||||
# self.output_parallel_mode = get_last_group(self.input_parallel_mode,
|
||||
# self.weight_parallel_mode)
|
||||
self.in_features = in_features
|
||||
self.num_classes = num_classes
|
||||
out_features = math.ceil(self.num_classes /
|
||||
(self.depth**2)) * (self.depth**2)
|
||||
self.num_classes_per_partition = divide(self.num_classes, self.depth)
|
||||
self.linear = Linear3D(self.in_features,
|
||||
out_features,
|
||||
self.input_parallel_mode,
|
||||
self.weight_parallel_mode,
|
||||
dtype=dtype,
|
||||
bias=bias)
|
||||
# out_features = math.ceil(self.num_classes /
|
||||
# (self.depth**2)) * (self.depth**2)
|
||||
# self.num_classes_per_partition = divide(self.num_classes, self.depth)
|
||||
self.init_weight = 'torch'
|
||||
self.init_bias = 'torch'
|
||||
if init_method == 'jax':
|
||||
self.init_weight = 'zero'
|
||||
self.init_bias = 'zero'
|
||||
|
||||
def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
|
||||
return self.linear.groups_for_next_layer()
|
||||
self.linear = Linear3D(self.in_features,
|
||||
self.num_classes,
|
||||
# self.input_parallel_mode,
|
||||
# self.weight_parallel_mode,
|
||||
dtype=dtype,
|
||||
bias=bias,
|
||||
init_weight=self.init_weight,
|
||||
init_bias=self.init_bias)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
# [b/q^2, s, h/q] --> [b/q^2, h/q]
|
||||
x = x[:, 0]
|
||||
# [b/q^2, h/q] --> [b/q^2, c/q]
|
||||
x = self.linear(x)
|
||||
return x[:, :self.num_classes_per_partition]
|
||||
# return x[:, :self.num_classes_per_partition]
|
||||
return x
|
||||
|
||||
def extra_repr(self):
|
||||
return 'in_features={}, num_classes={}'.format(self.in_features,
|
||||
|
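
A side note on the head change above (illustration only): the old ViTHead3D padded the class dimension so that it could be split depth**2 ways by the old Linear3D, then sliced the padding off per partition; the new head passes num_classes through unchanged. The numbers below are hypothetical.

import math

num_classes, depth = 1000, 3                      # depth is the 3D parallel size q
padded = math.ceil(num_classes / depth**2) * depth**2
print(padded)                                     # 1008 -> 8 padding columns the old forward sliced away
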
@@ -2,19 +2,28 @@
# -*- encoding: utf-8 -*-

import math
import os
from typing import Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.constants import (INPUT_GROUP_3D, OUTPUT_GROUP_3D,
WEIGHT_GROUP_3D)
from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc
from colossalai.nn.init import init_bias_, init_weight_
from colossalai.registry import LAYERS
from colossalai.utils import get_current_device
from torch import Tensor, dtype
from torch.nn import Parameter
from torch.nn import init as init

from .._common_utils import divide, set_tensor_parallel_attribute
from ._operation import Add_3D, Matmul_AB_3D, Mul_3D, Sum_3D
from ._utils import get_depth_from_env, get_last_group
from .._common_utils import divide, set_tensor_parallel_attribute_by_size
from ._operation import (Add_3D, Matmul_AB_3D, Mul_3D, Sum_3D, layer_norm_3d,
linear_3d)
from ._utils import (get_depth_from_env, get_last_group,
get_parallel_mode_from_env, swap_in_out_group)


@LAYERS.register_module
@@ -22,20 +31,19 @@ class LayerNorm3D(nn.Module):
def __init__(
self,
normalized_shape: int,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
# input_parallel_mode: ParallelMode,
# weight_parallel_mode: ParallelMode,
eps: float = 1e-12,
dtype: dtype = None,
):
super().__init__()
self.input_parallel_mode = input_parallel_mode
self.weight_parallel_mode = weight_parallel_mode
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode,
self.weight_parallel_mode)
self.depth = get_depth_from_env()
self.normalized_shape = normalized_shape
self.normalized_shape_per_partition = divide(normalized_shape,
self.depth**2)
self.normalized_shape_per_partition = divide(normalized_shape, self.depth)

self.weight = Parameter(
torch.ones(self.normalized_shape_per_partition,
@ -49,37 +57,40 @@ class LayerNorm3D(nn.Module):
|
||||
self._set_tensor_parallel_attributes()
|
||||
|
||||
def _set_tensor_parallel_attributes(self):
|
||||
set_tensor_parallel_attribute(self.weight)
|
||||
set_tensor_parallel_attribute(self.bias)
|
||||
|
||||
def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
|
||||
return self.input_parallel_mode, self.weight_parallel_mode
|
||||
set_tensor_parallel_attribute_by_size(self.weight, self.normalized_shape)
|
||||
set_tensor_parallel_attribute_by_size(self.bias, self.normalized_shape)
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.init.zeros_(self.bias)
|
||||
nn.init.ones_(self.weight)
|
||||
init.zeros_(self.bias)
|
||||
init.ones_(self.weight)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
'''x = weight * (x - mean) / sqrt(var + eps) + bias'''
|
||||
# input: [m/q^2, n, h/q]
|
||||
# [m/q^2, n, 1]
|
||||
mean = Sum_3D.apply(input_, -1, self.depth, self.output_parallel_mode,
|
||||
True) / self.normalized_shape
|
||||
# [m/q^2, n, 1]
|
||||
var = (input_ - mean).pow(2)
|
||||
var = Sum_3D.apply(var, -1, self.depth, self.output_parallel_mode,
|
||||
True) / self.normalized_shape
|
||||
# '''x = weight * (x - mean) / sqrt(var + eps) + bias'''
|
||||
# # input: [m/q^2, n, h/q]
|
||||
# # [m/q^2, n, 1]
|
||||
# mean = Sum_3D.apply(input_, -1, self.depth, self.output_parallel_mode,
|
||||
# True) / self.normalized_shape
|
||||
# # [m/q^2, n, 1]
|
||||
# var = (input_ - mean).pow(2)
|
||||
# var = Sum_3D.apply(var, -1, self.depth, self.output_parallel_mode,
|
||||
# True) / self.normalized_shape
|
||||
|
||||
output = (input_ - mean) / torch.sqrt(var + self.variance_epsilon)
|
||||
output = Mul_3D.apply(output, self.weight, self.depth,
|
||||
self.input_parallel_mode,
|
||||
self.weight_parallel_mode,
|
||||
self.output_parallel_mode)
|
||||
output = Add_3D.apply(output, self.bias, self.depth,
|
||||
self.input_parallel_mode,
|
||||
self.weight_parallel_mode,
|
||||
self.output_parallel_mode)
|
||||
return output
|
||||
# output = (input_ - mean) / torch.sqrt(var + self.variance_epsilon)
|
||||
# output = Mul_3D.apply(output, self.weight, self.depth,
|
||||
# self.input_parallel_mode,
|
||||
# self.weight_parallel_mode,
|
||||
# self.output_parallel_mode)
|
||||
# output = Add_3D.apply(output, self.bias, self.depth,
|
||||
# self.input_parallel_mode,
|
||||
# self.weight_parallel_mode,
|
||||
# self.output_parallel_mode)
|
||||
# return output
|
||||
return layer_norm_3d.apply(input_, self.weight, self.bias,
|
||||
self.normalized_shape,
|
||||
self.variance_epsilon,
|
||||
self.input_parallel_mode,
|
||||
self.weight_parallel_mode,
|
||||
self.output_parallel_mode)
|
||||
|
||||
def extra_repr(self):
|
||||
return '{}, eps={}'.format(self.normalized_shape,
|
||||
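
For comparison, a single-device sketch of the layer norm formula quoted in the forward above, x = weight * (x - mean) / sqrt(var + eps) + bias, checked against torch's built-in implementation (illustration only; shapes and eps below are hypothetical):

import torch
import torch.nn.functional as F

def manual_layer_norm(x, weight, bias, eps=1e-12):
    mean = x.mean(dim=-1, keepdim=True)
    var = (x - mean).pow(2).mean(dim=-1, keepdim=True)
    return weight * (x - mean) / torch.sqrt(var + eps) + bias

x = torch.randn(4, 16, 64)
w, b = torch.ones(64), torch.zeros(64)
ref = F.layer_norm(x, (64,), w, b, eps=1e-12)
assert torch.allclose(manual_layer_norm(x, w, b), ref, atol=1e-5)
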
@@ -88,33 +99,36 @@ class LayerNorm3D(nn.Module):

@LAYERS.register_module
class Linear3D(nn.Module):
def __init__(self,
in_features: int,
out_features: int,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
bias: bool = True,
dtype: dtype = None):
def __init__(
self,
in_features: int,
out_features: int,
# input_parallel_mode: ParallelMode,
# weight_parallel_mode: ParallelMode,
bias: bool = True,
dtype: dtype = None,
init_weight: str = 'torch',
init_bias: str = 'torch'):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.input_parallel_mode = input_parallel_mode
self.weight_parallel_mode = weight_parallel_mode
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode,
self.weight_parallel_mode)
self.with_bias = bias
# self.with_bias = bias
self.depth = get_depth_from_env()
self.in_features_per_partition = divide(in_features, self.depth)
self.out_features_per_partition = divide(out_features, self.depth**2)
self.out_features_per_partition = divide(out_features, self.depth)

# [k/q, h/q^2]
# [k/q, h/q]
self.weight = Parameter(
torch.empty(self.in_features_per_partition,
self.out_features_per_partition,
device=get_current_device(),
dtype=dtype))

# [h/q^2]
# [h/q]
if bias:
self.bias = Parameter(
torch.zeros(self.out_features_per_partition,
@ -123,49 +137,54 @@ class Linear3D(nn.Module):
|
||||
else:
|
||||
self.register_parameter('bias', None)
|
||||
|
||||
self.reset_parameters()
|
||||
self.reset_parameters(init_weight, init_bias)
|
||||
self._set_tensor_parallel_attributes()
|
||||
swap_in_out_group()
|
||||
|
||||
def _set_tensor_parallel_attributes(self):
|
||||
set_tensor_parallel_attribute(self.weight)
|
||||
set_tensor_parallel_attribute_by_size(self.weight, self.in_features * self.out_features)
|
||||
if self.bias is not None:
|
||||
set_tensor_parallel_attribute(self.bias)
|
||||
set_tensor_parallel_attribute_by_size(self.bias, self.out_features)
|
||||
|
||||
def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
|
||||
return self.output_parallel_mode, self.weight_parallel_mode
|
||||
|
||||
def reset_parameters(self):
|
||||
def reset_parameters(self, init_weight, init_bias) -> None:
|
||||
# setting
|
||||
fan_in = self.in_features
|
||||
a = math.sqrt(5)
|
||||
nonlinearity = 'leaky_relu'
|
||||
fan_in, fan_out = self.in_features, self.out_features
|
||||
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
|
||||
output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0]
|
||||
|
||||
# init weight
|
||||
std = nn.init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
|
||||
bound = math.sqrt(3.0) * std
|
||||
with seed(ParallelMode.TENSOR):
|
||||
nn.init.uniform_(self.weight, -bound, bound)
|
||||
|
||||
init_weight_(self.weight, fan_in, fan_out, init_method=init_weight)
|
||||
dist.broadcast(self.weight,
|
||||
src=weight_src_rank,
|
||||
group=gpc.get_group(self.weight_parallel_mode))
|
||||
# init bias
|
||||
if self.with_bias:
|
||||
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
|
||||
with seed(ParallelMode.TENSOR):
|
||||
nn.init.uniform_(self.bias, -bound, bound)
|
||||
if self.bias is not None:
|
||||
init_bias_(self.bias, fan_in, init_method=init_bias)
|
||||
dist.broadcast(self.bias,
|
||||
src=weight_src_rank,
|
||||
group=gpc.get_group(self.weight_parallel_mode))
|
||||
dist.broadcast(self.bias,
|
||||
src=output_src_rank,
|
||||
group=gpc.get_group(self.output_parallel_mode))
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
# input: [m/q^2, n, k/q]
|
||||
# output: [m/q^2, n, h/q]
|
||||
output = Matmul_AB_3D.apply(input_, self.weight, self.depth,
|
||||
self.input_parallel_mode,
|
||||
self.weight_parallel_mode,
|
||||
self.output_parallel_mode)
|
||||
# # input: [m/q^2, n, k/q]
|
||||
# # output: [m/q^2, n, h/q]
|
||||
# output = Matmul_AB_3D.apply(input_, self.weight, self.depth,
|
||||
# self.input_parallel_mode,
|
||||
# self.weight_parallel_mode,
|
||||
# self.output_parallel_mode)
|
||||
|
||||
if self.with_bias:
|
||||
output = Add_3D.apply(output, self.bias, self.depth,
|
||||
self.output_parallel_mode,
|
||||
self.weight_parallel_mode,
|
||||
self.input_parallel_mode)
|
||||
return output
|
||||
# if self.bias is not None:
|
||||
# output = Add_3D.apply(output, self.bias, self.depth,
|
||||
# self.output_parallel_mode,
|
||||
# self.weight_parallel_mode,
|
||||
# self.input_parallel_mode)
|
||||
# return output
|
||||
return linear_3d.apply(input_, self.weight, self.bias,
|
||||
self.input_parallel_mode,
|
||||
self.weight_parallel_mode,
|
||||
self.output_parallel_mode)
|
||||
|
||||
def extra_repr(self):
|
||||
return 'in_features={}, out_features={}, bias={}'.format(
|
||||
|
@ -1,3 +0,0 @@
|
||||
from .layers import ViTBlock
|
||||
|
||||
__all__ = ['ViTBlock']
|
@ -1,59 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from torch import nn as nn
|
||||
|
||||
from colossalai.builder import build_layer
|
||||
from colossalai.registry import LAYERS
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class ViTBlock(nn.Module):
|
||||
"""Vision Transformer block
|
||||
|
||||
:param attention_cfg: config of attention layer
|
||||
:type attention_cfg: dict
|
||||
:param droppath_cfg: config of drop path
|
||||
:type droppath_cfg: dict
|
||||
:param mlp_cfg: config of MLP layer
|
||||
:type mlp_cfg: dict
|
||||
:param norm_cfg: config of normalization layer
|
||||
:type norm_cfg: dict
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
attention_cfg: dict,
|
||||
droppath_cfg: dict,
|
||||
mlp_cfg: dict,
|
||||
norm_cfg: dict,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm1 = build_layer(norm_cfg)
|
||||
self.attn = build_layer(attention_cfg)
|
||||
self.drop_path = build_layer(
|
||||
droppath_cfg) if droppath_cfg['drop_path'] > 0. else nn.Identity()
|
||||
self.norm2 = build_layer(norm_cfg)
|
||||
self.mlp = build_layer(mlp_cfg)
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.drop_path(self.attn(self.norm1(x)))
|
||||
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||
|
||||
# x_ = x
|
||||
# x_ = self.norm1(x_)
|
||||
# if self.checkpoint:
|
||||
# x_ = checkpoint(self.attn, x_)
|
||||
# else:
|
||||
# x_ = self.attn(x_)
|
||||
# x_ = self.drop_path(x_)
|
||||
# x = x + x_
|
||||
#
|
||||
# x_ = x
|
||||
# x_ = self.norm2(x_)
|
||||
# if self.checkpoint:
|
||||
# x_ = checkpoint(self.mlp, x_)
|
||||
# else:
|
||||
# x_ = self.mlp(x_)
|
||||
# x_ = self.drop_path(x_)
|
||||
# x = x + x_
|
||||
return x
|
@ -1,5 +0,0 @@
|
||||
from .basic_block import ResNetBasicBlock
|
||||
from .bottleneck import ResNetBottleneck
|
||||
from .reslayer import ResLayer
|
||||
|
||||
__all__ = ['ResLayer', 'ResNetBottleneck', 'ResNetBasicBlock']
|
@ -1,64 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from typing import Optional, Callable
|
||||
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
from colossalai.registry import LAYERS
|
||||
from .conv import conv3x3
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class ResNetBasicBlock(nn.Module):
|
||||
"""Basic ResNet block
|
||||
"""
|
||||
expansion: int = 1
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inplanes: int,
|
||||
planes: int,
|
||||
stride: int = 1,
|
||||
downsample: Optional[nn.Module] = None,
|
||||
groups: int = 1,
|
||||
base_width: int = 64,
|
||||
dilation: int = 1,
|
||||
norm_layer: Optional[Callable[..., nn.Module]] = None
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
if groups != 1 or base_width != 64:
|
||||
raise ValueError(
|
||||
'BasicBlock only supports groups=1 and base_width=64')
|
||||
if dilation > 1:
|
||||
raise NotImplementedError(
|
||||
"Dilation > 1 not supported in BasicBlock")
|
||||
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
|
||||
self.conv1 = conv3x3(inplanes, planes, stride)
|
||||
self.bn1 = norm_layer(planes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.conv2 = conv3x3(planes, planes)
|
||||
self.bn2 = norm_layer(planes)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
@ -1,69 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from typing import Optional, Callable
|
||||
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
from colossalai.registry import LAYERS
|
||||
from .conv import conv3x3, conv1x1
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class ResNetBottleneck(nn.Module):
|
||||
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
|
||||
# while original implementation places the stride at the first 1x1 convolution(self.conv1)
|
||||
# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
|
||||
# This variant is also known as ResNet V1.5 and improves accuracy according to
|
||||
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
|
||||
|
||||
expansion: int = 4
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inplanes: int,
|
||||
planes: int,
|
||||
stride: int = 1,
|
||||
downsample: Optional[nn.Module] = None,
|
||||
groups: int = 1,
|
||||
base_width: int = 64,
|
||||
dilation: int = 1,
|
||||
norm_layer: Optional[Callable[..., nn.Module]] = None
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
width = int(planes * (base_width / 64.)) * groups
|
||||
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
|
||||
self.conv1 = conv1x1(inplanes, width)
|
||||
self.bn1 = norm_layer(width)
|
||||
self.conv2 = conv3x3(width, width, stride, groups, dilation)
|
||||
self.bn2 = norm_layer(width)
|
||||
self.conv3 = conv1x1(width, planes * self.expansion)
|
||||
self.bn3 = norm_layer(planes * self.expansion)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
@ -1,15 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
|
||||
"""3x3 convolution with padding"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
||||
padding=dilation, groups=groups, bias=False, dilation=dilation)
|
||||
|
||||
|
||||
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
|
||||
"""1x1 convolution"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
@ -1,63 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.registry import LAYERS
|
||||
from .conv import conv1x1
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class ResLayer(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
block_type: str,
|
||||
norm_layer_type: str,
|
||||
inplanes: int,
|
||||
planes: int,
|
||||
blocks: int,
|
||||
groups: int,
|
||||
base_width: int,
|
||||
stride: int = 1,
|
||||
dilation: int = 1,
|
||||
dilate: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.block = LAYERS.get_module(block_type)
|
||||
self.norm_layer = LAYERS.get_module(norm_layer_type)
|
||||
self.inplanes = inplanes
|
||||
self.planes = planes
|
||||
self.blocks = blocks
|
||||
self.groups = groups
|
||||
self.dilation = dilation
|
||||
self.base_width = base_width
|
||||
self.dilate = dilate
|
||||
self.stride = stride
|
||||
self.layer = self._make_layer()
|
||||
|
||||
def _make_layer(self):
|
||||
norm_layer = self.norm_layer
|
||||
downsample = None
|
||||
previous_dilation = self.dilation
|
||||
if self.dilate:
|
||||
self.dilation *= self.stride
|
||||
self.stride = 1
|
||||
if self.stride != 1 or self.inplanes != self.planes * self.block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
conv1x1(self.inplanes, self.planes * self.block.expansion, self.stride),
|
||||
norm_layer(self.planes * self.block.expansion),
|
||||
)
|
||||
|
||||
layers = []
|
||||
layers.append(self.block(self.inplanes, self.planes, self.stride, downsample, self.groups,
|
||||
self.base_width, previous_dilation, norm_layer))
|
||||
self.inplanes = self.planes * self.block.expansion
|
||||
for _ in range(1, self.blocks):
|
||||
layers.append(self.block(self.inplanes, self.planes, groups=self.groups,
|
||||
base_width=self.base_width, dilation=self.dilation,
|
||||
norm_layer=norm_layer))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
return self.layer(x)
|
@ -1,7 +0,0 @@
|
||||
from .layers import (VanillaViTBlock, VanillaViTMLP, VanillaViTPatchEmbedding,
|
||||
VanillaViTAttention, VanillaViTDropPath, VanillaViTHead)
|
||||
|
||||
__all__ = [
|
||||
'VanillaViTBlock', 'VanillaViTMLP', 'VanillaViTPatchEmbedding',
|
||||
'VanillaViTAttention', 'VanillaViTDropPath', 'VanillaViTHead'
|
||||
]
|
@@ -1,4 +1,3 @@
from .base_loss import BaseLoss
from .cross_entropy_2d import CrossEntropyLoss2D
from .cross_entropy_2p5d import CrossEntropyLoss2p5D
from .cross_entropy_3d import CrossEntropyLoss3D
@ -1,13 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class BaseLoss(ABC):
|
||||
"""Absctract loss class
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def calc_loss(self, *args, **kwargs):
|
||||
pass
|
@ -1,120 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.modules.loss import _Loss
|
||||
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.nn.layer.parallel_1d._utils import vocab_range_from_per_partition_vocab_size
|
||||
|
||||
|
||||
class _VocabParallelCrossEntropy_1D(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, vocab_parallel_logits, target):
|
||||
# Maximum value along vocab dimension across all GPUs.
|
||||
logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
|
||||
torch.distributed.all_reduce(logits_max,
|
||||
op=torch.distributed.ReduceOp.MAX,
|
||||
group=gpc.get_group(ParallelMode.PARALLEL_1D))
|
||||
# Subtract the maximum value.
|
||||
vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
|
||||
|
||||
# Get the partition's vocab indices
|
||||
partition_vocab_size = vocab_parallel_logits.size()[-1]
|
||||
rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
|
||||
world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
|
||||
vocab_start_index, vocab_end_index = vocab_range_from_per_partition_vocab_size(
|
||||
partition_vocab_size, rank, world_size)
|
||||
|
||||
# Create a mask of valid vocab ids (1 means it needs to be masked).
|
||||
target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
|
||||
masked_target = target.clone() - vocab_start_index
|
||||
masked_target[target_mask] = 0
|
||||
|
||||
# Get predicted-logits = logits[target].
|
||||
# For Simplicity, we convert logits to a 2-D tensor with size
|
||||
# [*, partition-vocab-size] and target to a 1-D tensor of size [*].
|
||||
logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
|
||||
masked_target_1d = masked_target.view(-1)
|
||||
arange_1d = torch.arange(start=0, end=logits_2d.size()[0],
|
||||
device=logits_2d.device)
|
||||
predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
|
||||
predicted_logits_1d = predicted_logits_1d.clone().contiguous()
|
||||
predicted_logits = predicted_logits_1d.view_as(target)
|
||||
predicted_logits[target_mask] = 0.0
|
||||
# All reduce is needed to get the chunks from other GPUs.
|
||||
torch.distributed.all_reduce(predicted_logits,
|
||||
op=torch.distributed.ReduceOp.SUM,
|
||||
group=gpc.get_group(ParallelMode.PARALLEL_1D))
|
||||
|
||||
# Sum of exponential of logits along vocab dimension across all GPUs.
|
||||
exp_logits = vocab_parallel_logits
|
||||
torch.exp(vocab_parallel_logits, out=exp_logits)
|
||||
sum_exp_logits = exp_logits.sum(dim=-1)
|
||||
torch.distributed.all_reduce(sum_exp_logits,
|
||||
op=torch.distributed.ReduceOp.SUM,
|
||||
group=gpc.get_group(ParallelMode.PARALLEL_1D))
|
||||
|
||||
# Loss = log(sum(exp(logits))) - predicted-logit.
|
||||
loss = torch.log(sum_exp_logits) - predicted_logits
|
||||
|
||||
# Store softmax, target-mask and masked-target for backward pass.
|
||||
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
|
||||
ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
|
||||
|
||||
return loss
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
# Retrieve tensors from the forward path.
|
||||
softmax, target_mask, masked_target_1d = ctx.saved_tensors
|
||||
|
||||
# All the inputs have softmax as their gradient.
|
||||
grad_input = softmax
|
||||
# For simplicity, work with the 2D gradient.
|
||||
partition_vocab_size = softmax.size()[-1]
|
||||
grad_2d = grad_input.view(-1, partition_vocab_size)
|
||||
|
||||
# Add the gradient from matching classes.
|
||||
arange_1d = torch.arange(start=0, end=grad_2d.size()[0],
|
||||
device=grad_2d.device)
|
||||
grad_2d[arange_1d, masked_target_1d] -= (
|
||||
1.0 - target_mask.view(-1).float())
|
||||
|
||||
# Finally elementwise multiplication with the output gradients.
|
||||
grad_input.mul_(grad_output.unsqueeze(dim=-1))
|
||||
|
||||
return grad_input, None
|
||||
|
||||
|
||||
class LmLoss1D(_Loss):
|
||||
|
||||
def forward(self, lm_logits, lm_labels, loss_mask):
|
||||
lm_loss = _VocabParallelCrossEntropy_1D.apply(lm_logits, lm_labels)
|
||||
lm_loss = torch.sum(
|
||||
lm_loss.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
|
||||
return lm_loss
|
||||
|
||||
|
||||
class SopLoss1D(_Loss):
|
||||
|
||||
def forward(self, sop_logits, sentence_order):
|
||||
sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
|
||||
sentence_order.view(-1),
|
||||
ignore_index=-1)
|
||||
return sop_loss
|
||||
|
||||
|
||||
class BERTDualHeadLoss(_Loss):
|
||||
|
||||
def __init__(self):
|
||||
self.lm_loss = LmLoss1D()
|
||||
self.sop_loss = SopLoss1D()
|
||||
|
||||
def forward(self, lm_logits, sop_logits, lm_labels, loss_mask, sentence_order):
|
||||
lm_loss = self.lm_loss(lm_logits, lm_labels, loss_mask)
|
||||
sop_loss = self.sop_loss(sop_logits, sentence_order)
|
||||
return lm_loss + sop_loss
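
For reference, a single-device sketch of the numerically stable cross entropy computed above (illustration only): subtract the per-row maximum, then loss = log(sum(exp(logits))) - logit[target]. The vocab-parallel version additionally all-reduces the maximum, the predicted logit and the sum of exponentials across the 1D parallel group.

import torch
import torch.nn.functional as F

def stable_cross_entropy(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    logits = logits - logits.max(dim=-1, keepdim=True).values
    predicted = logits.gather(-1, target.unsqueeze(-1)).squeeze(-1)
    return torch.log(torch.exp(logits).sum(dim=-1)) - predicted

logits = torch.randn(4, 32000)
target = torch.randint(0, 32000, (4,))
reference = F.cross_entropy(logits, target, reduction='none')
assert torch.allclose(stable_cross_entropy(logits, target), reference, atol=1e-4)
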
|
@ -7,18 +7,18 @@ from colossalai.core import global_context as gpc
|
||||
from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization, get_summa_dim_from_env
|
||||
from colossalai.registry import LOSSES
|
||||
from colossalai.utils import get_current_device
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
|
||||
class _ParallelCrossEntropyLossFunction_2D(torch.autograd.Function):
|
||||
### Modified based on megatron.mpu.cross_entropy ###
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float32)
|
||||
def forward(ctx, logits, targets):
|
||||
# logits: [b/q, h/q]
|
||||
# labels: [b/q]
|
||||
# loss: [b/q]
|
||||
# vocab_parallel_logits: [b/q, s, v/q]
|
||||
# target: [b/q, s]
|
||||
|
||||
logits_max = torch.max(logits, dim=-1)[0]
|
||||
torch.distributed.all_reduce(
|
||||
logits_max,
|
||||
@ -58,6 +58,7 @@ class _ParallelCrossEntropyLossFunction_2D(torch.autograd.Function):
|
||||
return loss
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, output_grad):
|
||||
# Retrieve tensors from the forward path.
|
||||
softmax, target_mask, masked_target = ctx.saved_tensors
|
||||
@ -91,12 +92,14 @@ class _ReduceByColumn(torch.autograd.Function):
|
||||
return input_
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float32)
|
||||
def forward(ctx, input_):
|
||||
dist.all_reduce(input_, group=gpc.get_group(
|
||||
ParallelMode.PARALLEL_2D_COL))
|
||||
return input_
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, grad_output):
|
||||
return grad_output
|
||||
|
||||
|
@ -1,32 +1,20 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.nn.modules.loss import _Loss
|
||||
|
||||
from colossalai.communication import all_gather
|
||||
from colossalai.constants import (INPUT_GROUP_3D, OUTPUT_GROUP_3D,
|
||||
WEIGHT_GROUP_3D)
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.nn.layer.parallel_3d._operation import Reduce_3D
|
||||
from colossalai.nn.layer.parallel_3d._utils import get_last_group, get_depth_from_env
|
||||
from colossalai.nn.layer.parallel_3d._utils import (get_depth_from_env,
|
||||
get_last_group,
|
||||
get_parallel_mode_from_env)
|
||||
from colossalai.registry import LOSSES
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
def accuracy_3d(output, target, input_parallel_mode, weight_parallel_mode):
|
||||
depth = get_depth_from_env()
|
||||
output_parallel_mode = get_last_group(input_parallel_mode,
|
||||
weight_parallel_mode)
|
||||
j = gpc.get_local_rank(input_parallel_mode)
|
||||
i = gpc.get_local_rank(weight_parallel_mode)
|
||||
target = torch.chunk(target, depth, dim=0)[i]
|
||||
target = torch.chunk(target, depth, dim=0)[j]
|
||||
output = all_gather(output, -1, output_parallel_mode)
|
||||
prediction = torch.argmax(output, dim=-1)
|
||||
correct = torch.sum(prediction == target)
|
||||
dist.all_reduce(correct, group=gpc.get_group(input_parallel_mode))
|
||||
dist.all_reduce(correct, group=gpc.get_group(weight_parallel_mode))
|
||||
return correct.item()
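
For reference, what accuracy_3d above reduces to on a single device once the output partitions have been gathered (illustration only; the parallel version chunks the targets by local rank and all-reduces the correct count over the input and weight groups):

import torch

def accuracy(output: torch.Tensor, target: torch.Tensor) -> int:
    prediction = torch.argmax(output, dim=-1)
    return torch.sum(prediction == target).item()

logits = torch.randn(8, 10)
labels = torch.randint(0, 10, (8,))
print(accuracy(logits, labels))
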
|
||||
from torch.nn.modules.loss import _Loss
|
||||
|
||||
|
||||
class _ParallelCrossEntropyLossFunction_3D(torch.autograd.Function):
|
||||
@ -112,16 +100,18 @@ class CrossEntropyLoss3D(_Loss):
|
||||
:param reduction: whether to average the loss, defaults to True
|
||||
:type reduction: bool, optional
|
||||
"""
|
||||
def __init__(self,
|
||||
input_parallel_mode,
|
||||
weight_parallel_mode,
|
||||
reduction=True):
|
||||
def __init__(
|
||||
self,
|
||||
# input_parallel_mode,
|
||||
# weight_parallel_mode,
|
||||
reduction=True,
|
||||
label_smoothing=0.0):
|
||||
super().__init__()
|
||||
self.depth = get_depth_from_env()
|
||||
self.input_parallel_mode = input_parallel_mode
|
||||
self.weight_parallel_mode = weight_parallel_mode
|
||||
self.output_parallel_mode = get_last_group(input_parallel_mode,
|
||||
weight_parallel_mode)
|
||||
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
|
||||
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
|
||||
self.output_parallel_mode = get_last_group(self.input_parallel_mode,
|
||||
self.weight_parallel_mode)
|
||||
self.input_rank = gpc.get_local_rank(self.input_parallel_mode)
|
||||
self.weight_rank = gpc.get_local_rank(self.weight_parallel_mode)
|
||||
self.reduction_mean = reduction
|
||||
@ -141,53 +131,53 @@ class CrossEntropyLoss3D(_Loss):
|
||||
return loss
|
||||
|
||||
|
||||
@LOSSES.register_module
|
||||
class LabelSmoothingCrossEntropy3D(_Loss):
|
||||
"""
|
||||
NLL loss with label smoothing, adapted from timm.loss.LabelSmoothingCrossEntropy
|
||||
# @LOSSES.register_module
|
||||
# class LabelSmoothingCrossEntropy3D(_Loss):
|
||||
# """
|
||||
# NLL loss with label smoothing, adapted from timm.loss.LabelSmoothingCrossEntropy
|
||||
|
||||
:param input_parallel_mode: parallel mode for input tensor
|
||||
:type input_parallel_mode: ParallelMode
|
||||
:param weight_parallel_mode: parallel mode for weight
|
||||
:type weight_parallel_mode: ParallelMode
|
||||
:param smoothing: label smoothing value, defaults to 0.1
|
||||
:type smoothing: float
|
||||
:param reduction: whether to average the loss, defaults to True
|
||||
:type reduction: bool, optional
|
||||
"""
|
||||
def __init__(self,
|
||||
input_parallel_mode,
|
||||
weight_parallel_mode,
|
||||
smoothing=0.1,
|
||||
reduction=True):
|
||||
super().__init__()
|
||||
assert smoothing < 1.0
|
||||
self.smoothing = smoothing
|
||||
self.confidence = 1. - smoothing
|
||||
self.depth = get_depth_from_env()
|
||||
self.input_parallel_mode = input_parallel_mode
|
||||
self.weight_parallel_mode = weight_parallel_mode
|
||||
self.output_parallel_mode = get_last_group(input_parallel_mode,
|
||||
weight_parallel_mode)
|
||||
self.reduction_mean = reduction
|
||||
# :param input_parallel_mode: parallel mode for input tensor
|
||||
# :type input_parallel_mode: ParallelMode
|
||||
# :param weight_parallel_mode: parallel mode for weight
|
||||
# :type weight_parallel_mode: ParallelMode
|
||||
# :param smoothing: label smoothing value, defaults to 0.1
|
||||
# :type smoothing: float
|
||||
# :param reduction: whether to average the loss, defaults to True
|
||||
# :type reduction: bool, optional
|
||||
# """
|
||||
# def __init__(self,
|
||||
# input_parallel_mode,
|
||||
# weight_parallel_mode,
|
||||
# smoothing=0.1,
|
||||
# reduction=True):
|
||||
# super().__init__()
|
||||
# assert smoothing < 1.0
|
||||
# self.smoothing = smoothing
|
||||
# self.confidence = 1. - smoothing
|
||||
# self.depth = get_depth_from_env()
|
||||
# self.input_parallel_mode = input_parallel_mode
|
||||
# self.weight_parallel_mode = weight_parallel_mode
|
||||
# self.output_parallel_mode = get_last_group(input_parallel_mode,
|
||||
# weight_parallel_mode)
|
||||
# self.reduction_mean = reduction
|
||||
|
||||
def forward(self, logits, targets):
|
||||
# split label partition from the entire batch
|
||||
j = gpc.get_local_rank(self.input_parallel_mode)
|
||||
i = gpc.get_local_rank(self.weight_parallel_mode)
|
||||
targets = torch.chunk(targets, self.depth, dim=0)[i]
|
||||
targets = torch.chunk(targets, self.depth, dim=0)[j]
|
||||
exp_logits = torch.exp(logits)
|
||||
sum_exp_logits = Sum3D.apply(exp_logits, -1, depth,
|
||||
self.output_parallel_mode, False)
|
||||
log_probs = torch.log(sum_exp_logits) - logits
|
||||
nll_loss = _ParallelCrossEntropyLossFunction_3D.apply(
|
||||
logits, targets, self.depth, self.output_parallel_mode)
|
||||
smooth_loss = -log_probs.mean(dim=-1)
|
||||
loss = self.confidence * nll_loss + self.smoothing * smooth_loss
|
||||
if self.reduction_mean:
|
||||
loss = loss.sum()
|
||||
loss = Reduce_3D.apply(loss, self.depth, self.input_parallel_mode)
|
||||
loss = Reduce_3D.apply(loss, self.depth, self.weight_parallel_mode)
|
||||
loss /= batch_size
|
||||
return loss
|
||||
# def forward(self, logits, targets):
|
||||
# # split label partition from the entire batch
|
||||
# j = gpc.get_local_rank(self.input_parallel_mode)
|
||||
# i = gpc.get_local_rank(self.weight_parallel_mode)
|
||||
# targets = torch.chunk(targets, self.depth, dim=0)[i]
|
||||
# targets = torch.chunk(targets, self.depth, dim=0)[j]
|
||||
# exp_logits = torch.exp(logits)
|
||||
# sum_exp_logits = Sum3D.apply(exp_logits, -1, depth,
|
||||
# self.output_parallel_mode, False)
|
||||
# log_probs = torch.log(sum_exp_logits) - logits
|
||||
# nll_loss = _ParallelCrossEntropyLossFunction_3D.apply(
|
||||
# logits, targets, self.depth, self.output_parallel_mode)
|
||||
# smooth_loss = -log_probs.mean(dim=-1)
|
||||
# loss = self.confidence * nll_loss + self.smoothing * smooth_loss
|
||||
# if self.reduction_mean:
|
||||
# loss = loss.sum()
|
||||
# loss = Reduce_3D.apply(loss, self.depth, self.input_parallel_mode)
|
||||
# loss = Reduce_3D.apply(loss, self.depth, self.weight_parallel_mode)
|
||||
# loss /= batch_size
|
||||
# return loss
|
||||
|
@@ -48,8 +48,10 @@ class DelayerScheduler(_LRScheduler):
if self.finished:
if epoch is None:
self.after_scheduler.step(None)
self._last_lr = self.after_scheduler.get_last_lr()
else:
self.after_scheduler.step(epoch - self.delay_epochs)
self._last_lr = self.after_scheduler.get_last_lr()
else:
return super(DelayerScheduler, self).step(epoch)

@ -66,6 +68,7 @@ class WarmupScheduler(_LRScheduler):
|
||||
:param last_epoch: The index of last epoch, defaults to -1
|
||||
:type last_epoch: int, optional
|
||||
"""
|
||||
|
||||
def __init__(self, optimizer, warmup_epochs, after_scheduler, last_epoch=-1):
|
||||
self.warmup_epochs = int(warmup_epochs)
|
||||
self.after_scheduler = after_scheduler
|
||||
@ -85,8 +88,10 @@ class WarmupScheduler(_LRScheduler):
|
||||
if self.finished:
|
||||
if epoch is None:
|
||||
self.after_scheduler.step(None)
|
||||
self._last_lr = self.after_scheduler.get_last_lr()
|
||||
else:
|
||||
self.after_scheduler.step(epoch - self.warmup_epochs)
|
||||
self._last_lr = self.after_scheduler.get_last_lr()
|
||||
else:
|
||||
return super().step(epoch)
|
||||
|
||||
@ -136,7 +141,9 @@ class WarmupDelayerScheduler(_LRScheduler):
|
||||
if self.finished:
|
||||
if epoch is None:
|
||||
self.after_scheduler.step(None)
|
||||
self._last_lr = self.after_scheduler.get_last_lr()
|
||||
else:
|
||||
self.after_scheduler.step(epoch - self.warmup_epochs)
|
||||
self._last_lr = self.after_scheduler.get_last_lr()
|
||||
else:
|
||||
return super().step(epoch)
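
For context, a plain-PyTorch sketch of the warmup-then-decay behaviour these wrapper schedulers implement (illustration only, not the ColossalAI API): a LambdaLR ramps the learning rate linearly for warmup_epochs and then follows a cosine decay; all numbers below are hypothetical.

import math
import torch

model = torch.nn.Linear(10, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
warmup_epochs, total_epochs = 5, 50

def lr_lambda(epoch: int) -> float:
    if epoch < warmup_epochs:
        return (epoch + 1) / warmup_epochs              # linear warmup
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return 0.5 * (1.0 + math.cos(math.pi * progress))   # cosine decay

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
for _ in range(total_epochs):
    optimizer.step()
    scheduler.step()    # scheduler.get_last_lr() reports the lr just applied
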
|
||||
|
@ -12,7 +12,6 @@ class MultiStepLR(_MultiStepLR):
|
||||
number of epoch reaches one of the milestones. Notice that such decay can
|
||||
happen simultaneously with other changes to the learning rate from outside
|
||||
this scheduler. When last_epoch=-1, sets initial lr as lr.
|
||||
|
||||
:param optimizer: Wrapped optimizer
|
||||
:type optimizer: torch.optim.Optimizer
|
||||
:param total_steps: number of total training steps
|
||||
@ -34,7 +33,6 @@ class MultiStepLR(_MultiStepLR):
|
||||
@LR_SCHEDULERS.register_module
|
||||
class MultiStepWarmupLR(WarmupScheduler):
|
||||
"""Multi-step laerning rate scheduler with warmup.
|
||||
|
||||
:param optimizer: Wrapped optimizer
|
||||
:type optimizer: torch.optim.Optimizer
|
||||
:param total_steps: number of total training steps
|
||||
|
@ -12,28 +12,21 @@ class OneCycleLR(_OneCycleLR):
|
||||
than the initial learning rate.
|
||||
This policy was initially described in the paper `Super-Convergence:
|
||||
Very Fast Training of Neural Networks Using Large Learning Rates`_.
|
||||
|
||||
The 1cycle learning rate policy changes the learning rate after every batch.
|
||||
`step` should be called after a batch has been used for training.
|
||||
|
||||
This scheduler is not chainable.
|
||||
|
||||
Note also that the total number of steps in the cycle can be determined in one
|
||||
of two ways (listed in order of precedence):
|
||||
|
||||
#. A value for total_steps is explicitly provided.
|
||||
#. A number of epochs (epochs) and a number of steps per epoch
|
||||
(steps_per_epoch) are provided.
|
||||
In this case, the number of total steps is inferred by
|
||||
total_steps = epochs * steps_per_epoch
|
||||
|
||||
You must either provide a value for total_steps or provide a value for both
|
||||
epochs and steps_per_epoch.
|
||||
|
||||
The default behaviour of this scheduler follows the fastai implementation of 1cycle, which
|
||||
claims that "unpublished work has shown even better results by using only two phases". To
|
||||
mimic the behaviour of the original paper instead, set ``three_phase=True``.
|
||||
|
||||
:param optimizer: Wrapped optimizer
|
||||
:type optimizer: torch.optim.Optimizer
|
||||
:param total_steps: number of total training steps
|
||||
@ -71,7 +64,6 @@ class OneCycleLR(_OneCycleLR):
|
||||
number of *batches* computed, not the total number of epochs computed.
|
||||
When last_epoch=-1, the schedule is started from the beginning, defaults to -1
|
||||
:type last_epoch: int, optional
|
||||
|
||||
.. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates:
|
||||
https://arxiv.org/abs/1708.07120
|
||||
"""
|
||||
|
@ -7,7 +7,6 @@ from .delayed import WarmupScheduler
|
||||
@LR_SCHEDULERS.register_module
|
||||
class PolynomialLR(_LRScheduler):
|
||||
"""Polynomial learning rate scheduler.
|
||||
|
||||
:param optimizer: Wrapped optimizer
|
||||
:type optimizer: torch.optim.Optimizer
|
||||
:param total_steps: number of total training steps
|
||||
@ -43,7 +42,6 @@ class PolynomialLR(_LRScheduler):
|
||||
@LR_SCHEDULERS.register_module
|
||||
class PolynomialWarmupLR(WarmupScheduler):
|
||||
"""Polynomial learning rate scheduler with warmup.
|
||||
|
||||
:param optimizer: Wrapped optimizer
|
||||
:type optimizer: torch.optim.Optimizer
|
||||
:param total_steps: number of total training steps
|
||||
|
@ -10,7 +10,6 @@ from colossalai.registry import LR_SCHEDULERS
|
||||
class LambdaLR(_LambdaLR):
|
||||
"""Sets the learning rate of each parameter group to the initial lr
|
||||
times a given function. When last_epoch=-1, sets initial lr as lr.
|
||||
|
||||
:param optimizer: Wrapped optimizer
|
||||
:type optimizer: torch.optim.Optimizer
|
||||
:param total_steps: number of total training steps
|
||||
@ -33,7 +32,6 @@ class LambdaLR(_LambdaLR):
|
||||
class MultiplicativeLR(_MultiplicativeLR):
|
||||
"""Multiply the learning rate of each parameter group by the factor given
|
||||
in the specified function. When last_epoch=-1, sets initial lr as lr
|
||||
|
||||
:param optimizer: Wrapped optimizer
|
||||
:type optimizer: torch.optim.Optimizer
|
||||
:param total_steps: number of total training steps
|
||||
@ -58,7 +56,6 @@ class StepLR(_StepLR):
|
||||
step_size epochs. Notice that such decay can happen simultaneously with
|
||||
other changes to the learning rate from outside this scheduler. When
|
||||
last_epoch=-1, sets initial lr as lr
|
||||
|
||||
:param optimizer: Wrapped optimizer
|
||||
:type optimizer: torch.optim.Optimizer
|
||||
:param total_steps: number of total training steps
|
||||
@ -82,7 +79,6 @@ class StepLR(_StepLR):
|
||||
class ExponentialLR(_ExponentialLR):
|
||||
"""Decays the learning rate of each parameter group by gamma every epoch.
|
||||
When last_epoch=-1, sets initial lr as lr
|
||||
|
||||
:param optimizer: Wrapped optimizer
|
||||
:type optimizer: torch.optim.Optimizer
|
||||
:param total_steps: number of total training steps
|
||||
|
@@ -1,3 +1,3 @@
from .base_model import BaseModel
from .vanilla_resnet import VanillaResNet
from .vision_transformer import *
from .model_from_config import ModelFromConfig

__all__ = ['ModelFromConfig']
@ -8,10 +8,10 @@ import torch.nn as nn
|
||||
from colossalai.builder import build_layer
|
||||
|
||||
|
||||
class BaseModel(nn.Module, ABC):
|
||||
class ModelFromConfig(nn.Module, ABC):
|
||||
|
||||
def __init__(self):
|
||||
super(BaseModel, self).__init__()
|
||||
super(ModelFromConfig, self).__init__()
|
||||
self.layers = nn.ModuleList()
|
||||
self.layers_cfg = []
|
||||
|
||||
@ -32,7 +32,6 @@ class BaseModel(nn.Module, ABC):
|
||||
|
||||
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
|
||||
keep_vars=False):
|
||||
|
||||
"""Use this function to override the state dict for
|
||||
saving checkpoints."""
|
||||
return self.state_dict(destination, prefix, keep_vars)
|
@ -1,3 +0,0 @@
|
||||
from .resnet import VanillaResNet
|
||||
|
||||
__all__ = ['VanillaResNet']
|
@ -1,163 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
from colossalai.registry import LAYERS
|
||||
from colossalai.registry import MODELS
|
||||
from ..base_model import BaseModel
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
class VanillaResNet(BaseModel):
|
||||
"""ResNet from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_cls: int,
|
||||
block_type: str,
|
||||
layers: List[int],
|
||||
norm_layer_type: str = 'BatchNorm2d',
|
||||
in_channels: int = 3,
|
||||
groups: int = 1,
|
||||
width_per_group: int = 64,
|
||||
zero_init_residual: bool = False,
|
||||
replace_stride_with_dilation: Optional[List[bool]] = None,
|
||||
dilations=(1, 1, 1, 1)
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.inplanes = 64
|
||||
self.zero_init_residual = zero_init_residual
|
||||
self.blocks = layers
|
||||
self.block_expansion = LAYERS.get_module(block_type).expansion
|
||||
self.dilations = dilations
|
||||
self.reslayer_common_cfg = dict(
|
||||
type='ResLayer',
|
||||
block_type=block_type,
|
||||
norm_layer_type=norm_layer_type,
|
||||
groups=groups,
|
||||
base_width=width_per_group
|
||||
)
|
||||
|
||||
if replace_stride_with_dilation is None:
|
||||
# each element in the tuple indicates if we should replace
|
||||
# the 2x2 stride with a dilated convolution instead
|
||||
replace_stride_with_dilation = [False, False, False]
|
||||
|
||||
if len(replace_stride_with_dilation) != 3:
|
||||
raise ValueError("replace_stride_with_dilation should be None "
|
||||
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
|
||||
|
||||
self.layers_cfg = [
|
||||
# conv1
|
||||
dict(type='Conv2d',
|
||||
in_channels=in_channels,
|
||||
out_channels=self.inplanes,
|
||||
kernel_size=7,
|
||||
stride=2,
|
||||
padding=3,
|
||||
bias=False),
|
||||
# bn1
|
||||
dict(
|
||||
type=norm_layer_type,
|
||||
num_features=self.inplanes
|
||||
),
|
||||
# relu
|
||||
dict(
|
||||
type='ReLU',
|
||||
inplace=True
|
||||
),
|
||||
# maxpool
|
||||
dict(
|
||||
type='MaxPool2d',
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1
|
||||
),
|
||||
# layer 1
|
||||
dict(
|
||||
inplanes=self.inplanes,
|
||||
planes=64,
|
||||
blocks=self.blocks[0],
|
||||
dilation=self.dilations[0],
|
||||
**self.reslayer_common_cfg
|
||||
),
|
||||
# layer 2
|
||||
dict(
|
||||
inplanes=64 * self.block_expansion,
|
||||
planes=128,
|
||||
blocks=self.blocks[1],
|
||||
stride=2,
|
||||
dilate=replace_stride_with_dilation[0],
|
||||
dilation=self.dilations[1],
|
||||
**self.reslayer_common_cfg
|
||||
),
|
||||
# layer 3
|
||||
dict(
|
||||
inplanes=128 * self.block_expansion,
|
||||
planes=256,
|
||||
blocks=layers[2],
|
||||
stride=2,
|
||||
dilate=replace_stride_with_dilation[1],
|
||||
dilation=self.dilations[2],
|
||||
**self.reslayer_common_cfg
|
||||
),
|
||||
# layer 4
|
||||
dict(
|
||||
inplanes=256 * self.block_expansion,
|
||||
planes=512,
|
||||
blocks=layers[3], stride=2,
|
||||
dilate=replace_stride_with_dilation[2],
|
||||
dilation=self.dilations[3],
|
||||
**self.reslayer_common_cfg
|
||||
),
|
||||
# avg pool
|
||||
dict(
|
||||
type='AdaptiveAvgPool2d',
|
||||
output_size=(1, 1)
|
||||
),
|
||||
# flatten
|
||||
dict(
|
||||
type='LambdaWrapper',
|
||||
func=lambda mod, x: torch.flatten(x, 1)
|
||||
),
|
||||
# linear
|
||||
dict(
|
||||
type='Linear',
|
||||
in_features=512 * self.block_expansion,
|
||||
out_features=num_cls
|
||||
)
|
||||
]
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
return x,
|
||||
|
||||
def init_weights(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(
|
||||
m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
# Zero-initialize the last BN in each residual branch,
|
||||
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
|
||||
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
|
||||
if self.zero_init_residual:
|
||||
for m in self.modules():
|
||||
if isinstance(m, LAYERS.get_module('ResNetBottleneck')):
|
||||
# type: ignore[arg-type]
|
||||
nn.init.constant_(m.bn3.weight, 0)
|
||||
elif isinstance(m, LAYERS.get_module('ResNetBasicBlock')):
|
||||
# type: ignore[arg-type]
|
||||
nn.init.constant_(m.bn2.weight, 0)
|
@ -1,3 +0,0 @@
|
||||
from .vision_transformer import VisionTransformerFromConfig
|
||||
|
||||
__all__ = ['VisionTransformerFromConfig']
|
@@ -1,14 +1,10 @@
from .fp16_optimizer import FP16Optimizer
from .colossalai_optimizer import ColossalaiOptimizer
from .fused_adam import FusedAdam
from .fused_lamb import FusedLAMB
from .fused_sgd import FusedSGD
from .lamb import Lamb
from .lars import Lars
from .zero_redundancy_optimizer_level_1 import ZeroRedundancyOptimizer_Level_1
from .zero_redundancy_optimizer_level_2 import ZeroRedundancyOptimizer_Level_2
from .zero_redundancy_optimizer_level_3 import ZeroRedundancyOptimizer_Level_3

__all__ = [
'ZeroRedundancyOptimizer_Level_1', 'ZeroRedundancyOptimizer_Level_2', 'ZeroRedundancyOptimizer_Level_3',
'FusedLAMB', 'FusedAdam', 'FusedSGD', 'Lamb', 'FP16Optimizer', 'Lars'
'ColossalaiOptimizer', 'FusedLAMB', 'FusedAdam', 'FusedSGD', 'Lamb', 'Lars'
]
@ -1,168 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
from torch._six import inf
|
||||
|
||||
try:
|
||||
import colossal_C
|
||||
except:
|
||||
print('Colossalai should be built with cuda extension to use the FP16 optimizer')
|
||||
|
||||
from ..multi_tensor_apply import multi_tensor_applier
|
||||
|
||||
from colossalai.constants import IS_TENSOR_PARALLEL, TENSOR_PARALLEL_ATTRIBUTES
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
|
||||
|
||||
def is_model_parallel_parameter(p):
|
||||
return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL)
|
||||
|
||||
|
||||
def _calc_l2_norm(grads):
|
||||
norm = 0.0
|
||||
if len(grads) > 0:
|
||||
dummy_overflow_buf = torch.cuda.IntTensor([0])
|
||||
norm, _ = multi_tensor_applier(
|
||||
colossal_C.multi_tensor_l2norm,
|
||||
dummy_overflow_buf,
|
||||
[grads],
|
||||
False # no per-parameter norm
|
||||
)
|
||||
return norm
|
||||
|
||||
|
||||
def _calc_lp(grads, norm_type):
|
||||
norm = 0.0
|
||||
for grad in grads:
|
||||
grad_norm = torch.norm(grad, norm_type)
|
||||
norm += grad_norm ** norm_type
|
||||
return norm
|
||||
|
||||
# ======== Gradient Clipping =========
|
||||
|
||||
|
||||
def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
|
||||
"""Clips gradient norm of an iterable of parameters whose gradients
|
||||
are in fp32.
|
||||
|
||||
This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
|
||||
added functionality to handle model parallel parameters. Note that
|
||||
the gradients are modified in place.
|
||||
|
||||
Arguments:
|
||||
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
|
||||
single Tensor that will have gradients normalized
|
||||
max_norm (float or int): max norm of the gradients
|
||||
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
|
||||
infinity norm.
|
||||
|
||||
Returns:
|
||||
Total norm of the parameters (viewed as a single vector).
|
||||
"""
|
||||
|
||||
if isinstance(parameters, torch.Tensor):
|
||||
parameters = [parameters]
|
||||
|
||||
# Filter parameters based on:
|
||||
# - grad should not be none
|
||||
# - parameter should not be shared
|
||||
# - should not be a replica due to tensor model parallelism
|
||||
params = []
|
||||
for param in parameters:
|
||||
if param.grad is not None:
|
||||
# Make sure the grads are in fp32
|
||||
assert param.grad.type() == 'torch.cuda.FloatTensor'
|
||||
params.append(param)
|
||||
# Norm parameters.
|
||||
max_norm = float(max_norm)
|
||||
norm_type = float(norm_type)
|
||||
|
||||
# Calculate norm.
|
||||
if norm_type == inf:
|
||||
total_norm = max(p.grad.data.abs().max() for p in params)
|
||||
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
|
||||
if gpc.is_initialized(ParallelMode.TENSOR):
|
||||
# Take max across all model-parallel GPUs.
|
||||
torch.distributed.all_reduce(total_norm_cuda,
|
||||
op=torch.distributed.ReduceOp.MAX,
|
||||
group=gpc.get_group(ParallelMode.TENSOR))
|
||||
total_norm = total_norm_cuda[0].item()
|
||||
else:
|
||||
tensor_parallel_grads = []
|
||||
no_tensor_parallel_grads = []
|
||||
for p in params:
|
||||
if is_model_parallel_parameter(p):
|
||||
tensor_parallel_grads.append(p.grad.data)
|
||||
else:
|
||||
no_tensor_parallel_grads.append(p.grad.data)
|
||||
if norm_type == 2.0:
|
||||
tensor_parallel_norm = _calc_l2_norm(
|
||||
tensor_parallel_grads) ** norm_type
|
||||
no_tensor_parallel_norm = _calc_l2_norm(
|
||||
no_tensor_parallel_grads) ** norm_type
|
||||
else:
|
||||
tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type)
|
||||
no_tensor_parallel_grads = _calc_lp(
|
||||
no_tensor_parallel_grads, norm_type)
|
||||
if gpc.is_initialized(ParallelMode.TENSOR) and len(tensor_parallel_grads) > 0:
|
||||
# Sum across all model-parallel GPUs.
|
||||
torch.distributed.all_reduce(tensor_parallel_norm,
|
||||
op=torch.distributed.ReduceOp.SUM,
|
||||
group=gpc.get_group(ParallelMode.TENSOR))
|
||||
total_norm = (tensor_parallel_norm +
|
||||
no_tensor_parallel_norm) ** (1.0 / norm_type)
|
||||
if type(total_norm) == 'torch.cuda.FloatTensor':
|
||||
total_norm = total_norm.item()
|
||||
|
||||
# Scale.
|
||||
clip_coeff = max_norm / (total_norm + 1.0e-6)
|
||||
if clip_coeff < 1.0:
|
||||
grads = [p.grad.detach() for p in params]
|
||||
dummy_overflow_buf = torch.cuda.IntTensor([0])
|
||||
multi_tensor_applier(colossal_C.multi_tensor_scale,
|
||||
dummy_overflow_buf,
|
||||
[grads, grads],
|
||||
clip_coeff)
|
||||
|
||||
return total_norm
|
||||
|
||||
|
||||
def count_zeros_fp32(parameters):
|
||||
if isinstance(parameters, torch.Tensor):
|
||||
parameters = [parameters]
|
||||
|
||||
# Filter parameters based on:
|
||||
# - grad should not be none
|
||||
# - parameter should not be shared
|
||||
# - should not be a replica due to tensor model parallelism
|
||||
total_num_zeros = 0.0
|
||||
for param in parameters:
|
||||
grad_not_none = param.grad is not None
|
||||
is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
|
||||
if grad_not_none and is_not_tp_duplicate:
|
||||
grad = param.grad.detach()
|
||||
num_zeros = grad.numel() - torch.count_nonzero(grad)
|
||||
total_num_zeros = num_zeros + total_num_zeros
|
||||
|
||||
# Sum across all model-parallel GPUs.
|
||||
torch.distributed.all_reduce(total_num_zeros,
|
||||
op=torch.distributed.ReduceOp.SUM,
|
||||
group=gpc.get_group(ParallelMode.TENSOR))
|
||||
total_num_zeros = total_num_zeros.item()
|
||||
|
||||
return total_num_zeros
|
||||
|
||||
|
||||
def copy_tensor_parallel_attributes(src_tensor, dst_tensor):
|
||||
for attr in TENSOR_PARALLEL_ATTRIBUTES:
|
||||
if hasattr(src_tensor, attr):
|
||||
val = getattr(src_tensor, attr)
|
||||
setattr(dst_tensor, attr, val)
|
||||
|
||||
|
||||
def param_is_not_tensor_parallel_duplicate(param):
|
||||
return (hasattr(param, IS_TENSOR_PARALLEL) and
|
||||
getattr(param, IS_TENSOR_PARALLEL)) or (
|
||||
gpc.get_local_rank(ParallelMode.TENSOR) == 0)
|
47
colossalai/nn/optimizer/colossalai_optimizer.py
Normal file
47
colossalai/nn/optimizer/colossalai_optimizer.py
Normal file
@ -0,0 +1,47 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch
import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer
from colossalai.utils import clip_grad_norm_fp32


class ColossalaiOptimizer(Optimizer):

    def __init__(self, optim: Optimizer):
        self.optim = optim

    @property
    def param_groups(self):
        return self.optim.param_groups

    @property
    def defaults(self):
        return self.optim.defaults

    def add_param_group(self, *args, **kwargs):
        return self.optim.add_param_group(*args, **kwargs)

    def step(self, *args, **kwargs):
        return self.optim.step(*args, **kwargs)

    def zero_grad(self, *args, **kwargs):
        self.optim.zero_grad(*args, **kwargs)

    def load_state_dict(self, *args, **kwargs):
        self.optim.load_state_dict(*args, **kwargs)

    def state_dict(self):
        return self.optim.state_dict()

    def backward(self, loss: Tensor):
        loss.backward()

    def backward_by_grad(self, tensor: Tensor, grad: Tensor):
        torch.autograd.backward(tensors=tensor, grad_tensors=grad)

    def clip_grad_norm(self, model: nn.Module, max_norm: float):
        if max_norm > 0.0:
            clip_grad_norm_fp32(model.parameters(), max_norm)
@@ -2,7 +2,7 @@
 import torch
 
 from colossalai.registry import OPTIMIZERS
-from ..multi_tensor_apply import multi_tensor_applier
+from colossalai.utils import multi_tensor_applier
 
 
 @OPTIMIZERS.register_module
@@ -2,7 +2,7 @@
 import torch
 
 from colossalai.registry import OPTIMIZERS
-from ..multi_tensor_apply import multi_tensor_applier
+from colossalai.utils import multi_tensor_applier
 
 
 @OPTIMIZERS.register_module
@@ -3,7 +3,7 @@ import torch
 from torch.optim.optimizer import Optimizer, required
 
 from colossalai.registry import OPTIMIZERS
-from ..multi_tensor_apply import multi_tensor_applier
+from colossalai.utils import multi_tensor_applier
 
 
 @OPTIMIZERS.register_module
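The three hunks above switch the fused optimizers to the public colossalai.utils import of multi_tensor_applier. As a rough, plain-PyTorch stand-in for what the multi_tensor_scale call in clip_grad_norm_fp32 does (scale every gradient in place by clip_coeff), without the fused colossal_C kernel (names here are illustrative):

import torch

def scale_grads_fallback(grads, clip_coeff):
    # Semantic stand-in for multi_tensor_applier(colossal_C.multi_tensor_scale,
    # dummy_overflow_buf, [grads, grads], clip_coeff): in-place scaling of each grad.
    for g in grads:
        g.mul_(clip_coeff)

grads = [torch.randn(3), torch.randn(5)]
before = [g.clone() for g in grads]
scale_grads_fallback(grads, 0.5)
assert all(torch.allclose(g, b * 0.5) for g, b in zip(grads, before))
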
@@ -1,707 +0,0 @@ (file removed in this PR)
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import math
from collections import defaultdict

import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.optim import Optimizer

from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import OPTIMIZER_WRAPPERS
from colossalai.utils import get_current_device, print_rank_0


def get_alignment_padding(flattened_lean_size, sub_partition_id, sub_partition_size):
    sub_partition_high_limit = (sub_partition_id + 1) * sub_partition_size
    if sub_partition_high_limit <= flattened_lean_size:
        return 0
    else:
        return min(sub_partition_size, sub_partition_high_limit - flattened_lean_size)


def get_group_alignment_padding(tensor_list, sub_partition_size, sub_partition_count):
    group_paddings = []
    flattened_size = sum([tensor.numel() for tensor in tensor_list])
    for i in range(sub_partition_count):
        padding = get_alignment_padding(flattened_size, i, sub_partition_size)
        group_paddings.append(padding)

    return group_paddings

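A standalone restatement of the padding arithmetic above with hand-picked toy numbers (illustrative only, not taken from the PR):

def alignment_padding(flattened_size, sub_partition_id, sub_partition_size):
    # Same rule as get_alignment_padding: only sub-partitions that extend past
    # the flattened tensor need padding.
    high = (sub_partition_id + 1) * sub_partition_size
    return 0 if high <= flattened_size else min(sub_partition_size, high - flattened_size)

flattened_size = 10          # total elements across the group's tensors
sub_partition_size = 4       # elements per sub-partition
sub_partition_count = 3      # e.g. num_comm_intervals * dp world size

paddings = [alignment_padding(flattened_size, i, sub_partition_size)
            for i in range(sub_partition_count)]
print(paddings)   # [0, 0, 2] -> only the last sub-partition needs 2 padding elements
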
def _single_range_check(current_index, start_index, end_index, tensor_size):
    offset = 0
    if (current_index >= start_index) and (current_index < end_index):
        # Fully inside bounds
        return True, offset
    elif (start_index > current_index) and (start_index < (current_index + tensor_size)):
        # Partially contained, compute offset
        offset = start_index - current_index
        return True, offset
    else:
        return False, offset


def _range_check(current_index, element_intervals, tensor_size):
    results = []
    for comm_idx, interval in enumerate(element_intervals):
        start_index, end_index = interval
        contained, offset = _single_range_check(
            current_index, start_index, end_index, tensor_size)
        if contained:
            results.append((contained, offset, comm_idx))
    if len(results) == 0:
        return [(False, 0, -1)]
    return results

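A toy check of the interval logic above, restated compactly for a tensor of 6 elements that starts at flat index 4, against a rank owning the intervals [(0, 5), (10, 15)] (numbers are illustrative):

def single_range_check(current_index, start_index, end_index, tensor_size):
    # Same conditions as _single_range_check, written with chained comparisons.
    if start_index <= current_index < end_index:
        return True, 0
    if current_index < start_index < current_index + tensor_size:
        return True, start_index - current_index
    return False, 0

intervals = [(0, 5), (10, 15)]
current_index, tensor_size = 4, 6

results = []
for comm_idx, (s, e) in enumerate(intervals):
    contained, offset = single_range_check(current_index, s, e, tensor_size)
    if contained:
        results.append((contained, offset, comm_idx))
print(results)   # [(True, 0, 0)] -> the tensor starts inside the first interval, offset 0
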

@OPTIMIZER_WRAPPERS.register_module
class ZeroRedundancyOptimizer_Level_1(Optimizer):
    """
    ZeroRedundancyOptimizer_Level_1 is designed to reduce the memory footprint
    required for training large deep learning models.

    For more details please see ZeRO: Memory Optimization Towards Training A Trillion Parameter Models,
    https://arxiv.org/abs/1910.02054

    This version aligns with stage-1 in the paper above.
    """

    def __init__(self,
                 init_optimizer: Optimizer,
                 dp_parallel_mode: ParallelMode = ParallelMode.DATA,
                 max_elements_per_comm=5e8,
                 verbose=False
                 ):
        # TODO: this class does not work with fp16 AMP_TYPE.PARALLEL, fix it
        assert get_current_device() != 'cpu', 'ZeRO optimizer cannot be used on CPU only'

        self.flatten = _flatten_dense_tensors
        self.unflatten = _unflatten_dense_tensors
        self.optimizer = init_optimizer
        self.dp_parallel_mode = dp_parallel_mode
        self.verbose = verbose

        # for compatibility with pytorch optim
        self.defaults = init_optimizer.defaults

        # params flattened by group
        self._param_groups = []
        self._param_groups_flat = []

        # parallel_sub_partitioned_groups[group-idx] -> [comm-ids] -> [rank-ids]
        self.parallel_sub_partitioned_groups = []
        # same underlying data as above but viewed as: [groups] -> [rank-ids] -> [comm-ids]
        self.parallel_comm_sub_partitioned_groups = []

        # param partition info
        # parameters in each group that will not be updated by this process directly
        self.params_not_local = []

        # parameters that will be updated by this process directly
        self.params_in_rank_sub_partitions = []

        # parameter offsets for parameters in sub-partitions. Parameter
        # boundaries may not align with sub-partition boundaries,
        # so we need to keep track of the offsets
        self.params_in_rank_sub_partitions_offsets = []

        # number of elements per sub-partition in each group
        self.sub_partition_sizes = []

        # number of communication intervals for each group
        self.num_comm_intervals_per_group = []

        self.local_rank = gpc.get_local_rank(self.dp_parallel_mode)
        self.partition_count = self.world_size = gpc.get_world_size(
            self.dp_parallel_mode)

        self.group_paddings = []
        self.default_device = self.optimizer.param_groups[0]['params'][0].device

        # max elems per param group
        self.max_elems_per_comm = []

        # loop over the parameter groups
        for i, param_group in enumerate(self.optimizer.param_groups):
            # push this group to the list before modifying it
            self._param_groups.append(param_group['params'])

            # calculate the best max elements per comm so as to minimize padding
            self.max_elems_per_comm.append(
                self.best_max_elems_per_comm(
                    num_elements=sum(t.numel() for t in self._param_groups[i]),
                    max_elements_per_comm=max_elements_per_comm
                )
            )

            # flatten all tensors into a single 1d tensor aligned with the sub-partition size for later dividing
            # RS: create aligned sub-partitions
            flat_aligned_params = self.flatten_dense_tensors_sub_partition_aligned(
                tensor_list=self._param_groups[i],
                max_elements_per_comm=self.max_elems_per_comm[i],
            )
            self._param_groups_flat.append(flat_aligned_params)

            updated_params = self.unflatten(self._param_groups_flat[i],
                                            self._param_groups[i])
            for p, q in zip(self._param_groups[i], updated_params):
                p.data = q.data

            # divide the flat weights into near-equal partitions, one per data-parallel rank;
            # each process will compute on a different part of the partition
            # RS: split into a two-layer list -> [comm-id] -> [sub-partitions per rank]
            comm_partitions, dp_sub_partitions, element_intervals, sub_partition_size, num_comm_intervals = \
                self.get_data_parallel_sub_partitions(
                    tensor=self._param_groups_flat[i],
                    max_elements_per_comm=self.max_elems_per_comm[i],
                )
            self.parallel_comm_sub_partitioned_groups.append(
                comm_partitions)  # comm -> rank
            self.parallel_sub_partitioned_groups.append(
                dp_sub_partitions)  # rank -> comm
            self.sub_partition_sizes.append(sub_partition_size)
            self.num_comm_intervals_per_group.append(num_comm_intervals)

            # Compute sub_partition paddings
            sub_partition_paddings = get_group_alignment_padding(
                tensor_list=self._param_groups[i],
                sub_partition_size=sub_partition_size,
                sub_partition_count=num_comm_intervals * self.partition_count)
            self.group_paddings.append(sub_partition_paddings)

            # modify the optimizer to hold the flat master weights owned by this rank
            param_group['params'] = self.parallel_sub_partitioned_groups[i][self.local_rank]

            # RS: divide up the sub-partitions and keep track of offsets for each param
            # partition_size = len(self.fp16_groups_flat[i]) / dist.get_world_size(group=self.dp_process_group)
            params_in_rank_sub_partition, params_in_rank_sub_partitions_offsets, params_not_local = self.get_all_sub_partition_info(
                tensor_list=self._param_groups[i],
                all_element_intervals=element_intervals,
            )

            self.params_in_rank_sub_partitions.append(
                params_in_rank_sub_partition)
            self.params_not_local.append(params_not_local)
            self.params_in_rank_sub_partitions_offsets.append(
                params_in_rank_sub_partitions_offsets)

        self.local_sub_partitions_of_groups = [
            group[self.local_rank] for group in self.parallel_sub_partitioned_groups]
        self._initialize_optimizer_states()

    @property
    def state(self):
        return self.optimizer.state

    @state.setter
    def state(self, value):
        self.optimizer.state = value

    @property
    def param_groups(self):
        # LSG: return the full param groups instead of the local partitions
        # of the param groups for compatibility with torch.cuda.amp
        param_groups = []

        for group_id, group in enumerate(self.optimizer.param_groups):
            group_containing_all_param = {
                'params': self._param_groups[group_id],
                **{k: v for k, v in group.items() if k != 'params'}
            }
            # LSG: work around an unknown bug in the lr scheduler
            # TODO: fix this
            group_containing_all_param.setdefault('initial_lr', group['lr'])
            param_groups.append(group_containing_all_param)
        return param_groups

    @param_groups.setter
    def param_groups(self, value):
        self.optimizer.param_groups = value

    def _initialize_optimizer_states(self):
        for group_idx, group in enumerate(self.local_sub_partitions_of_groups):
            for idx, sub_partition_param in enumerate(group):
                sub_partition_grad = torch.zeros(int(
                    self.sub_partition_sizes[group_idx]),
                    dtype=sub_partition_param.dtype).cuda()
                sub_partition_param.grad = sub_partition_grad

        self.optimizer.step()

        # LSG: commented out for compatibility with torch.cuda.amp
        # for group in self.local_sub_partitions_of_groups:
        #     for idx, sub_partition_param in enumerate(group):
        #         sub_partition_param.grad = None

    def best_max_elems_per_comm(self, num_elements, max_elements_per_comm):
        # if we use max-elems-per-comm as is, how many comm intervals will there be
        max_comm_intervals = math.ceil(num_elements / max_elements_per_comm)
        padding_for_max_comm = (max_elements_per_comm *
                                max_comm_intervals) - num_elements

        # if we use 1 less comm interval, how much extra comm padding would be required
        min_comm_intervals = num_elements // max_elements_per_comm
        if min_comm_intervals == 0:
            if self.verbose:
                print_rank_0(
                    f'Using default max_elements_per_comm {max_elements_per_comm}')
            return max_elements_per_comm

        padding_for_min_comm = math.ceil(
            num_elements / (self.world_size * min_comm_intervals))

        # choose the padding that incurs the least overhead
        if padding_for_max_comm > padding_for_min_comm:
            new_max_elements_per_comm = padding_for_min_comm + max_elements_per_comm
            if self.verbose:
                print_rank_0(
                    f'Updating max_elements_per_comm from {max_elements_per_comm} -> {new_max_elements_per_comm}')
            return new_max_elements_per_comm
        else:
            if self.verbose:
                print_rank_0(
                    f'Using default max_elements_per_comm {max_elements_per_comm}')
            return max_elements_per_comm

    def get_data_parallel_sub_partitions(self,
                                         tensor,
                                         max_elements_per_comm,
                                         ):
        total_num_elements = tensor.numel()

        # if the total element count is less than our max, revert to splitting into dp partitions
        max_elements_per_comm = min(total_num_elements, max_elements_per_comm)
        sub_partition_size = int(max_elements_per_comm // self.world_size)

        # Ensure partition alignment was done correctly
        num_sub_partitions = int(total_num_elements // sub_partition_size)
        assert total_num_elements % sub_partition_size == 0, "{} % {} != 0".format(total_num_elements,
                                                                                   sub_partition_size)

        # Ensure comm interval alignment was done correctly.
        num_comm_intervals = int(num_sub_partitions // self.world_size)
        assert num_sub_partitions % self.world_size == 0, "{} % {} != 0".format(
            num_sub_partitions, self.world_size)

        if self.verbose:
            print_rank_0("**** partition info:")
            print_rank_0(f"\t total_num_elements={total_num_elements}")
            print_rank_0(f"\t world_size={self.world_size}")
            print_rank_0(f"\t max_elements_per_comm={max_elements_per_comm}")
            print_rank_0(f"\t sub_partition_size={sub_partition_size}")
            print_rank_0(f"\t num_sub_partitions={num_sub_partitions}")
            print_rank_0(f"\t num_comm_intervals={num_comm_intervals}")
            print_rank_0("****")

        # [comm_id] -> [rank]
        comm_partitions = []
        for _ in range(num_comm_intervals):
            comm_partitions.append([])

        start = 0
        comm_id = 0
        element_intervals = defaultdict(
            list)  # [rank] -> [(start,end), (start,end), ...]
        for idx in range(num_sub_partitions):
            rank_id = idx % self.world_size
            sub_partition = tensor.narrow(
                0, start, sub_partition_size).detach()
            element_intervals[rank_id].append(
                (start, start + sub_partition_size))
            comm_partitions[comm_id].append(sub_partition)
            start = start + sub_partition_size
            if rank_id == (self.world_size - 1):
                comm_id += 1

        # [rank] -> [comm_id]
        sub_partitions = []
        for _ in range(self.world_size):
            sub_partitions.append([])
        for comm_id, partitions in enumerate(comm_partitions):
            for rank_id, partition in enumerate(partitions):
                sub_partitions[rank_id].append(partition)

        return comm_partitions, sub_partitions, element_intervals, sub_partition_size, num_comm_intervals

    def get_all_sub_partition_info(self,
                                   tensor_list,
                                   all_element_intervals,
                                   ):
        params_not_local = []

        # [rank] -> [comm-id] -> [param/offset]
        params_in_rank_sub_partition = []
        params_in_rank_sub_partitions_offsets = []

        for rank in range(self.world_size):
            params_in_local_sub_partition = []
            local_sub_partition_offsets = []
            comm_tensor_list = []
            comm_offset_list = []
            current_index = 0
            prev_comm_idx = 0
            for iii, tensor in enumerate(tensor_list):
                tensor_size = tensor.numel()
                results_list = _range_check(current_index,
                                            all_element_intervals[rank],
                                            tensor_size)
                for contained, offset, comm_idx in results_list:
                    if contained:
                        if prev_comm_idx != comm_idx:
                            params_in_local_sub_partition.append(
                                comm_tensor_list)
                            comm_tensor_list = []
                            local_sub_partition_offsets.append(
                                comm_offset_list)
                            comm_offset_list = []
                        comm_tensor_list.append(tensor)
                        comm_offset_list.append(offset)
                        prev_comm_idx = comm_idx
                    elif rank == self.local_rank:
                        params_not_local.append(tensor)

                current_index = current_index + tensor_size

            # assert len(comm_tensor_list) > 0
            # assert len(comm_offset_list) > 0
            params_in_local_sub_partition.append(comm_tensor_list)
            local_sub_partition_offsets.append(comm_offset_list)

            params_in_rank_sub_partition.append(params_in_local_sub_partition)
            params_in_rank_sub_partitions_offsets.append(
                local_sub_partition_offsets)

        return params_in_rank_sub_partition, params_in_rank_sub_partitions_offsets, params_not_local

    def get_flat_sub_partitions(self,
                                comm_tensor_list,
                                comm_param_offsets,
                                sub_partition_size,
                                dtype,
                                default_device,
                                num_comm_intervals=None,
                                return_partition_params=False):
        partition_params = []
        final_param_offsets = []
        flat_sub_partitions = []
        for tensor_list, param_offsets in zip(comm_tensor_list, comm_param_offsets):
            flat_tensor_list = []
            current_size = 0
            my_offsets = []
            my_params = []

            for i, tensor in enumerate(tensor_list):
                if tensor.grad is None:
                    tensor.grad = torch.zeros(tensor.size(),
                                              dtype=tensor.dtype,
                                              device=tensor.device)
                param = tensor
                tensor = tensor.grad
                num_elements = tensor.numel()
                tensor_offset = 0

                # we need to offset to get to the right element
                if i == 0 and param_offsets[i] > 0:
                    tensor_offset = param_offsets[i]
                    num_elements = num_elements - tensor_offset

                # We don't need all elements of the tensor if this tensor is
                # larger than we have space for in our current sub-partition
                if num_elements > (sub_partition_size - current_size):
                    num_elements = sub_partition_size - current_size

                # we need a narrow view of the tensor based on the tensor offset and the number of elements
                # we need from this tensor
                if tensor_offset > 0 or num_elements < tensor.numel():
                    flat_tensor_list.append(tensor.contiguous().view(-1).narrow(
                        0,
                        int(tensor_offset),
                        int(num_elements)).to(dtype))
                else:
                    flat_tensor_list.append(tensor.to(dtype))
                my_params.append(param)

                # remember the offset into the partition and #elems for this tensor
                my_offsets.append((current_size, num_elements))

                current_size = current_size + num_elements

            # this means it is the last partition and does not align with the dp boundary;
            # we need to pad before flattening
            if current_size < sub_partition_size:
                my_offsets.append((None, None))
                my_params.append(None)
                if len(tensor_list) == 0:
                    assert default_device is not None
                    flat_tensor_list.append(
                        torch.zeros(int(sub_partition_size - current_size),
                                    dtype=dtype,
                                    device=default_device))
                else:
                    flat_tensor_list.append(
                        torch.zeros(int(sub_partition_size - current_size),
                                    dtype=dtype,
                                    device=tensor_list[0].device))
            partition_params.append(my_params)  # flat_tensor_list)
            final_param_offsets.append(my_offsets)
            assert len(flat_tensor_list) == len(my_offsets), "{} {}".format(
                len(flat_tensor_list), len(my_offsets))
            flat_sub_partitions.append(self.flatten(flat_tensor_list))
        if num_comm_intervals is not None and len(
                flat_sub_partitions) < num_comm_intervals:
            # print("padding w. sub partitions to ensure uniform communication")
            device = flat_sub_partitions[0].device
            for _ in range(num_comm_intervals - len(flat_sub_partitions)):
                flat_sub_partitions.append(
                    torch.zeros(int(sub_partition_size),
                                dtype=dtype,
                                device=device))
                partition_params.append([None])
                final_param_offsets.append([(None, None)])

        if return_partition_params:
            assert len(flat_sub_partitions) == len(partition_params)
            assert len(partition_params) == len(final_param_offsets), "{} {}".format(len(partition_params),
                                                                                     len(final_param_offsets))
            return flat_sub_partitions, partition_params, final_param_offsets
        return flat_sub_partitions

    def zero_grad(self, set_grads_to_None=False):
        """
        Zero FP16 parameter grads.
        """
        # FP32 grads should never exist.
        # For speed, set model fp16 grads to None by default.
        for group in self._param_groups:
            for p in group:
                if set_grads_to_None:
                    p.grad = None
                else:
                    if p.grad is not None:
                        p.grad.detach_()
                        p.grad.zero_()

    def free_grad_in_param_list(self, param_list):
        for p in param_list:
            if isinstance(p, list):
                for _p in p:
                    _p.grad = None
            else:
                p.grad = None

    def flatten_dense_tensors_sub_partition_aligned(self,
                                                    tensor_list,
                                                    max_elements_per_comm
                                                    ):
        assert max_elements_per_comm >= self.world_size, f"max_elements_per_comm {max_elements_per_comm} < dp {self.world_size}"

        num_elements = sum(t.numel() for t in tensor_list)

        # Compute aligned partition size based on parameter count
        aligned_param_partition_size = math.ceil(
            num_elements / self.world_size)

        # Compute aligned partition size based on communication size
        aligned_comm_partition_size = int(
            max_elements_per_comm // self.world_size)

        if aligned_param_partition_size <= aligned_comm_partition_size:
            sub_partition_count = 1
            sub_partition_size = aligned_param_partition_size
        else:
            sub_partition_count = math.ceil(aligned_param_partition_size /
                                            aligned_comm_partition_size)
            sub_partition_size = aligned_comm_partition_size

        # Compute the padding required for alignment to dp degree and max_elements_per_comm
        padding = (sub_partition_count * sub_partition_size *
                   self.world_size) - num_elements

        if self.verbose:
            print_rank_0(
                f"sub_partition_count: {sub_partition_count}, sub_partition_size: {sub_partition_size}, padding: {padding}")
            print_rank_0(
                f"number of elements with padding: {num_elements} + {padding} = {num_elements + padding}")

        if padding == 0:
            aligned_tensor_list = tensor_list
        else:
            pad_tensor = torch.zeros(padding,
                                     device=tensor_list[0].device,
                                     dtype=tensor_list[0].dtype)
            aligned_tensor_list = tensor_list + [pad_tensor]

        flat_tensors = self.flatten(aligned_tensor_list)
        return flat_tensors

    # def reduce_gradients(self):
    #     # LSG: this reduce_gradients method no longer works
    #     # after the code change, please use DataParallelGradientHandler instead
    #
    #     world_size = gpc.get_world_size(self.parallel_mode)
    #     local_rank = gpc.get_local_rank(self.parallel_mode)
    #
    #     for i, group in enumerate(self._param_groups):
    #         num_comm_intervals = self.num_comm_intervals_per_group[i]
    #         all_sub_partitions = []
    #         for rank in range(world_size):
    #             # gsp is a list of partitions indexed by comm_idx
    #             grad_sub_partitions = self.get_flat_sub_partitions(
    #                 comm_tensor_list=self.params_in_rank_sub_partitions[i][rank],
    #                 comm_param_offsets=self.params_in_rank_sub_partitions_offsets[i][rank],
    #                 dtype=self.local_sub_partitions_of_groups[i][0].dtype,
    #                 default_device=self.default_device,
    #                 sub_partition_size=self.sub_partition_sizes[i],
    #                 num_comm_intervals=self.num_comm_intervals_per_group[i])
    #             all_sub_partitions.append(grad_sub_partitions)
    #
    #             assert len(grad_sub_partitions) == num_comm_intervals
    #
    #         local_comm_partitions = []
    #         for comm_idx in range(num_comm_intervals):
    #             single_comm_all_partitions = []
    #             for rank in range(world_size):
    #                 single_comm_all_partitions.append(all_sub_partitions[rank][comm_idx])
    #
    #             for partition in single_comm_all_partitions:
    #                 partition.div_(world_size)
    #
    #             dist.reduce_scatter(output=single_comm_all_partitions[local_rank],
    #                                 input_list=single_comm_all_partitions,
    #                                 group=gpc.get_group(self.parallel_mode))

    def step(self, closure=None):
        local_sub_partitions_grad_groups = []

        for i, group in enumerate(self._param_groups):
            # RS: update free grads w.r.t. sub-partitions
            # free gradients for all the parameters that are not updated by this process
            self.free_grad_in_param_list(self.params_not_local[i])

            # create flat gradient partitions for the parameters updated by this process
            local_grad_sub_partitions = self.get_flat_sub_partitions(
                comm_tensor_list=self.params_in_rank_sub_partitions[i][self.local_rank],
                comm_param_offsets=self.params_in_rank_sub_partitions_offsets[i][self.local_rank],
                sub_partition_size=self.sub_partition_sizes[i],
                dtype=self.local_sub_partitions_of_groups[i][0].dtype,
                num_comm_intervals=self.num_comm_intervals_per_group[i],
                default_device=self.default_device)

            # RS: update all our local params with sub-partition grads
            for idx, sub_partition_param in enumerate(self.local_sub_partitions_of_groups[i]):
                sub_partition_param.grad = local_grad_sub_partitions[idx]

            # RS: update free grads for sub-partitions
            # release all the gradients since we have already created a necessary copy in dp_grad_partition
            self.free_grad_in_param_list(
                self.params_in_rank_sub_partitions[i][self.local_rank])

            local_sub_partitions_grad_groups.append(local_grad_sub_partitions)

        if closure is None:
            loss = self.optimizer.step()
        else:
            loss = self.optimizer.step(closure=closure)

        # RS: clear our sub-partition grads
        # LSG: not needed as amp is used instead
        # get rid of the fp32 gradients. Not needed anymore
        # for group in self.local_sub_partitions_of_groups:
        #     for idx, sub_partition_param in enumerate(group):
        #         sub_partition_param.grad = None

        # RS: all_gather/broadcast sub-partitions in separate comm calls
        # gather the updated weights from everyone
        for all_sub_partitions in self.parallel_comm_sub_partitioned_groups:
            for comm_id, sub_partitions in enumerate(all_sub_partitions):
                dist.all_gather(sub_partitions,
                                sub_partitions[self.local_rank],
                                group=gpc.get_group(self.dp_parallel_mode))

        # TODO: we probably don't need this? just to be safe
        for i in range(len(self._param_groups)):
            updated_params = self.unflatten(self._param_groups_flat[i],
                                            self._param_groups[i])
            for p, q in zip(self._param_groups[i], updated_params):
                p.data = q.data

        return loss

    def _rigid_state_dict(self):
        """Returns a dict that can be loaded for continued training with the same DP degree.

        The dict contains the state_dict of the wrapped PyTorch optimizer as well as
        the local sub-partitions of the flat parameter groups owned by this rank.

        Example::

            checkpoint = {}
            checkpoint['model'] = model.state_dict()
            checkpoint['optimizer'] = optimizer.state_dict()
            torch.save(checkpoint, "saved.pth")
        """
        state_dict = {}
        for k, v in self.optimizer.state_dict().items():
            state_dict[k] = v
        state_dict[
            'local_sub_partitions_of_groups'] = self.local_sub_partitions_of_groups
        return state_dict

    def state_dict(self):
        """
        Returns a dict containing the current state of this optimizer instance.
        The dict contains the state_dict of the wrapped PyTorch optimizer as well as
        the local sub-partitions of the flat parameter groups owned by this rank.

        Example::

            checkpoint = {}
            checkpoint['model'] = model.state_dict()
            checkpoint['optimizer'] = optimizer.state_dict()
            torch.save(checkpoint, "saved.pth")
        """
        return self._rigid_state_dict()

    def load_state_dict(self,
                        state_dict,
                        load_optimizer_states=True,
                        ):
        """
        Loads a state_dict created by an earlier call to state_dict().
        If this optimizer instance was constructed from some ``init_optimizer``
        whose parameters in turn came from ``model``, it is expected that the user
        will call ``model.load_state_dict()`` before this optimizer's
        ``load_state_dict()`` is called.

        Example::

            model = torch.nn.Linear(D_in, D_out).cuda().half()
            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
            optimizer = ZeroRedundancyOptimizer_Level_1(optimizer)
            ...
            checkpoint = torch.load("saved.pth")
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        """
        self._rigid_load_state_dict(
            state_dict,
            load_optimizer_states)

    def _rigid_load_state_dict(self, state_dict, load_optimizer_states=True):
        # It should be fine to reload the optimizer before the model.
        state_dict_ = state_dict.copy()
        local_sub_partitions_of_groups = state_dict_.pop(
            'local_sub_partitions_of_groups')

        if load_optimizer_states:
            self.optimizer.load_state_dict(state_dict_)

        for curr_group, saved_group in zip(self.local_sub_partitions_of_groups,
                                           local_sub_partitions_of_groups):
            for curr_param, saved_param in zip(curr_group, saved_group):
                curr_param.data.copy_(saved_param.data)
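To make the [comm_id] -> [rank] versus [rank] -> [comm_id] views used throughout this class concrete, here is a single-process toy sketch that mirrors get_data_parallel_sub_partitions on a hand-made flat tensor (world_size and sizes are illustrative; no distributed setup is involved):

import torch

world_size = 2
flat = torch.arange(8.0)                                  # already padded: 8 % (2 * 2) == 0
sub_partition_size = 2                                    # max_elements_per_comm=4 -> 4 // world_size
num_sub_partitions = flat.numel() // sub_partition_size   # 4
num_comm_intervals = num_sub_partitions // world_size     # 2

comm_partitions = [[] for _ in range(num_comm_intervals)]  # [comm_id] -> [rank]
start, comm_id = 0, 0
for idx in range(num_sub_partitions):
    rank_id = idx % world_size
    comm_partitions[comm_id].append(flat.narrow(0, start, sub_partition_size))
    start += sub_partition_size
    if rank_id == world_size - 1:
        comm_id += 1

# [rank] -> [comm_id] view of the same storage
sub_partitions = [[comm_partitions[c][r] for c in range(num_comm_intervals)]
                  for r in range(world_size)]
print([t.tolist() for t in sub_partitions[0]])   # rank 0 owns [0., 1.] and [4., 5.]
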