Develop/experiments (#59)

* Add gradient accumulation, fix lr scheduler

* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)

* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes

* fixed trainer

* Revert "fixed trainer"

This reverts commit 2e0b0b7699.

* improved consistency between trainer, engine and schedule (#23)

Co-authored-by: 1SAA <c2h214748@gmail.com>

* Split conv2d, class token, positional embedding in 2d, Fix random number in ddp
Fix convergence in cifar10, Imagenet1000

* Integrate 1d tensor parallel in Colossal-AI (#39)

* fixed 1D and 2D convergence (#38)

* optimized 2D operations

* fixed 1D ViT convergence problem

* Feature/ddp (#49)

* remove redundancy func in setup (#19) (#20)

* use env to control the language of doc (#24) (#25)

* Support TP-compatible Torch AMP and Update trainer API (#27)

* Add gradient accumulation, fix lr scheduler

* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)

* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes

* fixed trainer

* Revert "fixed trainer"

This reverts commit 2e0b0b7699.

* improved consistency between trainer, engine and schedule (#23)

Co-authored-by: 1SAA <c2h214748@gmail.com>

Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>

* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)

* add explanation for ViT example (#35) (#36)

* support torch ddp

* fix loss accumulation

* add log for ddp

* change seed

* modify timing hook

Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>

* Feature/pipeline (#40)

* remove redundancy func in setup (#19) (#20)

* use env to control the language of doc (#24) (#25)

* Support TP-compatible Torch AMP and Update trainer API (#27)

* Add gradient accumulation, fix lr scheduler

* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)

* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes

* fixed trainer

* Revert "fixed trainer"

This reverts commit 2e0b0b7699.

* improved consistency between trainer, engine and schedule (#23)

Co-authored-by: 1SAA <c2h214748@gmail.com>

Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>

* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)

* add explanation for ViT example (#35) (#36)

* optimize communication of pipeline parallel

* fix grad clip for pipeline

Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>

* optimized 3d layer to fix slow computation ; tested imagenet performance with 3d; reworked lr_scheduler config definition; fixed launch args; fixed some printing issues; simplified apis of 3d layers (#51)

* Update 2.5d layer code to get a similar accuracy on imagenet-1k dataset

* update api for better usability (#58)

update api for better usability

Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: puck_WCR <46049915+WANG-CR@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
This commit is contained in:
Frank Lee 2021-12-09 15:08:29 +08:00 committed by GitHub
parent eb2f8b1f6b
commit da01c234e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
229 changed files with 6532 additions and 8741 deletions

View File

@ -1,4 +1,4 @@
from .initialize import init_dist, initialize
from .nn import *
from .initialize import (initialize, launch, launch_from_openmpi,
launch_from_slurm, launch_from_torch, get_default_parser)
__version__ = '0.0.1'

View File

@ -0,0 +1,32 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from .amp_type import AMP_TYPE
from colossalai.context import Config
import torch.nn as nn
from torch.optim import Optimizer
from torch.nn.modules.loss import _Loss
from .torch_amp import convert_to_torch_amp
from .apex_amp import convert_to_apex_amp
from .naive_amp import convert_to_naive_amp
def convert_to_amp(model: nn.Module,
optimizer: Optimizer,
criterion: _Loss,
mode: AMP_TYPE,
amp_config: Config = None):
assert isinstance(mode, AMP_TYPE), \
f'expected the argument mode be AMP_TYPE, but got {type(mode)}'
if amp_config is None:
amp_config = Config()
if mode == AMP_TYPE.TORCH:
model, optimizer, criterion = convert_to_torch_amp(model, optimizer, criterion, amp_config)
elif mode == AMP_TYPE.APEX:
model, optimizer = convert_to_apex_amp(model, optimizer, amp_config)
elif mode == AMP_TYPE.NAIVE:
model, optimizer = convert_to_naive_amp(model, optimizer, amp_config)
return model, optimizer, criterion

View File

@ -7,4 +7,4 @@ from enum import Enum
class AMP_TYPE(Enum):
APEX = 'apex'
TORCH = 'torch'
PARALLEL = 'parallel'
NAIVE = 'naive'

View File

@ -0,0 +1,15 @@
from .apex_amp import ApexAMPOptimizer
import torch.nn as nn
from torch.optim import Optimizer
import apex.amp as apex_amp
def convert_to_apex_amp(model: nn.Module,
optimizer: Optimizer,
amp_config):
model, optimizer = apex_amp.initialize(model, optimizer, **amp_config)
optimizer = ApexAMPOptimizer(optimizer)
return model, optimizer
__all__ = ['convert_to_apex_amp', 'ApexAMPOptimizer']

View File

@ -0,0 +1,23 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.nn as nn
try:
import apex.amp as apex_amp
except:
pass
from torch import Tensor
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils import clip_grad_norm_fp32
class ApexAMPOptimizer(ColossalaiOptimizer):
def backward(self, loss: Tensor):
with apex_amp.scale_loss(loss, self.optim) as scaled_loss:
scaled_loss.backward()
def clip_grad_norm(self, model: nn.Module, max_norm: float):
if max_norm > 0:
clip_grad_norm_fp32(apex_amp.master_params(self.optim), max_norm)

View File

@ -0,0 +1,20 @@
import torch.nn as nn
from torch.optim import Optimizer
from colossalai.utils import is_no_pp_or_last_stage
from .naive_amp import NaiveAMPOptimizer, NaiveAMPModel
def convert_to_naive_amp(model: nn.Module,
optimizer: Optimizer,
amp_config):
if is_no_pp_or_last_stage():
model = NaiveAMPModel(model, output_to_fp32=True)
else:
model = NaiveAMPModel(model, output_to_fp32=False)
optimizer = NaiveAMPOptimizer(optimizer, **amp_config)
return model, optimizer
__all__ = ['convert_to_naive_amp', 'NaiveAMPOptimizer']

View File

@ -12,11 +12,9 @@ from torch.optim import Optimizer
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger
from colossalai.registry import OPTIMIZER_WRAPPERS
from colossalai.utils import print_rank_0
from ._utils import copy_tensor_parallel_attributes, clip_grad_norm_fp32, count_zeros_fp32
from ..multi_tensor_apply import multi_tensor_applier
from colossalai.logging import get_dist_logger
from colossalai.utils import (print_rank_0, copy_tensor_parallel_attributes,
clip_grad_norm_fp32, count_zeros_fp32, multi_tensor_applier)
def _zero_grad_group_helper(group, set_to_none):
@ -92,7 +90,7 @@ class DynamicGradScaler:
self._growth_tracker = 0
self._hysteresis_tracker = self.hysteresis
self._logger = get_global_dist_logger()
self._logger = get_dist_logger()
@property
def scale(self):
@ -113,7 +111,7 @@ class DynamicGradScaler:
if self._hysteresis_tracker <= 0:
self._scale = torch.max(self._scale * self.backoff_factor,
self.min_scale)
self._logger.info(f'overflow occurs, loss scale is adjusted to {self._scale}')
self._logger.info(f'overflow occurs, loss scale is adjusted to {self._scale}', ranks=[0])
else:
# If there is no nan/inf, increment the growth tracker.
self._growth_tracker += 1
@ -125,10 +123,10 @@ class DynamicGradScaler:
# and scale up the loss scale.
if self._max_scale is not None and self._scale >= self._max_scale:
self._logger.info(
f'Current loss scale {self._scale} has reached the max scale {self._max_scale} allowed')
f'Current loss scale {self._scale} has reached the max scale {self._max_scale} allowed', ranks=[0])
else:
self._scale = self._scale * self.growth_factor
self._logger.info(f'no consecutive overflow, loss scale is adjusted to {self._scale}')
self._logger.info(f'no consecutive overflow, loss scale is adjusted to {self._scale}', ranks=[0])
def state_dict(self):
state_dict = {}
@ -145,7 +143,6 @@ class DynamicGradScaler:
self._max_scale = state_dict['max_scale']
@OPTIMIZER_WRAPPERS.register_module
class FP16Optimizer(Optimizer):
"""Float16 optimizer for fp16 and bf16 data types.
@ -184,13 +181,13 @@ class FP16Optimizer(Optimizer):
max_scale: int = 2 ** 32):
# default args for compatibility
bf16 = False
params_have_main_grad = False
params_have_main_grad = True
# have a defaults for compatibility with pytorch optim
self.defaults = optimizer.defaults
# log config
self._logger = get_global_dist_logger()
self._logger = get_dist_logger()
self._logger.info(f"\n========= FP16 Optimizer Config =========\n"
f"Optimizer: {optimizer.__class__.__name__}\n"
f"clip_grad = {clip_grad}\n"
@ -328,6 +325,7 @@ class FP16Optimizer(Optimizer):
else:
if model_param.grad is not None:
main_param.grad = model_param.grad.float()
# For fp32 grads, we need to reset the grads to main grad.
if self.params_have_main_grad:
for model_group in self.fp32_from_fp32_groups:
@ -387,10 +385,6 @@ class FP16Optimizer(Optimizer):
@torch.no_grad()
def step(self):
# for param_group in self.float16_groups:
# for param in param_group:
# print(param.grad is None)
# Copy gradients from model params to main params.
self._copy_model_grads_to_main_grads()

View File

@ -0,0 +1,65 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
import torch.nn as nn
from torch import Tensor
from typing import Union, List, Any, Dict
from torch.optim import Optimizer
import torch.cuda.amp as torch_amp
from colossalai.nn.optimizer import ColossalaiOptimizer
from ._fp16_optimizer import FP16Optimizer
class NaiveAMPOptimizer(ColossalaiOptimizer):
def __init__(self, optim: Optimizer, *args, **kwargs):
optim = FP16Optimizer(optimizer=optim, *args, **kwargs)
super().__init__(optim)
def backward(self, loss: Tensor):
loss = self.optim.scale_loss(loss)
loss.backward()
def step(self):
self.optim.step()
def clip_grad_norm(self, model: nn.Module, max_norm: float):
pass
class NaiveAMPModel(nn.Module):
def __init__(self,
model: nn.Module,
output_to_fp32: bool = True):
super().__init__()
self.model = model.half()
self._output_to_fp32 = output_to_fp32
def _convert_to_fp16(self, input_: Any):
if isinstance(input_, Tensor) and input_.dtype == torch.float32:
input_ = input_.half()
return input_
def _convert_to_fp32(self, input_: Any):
if isinstance(input_, Tensor) and input_.dtype == torch.float16:
input_ = input_.float()
return input_
def forward(self, *args, **kwargs):
if args:
args = [self._convert_to_fp16(arg) for arg in args]
if kwargs:
for k, v in kwargs.items():
kwargs[k] = self._convert_to_fp16(v)
out = self.model(*args, **kwargs)
if self._output_to_fp32:
if isinstance(out, Tensor):
out = self._convert_to_fp32(out)
elif isinstance(out, (tuple, list)):
out = [self._convert_to_fp32(val) for val in out]
return out

View File

@ -0,0 +1,18 @@
import torch.nn as nn
from torch.optim import Optimizer
from torch.nn.modules.loss import _Loss
from colossalai.context import Config
from .torch_amp import TorchAMPOptimizer, TorchAMPModel, TorchAMPLoss
def convert_to_torch_amp(model: nn.Module,
optimizer: Optimizer,
criterion: _Loss,
amp_config: Config):
model = TorchAMPModel(model)
optimizer = TorchAMPOptimizer(optimizer, **amp_config)
criterion = TorchAMPLoss(criterion)
return model, optimizer, criterion
__all__ = ['convert_to_torch_amp', 'TorchAMPModel', 'TorchAMPLoss', 'TorchAMPOptimizer']

View File

@ -1,4 +1,8 @@
# modified from https://github.com/pytorch/pytorch/blob/master/torch/cuda/amp/grad_scaler.p
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# modified from https://github.com/pytorch/pytorch/blob/master/torch/cuda/amp/grad_scaler.py
# to support tensor parallel
import torch
from collections import defaultdict, abc
import warnings

View File

@ -0,0 +1,54 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.nn as nn
import torch.cuda.amp as torch_amp
from torch import Tensor
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from ._grad_scaler import GradScaler
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils import clip_grad_norm_fp32
class TorchAMPOptimizer(ColossalaiOptimizer):
def __init__(self, optim: Optimizer, *args, **kwargs):
super().__init__(optim)
self.scaler = GradScaler(*args, **kwargs)
def backward(self, loss: Tensor):
self.scaler.scale(loss).backward()
def step(self):
self.scaler.step(self.optim)
self.scaler.update()
def clip_grad_norm(self, model: nn.Module, max_norm: float):
if max_norm > 0.0:
self.scaler.unscale_(self.optim)
clip_grad_norm_fp32(model.parameters(), max_norm)
class TorchAMPModel(nn.Module):
def __init__(self, model: nn.Module) -> None:
super().__init__()
self.model = model
@torch_amp.autocast()
def forward(self, *args, **kwargs):
return self.model(*args, **kwargs)
class TorchAMPLoss(nn.Module):
def __init__(self, loss: _Loss):
super().__init__()
self.loss = loss
@torch_amp.autocast()
def forward(self, *args, **kwargs):
return self.loss(*args, **kwargs)

View File

@ -1,10 +1,10 @@
from .builder import (build_schedule, build_lr_scheduler, build_model, build_optimizer, build_optimizer_wrapper,
build_layer, build_loss, build_hooks, build_dataset, build_transform, build_data_sampler,
from .builder import (build_schedule, build_lr_scheduler, build_model, build_optimizer, build_layer,
build_loss, build_hooks, build_dataset, build_transform, build_data_sampler,
build_gradient_handler)
from .pipeline import ModelInitializer
from .pipeline import PipelineModelInitializer
__all__ = [
'build_schedule', 'build_lr_scheduler', 'build_model', 'build_optimizer', 'build_optimizer_wrapper',
'build_schedule', 'build_lr_scheduler', 'build_model', 'build_optimizer',
'build_layer', 'build_loss', 'build_hooks', 'build_dataset', 'build_transform', 'build_data_sampler',
'build_gradient_handler', 'ModelInitializer'
'build_gradient_handler', 'PipelineModelInitializer'
]

View File

@ -106,7 +106,7 @@ def build_dataset(config):
return build_from_registry(config, DATASETS)
def build_optimizer(config, model, params: Iterable = None, need_module=False):
def build_optimizer(config, model):
"""Returns an optimizer object of :class:`torch.optim.Optimizer` constructed from `config`,
'model' and 'params'.
@ -115,23 +115,12 @@ def build_optimizer(config, model, params: Iterable = None, need_module=False):
:type config: dict or :class:`colossalai.context.Config`
:param model: A model containing parameters for the optimizer
:type model: :class:`nn.Module`
:param params: A dict containing parameters for the optimizer
:type params: dict, optional
:param need_module: Indicates whether the optimizer needs a module
:type params: bool, optional
:raises AssertionError: Raises an AssertionError if both `model` and `params` are None
:return: An object of :class:`torch.optim.Optimizer`
:rtype: :class:`torch.optim.Optimizer`
"""
assert model is not None or params is not None, 'arguments model and params can not both be None'
if need_module:
config['module'] = model
elif model is not None:
config['params'] = model.parameters()
elif params is not None:
config['params'] = params
return build_from_registry(config, OPTIMIZERS)
config_ = config.copy()
config_['params'] = model.parameters()
return build_from_registry(config_, OPTIMIZERS)
def build_gradient_handler(config, model, optimizer):
@ -149,8 +138,9 @@ def build_gradient_handler(config, model, optimizer):
:rtype: :class:`BaseGradientHandler`
"""
config_ = config.copy()
mod_type = config_.pop('type')
return GRADIENT_HANDLER.get_module(mod_type)(model, optimizer, **config_)
config_['model'] = model
config_['optimizer'] = optimizer
return build_from_registry(config_, GRADIENT_HANDLER)
def build_hooks(config, trainer):
@ -164,8 +154,9 @@ def build_hooks(config, trainer):
:return: An object of :class:`BaseHook`
:rtype: :class:`BaseHook`
"""
config['trainer'] = trainer
return build_from_registry(config, HOOKS)
config_ = config.copy()
config_['trainer'] = trainer
return build_from_registry(config_, HOOKS)
def build_transform(config):
@ -195,32 +186,8 @@ def build_data_sampler(config, dataset):
:rtype: :class:`colossalai.nn.data.sampler.BaseSampler`
"""
config_ = config.copy()
mod_type = config_.pop('type')
return SAMPLERS.get_module(mod_type)(dataset, **config_)
def build_optimizer_wrapper(config, optimizer, model=None):
"""Returns an optimizer wrapper object of :class:`torch.optim.Optimizer` constructed
from `config`, `model` and `optimizer`.
:param config: A python dict or a :class:`colossalai.context.Config` object
containing information used in the construction of the return object
:type config: dict or :class:`colossalai.context.Config`
:param optimizer: An optimizer object containing parameters for the gradient handler
:type optimizer: :class:`torch.optim.Optimizer`
:param model: A model containing parameters for the gradient handler
:type model: :class:`nn.Module`, optional
:return: An object of :class:`torch.optim.Optimizer`
:rtype: :class:`torch.optim.Optimizer`
"""
config_ = config.copy()
mod_type = config_.pop('type')
# LSG: special treatment for zeor level 3
if mod_type == 'ZeroRedundancyOptimizer_Level_3':
return OPTIMIZER_WRAPPERS.get_module(mod_type)(model, optimizer, **config_)
else:
return OPTIMIZER_WRAPPERS.get_module(mod_type)(optimizer, **config_)
config_['dataset'] = dataset
return build_from_registry(config_, DATA_SAMPLERS)
def build_lr_scheduler(config, optimizer):
@ -241,8 +208,8 @@ def build_lr_scheduler(config, optimizer):
:rtype: :class:`torch.optim.lr_scheduler`
"""
config_ = config.copy()
mod_type = config_.pop('type')
return LR_SCHEDULERS.get_module(mod_type)(optimizer, **config_)
config_['optimizer'] = optimizer
return build_from_registry(config_, LR_SCHEDULERS)
def build_schedule(config):

View File

@ -4,7 +4,7 @@ import heapq
from colossalai.builder import build_model, build_layer
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger
from colossalai.logging import get_dist_logger
from colossalai.utils import set_to_cuda
@ -111,21 +111,21 @@ def _binary_search(weights, num):
return intervals
def _partition_uniform(num_items, num_parts, num_chunks):
def _partition_uniform(num_items, pipeline_parallel_size, num_chunks):
assert num_items % num_chunks == 0, \
"Layer length should be divided by the number of chunks, otherwise parameter method is recomended"
logger = get_global_dist_logger()
parts = [[] for _ in range(num_parts)]
logger = get_dist_logger()
parts = [[] for _ in range(pipeline_parallel_size)]
partition_items = num_items // num_chunks
for idx in range(num_chunks):
base_idx = idx * partition_items
chunk_size = partition_items // num_parts
left = num_parts - partition_items % num_parts
chunk_size = partition_items // pipeline_parallel_size
left = pipeline_parallel_size - partition_items % pipeline_parallel_size
if chunk_size == 0:
logger.warning("Some nodes in Pipeline have no requests")
for p in range(num_parts):
for p in range(pipeline_parallel_size):
st = base_idx
base_idx += chunk_size + (p >= left)
parts[p].append((st, base_idx))
@ -133,34 +133,34 @@ def _partition_uniform(num_items, num_parts, num_chunks):
return parts
def _partition_balanced(weights, num_parts, num_chunks):
num_total = num_parts * num_chunks
def _partition_balanced(weights, pipeline_parallel_size, num_chunks):
num_total = pipeline_parallel_size * num_chunks
num_items = len(weights)
if num_items <= num_total:
return _partition_uniform(num_items, num_parts, num_chunks)
return _partition_uniform(num_items, pipeline_parallel_size, num_chunks)
intervals = _binary_search(weights, num_total)
current = 0
parts = [[] for _ in range(num_parts)]
parts = [[] for _ in range(pipeline_parallel_size)]
for inter in intervals:
parts[current].append(inter)
current = (current + 1) % num_parts
current = (current + 1) % pipeline_parallel_size
return parts
class ModelInitializer():
class PipelineModelInitializer():
def __init__(self, config, num_chunks, verbose=False):
self.num_chunks = num_chunks
self.ori_model = build_model(config)
self.layers = self.ori_model.layers_cfg
layer_length = len(self.layers)
self.verbose = verbose
self._logger = get_global_dist_logger()
self._logger = get_dist_logger()
self._logger.info(f"The total length of layers is {layer_length}", ranks=[0])
def model_initialize(self, partition_method='parameter'):
def initialize(self, partition_method='parameter'):
# Some space for initializing comunication groups
self._interval = None
self._partition_layers(method=partition_method)
@ -198,7 +198,7 @@ class ModelInitializer():
for st, ed in self.parts[stage]:
for idx, layer in enumerate(self.layers[st: ed]):
log_str += f'\t{idx + st:2d}: {layer}\n'
self._logger.info(log_str)
self._logger.info(log_str, ranks=[0])
# Save the partition
self._interval = self.parts[pipeline_rank]

View File

@ -1,4 +1,4 @@
from .collective import all_gather, reduce_scatter, scatter
from .collective import all_gather, reduce_scatter, all_reduce
from .p2p import (send_forward, send_forward_recv_forward, send_backward_recv_forward,
send_backward, send_backward_recv_backward, send_forward_recv_backward,
send_forward_backward_recv_forward_backward, recv_forward, recv_backward)
@ -6,7 +6,7 @@ from .ring import ring_forward
from .utils import send_tensor_meta, recv_tensor_meta
__all__ = [
'all_gather', 'reduce_scatter', 'scatter',
'all_gather', 'reduce_scatter', 'all_reduce',
'send_forward', 'send_forward_recv_forward', 'send_forward_backward_recv_forward_backward',
'send_backward', 'send_backward_recv_backward', 'send_backward_recv_forward',
'send_forward_recv_backward', 'recv_backward', 'recv_forward',

View File

@ -11,7 +11,7 @@ from colossalai.utils import get_current_device
def all_gather(tensor: Tensor, dim: int,
parallel_mode: ParallelMode) -> Tensor:
parallel_mode: ParallelMode, async_op=False) -> Tensor:
"""Gathers all tensors from the parallel group and concatenates them in a
specific dimension.
@ -26,18 +26,28 @@ def all_gather(tensor: Tensor, dim: int,
"""
depth = gpc.get_world_size(parallel_mode)
temp = tensor.clone()
shape = list(temp.shape)
shape[dim] *= depth
out = torch.empty(shape, dtype=temp.dtype, device=get_current_device())
out = list(torch.chunk(out, depth, dim=dim))
out = [val.contiguous() for val in out]
dist.all_gather(out, temp, group=gpc.get_group(parallel_mode))
out = torch.cat(out, dim=dim)
return out
# shape = list(temp.shape)
# shape[dim] *= depth
# out = torch.zeros(shape, dtype=temp.dtype, device=get_current_device())
# out = list(torch.chunk(out, depth, dim=dim))
# out = [val.contiguous() for val in out]
shape = [1] * len(tensor.shape)
shape[dim] = depth
out = tensor.repeat(shape)
out = list(map(lambda x: x.contiguous(), torch.chunk(out, depth, dim=dim)))
op = dist.all_gather(tensor_list=out,
tensor=temp,
group=gpc.get_group(parallel_mode),
async_op=async_op)
# out = torch.cat(out, dim=dim)
if async_op:
return out, op
else:
return out
def reduce_scatter(tensor: Tensor, dim: int,
parallel_mode: ParallelMode) -> Tensor:
parallel_mode: ParallelMode, async_op=False) -> Tensor:
"""Reduces all tensors then scatters it in a specific dimension to all
members in the parallel group.
@ -51,34 +61,52 @@ def reduce_scatter(tensor: Tensor, dim: int,
:rtype: Tensor
"""
depth = gpc.get_world_size(parallel_mode)
temp = list(torch.chunk(tensor, depth, dim=dim))
temp = [val.contiguous() for val in temp]
out = torch.empty(temp[0].shape,
dtype=temp[0].dtype,
device=get_current_device())
dist.reduce_scatter(output=out,
input_list=temp,
group=gpc.get_group(parallel_mode))
return out
# temp = list(torch.chunk(tensor, depth, dim=dim))
# temp = [val.contiguous() for val in temp]
# out = torch.zeros(temp[0].shape,
# dtype=temp[0].dtype,
# device=get_current_device())
temp = list(map(lambda x: x.contiguous(), torch.chunk(tensor, depth, dim=dim)))
out = temp[0].clone()
op = dist.reduce_scatter(output=out,
input_list=temp,
group=gpc.get_group(parallel_mode),
async_op=async_op)
if async_op:
return out, op
else:
return out
def scatter(tensor: Tensor, src: int, dim: int,
parallel_mode: ParallelMode) -> Tensor:
"""Scatters in a specific dimension from source rank to all ranks in
the parallel group.
def all_reduce(tensor: Tensor,
parallel_mode: ParallelMode,
async_op=False) -> Tensor:
op = dist.all_reduce(tensor,
group=gpc.get_group(parallel_mode),
async_op=async_op)
if async_op:
return tensor, op
else:
return tensor
# def scatter(tensor: Tensor, src: int, dim: int,
# parallel_mode: ParallelMode) -> Tensor:
# """Scatters in a specific dimension from source rank to all ranks in
# the parallel group.
:param tensor: Tensor to be scattered
:param dim: The dimension scattering in
:param parallel_mode: Parallel group mode used in this communication
:type tensor: Tensor
:type dim: int
:type parallel_mode: ParallelMode
:return: The tensor generated by scatter
:rtype: Tensor
"""
depth = gpc.get_world_size(parallel_mode)
temp = tensor.clone()
dist.broadcast(temp, src=src, group=gpc.get_group(parallel_mode))
rank = gpc.get_local_rank(parallel_mode)
out = torch.chunk(temp, depth, dim=dim)[rank].contiguous()
return out
# :param tensor: Tensor to be scattered
# :param dim: The dimension scattering in
# :param parallel_mode: Parallel group mode used in this communication
# :type tensor: Tensor
# :type dim: int
# :type parallel_mode: ParallelMode
# :return: The tensor generated by scatter
# :rtype: Tensor
# """
# depth = gpc.get_world_size(parallel_mode)
# temp = tensor.clone()
# dist.broadcast(temp, src=src, group=gpc.get_group(parallel_mode))
# rank = gpc.get_local_rank(parallel_mode)
# out = torch.chunk(temp, depth, dim=dim)[rank].contiguous()
# return out

View File

@ -17,8 +17,6 @@ def _communicate(tensor_send_next=None,
recv_next_shape=None,
prev_rank=None,
next_rank=None,
up_group=None,
down_group=None,
dtype=None):
"""
Adapted from megatron.p2p_communication.
@ -59,60 +57,44 @@ def _communicate(tensor_send_next=None,
if prev_rank is None:
prev_rank = gpc.get_prev_global_rank(
ParallelMode.PIPELINE)
if up_group is None:
up_group = gpc.get_group(ParallelMode.PIPELINE_PREV)
if tensor_send_next is not None or recv_next:
if next_rank is None:
next_rank = gpc.get_next_global_rank(
ParallelMode.PIPELINE)
if down_group is None:
down_group = gpc.get_group(ParallelMode.PIPELINE_NEXT)
# rank = dist.get_rank()
rank = gpc.get_global_rank()
ops = []
if tensor_send_prev is not None:
send_prev_op = dist.broadcast(tensor_send_prev,
src=rank,
group=up_group,
async_op=True)
send_prev_op = dist.P2POp(dist.isend, tensor_send_prev, prev_rank)
ops.append(send_prev_op)
if tensor_recv_prev is not None:
recv_prev_op = dist.broadcast(tensor_recv_prev,
src=prev_rank,
group=up_group,
async_op=True)
recv_prev_op = dist.P2POp(dist.irecv, tensor_recv_prev, prev_rank)
ops.append(recv_prev_op)
if tensor_recv_next is not None:
recv_next_op = dist.broadcast(tensor_recv_next,
src=next_rank,
group=down_group,
async_op=True)
recv_next_op = dist.P2POp(dist.irecv, tensor_recv_next, next_rank)
ops.append(recv_next_op)
if tensor_send_next is not None:
send_next_op = dist.broadcast(tensor_send_next,
src=rank,
group=down_group,
async_op=True)
send_next_op = dist.P2POp(dist.isend, tensor_send_next, next_rank)
ops.append(send_next_op)
for req in ops:
req.wait()
if len(ops) > 0:
reqs = dist.batch_isend_irecv(ops)
for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()
return tensor_recv_prev, tensor_recv_next
def recv_forward(input_tensor_shape, prev_rank=None, up_group=None):
def recv_forward(input_tensor_shape, prev_rank=None):
"""Receives the input tensor from the previous member in pipeline.
:param input_tensor_shape: The shape of the tensor to be recieved
:param prev_rank: The rank of the source of the tensor
:param up_group: Communication group including the previous member in pipeline parallel group
:type input_tensor_shape: torch.Size
:type prev_rank: int, optional
:type up_group: ProcessGroup, optional
:return: The input tensor in forward step
:rtype: Tensor
"""
@ -121,20 +103,17 @@ def recv_forward(input_tensor_shape, prev_rank=None, up_group=None):
else:
input_tensor, _ = _communicate(recv_prev=True,
recv_prev_shape=input_tensor_shape,
prev_rank=prev_rank,
up_group=up_group)
prev_rank=prev_rank)
return input_tensor
def recv_backward(output_grad_shape, next_rank=None, down_group=None):
def recv_backward(output_grad_shape, next_rank=None):
"""Receives the grad tensor from the next member in pipeline.
:param output_grad_shape: The shape of the tensor to be recieved
:param next_rank: The rank of the source of the tensor
:param down_group: Communication group including the next member in pipeline parallel group
:type output_grad_shape: torch.Size
:type next_rank: int, optional
:type down_group: ProcessGroup, optional
:return: The grad of output tensor in forward step
:rtype: Tensor
"""
@ -143,56 +122,44 @@ def recv_backward(output_grad_shape, next_rank=None, down_group=None):
else:
_, output_tensor_grad = _communicate(recv_next=True,
recv_next_shape=output_grad_shape,
next_rank=next_rank,
down_group=down_group)
next_rank=next_rank)
return output_tensor_grad
def send_forward(output_tensor,
next_rank=None,
down_group=None):
def send_forward(output_tensor, next_rank=None):
"""Sends the input tensor to the next member in pipeline.
:param output_tensor: Tensor to be sent
:param next_rank: The rank of the recipient of the tensor
:param down_group: Communication group including the next member in pipeline parallel group
:type output_tensor: Tensor
:type next_rank: int, optional
:type down_group: ProcessGroup, optional
"""
if not gpc.is_last_rank(ParallelMode.PIPELINE):
_communicate(tensor_send_next=output_tensor,
next_rank=next_rank,
down_group=down_group)
next_rank=next_rank)
def send_backward(input_tensor_grad,
prev_rank=None,
up_group=None):
def send_backward(input_tensor_grad, prev_rank=None):
"""Sends the grad tensor to the previous member in pipeline.
:param input_tensor_grad: Tensor to be sent
:param prev_rank: The rank of the recipient of the tensor
:param up_group: Communication group including the previous member in pipeline parallel group
:type input_tensor_grad: Tensor
:type prev_rank: int, optional
:type up_group: ProcessGroup, optional
"""
if not gpc.is_first_rank(ParallelMode.PIPELINE):
_communicate(tensor_send_prev=input_tensor_grad,
prev_rank=prev_rank,
up_group=up_group)
prev_rank=prev_rank)
def send_forward_recv_backward(output_tensor,
output_grad_shape,
recv_next=True,
next_rank=None,
down_group=None):
next_rank=None):
"""Batched communication operation. Sends the input tensor to the
next member in pipeline, while recieves the grad tensor from the
next member in pipeline.
:param output_tensor: Tensor to be sent
:param output_grad_shape: The shape of the tensor to be recieved
:type output_tensor: Tensor
@ -206,20 +173,18 @@ def send_forward_recv_backward(output_tensor,
_, output_tensor_grad = _communicate(tensor_send_next=output_tensor,
recv_next=recv_next,
recv_next_shape=output_grad_shape,
next_rank=next_rank,
down_group=down_group)
next_rank=next_rank)
return output_tensor_grad
def send_backward_recv_forward(input_tensor_grad,
input_tensor_shape,
recv_prev=True,
prev_rank=None,
up_group=None):
prev_rank=None):
"""Batched communication operation. Sends the grad tensor to the
previous member in pipeline, while recieves the input tensor from the
previous member in pipeline.
:param input_tensor_grad: Tensor to be sent
:param input_tensor_shape: The shape of the tensor to be recieved
:type input_tensor_grad: Tensor
@ -233,8 +198,7 @@ def send_backward_recv_forward(input_tensor_grad,
input_tensor, _ = _communicate(tensor_send_prev=input_tensor_grad,
recv_prev=recv_prev,
recv_prev_shape=input_tensor_shape,
prev_rank=prev_rank,
up_group=up_group)
prev_rank=prev_rank)
return input_tensor
@ -242,13 +206,11 @@ def send_forward_recv_forward(output_tensor,
input_tensor_shape,
recv_prev=True,
prev_rank=None,
next_rank=None,
up_group=None,
down_group=None):
next_rank=None):
"""Batched communication operation. Sends the input tensor to the
next member in pipeline, while recieves the input tensor from the
previous member in pipeline.
:param output_tensor: Tensor to be sent
:param input_tensor_shape: The shape of the tensor to be recieved
:type output_tensor: Tensor
@ -260,9 +222,7 @@ def send_forward_recv_forward(output_tensor,
recv_prev=recv_prev,
recv_prev_shape=input_tensor_shape,
prev_rank=prev_rank,
next_rank=next_rank,
up_group=up_group,
down_group=down_group)
next_rank=next_rank)
return input_tensor
@ -270,13 +230,11 @@ def send_backward_recv_backward(input_tensor_grad,
output_grad_shape,
recv_next=True,
prev_rank=None,
next_rank=None,
up_group=None,
down_group=None):
next_rank=None):
"""Batched communication operation. Sends the grad tensor to the
previous member in pipeline, while recieves the grad tensor from the
next member in pipeline.
:param input_tensor_grad: Tensor to be sent
:param output_grad_shape: The shape of the tensor to be recieved
:type input_tensor_grad: Tensor
@ -288,9 +246,7 @@ def send_backward_recv_backward(input_tensor_grad,
recv_next=recv_next,
recv_next_shape=output_grad_shape,
prev_rank=prev_rank,
next_rank=next_rank,
up_group=up_group,
down_group=down_group)
next_rank=next_rank)
return output_tensor_grad
@ -301,13 +257,11 @@ def send_forward_backward_recv_forward_backward(output_tensor,
recv_prev=True,
recv_next=True,
prev_rank=None,
next_rank=None,
up_group=None,
down_group=None):
next_rank=None):
"""Batched communication operation. Sends the input tensor to the next and
the grad tensor to the previous, while recieves the grad tensor from the
next and the input tensor from the previous.
:param output_tensor: Tensor sent to the next
:param input_tensor_grad: Tensor sent to the previous
:param input_tensor_shape: The shape of the tensor recieved from the previous
@ -327,7 +281,5 @@ def send_forward_backward_recv_forward_backward(output_tensor,
recv_prev_shape=input_tensor_shape,
recv_next_shape=output_grad_shape,
prev_rank=prev_rank,
next_rank=next_rank,
up_group=up_group,
down_group=down_group)
next_rank=next_rank)
return input_tensor, output_tensor_grad

View File

@ -6,7 +6,7 @@ from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device
def send_tensor_meta(tensor, need_meta=True, down_group=None):
def send_tensor_meta(tensor, need_meta=True, next_rank=None):
"""Sends tensor meta information before sending a specific tensor.
Since the recipient must know the shape of the tensor in p2p communications,
meta information of the tensor should be sent before communications. This function
@ -14,31 +14,34 @@ def send_tensor_meta(tensor, need_meta=True, down_group=None):
:param tensor: Tensor to be sent
:param need_meta: If False, meta information won't be sent
:param down_group: Communication group including the next member in pipeline parallel group
:param next_rank: The rank of the next member in pipeline parallel group
:type tensor: Tensor
:type need_meta: bool, optional
:type down_group: ProcessGroup, optional
:type next_rank: int
:return: False
:rtype: bool
"""
if need_meta:
rank = gpc.get_global_rank()
if down_group is None:
down_group = gpc.get_group(ParallelMode.PIPELINE_NEXT)
if next_rank is None:
next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()}
send_shape = torch.tensor(tensor.size(), **tensor_kwargs)
send_ndims = torch.tensor(len(tensor.size()), **tensor_kwargs)
dist.broadcast(send_ndims, src=rank, group=down_group)
dist.broadcast(send_shape, src=rank, group=down_group)
ops = [
dist.P2POp(dist.isend, send_ndims, next_rank),
dist.P2POp(dist.isend, send_shape, next_rank)
]
reqs = dist.batch_isend_irecv(ops)
for req in reqs:
req.wait()
torch.cuda.synchronize()
return False
def recv_tensor_meta(tensor_shape, prev_rank=None, up_group=None):
def recv_tensor_meta(tensor_shape, prev_rank=None):
"""Recieves tensor meta information before recieving a specific tensor.
Since the recipient must know the shape of the tensor in p2p communications,
meta information of the tensor should be recieved before communications. This function
@ -46,27 +49,21 @@ def recv_tensor_meta(tensor_shape, prev_rank=None, up_group=None):
:param tensor_shape: The shape of the tensor to be recieved
:param prev_rank: The rank of the source of the tensor
:param up_group: Communication group including the previous member in pipeline parallel group
:type tensor_shape: torch.Size
:type prev_rank: int, optional
:type up_group: ProcessGroup, optional
:return: The shape of the tensor to be recieved
:rtype: torch.Size
"""
if tensor_shape is None:
if prev_rank is None:
prev_rank = gpc.get_prev_global_rank(
ParallelMode.PIPELINE)
if up_group is None:
up_group = gpc.get_group(ParallelMode.PIPELINE_PREV)
prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()}
recv_ndims = torch.empty((), **tensor_kwargs)
dist.broadcast(recv_ndims, src=prev_rank, group=up_group)
dist.recv(recv_ndims, prev_rank)
recv_shape = torch.empty(recv_ndims, **tensor_kwargs)
dist.broadcast(recv_shape, src=prev_rank, group=up_group)
dist.recv(recv_shape, prev_rank)
tensor_shape = torch.Size(recv_shape)

View File

@ -25,7 +25,11 @@ TESSERACT_DEP = 'TESSERACT_DEP'
# 3D parallel
DEPTH_3D = 'DEPTH_3D'
INPUT_GROUP_3D = 'PARALLEL_3D_INPUT'
WEIGHT_GROUP_3D = 'PARALLEL_3D_WEIGHT'
OUTPUT_GROUP_3D = 'PARALLEL_3D_OUTPUT'
# Tensor parallel attributes
IS_TENSOR_PARALLEL = 'is_tensor_parallel'
TENSOR_PARALLEL_ATTRIBUTES = [IS_TENSOR_PARALLEL]
NUM_PARTITIONS = 'num_partitions'
TENSOR_PARALLEL_ATTRIBUTES = [IS_TENSOR_PARALLEL, NUM_PARTITIONS]

View File

@ -1,5 +1,5 @@
from .config import Config
from .config import Config, ConfigException
from .parallel_context import ParallelContext
from .parallel_context import ParallelMode
from .parallel_mode import ParallelMode
from .process_group_initializer import *
from .random import *

View File

@ -1,70 +0,0 @@
import math
def set_parallel_size(obj, config: dict, key: str, attr_name: str):
if key in config:
ele = config[key]
if isinstance(ele, int):
setattr(obj, attr_name, ele)
elif isinstance(ele, dict):
setattr(obj, attr_name, ele['size'])
else:
raise NotImplementedError(
f"Parallel configuration does not support this kind of argument, please use int or dict"
)
def add_tensor_pg(pg_init, mode, size, depth=None):
if mode == '1d':
pg_init.append(dict(
type='Initializer1D',
parallel_size=size
))
elif mode == '2d':
dim = math.floor(math.sqrt(size))
pg_init.append(dict(
type='Initializer2D_Col',
summa_dim=dim
))
pg_init.append(dict(
type='Initializer2D_Row',
summa_dim=dim
))
elif mode == '2.5d':
dim = math.floor(math.sqrt(size // depth))
pg_init.append(dict(
type='Initializer_Tesseract_ROW',
tesseract_dim=dim,
tesseract_dep=depth
))
pg_init.append(dict(
type='Initializer_Tesseract_COL',
tesseract_dim=dim,
tesseract_dep=depth
))
pg_init.append(dict(
type='Initializer_Tesseract_DEP',
tesseract_dim=dim,
tesseract_dep=depth
))
pg_init.append(dict(
type='Initializer_Tesseract_XZ',
tesseract_dim=dim,
tesseract_dep=depth
))
elif mode == '3d':
dim = math.floor(math.pow(size, 1.0 / 3.0) + 0.5)
pg_init.append(dict(
type='ParallelInitializer3D_Input',
depth=dim
))
pg_init.append(dict(
type='ParallelInitializer3D_Weight',
depth=dim
))
pg_init.append(dict(
type='ParallelInitializer3D_Output',
depth=dim
))
else:
raise NotImplementedError("This kind of tensor splitting has not been implemented yet")

View File

@ -97,3 +97,7 @@ class Config(dict):
sys.path.pop(0)
return config
class ConfigException(Exception):
pass

View File

@ -1,7 +1,6 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os
import random
from typing import Union
@ -11,8 +10,8 @@ import torch.distributed as dist
from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING
from colossalai.context.config import Config
from colossalai.logging import get_dist_logger
from colossalai.registry import DIST_GROUP_INITIALIZER
from ._utils import set_parallel_size
from .parallel_mode import ParallelMode
from .random import add_seed, get_seeds, set_mode
@ -21,11 +20,24 @@ class ParallelContext:
"""This class provides interface functions for users to get the parallel context,
such as the global rank, the local rank, the world size, etc. of each device.
:param args: The distributed arguments in the system
:type args: dict
"""
def __init__(self, args=None):
__instance = None
@staticmethod
def get_instance():
if ParallelContext.__instance is None:
ParallelContext()
return ParallelContext.__instance
def __init__(self):
# create a singleton instance
if ParallelContext.__instance is not None:
raise Exception(
'ParallelContext is a singleton class, you should get the instance by colossalai.core.global_context')
else:
ParallelContext.__instance = self
# distributed settings
self._global_ranks = dict()
self._local_ranks = dict()
@ -34,7 +46,6 @@ class ParallelContext:
self._ranks_in_group = dict()
# load config from file
self._dist_args = args
self._config = None
# default 3D parallel args, will be overwritten during process group intialization
@ -43,10 +54,22 @@ class ParallelContext:
self.pipeline_parallel_size = 1
self.tensor_parallel_size = 1
# logging
self._verbose = False
self._logger = get_dist_logger()
@property
def config(self):
return self._config
@property
def verbose(self):
return self._verbose
@verbose.setter
def verbose(self, verbose_: bool):
self._verbose = verbose_
def load_config(self, config: Union[dict, str]):
"""Loads the configuration from either a dict or a file.
@ -62,14 +85,6 @@ class ParallelContext:
else:
raise TypeError("Invalid type for config, only dictionary or string is supported")
def set_dist_args(self, args):
"""Sets the distributed arguments.
:param args: The distributed arguments in the system
:type args: dict
"""
self._dist_args = args
@staticmethod
def _check_parallel_mode(parallel_mode: ParallelMode):
assert isinstance(parallel_mode, ParallelMode)
@ -268,32 +283,36 @@ class ParallelContext:
self._check_parallel_mode(parallel_mode)
self._ranks_in_group[parallel_mode] = ranks
def init_global_dist(self, addr=None, port=None):
"""Initializes the global distributed environment.
:param addr: The IP address of the current device
:type addr: str, optional
:param port: The port to be used in the system of the current device
:type port: int, optional
def init_global_dist(self,
rank: int,
world_size: int,
backend: str,
host: str,
port: int
):
"""Initializes the global distributed environment
:param rank: rank for the default process group
:type rank: int
:param world_size: world size of the default process group
:type world_size: int
:param host: the master address for distributed training
:type host: str
:param port: the master port for distributed training
:type port: str
:param backend: backend for torch.distributed
:type backend: str
"""
# get config
rank = self._dist_args.local_rank
world_size = self._dist_args.world_size
# default env config, overwrite by exporting
# them in your bash script
addr = os.getenv('MASTER_ADDR', 'localhost') if addr is None else addr
port = os.getenv('MASTER_PORT', '8008') if port is None else port
init_method = f'tcp://{addr}:{port}'
dist.init_process_group(backend=self._dist_args.backend,
rank=rank,
# initialize the default process group
init_method = f'tcp://{host}:{port}'
dist.init_process_group(rank=rank,
world_size=world_size,
backend=backend,
init_method=init_method)
# None will give the default global process group for pytorch dist operations
self._register_dist(rank, world_size, None,
list(range(world_size)), ParallelMode.GLOBAL)
self._global_ranks[ParallelMode.GLOBAL] = rank
self.add_global_rank(ParallelMode.GLOBAL, rank)
def _register_dist(self, local_rank, world_size,
process_group, ranks_in_group, mode):
@ -312,7 +331,20 @@ class ParallelContext:
pps = self.pipeline_parallel_size
tps = self.tensor_parallel_size
ws = self.world_size
assert ws == dps * pps * tps, f"Expected the world size {ws} to be equal to data parallel size ({dps}) * pipeline parallel size ({pps}) * tensor parallel size ({tps})"
assert ws == dps * pps * \
tps, f"Expected the world size {ws} to be equal to data parallel size ({dps}) * pipeline parallel size ({pps}) * tensor parallel size ({tps})"
def _set_parallel_size_from_config(self, config: dict, key: str, attr_name: str):
if key in config:
ele = config[key]
if isinstance(ele, int):
setattr(self, attr_name, ele)
elif isinstance(ele, dict):
setattr(self, attr_name, ele['size'])
else:
raise NotImplementedError(
f"Parallel configuration does not support this kind of argument, please use int or dict"
)
def init_parallel_groups(self):
"""Initializes the parallel groups.
@ -325,21 +357,20 @@ class ParallelContext:
world_size = self.get_world_size(ParallelMode.GLOBAL)
self.world_size = world_size
assert hasattr(self.config, 'parallel'), 'Expected the field parallel to be present in the config file'
# set parallel size as attributes for global context
parallel_config = self.config.parallel
set_parallel_size(self, parallel_config, 'pipeline',
'pipeline_parallel_size')
set_parallel_size(self, parallel_config, 'tensor',
'tensor_parallel_size')
parallel_config = self.config.get('parallel', None)
if parallel_config is not None:
self._set_parallel_size_from_config(parallel_config, 'pipeline', 'pipeline_parallel_size')
self._set_parallel_size_from_config(parallel_config, 'tensor', 'tensor_parallel_size')
# the user should not set the data parallel size manually
# instead, it should be calculated based on other parallel config
self.data_parallel_size = self.world_size // (self.pipeline_parallel_size * self.tensor_parallel_size)
# get the tensor parallel mode and check
tensor_parallel_mode = parallel_config['tensor'].get('mode', None)
tensor_parallel_mode = None
if parallel_config is not None and 'tensor' in parallel_config and 'mode' in parallel_config['tensor']:
tensor_parallel_mode = parallel_config['tensor']['mode']
assert tensor_parallel_mode in ALLOWED_MODES, f"mode in the parallel config must be set to one of {ALLOWED_MODES}"
self.check_sanity()
@ -400,23 +431,21 @@ class ParallelContext:
# destroy global process group
dist.destroy_process_group()
def set_device(self):
def set_device(self, device_ordinal: int = None):
"""Sets distributed processes to be bound to devices.
"""
devices_per_node = torch.cuda.device_count()
global_rank = self.get_global_rank()
device = global_rank % devices_per_node
torch.cuda.set_device(device)
print(f'process rank {global_rank} is bound to device {device}')
if device_ordinal is None:
devices_per_node = torch.cuda.device_count()
device_ordinal = global_rank % devices_per_node
def set_seed(self):
torch.cuda.set_device(device_ordinal)
if self._verbose:
self._logger.info(f'process rank {global_rank} is bound to device {device_ordinal}')
def set_seed(self, seed: int):
"""Sets seeds for all random libraries.
"""
if hasattr(self.config, 'seed'):
seed = getattr(self.config, 'seed')
else:
seed = 2 # default seed
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
@ -444,11 +473,18 @@ class ParallelContext:
seeds = get_seeds()
seed_str = ', '.join([f'{k}: {v}' for k, v in seeds.items()])
print(f"initialized seed on rank {global_rank}, "
f"numpy: {seed}, python random: {seed}, {seed_str},"
f"the default parallel seed is {ParallelMode.DATA}.", flush=True)
if self._verbose:
self._logger.info(
f"initialized seed on rank {global_rank}, "
f"numpy: {seed}, python random: {seed}, {seed_str},"
f"the default parallel seed is {ParallelMode.DATA}.",
ranks=[0])
else:
print(f"initialized seed on rank {global_rank}, "
f"numpy: {seed}, python random: {seed}, pytorch: {seed}", flush=True)
print('WARNING: CUDA is not available, thus CUDA RNG cannot be used to track CUDA random number states',
flush=True)
if self._verbose:
self._logger.info(
f"initialized seed on rank {global_rank}, "
f"numpy: {seed}, python random: {seed}, pytorch: {seed}",
ranks=[0])
self._logger.info(
'WARNING: CUDA is not available, thus CUDA RNG cannot be used to track CUDA random number states',
ranks=[0])

View File

@ -4,7 +4,6 @@
import torch.distributed as dist
from colossalai.context import Config
from colossalai.core import global_context as gpc
from colossalai.registry import DIST_GROUP_INITIALIZER
from .process_group_initializer import ProcessGroupInitializer
from ..parallel_mode import ParallelMode

View File

@ -8,7 +8,6 @@ import torch.distributed as dist
from colossalai.constants import TESSERACT_DIM, TESSERACT_DEP
from colossalai.context import Config
from colossalai.core import global_context as gpc
from colossalai.registry import DIST_GROUP_INITIALIZER
from .process_group_initializer import ProcessGroupInitializer
from ..parallel_mode import ParallelMode
@ -42,8 +41,6 @@ class Initializer_2p5D_ROW(ProcessGroupInitializer):
tesseract_dep: int,
*args):
super(Initializer_2p5D_ROW, self).__init__(*args)
self.tensor_parallel_size = gpc.tensor_parallel_size
self.num_group = self.world_size // self.tensor_parallel_size
self.tesseract_dep = tesseract_dep
self.tesseract_dim = tesseract_dim
@ -66,7 +63,7 @@ class Initializer_2p5D_ROW(ProcessGroupInitializer):
for j in range(self.tesseract_dim):
for k in range(self.tesseract_dep):
ranks = [h * self.tensor_parallel_size + i + self.tesseract_dim * (
j + self.tesseract_dim * k) for i in range(self.tesseract_dim)]
j + self.tesseract_dim * k) for i in range(self.tesseract_dim)]
group = dist.new_group(ranks)
if self.rank in ranks:
@ -81,13 +78,12 @@ class Initializer_2p5D_ROW(ProcessGroupInitializer):
class Initializer_2p5D_Col(ProcessGroupInitializer):
'''2p5d tensor parallel initialization among cols.
'''
def __init__(self,
tesseract_dim: int,
tesseract_dep: int,
*args):
super(Initializer_2p5D_Col, self).__init__(*args)
self.tensor_parallel_size = gpc.tensor_parallel_size
self.num_group = self.world_size // self.tensor_parallel_size
self.tesseract_dep = tesseract_dep
self.tesseract_dim = tesseract_dim
@ -110,7 +106,7 @@ class Initializer_2p5D_Col(ProcessGroupInitializer):
for i in range(self.tesseract_dim):
for k in range(self.tesseract_dep):
ranks = [h * self.tensor_parallel_size + i + self.tesseract_dim * (
j + self.tesseract_dim * k) for j in range(self.tesseract_dim)]
j + self.tesseract_dim * k) for j in range(self.tesseract_dim)]
group = dist.new_group(ranks)
if self.rank in ranks:
@ -125,13 +121,12 @@ class Initializer_2p5D_Col(ProcessGroupInitializer):
class Initializer_2p5D_Dep(ProcessGroupInitializer):
'''2p5D tensor parallel initialization among depths.
'''
def __init__(self,
tesseract_dim: int,
tesseract_dep: int,
*args):
super(Initializer_2p5D_Dep, self).__init__(*args)
self.tensor_parallel_size = gpc.tensor_parallel_size
self.num_group = self.world_size // self.tensor_parallel_size
self.tesseract_dep = tesseract_dep
self.tesseract_dim = tesseract_dim
@ -154,7 +149,7 @@ class Initializer_2p5D_Dep(ProcessGroupInitializer):
for i in range(self.tesseract_dim):
for j in range(self.tesseract_dim):
ranks = [h * self.tensor_parallel_size + i + self.tesseract_dim * (
j + self.tesseract_dim * k) for k in range(self.tesseract_dep)]
j + self.tesseract_dim * k) for k in range(self.tesseract_dep)]
group = dist.new_group(ranks)
if self.rank in ranks:
@ -170,13 +165,12 @@ class Initializer_2p5D_Dep(ProcessGroupInitializer):
class Initializer_2p5D_XZ(ProcessGroupInitializer):
'''2p5d tensor parallel initialization among cols times dep.
'''
def __init__(self,
tesseract_dim: int,
tesseract_dep: int,
*args):
super(Initializer_2p5D_XZ, self).__init__(*args)
self.tensor_parallel_size = gpc.tensor_parallel_size
self.num_group = self.world_size // self.tensor_parallel_size
self.tesseract_dep = tesseract_dep
self.tesseract_dim = tesseract_dim
@ -198,8 +192,8 @@ class Initializer_2p5D_XZ(ProcessGroupInitializer):
for h in range(self.num_group):
for i in range(self.tesseract_dim):
ranks = [h * self.tensor_parallel_size + i + self.tesseract_dim * (
j + self.tesseract_dim * k) for k in range(self.tesseract_dep) for j in
range(self.tesseract_dim)]
j + self.tesseract_dim * k) for k in range(self.tesseract_dep) for j in
range(self.tesseract_dim)]
group = dist.new_group(ranks)
if self.rank in ranks:

View File

@ -5,7 +5,7 @@ import math
import os
import torch.distributed as dist
from colossalai.constants import DEPTH_3D
from colossalai.constants import DEPTH_3D, INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D
from colossalai.registry import DIST_GROUP_INITIALIZER
from ..parallel_mode import ParallelMode
@ -18,7 +18,7 @@ def _check_depth_env_var(depth):
if env_depth:
assert int(env_depth) == depth, \
'SUMMA_DIM has been set in the current environment and ' \
'DEPTH_3D has been set in the current environment and ' \
'does not match with the value passed to this initialized'
else:
os.environ[DEPTH_3D] = str(depth)
@ -43,6 +43,7 @@ class Initializer_3D_Input(ProcessGroupInitializer):
process_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_3D_INPUT
os.environ[INPUT_GROUP_3D] = INPUT_GROUP_3D
for h in range(self.num_group):
for i in range(self.depth):
@ -82,6 +83,7 @@ class Initializer_3D_Weight(ProcessGroupInitializer):
process_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_3D_WEIGHT
os.environ[WEIGHT_GROUP_3D] = WEIGHT_GROUP_3D
for h in range(self.num_group):
for k in range(self.depth):
@ -121,6 +123,7 @@ class Initializer_3D_Output(ProcessGroupInitializer):
process_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_3D_OUTPUT
os.environ[OUTPUT_GROUP_3D] = OUTPUT_GROUP_3D
for h in range(self.num_group):
for i in range(self.depth):

View File

@ -3,14 +3,4 @@
from colossalai.context import ParallelContext
global_context = ParallelContext()
def set_global_context(context: ParallelContext):
'''Reset global context to be identical to a given :class:ParallelContext.
:param context: Parallel context to generate our global parallel context.
:type context: ParallelContext
'''
global global_context
global_context = context
global_context = ParallelContext.get_instance()

View File

@ -1,7 +1,5 @@
from ._base_engine import Engine
from .gradient_handler import *
from .schedule import *
from .amp import *
__all__ = ['Engine']

View File

@ -1,17 +1,17 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
from typing import List
from torch.nn import Module
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from colossalai.builder import build_gradient_handler
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger
from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)
from .schedule import BaseSchedule
from colossalai.logging import get_dist_logger
from colossalai.utils import is_using_ddp, is_using_pp
from torch import Tensor
class Engine:
@ -20,74 +20,40 @@ class Engine:
It controls a iteration in training.
:param model: The neural network model
:type model: ``torch.nn.Module``
:param optimizer: Optimizer for updating the parameters
:param step_schedule: Running schedule in :meth:`step`
:param gradient_accumulation: Steps of gradient accumulation
:type optimizer: ``torch.optim.Optimizer``
:param criterion: Loss function for calculating loss
:type criterion: ``torch.nn.modules.loss._Loss``
:param gradient_clipping: The norm of gradient clipping
:type model: Module
:type optimizer: Optimizer
:type step_schedule: BaseSchedule, optional
:type gradient_accumulation: int, optional
:type gradient_clipping: float, optional
:param verbose: whether to display log info
:type verbose: bool
"""
def __init__(self,
model: Module,
optimizer: Optimizer,
criterion: _Loss,
step_schedule: BaseSchedule,
gradient_handlers: list = None,
gradient_accumulation: int = 1,
gradient_clipping: float = 0.0,
gradient_handlers: List = None,
clip_grad_norm: float = 0.0,
verbose: bool = True
):
self._model = model
self._optimizer = optimizer
self._criterion = criterion
self._schedule = step_schedule
# schedule initialize
self._schedule.initialize(model, optimizer)
self._clip_grad_norm = clip_grad_norm
self._verbose = verbose
self._logger = get_dist_logger()
# state
self.training = True # default
# gradient accumulation
assert gradient_accumulation > 0, 'gradient accumulation size must be larger than 0'
self._grad_accum_size = gradient_accumulation
self._grad_clip = gradient_clipping
self._logger = get_global_dist_logger()
# build gradient handler
self._gradient_handlers = []
if gradient_handlers is not None:
assert isinstance(gradient_handlers, list), \
f'argument gradient_handler_cfg expected type list, ' \
f'but got type {type(gradient_handlers)}'
elif isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)):
gradient_handlers = [dict(type='ZeROGradientHandler')]
self._logger.info(
"Training with zero is detected, ZeROGradientHandler is automatically "
"added even though not specified in the configuration",
ranks=[0])
elif gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(
ParallelMode.DATA) > 1:
gradient_handlers = [dict(type='DataParallelGradientHandler')]
self._logger.info(
"Data parallel training is detected, DataParallelGradientHandler is automatically "
"added even though not specified in the configuration",
ranks=[0])
if gradient_handlers is None:
self._logger.warning(
"No gradient handler is set up, please make sure you do not need "
"to all-reduce the gradients after a training step.",
ranks=[0])
if gradient_handlers:
self._gradient_handlers = gradient_handlers
else:
for cfg in gradient_handlers:
handler = build_gradient_handler(cfg, model, optimizer)
self._gradient_handlers.append(handler)
self._gradient_handlers = []
@property
def model(self):
@ -105,11 +71,27 @@ class Engine:
def schedule(self):
return self._schedule
@property
def gradient_accumulation(self):
return self._grad_accum_size
def zero_grad(self):
self.optimizer.zero_grad()
def handle_gradient(self):
def step(self):
self._all_reduce_gradients()
self.optimizer.clip_grad_norm(self.model, self._clip_grad_norm)
self.optimizer.step()
def backward(self, loss: Tensor):
return self.optimizer.backward(loss)
def backward_by_grad(self, tensor, grad):
return self.optimizer.backward_by_grad(tensor, grad)
def calc_loss(self, *args, **kwargs):
return self.criterion(*args, **kwargs)
def __call__(self, *args, **kwargs):
return self.model(*args, **kwargs)
def _all_reduce_gradients(self):
"""Handles all-reduce operations of gradients across different parallel groups.
"""
for handler in self._gradient_handlers:
@ -126,51 +108,3 @@ class Engine:
"""
self.training = False
self._model.eval()
def step(self,
data_iter,
is_last_iteration: bool = False,
return_loss=True):
"""A running step based on the schedule. Usually, it runs a training or
evaluation over a batch of dataset.
:param data_iter: Data iterator of the dataset
:param is_last_iteration: If True, this iteration is the last iteration in the epoch
:param return_loss: loss will be returned if True
:type data_iter: Iterator
:type is_last_iteration: bool, optional
:type return_loss: bool, optional
:return: (output, lablel, loss)
"""
if self.training:
self._optimizer.zero_grad()
# differentiate training and eval with grad accum
if self.training:
for i in range(self._grad_accum_size):
output, label, loss = self._schedule.forward_backward_step(
data_iter, self._model, self._criterion, self._optimizer,
forward_only=False,
grad_accum_size=self._grad_accum_size,
return_loss=return_loss)
if i == self._grad_accum_size - 1:
# all reduce gradients
self.handle_gradient()
self._schedule.optimizer_step(self._model, self._optimizer, self._grad_clip)
else:
output, label, loss = self._schedule.forward_backward_step(
data_iter, self._model, self._criterion, self._optimizer,
forward_only=True,
grad_accum_size=1,
return_loss=return_loss)
# consume the remaining dataset left out due to gradient accumulation
if is_last_iteration:
while True:
try:
_ = next(data_iter)
except StopIteration:
break
return output, label, loss

View File

@ -1,2 +0,0 @@
from .grad_scaler import GradScaler
from .amp_type import AMP_TYPE

View File

@ -1,5 +1,5 @@
from ._base_schedule import BaseSchedule
from ._no_pipeline import NoPipelineSchedule
from ._pipeline import PipelineSchedule
from ._pipeline_schedule import PipelineSchedule
from ._non_pipeline_schedule import NonPipelineSchedule
__all__ = ['BaseSchedule', 'NoPipelineSchedule', 'PipelineSchedule']
__all__ = ['BaseSchedule', 'PipelineSchedule', 'NonPipelineSchedule']

View File

@ -5,8 +5,10 @@ from abc import ABC, abstractmethod
import torch
from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger
from torch import Tensor
from typing import Iterable, Union, List, Callable
from .._base_engine import Engine
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
@ -18,8 +20,9 @@ class BaseSchedule(ABC):
control of FP16 in class schedule.
"""
def __init__(self):
self.logger = get_global_dist_logger()
def __init__(self, batch_data_process_func: Callable = None):
self.logger = get_dist_logger()
self.batch_data_process_func = batch_data_process_func
@staticmethod
def _move_tensor(element):
@ -35,6 +38,11 @@ class BaseSchedule(ABC):
data = data.to(get_current_device()).detach()
return data
def _to_list(self, data):
if torch.is_tensor(data):
return [data]
return data
def load_batch(self, data_iter):
"""Loads a batch from data iterator. It returns the data and labels which are
already in the same GPU as where the model's.
@ -44,46 +52,34 @@ class BaseSchedule(ABC):
"""
if data_iter is None:
raise RuntimeError('Dataloader is not defined.')
data, label = next(data_iter)
batch_data = next(data_iter)
if self.batch_data_process_func:
data, label = self.batch_data_process_func(batch_data)
else:
data, label = batch_data
data, label = self._to_list(data), self._to_list(label)
return self._move_to_device(data), self._move_to_device(label)
def initialize(self, model, optimizer):
"""Initializes the model and the optimizer before training.
This is often used in FP16 training.
:param model: The neural network model
:param optimizer: Optimizer for updating the parameters
def pre_processing(self, engine: Engine):
"""To perform actions before running the schedule.
"""
return model, optimizer
pass
@abstractmethod
def forward_backward_step(self,
data_iter,
model,
criterion,
optimizer=None,
forward_only=False,
grad_accum_size: int = 1,
return_loss=True):
engine: Engine,
data_iter: Iterable,
forward_only: bool,
return_loss: bool = True
):
"""The process function over a batch of dataset for training or evaluation.
:param data_iter: Data iterator of the dataset
:param model: Model used in training or evaluation
:param optimizer: Optimizer used in training or evaluation
:param criterion: Loss function
:param engine: Colossalai training engine
:param inputs: input data
:param labels: ground truth
:param forward_only: If True, the process won't include backward
:param grad_accum_size: Steps of gradient accumulation
:param return_loss: If False, the loss won't be returned
"""
pass
@abstractmethod
def optimizer_step(self, model, optimizer, grad_clipping: float = 0.0):
"""Updates the parameters with the optimizer.
:param model: The neural network model
:param optimizer: Optimizer for updating the parameters
:param grad_clipping: The norm of gradient clipping
:type grad_clipping: float, optional
"""
pass
pass

View File

@ -1,188 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
try:
import apex.amp as apex_amp
except:
pass
try:
import torch.cuda.amp as torch_amp
except:
pass
from typing import Iterable
import torch.nn as nn
from torch.optim import Optimizer
from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)
from colossalai.nn.optimizer._utils import clip_grad_norm_fp32
from ._base_schedule import BaseSchedule
from ._utils import convert_to_fp16, convert_to_fp32
from ..amp import AMP_TYPE, GradScaler
class NoPipelineSchedule(BaseSchedule):
"""A helper schedule class for no pipeline parallelism running environment.
During one process, it loads a batch of dataset and feeds it to the model.
After getting the output and calculating the loss, it will use :meth:`step`
to update the parameters if it is in training mode.
:param amp_type: The type of automatic mixed precision
:param amp_config: The configuration of automatic mixed procision
:type amp_type: AMP_TYPE
:type amp_config: dict
"""
def __init__(
self,
amp_type: AMP_TYPE = None,
amp_config: dict = None,
):
super().__init__()
# mixed precision training
assert amp_type is None or isinstance(amp_type, AMP_TYPE), \
'unrecognised value for argument fp16, it can only be None, torch or apex'
self.use_zero_level_2_3 = False
if amp_type is not None:
self.fp16 = True
self.amp_type = amp_type
if amp_config is not None:
assert isinstance(amp_config, dict), \
f'expected argument fp16_config to be type dictionary, but got {type(amp_config)}'
if self.amp_type == AMP_TYPE.TORCH:
# torch apex
if amp_config is None:
amp_config = dict()
self.amp_cfg = amp_config
elif self.amp_type == AMP_TYPE.APEX:
# apex amp
if amp_config is None:
amp_config = dict(opt_level='O2')
self.logger.warning(
'apex is deprecated, please consider using torch.cuda.amp instead.'
)
self.amp_cfg = amp_config
elif self.amp_type == AMP_TYPE.PARALLEL:
# use fp16 optimizer for tensor parallel training
if amp_config is None:
amp_config = dict()
self.amp_cfg = amp_config
else:
self.fp16 = False
self.amp_type = None
def initialize(self, model: nn.Module, optimizer: Optimizer):
if isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)):
self.use_zero_level_2_3 = True
assert self.amp_type != AMP_TYPE.PARALLEL, \
'ZeRO Level 2 and 3 are mutually exclusive with AMP_TYPE.PARALLEL'
if self.fp16:
if self.amp_type == AMP_TYPE.TORCH:
self._torch_amp_scaler = GradScaler(**self.amp_cfg)
elif self.amp_type == AMP_TYPE.APEX:
model, optimizer = apex_amp.initialize(model, optimizer, **self.amp_cfg)
return model, optimizer
def forward_backward_step(self,
data_iter: Iterable,
model: nn.Module,
criterion: nn.modules.loss._Loss,
optimizer: Optimizer = None,
forward_only: bool = False,
grad_accum_size: int = 1,
return_loss: bool = True):
"""The process function that loads loads a batch of dataset and feeds it to the model.
The returned labels and loss will None if :attr:`return_loss` is False.
:param data_iter: Data iterator of the dataloader, e.g. iter(dataloader)
:param model: Model for training and inference
:param criterion: Loss function for training
:param optimizer: Optimizer used for training
:param forward_only: If True, the model is run for the forward pass, else back propagation will be executed
:param grad_accum_size: The number of iterations for gradient accumulation
:param return_loss: Loss will be returned if True
:type data_iter: Iterator
:type model: torch.nn.Module
:type criterion: torch.nn.modules.loss._Loss
:type optimizer: torch.optim.Optimizer
:type forward_only: bool, optional
:type grad_accum_size: int
:type return_loss: bool, optional
:return: (output, label, loss)
"""
assert forward_only or return_loss, \
'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
data, label = self.load_batch(data_iter)
loss = None
# forward
if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
with torch_amp.autocast():
output = model(*data)
if not isinstance(output, (tuple, list)):
output = (output,)
if return_loss:
loss = criterion(*output, *label)
else:
if self.use_zero_level_2_3 or self.amp_type == AMP_TYPE.PARALLEL:
data = convert_to_fp16(data)
output = model(*data)
if self.use_zero_level_2_3 or self.amp_type == AMP_TYPE.PARALLEL:
output = convert_to_fp32(output)
if not isinstance(output, (tuple, list)):
output = (output,)
if return_loss:
loss = criterion(*output, *label)
loss /= grad_accum_size
if not forward_only:
# backward
if self.use_zero_level_2_3:
optimizer.backward(loss)
elif self.fp16:
if self.amp_type == AMP_TYPE.APEX:
with apex_amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
elif self.amp_type == AMP_TYPE.TORCH:
self._torch_amp_scaler.scale(loss).backward()
elif self.amp_type == AMP_TYPE.PARALLEL:
loss = optimizer.scale_loss(loss)
loss.backward()
# scale back to display the original value in logs
loss.div_(optimizer.grad_scaler.scale)
else:
loss.backward()
if return_loss:
return output, label, loss * grad_accum_size
else:
return output, None, None
def optimizer_step(self, model: nn.Module, optimizer: Optimizer, grad_clipping: float = 0.0):
# step optimizer
if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
if grad_clipping > 0.0:
self._torch_amp_scaler.unscale_(optimizer)
clip_grad_norm_fp32(model.parameters(), grad_clipping)
self._torch_amp_scaler.step(optimizer)
self._torch_amp_scaler.update()
else:
if not self.fp16 and not self.use_zero_level_2_3 and grad_clipping > 0.0:
clip_grad_norm_fp32(model.parameters(), grad_clipping)
optimizer.step()

View File

@ -0,0 +1,61 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Iterable
import torch
import torch.nn as nn
from colossalai.engine import Engine
from torch.optim import Optimizer
from ._base_schedule import BaseSchedule
from colossalai.utils import conditional_context
class NonPipelineSchedule(BaseSchedule):
"""A helper schedule class for no pipeline parallelism running environment.
During one process, it loads a batch of dataset and feeds it to the model.
After getting the output and calculating the loss, it will use :meth:`step`
to update the parameters if it is in training mode.
:param amp_type: The type of automatic mixed precision
:param amp_config: The configuration of automatic mixed procision
:type amp_type: AMP_TYPE
:type amp_config: dict
"""
def forward_backward_step(self,
engine: Engine,
data_iter: Iterable,
forward_only: bool = False,
return_loss: bool = True):
"""The process function that loads loads a batch of dataset and feeds it to the model.
The returned labels and loss will None if :attr:`return_loss` is False.
:param engine: Model for training and inference
:param data_iter: Data iterator of the dataloader, e.g. iter(dataloader)
:param forward_only: If True, the model is run for the forward pass, else back propagation will be executed
:param return_loss: Loss will be returned if True
:type engine: Iterator
:type data_iter: Iterator
:type forward_only: bool, optional
:type return_loss: bool, optional
:return: (output, label, loss)
"""
assert forward_only or return_loss, \
"The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
data, label = self.load_batch(data_iter)
# forward
with conditional_context(torch.no_grad(), enable=forward_only):
output = engine(*data)
if not isinstance(output, (tuple, list)):
output = (output,)
if return_loss:
loss = engine.criterion(*output, *label)
if not forward_only:
engine.backward(loss)
if return_loss:
return output, label, loss
else:
return output, None, None

View File

@ -10,12 +10,12 @@ from torch import Tensor
from colossalai.communication import *
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)
from colossalai.amp.naive_amp import NaiveAMPModel
from colossalai.zero import (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)
from colossalai.utils import get_current_device
from ._base_schedule import BaseSchedule
from ._utils import convert_to_fp16
from ..amp import AMP_TYPE
from colossalai.amp import AMP_TYPE
def squeeze(x: Union[Tensor, tuple, list]):
@ -28,32 +28,25 @@ def squeeze(x: Union[Tensor, tuple, list]):
class PipelineSchedule(BaseSchedule):
"""A helper schedule class for pipeline parallelism running environment.
It uses non-interleaved 1F1B strategy. Other properties are similar as
:class:`NoPipelineSchedule`.
:class:`NonPipelineSchedule`.
:param num_microbatches: The number of microbatches
:param amp_type: The type of automatic mixed precision
:param amp_config: The configuration of automatic mixed procision
:param sync_data: If set to `True`, will sync data every batch over pipeline stages
:type num_microbatches: int
:type amp_type: AMP_TYPE
:type amp_config: dict
:type sync_data: bool
"""
def __init__(self,
num_microbatches,
amp_type: AMP_TYPE = None,
amp_config: dict = None):
sync_data: bool = True):
super().__init__()
self.num_microbatches = num_microbatches
self.data_sync = True # close after making sure data is identical
# amp
# LSGL: amp_config is not used, but leave here for future extension
self.amp_type = amp_type
self.amp_config = amp_config
if self.amp_type is not None:
assert self.amp_type == AMP_TYPE.PARALLEL, 'We only support AMP_TYPE.PARALLEL for pipeline training for now'
self.sync_data = sync_data
def _move_to_device(self, data):
if isinstance(data, (
@ -67,30 +60,37 @@ class PipelineSchedule(BaseSchedule):
return data
def _sync_data(self):
reqs = []
if gpc.is_first_rank(ParallelMode.PIPELINE):
src_rank = gpc.get_global_rank()
dist.broadcast(
reqs.append(dist.broadcast(
tensor=self.batch_data,
src=src_rank,
group=gpc.get_group(ParallelMode.PIPELINE_PREV)
)
dist.broadcast(
group=gpc.get_group(ParallelMode.PIPELINE_PREV),
async_op=True
))
reqs.append(dist.broadcast(
tensor=self.batch_label,
src=src_rank,
group=gpc.get_group(ParallelMode.PIPELINE_PREV)
)
group=gpc.get_group(ParallelMode.PIPELINE_PREV),
async_op=True
))
if gpc.is_last_rank(ParallelMode.PIPELINE):
src_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
dist.broadcast(
reqs.append(dist.broadcast(
tensor=self.batch_data,
src=src_rank,
group=gpc.get_group(ParallelMode.PIPELINE_NEXT)
)
dist.broadcast(
group=gpc.get_group(ParallelMode.PIPELINE_NEXT),
async_op=True
))
reqs.append(dist.broadcast(
tensor=self.batch_label,
src=src_rank,
group=gpc.get_group(ParallelMode.PIPELINE_NEXT)
)
group=gpc.get_group(ParallelMode.PIPELINE_NEXT),
async_op=True
))
for req in reqs:
req.wait()
# Pipeline schedule just puts data in memory
def load_batch(self, data_iter):
@ -104,7 +104,7 @@ class PipelineSchedule(BaseSchedule):
assert batch_size % self.num_microbatches == 0, \
"Batch size should divided by the number of microbatches"
self.microbatch_size = batch_size // self.num_microbatches
if self.data_sync:
if self.sync_data:
self._sync_data()
def _get_data_slice(self, tensor):
@ -116,21 +116,20 @@ class PipelineSchedule(BaseSchedule):
self.batch_pos += self.microbatch_size
return (data,), (label,)
def initialize(self, model, optimizer):
if isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)):
def pre_processing(self, engine):
if isinstance(engine.optimizer, (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)):
raise TypeError(
"Pipeline schedule is currently not compatible with ZeRO Level 2 and Level 3"
)
# LSG: set default dtype to fp16 for communication
if self.amp_type == AMP_TYPE.PARALLEL:
if isinstance(engine.model, NaiveAMPModel):
torch.set_default_dtype(torch.half)
self.logger.info(
self.logger.warning(
'default tensor dtype is set to torch.half for fp16 training',
ranks=[0])
def forward_step(self, model, criterion, input_tensor, return_tensors,
grad_accum_size, return_loss=True):
def forward_step(self, engine, input_tensor, return_tensors, return_loss=True):
"""Forward step for passed-in model. If it is the first stage, the input tensor
is obtained from data_iterator, otherwise the passed-in input_tensor is used.
Returns output tensor. This is a helper function and can be ignored by users.
@ -138,17 +137,16 @@ class PipelineSchedule(BaseSchedule):
if input_tensor is None:
input_tensor, label = self.load_micro_batch()
if self.amp_type == AMP_TYPE.PARALLEL:
input_tensor = convert_to_fp16(input_tensor)
input_tensor = squeeze(input_tensor)
output_tensor = model(input_tensor)
output_tensor = engine(input_tensor)
output_tensor = squeeze(output_tensor)
if gpc.is_last_rank(ParallelMode.PIPELINE):
if return_loss:
input_tensor, label = self.load_micro_batch()
loss_reduced = criterion(output_tensor, *label) \
/ (self.num_microbatches * grad_accum_size)
loss_reduced = engine.criterion(output_tensor, *label) \
/ self.num_microbatches
return_tensors.append(
tuple((output_tensor, label[0], loss_reduced)))
return loss_reduced
@ -159,7 +157,7 @@ class PipelineSchedule(BaseSchedule):
else:
return output_tensor
def backward_step(self, optimizer, input_tensor, output_tensor, output_tensor_grad):
def backward_step(self, engine, input_tensor, output_tensor, output_tensor_grad):
"""Backward step through the passed-in output tensor. If it is the last stage, the
output_tensor_grad is None, otherwise it is the gradients with respect to stage's output tensor.
Returns the gradients with respect to the input tensor (None if first stage).
@ -171,9 +169,10 @@ class PipelineSchedule(BaseSchedule):
input_tensor.retain_grad()
# Backward pass.
if output_tensor_grad is None and self.amp_type == AMP_TYPE.PARALLEL:
output_tensor = optimizer.scale_loss(output_tensor)
torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
if output_tensor_grad is None:
engine.backward(output_tensor)
else:
engine.backward_by_grad(output_tensor, output_tensor_grad)
# Collect the grad of the input_tensor.
input_tensor_grad = None
@ -183,12 +182,9 @@ class PipelineSchedule(BaseSchedule):
return input_tensor_grad
def forward_backward_step(self,
engine,
data_iter,
model,
criterion,
optimizer=None,
forward_only=False,
grad_accum_size: int = 1,
return_loss=True):
"""Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
Returns a tuple with losses if the last stage, an empty tuple otherwise.
@ -226,9 +222,8 @@ class PipelineSchedule(BaseSchedule):
ft_shape = recv_tensor_meta(ft_shape)
input_tensor = recv_forward(ft_shape)
output_tensor = self.forward_step(
model, criterion,
input_tensor, return_tensors,
grad_accum_size, return_loss=return_loss
engine, input_tensor, return_tensors,
return_loss=return_loss
)
if not gpc.is_last_rank(ParallelMode.PIPELINE):
bt_shape = output_tensor.shape
@ -252,9 +247,8 @@ class PipelineSchedule(BaseSchedule):
last_iteration = (i == (num_microbatches_remaining - 1))
output_tensor = self.forward_step(
model, criterion,
input_tensor, return_tensors,
grad_accum_size, return_loss=return_loss
engine, input_tensor, return_tensors,
return_loss=return_loss
)
if forward_only:
send_forward(output_tensor)
@ -276,7 +270,7 @@ class PipelineSchedule(BaseSchedule):
output_tensor = output_tensors.pop(0)
input_tensor_grad = self.backward_step(
optimizer,
engine,
input_tensor, output_tensor,
output_tensor_grad
)
@ -297,7 +291,7 @@ class PipelineSchedule(BaseSchedule):
output_tensor_grad = recv_backward(bt_shape)
input_tensor_grad = self.backward_step(
optimizer,
engine,
input_tensor, output_tensor,
output_tensor_grad
)
@ -309,11 +303,8 @@ class PipelineSchedule(BaseSchedule):
output, label, loss = tuple(map(list, zip(*return_tensors)))
return (torch.cat(output, dim=0),
torch.cat(label, dim=0),
sum(loss) * grad_accum_size)
sum(loss))
else:
return tuple((torch.cat(return_tensors, dim=0), None, None))
else:
return tuple((None, None, None))
def optimizer_step(self, model, optimizer, grad_clipping: float = 0.0):
optimizer.step()

View File

@ -1,27 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Union, List
from torch import Tensor
def convert_to_fp16(data: Union[Tensor, List[Tensor]]):
if isinstance(data, Tensor):
ret = data.half()
elif isinstance(data, (list, tuple)):
ret = [val.half() for val in data]
else:
raise TypeError(f"Expected argument 'data' to be a Tensor or a list/tuple of Tensor, but got {type(data)}")
return ret
def convert_to_fp32(data: Union[Tensor, List[Tensor]]):
if isinstance(data, Tensor):
ret = data.float()
elif isinstance(data, (list, tuple)):
ret = [val.float() for val in data]
else:
raise TypeError(f"Expected argument 'data' to be a Tensor or a list/tuple of Tensor, but got {type(data)}")
return ret

View File

@ -3,377 +3,326 @@
import argparse
import pprint
import random
from pathlib import Path
from typing import Callable, Iterable, Optional, Union
from typing import Tuple
import os
from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
import numpy as np
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
from colossalai.engine import AMP_TYPE, NoPipelineSchedule, PipelineSchedule
from pathlib import Path
from typing import Iterable, Union, Optional, Tuple, List, Dict
from colossalai.amp import convert_to_amp, AMP_TYPE
from colossalai.context import Config, ParallelMode, ConfigException
from colossalai.core import global_context as gpc
from colossalai.engine import Engine
from colossalai.logging import get_global_dist_logger, init_global_dist_logger
from colossalai.nn import DataParallelSampler
from colossalai.nn.model.base_model import BaseModel
from .builder import (ModelInitializer, build_dataset, build_loss,
build_model, build_optimizer,
build_optimizer_wrapper, build_schedule)
from .context import Config, ParallelMode
from .core import global_context as gpc
from .utils import get_current_device, sync_model_param_in_dp
from colossalai.logging import get_dist_logger
from colossalai.utils import (accumulate_gradient, get_current_device,
sync_model_param_in_dp, is_using_ddp, is_using_pp)
from colossalai.zero import convert_to_zero, ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3
from colossalai.builder.builder import build_gradient_handler
from torch.optim.optimizer import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from torch.nn.modules.loss import _Loss
from torch.nn.parallel import DistributedDataParallel as DDP
def parse_args():
def get_default_parser():
'''Reads user command line and uses an argument parser to parse the input arguments.
Input arguments include configuration, host, port, world size, local rank, backend for torch.distributed.
:return: call the parse arguments function
:return: returns the parser with the default arguments, the user may add customized arguments into this parser
:rtype: Namespace
'''
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, help='path to the config file')
parser.add_argument('--host',
type=str,
default=None,
help='the master address for distributed training')
parser.add_argument('--port',
type=str,
default=None,
type=int,
help='the master port for distributed training')
parser.add_argument('--world_size', type=int, help='world size for ')
parser.add_argument('--world_size', type=int, help='world size for distributed training')
parser.add_argument('--rank', type=int, help='rank for the default process group')
parser.add_argument('--local_rank',
type=int,
help='rank for the default process group')
help='local rank on the node')
parser.add_argument('--backend',
type=str,
default='nccl',
help='backend for torch.distributed')
return parser.parse_args()
help='backend for distributed communication')
return parser
def init_dist(config: Union[str, dict] = None,
local_rank: int = None,
world_size: int = None,
host: str = None,
port: str = None,
backend: str = None):
def launch(config: Union[str, Path, Config, Dict],
rank: int,
world_size: int,
host: str,
port: int,
backend: str = 'nccl',
local_rank: int = None,
seed: int = 1024,
verbose: bool = True):
'''This function first parses the configuration arguments, using :func:parse_args() in case one of the input arguments are not given.
Then initialize and set distributed environment by calling global_context's functions.
Then initialize and set distributed environment by calling global_context's functions.
:param config: config file or config file path are both acceptable
:type config: Union[str, dict], optional
:param local_rank: rank for the default process group, defaults to None
:type config: Union[str, dict, Config]
:param rank: rank for the default process group
:type rank: int
:param world_size: world size of the default process group
:type world_size: int
:param host: the master address for distributed training
:type host: str
:param port: the master port for distributed training
:type port: str
:param backend: backend for torch.distributed
:type backend: str
:param local_rank: rank for the process on the node and is used to set the default CUDA device,
defaults to None. If local_rank = None, the default device ordinal will be calculated automatically
:type local_rank: int, optional
:param world_size: world size of GPUs, defaults to None
:type world_size: int, optional
:param host: the master address for distributed training, defaults to None
:type host: str, optional
:param port: the master port for distributed training, defaults to None
:type port: str, optional
:param backend: backend for torch.distributed, defaults to None
:type backend: str, optional
:raises Exception: raise exception when config type is wrong
'''
args = [config, local_rank, world_size, host, port, backend]
arg_given = [arg is not None for arg in args]
if not all(arg_given):
args = parse_args()
if config is None:
config = args.config
if local_rank is None:
local_rank = args.local_rank
if world_size is None:
world_size = args.world_size
if host is None:
host = args.host
if port is None:
port = args.port
if backend is None:
backend = args.backend
args = Config(
dict(config=config,
host=host,
port=port,
world_size=world_size,
local_rank=local_rank,
backend=backend))
# set distributed settings
dist_args = Config(
dict(local_rank=args.local_rank,
world_size=args.world_size,
backend=args.backend))
gpc.set_dist_args(dist_args)
gpc.verbose = verbose
# set config
if isinstance(args.config, dict):
cfg = args.config
elif isinstance(args.config, (str, Path)):
cfg = Config.from_file(args.config)
else:
raise Exception('Config type error: {}'.format(type(args.config)))
gpc.load_config(cfg)
assert isinstance(config, (Config, str, Path, dict)), \
f'expected argument config to be Config, str or Path, but got {type(config)}'
if not isinstance(config, Config) and isinstance(config, dict):
config = Config(config)
if isinstance(config, (str, Path)):
config = Config.from_file(config)
gpc.load_config(config)
# init dist groups
gpc.init_global_dist(args.host, args.port)
# init default process group
gpc.init_global_dist(rank, world_size, backend, host, port)
# init process groups for different parallel modes from config
gpc.init_parallel_groups()
# init dist logger
init_global_dist_logger()
# set cuda device
if torch.cuda.is_available():
gpc.set_device()
# if local rank is not given, calculate automatically
gpc.set_device(local_rank)
gpc.set_seed(seed)
if verbose:
logger = get_dist_logger()
logger.info(f'Distributed environment is initialized, '
f'data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, '
f'tensor parallel size: {gpc.tensor_parallel_size}', ranks=[0])
def get_dataloader(dataset, seed=1024, add_sampler_if_possible=False, **kwargs):
'''Set up a deterministic dataloader (also configure seed workers, samplers and whether shuffle or not)
.. note: when pipeline parallel is enabled, shuffle cannot be True
as it will result in mismatch between input data on the 1st
stage and label on the last stage
:param dataset: a :class:utils.data.dataset dataset
:param seed: random worker seed, defaults to 1024
:type seed: int, optional
:param add_sampler_if_possible: [description], defaults to False
:type add_sampler_if_possible: bool, optional
:return: a :class:utils.data.dataset dataloader
:rtype: torch.utils.data.dataset
'''
_kwargs = kwargs.copy()
if 'shuffle' in _kwargs:
shuffle = _kwargs.pop('shuffle')
else:
shuffle = False
if add_sampler_if_possible and gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1:
sampler = DataParallelSampler(dataset, shuffle=shuffle)
else:
sampler = None
# Deterministic dataloader
def seed_worker(worker_id):
worker_seed = seed
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)
random.seed(worker_seed)
if sampler is None:
return DataLoader(dataset,
worker_init_fn=seed_worker,
shuffle=shuffle,
**_kwargs)
else:
return DataLoader(dataset,
sampler=sampler,
worker_init_fn=seed_worker,
**_kwargs)
def launch_from_slurm(config: Union[str, Path, Config, Dict],
host: str,
port: int,
backend: str = 'nccl',
seed: int = 1024,
verbose: bool = True):
rank = int(os.environ['SLURM_PROCID'])
world_size = int(os.environ['SLURM_NPROCS'])
launch(config=config,
rank=rank,
world_size=world_size,
host=host,
port=port,
backend=backend,
seed=seed,
verbose=verbose)
def initialize(config: Union[str, dict] = None,
local_rank: int = None,
world_size: int = None,
host: str = None,
port: str = None,
backend: str = None,
train_dataloader: Optional[Union[Iterable, Callable]] = None,
test_dataloader: Optional[Union[Iterable, Callable]] = None,
def launch_from_openmpi(config: Union[str, Path, Config, Dict],
host: str,
port: int,
backend: str = 'nccl',
seed: int = 1024,
verbose: bool = True):
rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
launch(config=config,
local_rank=local_rank,
rank=rank,
world_size=world_size,
host=host,
port=port,
backend=backend,
seed=seed,
verbose=verbose)
def launch_from_torch(config: Union[str, Path, Config, Dict],
host: str,
port: int,
backend: str = 'nccl',
seed: int = 1024,
verbose: bool = True):
rank = int(os.environ['RANK'])
local_rank = int(os.environ['LOCAL_RANK'])
world_size = int(os.environ['WORLD_SIZE'])
launch(config=config,
local_rank=local_rank,
rank=rank,
world_size=world_size,
host=host,
port=port,
backend=backend,
seed=seed,
verbose=verbose)
def initialize(model: Union[nn.Module, List[nn.Module]],
optimizer: Union[Optimizer, List[Optimizer]],
criterion: Union[_Loss, List[_Loss]],
train_dataloader: Optional[Union[Iterable, List[Iterable]]] = None,
test_dataloader: Optional[Union[Iterable, List[Iterable]]] = None,
lr_scheduler: _LRScheduler = None,
verbose: bool = True
) -> Tuple[Engine, DataLoader, DataLoader]:
'''Core function that initializes distributed environment, logger, cudnn, data, model, loss function, optimizer, and lr_scheduler(their configs are in gpc.config).
''' Core function to wrap the essential training components with our functionality based on the config which is loaded into gpc.config.
:param config: config file or config file path are both acceptable
:type config: Union[str, dict], optional
:param local_rank: rank for the default process group, defaults to None
:type local_rank: int, optional
:param world_size: world size of GPUs, defaults to None
:type world_size: int, optional
:param host: the master address for distributed training, defaults to None
:type host: str, optional
:param port: the master port for distributed training, defaults to None
:type port: str, optional
:param backend: backend for torch.distributed, defaults to None
:type backend: str, optional
:param train_dataloader: If None, the config is used to build a dataloder; Else, it should be a dataloader object or a function with no arguments which can build a dataloader, defaults to None
:type train_dataloader: Optional[Union[Iterable, Callable]], optional
:param test_dataloader: If None, the config is used to build a dataloder; Else, it should be a dataloader object or a function with no arguments which can build a dataloader, defaults to None
:type test_dataloader: Optional[Union[Iterable, Callable]], optional
:return: (engine, train_dataloader, test_dataloader, criterion)
:param model: your model instance
:type model: a single or a list of ``torch.nn.Module`` objects
:param optimizer: your optimizer instance
:type optimizer: a single or a list of ``torch.optim.optimizer.Optimizer`` objects
:param criterion: your criterion instance
:type criterion: a single or a list of ``torch.nn.modules.loss._Loss`` objects
:param train_dataloader: dataloaders for training data
:type train_dataloader: a single or a list of ``torch.utils.data.DataLoader`` objects, defaults to None
:param train_dataloader: dataloaders for testing data
:type train_dataloader: a single or a list of ``torch.utils.data.DataLoader`` objects, defaults to None
:return: (engine, criterion, train_dataloader, test_dataloader)
:rtype: tuple
'''
# initialize distributed environment
init_dist(config=config,
local_rank=local_rank,
world_size=world_size,
host=host,
port=port,
backend=backend)
# get logger
logger = get_dist_logger()
gpc.verbose = verbose
# init logger
logger = get_global_dist_logger()
logger.info(f'Distributed environment is initialized, '
f'data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, '
f'tensor parallel size: {gpc.tensor_parallel_size}', ranks=[0])
# get config from gpc
config = gpc.config
# print config
logger.info(f"\n========== Your Config ========\n"
f"{pprint.pformat(gpc.config)}\n"
f"================================", ranks=[0])
if verbose:
logger.info(f"\n========== Your Config ========\n"
f"{pprint.pformat(gpc.config)}\n"
f"================================\n", ranks=[0])
# cudnn
cudnn_benchmark = gpc.config.get('cudnn_benchmark', True)
cudnn_deterministic = gpc.config.get('cudnn_deterministic', False)
cudnn_benchmark = config.get('cudnn_benchmark', True)
cudnn_deterministic = config.get('cudnn_deterministic', False)
torch.backends.cudnn.benchmark = cudnn_benchmark
torch.backends.cudnn.deterministic = cudnn_deterministic
logger.info(
f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])
if verbose:
logger.info(
f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])
# set seed, cuda seed is only set when cuda is avail
gpc.set_seed()
# return_items = list()
# check fp16 and zero
should_convert_model_to_half = False
should_wrap_fp16_optimizer = False
should_wrap_zero_optimizer_level_2_3 = False
if hasattr(gpc.config, 'fp16'):
fp16_mode = gpc.config.fp16.mode
if fp16_mode == AMP_TYPE.PARALLEL:
should_convert_model_to_half = True
should_wrap_fp16_optimizer = True
if hasattr(gpc.config, 'zero'):
should_wrap_zero_optimizer_level_2_3 = True
zero_type = gpc.config.zero.type
if zero_type in ['ZeroRedundancyOptimizer_Level_2', 'ZeroRedundancyOptimizer_Level_3']:
should_convert_model_to_half = True
assert not should_wrap_fp16_optimizer, \
'AMP_TYPE.PARALLEL is mutually exclusive with zero level 2 and 3'
# build model
logger.info('Building model ...', ranks=[0])
assert hasattr(
gpc.config, 'model'), "Build error: configuration 'model' is missing"
if gpc.pipeline_parallel_size > 1:
model = ModelInitializer(gpc.config.model, 1, verbose=True)
model = model.model_initialize()
else:
model = build_model(gpc.config.model)
if isinstance(model, BaseModel):
model.build_from_cfg()
model = model.to(get_current_device())
# first sync model across dp ranks
model.to(get_current_device())
sync_model_param_in_dp(model)
logger.info('Model is created', ranks=[0])
if should_convert_model_to_half:
model = model.half()
logger.info("Model is cast to fp16", ranks=[0])
# check amp and zero
fp16_cfg = gpc.config.get('fp16', None)
zero_cfg = gpc.config.get('zero', None)
# training data
if callable(train_dataloader):
logger.info(
f'Build train data loader from {train_dataloader}', ranks=[0])
train_dataloader = train_dataloader()
if train_dataloader is None and hasattr(gpc.config, 'train_data'):
logger.info('Preparing data ...', ranks=[0])
# assert hasattr(gpc.config, 'train_data'), "Build error: configuration 'train_data' is missing."
train_dataset = build_dataset(gpc.config.train_data.dataset)
logger.info('Train dataset is ready.', ranks=[0])
if fp16_cfg is not None and fp16_cfg.mode is not None and zero_cfg is not None:
raise ConfigException(
"It is not allowed to set fp16 and zero configuration in your config file at the same time")
train_dataloader = get_dataloader(train_dataset,
gpc.config.get('seed', 1024),
True,
**gpc.config.train_data.dataloader,
)
logger.info(
f'Loaded {len(train_dataset)} samples in {len(train_dataloader)} batches for training', ranks=[0])
# initialize amp
amp_mode = None
if fp16_cfg is not None and fp16_cfg.mode is not None:
cfg_ = fp16_cfg.copy()
amp_mode = cfg_.pop('mode')
model, optimizer, criterion = convert_to_amp(model=model,
optimizer=optimizer,
criterion=criterion,
mode=amp_mode,
amp_config=cfg_)
if callable(test_dataloader):
logger.info(
f'Build test data loader from {test_dataloader}', ranks=[0])
test_dataloader = test_dataloader()
# testing data, allowed to be None
if test_dataloader is None and hasattr(gpc.config, 'test_data'):
test_dataset = build_dataset(gpc.config.test_data.dataset)
test_dataloader = get_dataloader(
test_dataset, add_sampler_if_possible=True, **gpc.config.test_data.dataloader)
logger.info(
f'Loaded {len(test_dataset)} samples in {len(test_dataloader)} batches for testing', ranks=[0])
if zero_cfg is not None:
cfg_ = zero_cfg.copy()
level = cfg_.pop('level')
model, optimizer = convert_to_zero(model=model,
optimizer=optimizer,
level=level,
zero_config=cfg_
)
# build loss function
assert hasattr(gpc.config, 'loss'), \
'Build error: configuration \'loss\' is missing.'
criterion = build_loss(gpc.config.loss)
logger.info('Loss function is created', ranks=[0])
# build optimizer
assert hasattr(gpc.config, 'optimizer'), \
"Build error: configuration 'optimizer' is missing."
optim_type = gpc.config.optimizer.type
is_pytorch_native_zero_level_1 = optim_type == 'ZeroRedundancyOptimizer'
if is_pytorch_native_zero_level_1:
original_cfg_copy = gpc.config.optimizer.copy()
original_cfg_copy.pop('type')
cfg = dict(type=optim_type, process_group=gpc.get_group(
ParallelMode.DATA), **original_cfg_copy)
optimizer = build_optimizer(cfg, model)
# gradient handler
gradient_handler_cfg = gpc.config.get('gradient_handler', None)
if gradient_handler_cfg is None:
# if gradient handler is not specified in the configuration file,
# check in the following order
# 1. if optimizer is ZERO, then use zero grad handler
# 2. if dp size is larger than 1 and pipeline is not used, use pytorch ddp
# 3. if using pipeline and dp size larger than 1, use data parallel grad handler
if isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)):
gradient_handler_cfg = [dict(type='ZeROGradientHandler')]
if verbose:
logger.info(
"Training with zero is detected, ZeROGradientHandler is automatically "
"added even though not specified in the configuration",
ranks=[0])
elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE:
model = DDP(model, process_group=gpc.get_group(ParallelMode.DATA))
if verbose:
logger.info(
'Model is using torch.nn.parallel.DistributedDataParallel', ranks=[0])
elif is_using_ddp():
gradient_handler_cfg = [dict(type='DataParallelGradientHandler')]
if verbose:
logger.info(
"Data parallel training is detected when using pipeline parallel, DataParallelGradientHandler is automatically "
"added even though not specified in the configuration",
ranks=[0])
else:
optimizer = build_optimizer(gpc.config.optimizer, model)
if not isinstance(gradient_handler_cfg, list):
raise ConfigException(
f"expected gradient_handler in the configuration file to be a list but got {type(gradient_handler_cfg)}")
if should_wrap_zero_optimizer_level_2_3:
optimizer = build_optimizer_wrapper(gpc.config.zero, optimizer, model)
if should_wrap_fp16_optimizer:
# replace the field mode with type
fp16_cfg = gpc.config.fp16.copy()
amp_type = fp16_cfg.pop('mode')
assert amp_type == AMP_TYPE.PARALLEL, 'FP Optimizer should only be used for AMP_TYPE.PARALLEL'
fp16_cfg['type'] = 'FP16Optimizer'
optimizer = build_optimizer_wrapper(fp16_cfg, optimizer)
logger.info('Optimizer is created', ranks=[0])
# build schedule and engine
if hasattr(gpc.config, 'fp16'):
amp_type = gpc.config.fp16.mode
amp_cfg = gpc.config.fp16.copy()
amp_cfg.pop('mode')
if gradient_handler_cfg is None:
gradient_handlers = None
if verbose and not isinstance(model, DDP):
logger.warning(
"No PyTorch DDP or gradient handler is set up, please make sure you do not need "
"to all-reduce the gradients after a training step.",
ranks=[0])
else:
amp_type = None
amp_cfg = None
gradient_handlers = [build_gradient_handler(cfg, model, optimizer) for cfg in gradient_handler_cfg]
engine_cfg = gpc.config.get('engine', dict())
schedule_cfg = engine_cfg.pop('schedule', None)
# check if optimizer is ColossalaiOptimizer
if not isinstance(optimizer, (ColossalaiOptimizer, ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)):
optimizer = ColossalaiOptimizer(optim=optimizer)
schedule_type = None
if schedule_cfg is not None:
schedule_type = schedule_cfg.get('type', None)
# gradient accumulation
grad_accum_size = gpc.config.get('gradient_accumulation', None)
if grad_accum_size is not None:
optimizer, train_dataloader, gradient_handlers, lr_scheduler = accumulate_gradient(model=model,
optimizer=optimizer,
dataloader=train_dataloader,
accumulate_size=grad_accum_size,
gradient_handlers=gradient_handlers,
lr_scheduler=lr_scheduler)
if schedule_type is not None:
# run customized schedule
schedule_cfg['amp_type'] = amp_type
schedule_cfg['amp_config'] = amp_cfg
schedule = build_schedule(schedule_cfg)
elif gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
assert schedule_cfg is not None, \
"Config 'engine.schedule' not found in your configuration file for pipeline parallel training"
schedule = PipelineSchedule(
amp_type=amp_type, amp_config=amp_cfg, **schedule_cfg.copy())
else:
schedule = NoPipelineSchedule(amp_type=amp_type, amp_config=amp_cfg)
# clip grad norm
clip_grad_norm = gpc.config.get('clip_grad_norm', 0.0)
if clip_grad_norm > 0:
if zero_cfg is not None:
raise ConfigException(
"clip_grad_norm should be specified with zero, you should specify clip_grad in zero configuration")
elif fp16_cfg is not None and fp16_cfg.mode == AMP_TYPE.NAIVE:
raise ConfigException(
"clip_grad_norm should be specified with AMP_TYPE.NAIVE, you should specify clip_grad in fp16 configuration")
engine = Engine(
model=model,
optimizer=optimizer,
criterion=criterion,
step_schedule=schedule,
**gpc.config.get('engine', dict())
gradient_handlers=gradient_handlers,
clip_grad_norm=clip_grad_norm
)
return engine, train_dataloader, test_dataloader
return engine, train_dataloader, test_dataloader, lr_scheduler

View File

@ -1,26 +1,10 @@
from colossalai.core import global_context as gpc
from .logging import DistributedLogger
__all__ = ['get_global_dist_logger', 'get_dist_logger', 'DistributedLogger', 'init_global_dist_logger']
_GLOBAL_LOGGER: DistributedLogger = None
__all__ = ['get_dist_logger', 'DistributedLogger']
def get_dist_logger(name, level='INFO', root_path: str = None, mode='a'):
return DistributedLogger(name=name, level=level, root_path=root_path, mode=mode)
def get_global_dist_logger():
assert _GLOBAL_LOGGER is not None, 'Global distributed logger is not initialized'
return _GLOBAL_LOGGER
def init_global_dist_logger():
rank = gpc.get_global_rank()
if hasattr(gpc.config, 'logging'):
logger = get_dist_logger(name=f'rank_{rank}', **gpc.config.logging)
else:
logger = get_dist_logger(name=f'rank_{rank}', level='INFO')
global _GLOBAL_LOGGER
assert _GLOBAL_LOGGER is None, 'Global distributed logger has already been initialized'
_GLOBAL_LOGGER = logger
def get_dist_logger(name='root'):
"""Get logger instance based on name. The DistributedLogger will create singleton instances,
which means that only one logger instance is created per name.
"""
return DistributedLogger.get_instance(name=name)

View File

@ -1,11 +1,13 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import colossalai
import logging
from pathlib import Path
from typing import Union
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
_FORMAT = 'colossalai - %(name)s - %(asctime)s %(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=_FORMAT)
@ -16,40 +18,92 @@ class DistributedLogger:
:param name: The name of the logger
:type name: str
:param level: The threshold for the logger. Logging messages which are less severe than `level`
will be ignored
:type level: str
:param root_path: The root path where logs are stored
:type root_path: str, optional
:param mode: The mode that the file is opened in. Defaults to 'a'
:type mode: str, optional
"""
def __init__(self, name, level='INFO', root_path: str = None, mode='a'):
self._logger = logging.getLogger(name)
__instances = dict()
@staticmethod
def get_instance(name: str):
"""Get the unique single logger instance based on name.
:param name: The name of the logger
:type name: str
:return: a DistributedLogger object
:rtype: DistributedLogger
"""
if name in DistributedLogger.__instances:
return DistributedLogger.__instances[name]
else:
logger = DistributedLogger(name=name)
return logger
def __init__(self, name):
if name in DistributedLogger.__instances:
raise Exception('Logger with the same name has been created, you should use colossalai.logging.get_dist_logger')
else:
self._name = name
self._logger = logging.getLogger(name)
DistributedLogger.__instances[name] = self
@staticmethod
def _check_valid_logging_level(level: str):
assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR'], 'found invalid logging level'
def set_level(self, level: str):
"""Set the logging level
:param level: can only be INFO, DEBUG, WARNING and ERROR
:type level: str
"""
self._check_valid_logging_level(level)
self._logger.setLevel(getattr(logging, level))
if root_path is not None:
log_root_path = Path(root_path)
# create path if not exists
log_root_path.mkdir(parents=True, exist_ok=True)
log_path = log_root_path.joinpath(f'{name}.log')
file_handler = logging.FileHandler(log_path, mode)
file_handler.setLevel(getattr(logging, level))
formatter = logging.Formatter(_FORMAT)
file_handler.setFormatter(formatter)
self._logger.addHandler(file_handler)
def log_to_file(self,
path: Union[str, Path],
mode: str = 'a',
level: str = 'INFO',
suffix: str = None):
"""Save the logs to file
:param path: the file to save the log
:type path: a string or pathlib.Path object
:param mode: the mode to write log into the file
:type mode: str
:param level: can only be INFO, DEBUG, WARNING and ERROR
:type level: str
"""
assert isinstance(path, (str, Path)), \
f'expected argument path to be type str or Path, but got {type(path)}'
self._check_valid_logging_level(level)
if isinstance(path, str):
path = Path(path)
# set the default file name if path is a directory
if not colossalai.core.global_context.is_initialized(ParallelMode.GLOBAL):
rank = 0
else:
rank = colossalai.core.global_context.get_global_rank()
if suffix is not None:
log_file_name = f'rank_{rank}_{suffix}.log'
else:
log_file_name = f'rank_{rank}.log'
path = path.joinpath(log_file_name)
# add file handler
file_handler = logging.FileHandler(path, mode)
file_handler.setLevel(getattr(logging, level))
formatter = logging.Formatter(_FORMAT)
file_handler.setFormatter(formatter)
self._logger.addHandler(file_handler)
def _log(self, level, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
if ranks is None:
getattr(self._logger, level)(message)
else:
local_rank = gpc.get_local_rank(parallel_mode)
local_rank = colossalai.core.global_context.get_local_rank(parallel_mode)
if local_rank in ranks:
getattr(self._logger, level)(message)
def info(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
"""Stores an info log message.
"""Log an info message.
:param message:
:type message:
@ -61,7 +115,7 @@ class DistributedLogger:
self._log('info', message, parallel_mode, ranks)
def warning(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
"""Stores a warning log message.
"""Log a warning message.
:param message: The message to be logged
:type message: str
@ -73,7 +127,7 @@ class DistributedLogger:
self._log('warning', message, parallel_mode, ranks)
def debug(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
"""Stores a debug log message.
"""Log a debug message.
:param message: The message to be logged
:type message: str
@ -85,7 +139,7 @@ class DistributedLogger:
self._log('debug', message, parallel_mode, ranks)
def error(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
"""Stores an error log message.
"""Log an error message.
:param message: The message to be logged
:type message: str

View File

@ -1,4 +1,3 @@
from .data import *
from .layer import *
from .loss import *
from .lr_scheduler import *

View File

@ -1,3 +0,0 @@
from .caltech101_dataset import Caltech101Dataset
from .cifar10_dataset import CIFAR10Dataset
from .sampler import *

View File

@ -1,14 +0,0 @@
import numpy as np
def pil_img_to_numpy(pil_img):
"""convert a PIL image to numpy nd-array
:param pil_img: a PIL image
:type pil_img: PIL.Image
:return: a nd-array
:rtype: numpy.ndarray
"""
np_img = np.array(pil_img)
np_img = np.rollaxis(np_img, 2) # HWC to CHW
return np_img

View File

@ -1,17 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from abc import ABC
from torch.utils.data import Dataset
from torchvision.transforms import transforms
from colossalai.builder import build_transform
class BaseDataset(Dataset, ABC):
def __init__(self, transform_pipeline: list):
transform_list = [build_transform(cfg) for cfg in transform_pipeline]
transform = transforms.Compose(transform_list)
self._transform_pipeline = transform

View File

@ -1,43 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.distributed as dist
from torchvision.datasets import Caltech101
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import DATASETS
from .base_dataset import BaseDataset
@DATASETS.register_module
class Caltech101Dataset(BaseDataset):
"""`Caltech 101 <http://www.vision.caltech.edu/Image_Datasets/Caltech101/>`_ Dataset.
:param transform_pipeline: A list of functions' config, which takes in an PIL image
and returns a transformed version
:type transform_pipeline: list
"""
def __init__(self, transform_pipeline: list, *args, **kwargs):
super().__init__(transform_pipeline)
if gpc.is_initialized(ParallelMode.GLOBAL) and gpc.get_global_rank() != 0:
dist.barrier()
self._dataset = Caltech101(
transform=self._transform_pipeline, *args, **kwargs)
if gpc.is_initialized(ParallelMode.GLOBAL) and gpc.get_global_rank() == 0:
dist.barrier()
def __len__(self):
return len(self._dataset)
def __getitem__(self, item):
"""
:param item: Index
:type item: int
:return: ((image,), (target,)) where the type of target specified by target_type.
:rtype: tuple
"""
img, label = self._dataset.__getitem__(item)
return (img,), (label,)

View File

@ -1,44 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.distributed as dist
from torchvision.datasets import CIFAR10
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import DATASETS
from .base_dataset import BaseDataset
@DATASETS.register_module
class CIFAR10Dataset(BaseDataset):
"""`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
:param transform_pipeline: A list of functions' config, which takes in an PIL image
and returns a transformed version
:type transform_pipeline: list
"""
def __init__(self, transform_pipeline: list, *args, **kwargs):
super().__init__(transform_pipeline)
if gpc.is_initialized(ParallelMode.GLOBAL) and gpc.get_global_rank() != 0:
dist.barrier()
self._dataset = CIFAR10(transform=self._transform_pipeline,
*args,
**kwargs)
if gpc.is_initialized(ParallelMode.GLOBAL) and gpc.get_global_rank() == 0:
dist.barrier()
def __len__(self):
return len(self._dataset)
def __getitem__(self, item):
"""
:param item: Index
:type item: int
:return: ((image,), (target,)) where the type of target specified by target_type.
:rtype: tuple
"""
img, label = self._dataset.__getitem__(item)
return (img,), (label,)

View File

@ -1,4 +0,0 @@
from .base_sampler import BaseSampler
from .data_parallel_sampler import DataParallelSampler
__all__ = ['BaseSampler', 'DataParallelSampler']

33
colossalai/nn/init.py Normal file
View File

@ -0,0 +1,33 @@
import math
from torch import Tensor
from torch.nn import init as init
def init_weight_(tensor: Tensor, fan_in: int, fan_out: int = None, init_method: str = 'torch'):
if init_method == 'torch':
a = math.sqrt(5)
nonlinearity = 'leaky_relu'
std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
bound = math.sqrt(3.0) * std
init.uniform_(tensor, -bound, bound)
elif init_method == 'jax':
std = math.sqrt(2.0 / float(fan_in + fan_out))
a = math.sqrt(3.0) * std
init.uniform_(tensor, -a, a)
elif init_method == 'jax_embed':
std = math.sqrt(1.0 / fan_in)
init.trunc_normal_(tensor, std=std / .87962566103423978)
elif init_method == 'zero':
init.zeros_(tensor)
def init_bias_(tensor: Tensor, fan_in: int, init_method: str = 'torch'):
if init_method == 'torch':
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
init.uniform_(tensor, -bound, bound)
elif init_method == 'jax':
init.normal_(tensor, std=1e-6)
elif init_method == 'jax_embed':
init.trunc_normal_(tensor, std=.02)
elif init_method == 'zero':
init.zeros_(tensor)

View File

@ -1,9 +1,8 @@
from .fused_bias_gelu import bias_gelu_impl
from .parallel_1d import *
from .parallel_2d import *
from .parallel_2p5d import *
from .parallel_3d import *
from .parallel_sequence import *
from .parallel_vision_transformer import *
from .vanilla_resnet import *
from .vanilla_vision_transformer import *
from .non_parallel_layers import *
from .wrapper import *

View File

@ -2,40 +2,14 @@
# -*- encoding: utf-8 -*-
import math
import collections.abc
from itertools import repeat
import numpy as np
from colossalai.utils.common import print_rank_0
import torch
from torch import Tensor
from torch import nn
from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS
from colossalai.utils import checkpoint
from colossalai.constants import IS_TENSOR_PARALLEL
def divide(numerator, denominator):
""" only allow exact division """
assert numerator % denominator == 0, \
'{} is not divisible by {}'.format(numerator, denominator)
return numerator // denominator
def gelu(x: Tensor) -> Tensor:
"""Implementation of the gelu activation function.
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
"""
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
def swish(x: Tensor) -> Tensor:
return x * torch.sigmoid(x)
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
def set_tensor_parallel_attribute(param):
if not hasattr(param, IS_TENSOR_PARALLEL):
setattr(param, IS_TENSOR_PARALLEL, True)
from torch import Tensor, nn
class CheckpointModule(nn.Module):
@ -44,15 +18,15 @@ class CheckpointModule(nn.Module):
self.checkpoint = checkpoint
self._use_checkpoint = checkpoint
def _forward(self, *args):
def _forward(self, *args, **kwargs):
raise NotImplementedError(
'CheckpointModule should implement _forward method instead of origin forward')
def forward(self, *args):
def forward(self, *args, **kwargs):
if self._use_checkpoint:
return checkpoint(self._forward, *args)
return checkpoint(self._forward, *args, **kwargs)
else:
return self._forward(*args)
return self._forward(*args, **kwargs)
def train(self, mode: bool = True):
self._use_checkpoint = self.checkpoint
@ -61,3 +35,38 @@ class CheckpointModule(nn.Module):
def eval(self):
self._use_checkpoint = False
return super().eval()
def divide(numerator, denominator):
""" only allow exact division """
assert numerator % denominator == 0, \
'{} is not divisible by {}'.format(numerator, denominator)
return numerator // denominator
def swish(x: Tensor) -> Tensor:
return x * torch.sigmoid(x)
ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}
def set_tensor_parallel_attribute_by_size(param, size):
setattr(param, IS_TENSOR_PARALLEL, True)
setattr(param, NUM_PARTITIONS, size // np.prod(param.shape))
def set_tensor_parallel_attribute_by_partition(param, num_partitions):
setattr(param, IS_TENSOR_PARALLEL, True)
setattr(param, NUM_PARTITIONS, num_partitions)
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
return x
return tuple(repeat(x, n))
return parse
to_2tuple = _ntuple(2)

View File

@ -0,0 +1,35 @@
# adapted from Megatron-LM
# https://github.com/NVIDIA/Megatron-LM/blob/b31e1296354e979722627a6c4dedafe19b51fa97/megatron/model/fused_bias_gelu.py
import torch
@torch.jit.script
def bias_gelu(bias, y):
x = bias + y
return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def bias_gelu_back(g, bias, y):
x = bias + y
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
return ff*g
class GeLUFunction(torch.autograd.Function):
@staticmethod
# bias is an optional argument
def forward(ctx, input, bias):
ctx.save_for_backward(input, bias)
return bias_gelu(bias, input)
@staticmethod
def backward(ctx, grad_output):
input, bias = ctx.saved_tensors
tmp = bias_gelu_back(grad_output, bias, input)
return tmp, tmp
bias_gelu_impl = GeLUFunction.apply

View File

@ -0,0 +1,8 @@
from ._vit import (ViTBlock, VanillaViTAttention, VanillaViTBlock, VanillaViTDropPath,
VanillaViTHead, VanillaViTMLP, VanillaViTPatchEmbedding)
__all__ = [
'ViTBlock', 'VanillaViTAttention', 'VanillaViTBlock', 'VanillaViTDropPath',
'VanillaViTHead', 'VanillaViTMLP', 'VanillaViTPatchEmbedding'
]

View File

@ -1,23 +1,47 @@
import collections.abc
from itertools import repeat
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
from torch import nn as nn
from colossalai.builder import build_layer
from colossalai.registry import LAYERS
from .._common_utils import to_2tuple
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
return x
return tuple(repeat(x, n))
@LAYERS.register_module
class ViTBlock(nn.Module):
"""Vision Transformer block
return parse
:param attention_cfg: config of attention layer
:type attention_cfg: dict
:param droppath_cfg: config of drop path
:type droppath_cfg: dict
:param mlp_cfg: config of MLP layer
:type mlp_cfg: dict
:param norm_cfg: config of normlization layer
:type norm_cfg: dict
"""
def __init__(self,
attention_cfg: dict,
droppath_cfg: dict,
mlp_cfg: dict,
norm_cfg: dict,
):
super().__init__()
self.norm1 = build_layer(norm_cfg)
self.attn = build_layer(attention_cfg)
self.drop_path = build_layer(
droppath_cfg) if droppath_cfg['drop_path'] > 0. else nn.Identity()
self.norm2 = build_layer(norm_cfg)
self.mlp = build_layer(mlp_cfg)
to_2tuple = _ntuple(2)
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
@LAYERS.register_module

View File

@ -1,5 +1,11 @@
from .layers import Linear1D_Col, Linear1D_Row
from .layers import MixedFusedLayerNorm1D as LayerNorm1D
from ._transformer import TransformerMLP1D, TransformerSelfAttention1D, TransformerLayer1D
from ._vit import ViTMLP1D, ViTSelfAttention1D, ViTHead1D, ViTPatchEmbedding1D, ViTTokenFuser1D, ViTHead
__all__ = [
'Linear1D_Col', 'Linear1D_Row',
'Linear1D_Col', 'Linear1D_Row', 'ViTMLP1D', 'ViTSelfAttention1D', 'ViTHead1D', 'ViTPatchEmbedding1D', 'ViTTokenFuser1D',
'TransformerMLP1D', 'TransformerSelfAttention1D', 'TransformerLayer1D', 'LayerNorm1D', 'ViTHead'
]

View File

@ -0,0 +1,34 @@
import torch
try:
import fused_mix_prec_layer_norm_cuda
except:
fused_mix_prec_layer_norm_cuda = None
class FusedLayerNormAffineFunction1D(torch.autograd.Function):
@staticmethod
def forward(ctx, input, weight, bias, normalized_shape, eps):
ctx.normalized_shape = normalized_shape
ctx.eps = eps
input_ = input.contiguous()
weight_ = weight.contiguous()
bias_ = bias.contiguous()
output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(
input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
return output
@staticmethod
def backward(ctx, grad_output):
input_, weight_, bias_, mean, invvar = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None
grad_input, grad_weight, grad_bias \
= fused_mix_prec_layer_norm_cuda.backward_affine(
grad_output.contiguous(), mean, invvar,
input_, ctx.normalized_shape,
weight_, bias_, ctx.eps)
return grad_input, grad_weight, grad_bias, None, None

View File

@ -0,0 +1,220 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import math
from torch import Tensor
from torch.nn.parameter import Parameter
from typing import Tuple
from colossalai.context import seed, ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import LAYERS
from colossalai.utils import get_current_device
from .._common_utils import divide, ACT2FN
from .._parallel_utilities import reduce_grad, reduce_input, gather_forward_split_backward, \
split_forward_gather_backward
from ..base_layer import ParallelLayer
from .layers import Linear1D_Col, Linear1D_Row
from .layers import MixedFusedLayerNorm1D as LayerNorm1D
@LAYERS.register_module
class TransformerMLP1D(ParallelLayer):
"""MLP.
MLP will take the input with h hidden state, project it to 4*h
hidden dimension, perform nonlinear transformation, and project the
state back into h hidden dimension.
"""
def __init__(self,
in_features: int,
mlp_ratio: int = 4.0,
act_func: str = 'gelu',
dropout_prob: float = 0.,
dtype=None,
skip_bias_add: bool = False
):
super(TransformerMLP1D, self).__init__()
self.in_features = in_features
self.mlp_ratio = mlp_ratio
self.skip_bias_add = skip_bias_add
# Project to h * mlp_ratio.
self.dense_1 = Linear1D_Col(
self.in_features,
int(self.mlp_ratio * self.in_features),
bias=not skip_bias_add,
dtype=dtype,
gather_output = False,
)
assert act_func in ACT2FN.keys(), f'Invalid value for argument act_func, ' \
f'activation function can only be {list(ACT2FN.keys())}'
self.activation_func = ACT2FN[act_func]
# Project back to h.
self.dense_2 = Linear1D_Row(
int(self.mlp_ratio * self.in_features),
self.in_features,
bias=not skip_bias_add,
dtype=dtype,
parallel_input = True,
)
self.dropout = nn.Dropout(dropout_prob)
# self.layernorm = LayerNorm1D(in_features, dtype=dtype)
self.layernorm = nn.LayerNorm(in_features, dtype=dtype)
def forward(self, x):
if self.skip_bias_add:
intermediate_output, _ = self.dense_1(x)
else:
intermediate_output = self.dense_1(x)
intermediate_output = self.activation_func(intermediate_output)
if self.skip_bias_add:
output, _ = self.dense_2(intermediate_output)
else:
output = self.dense_2(intermediate_output)
with seed(ParallelMode.TENSOR):
output = self.dropout(output)
output = self.layernorm(x + output)
return output
@LAYERS.register_module
class TransformerSelfAttention1D(ParallelLayer):
"""Self attention layer for 1D parallel Transformer
:param hidden_size: hidden size
:type hidden_size: int
:param num_attention_heads: number of attention heads
:type num_attention_heads: int
:param attention_dropout_prob: dropout probability for attention layer
:type attention_dropout_prob: float
:param hidden_dropout_prob: dropout probability for hidden layer
:type hidden_dropout_prob: float
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
hidden_size: int,
num_attention_heads: int,
attention_dropout_prob: float,
hidden_dropout_prob: float,
dtype=None,
):
super().__init__()
self.hidden_size = hidden_size
self.num_attention_heads = divide(num_attention_heads, gpc.tensor_parallel_size)
self.attention_head_size = divide(hidden_size, num_attention_heads)
self.hidden_size_per_partition = divide(hidden_size, gpc.tensor_parallel_size)
self.query_key_value = Linear1D_Col(
hidden_size,
3 * hidden_size,
dtype=dtype,
)
self.attention_dropout = nn.Dropout(attention_dropout_prob)
self.dense = Linear1D_Row(
hidden_size,
hidden_size,
dtype=dtype,
parallel_input=True,
)
self.dropout = nn.Dropout(hidden_dropout_prob)
# need to re-enable torch grad to enable fused optimization.
# self.layernorm = LayerNorm1D(
# hidden_size,
# dtype=dtype)
self.layernorm = nn.LayerNorm(
hidden_size,
dtype=dtype)
def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
query_key_value = self.query_key_value(hidden_states)
new_qkv_shape = query_key_value.shape[:-1] + \
(self.num_attention_heads, 3 * self.attention_head_size)
query_key_value = query_key_value.view(new_qkv_shape)
query_key_value = query_key_value.permute((0, 2, 1, 3))
query_layer, key_layer, value_layer = torch.chunk(
query_key_value, 3, dim=-1)
attention_scores = torch.matmul(
query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / \
math.sqrt(self.attention_head_size)
attention_scores = attention_scores + attention_mask
attention_probs = nn.Softmax(dim=-1)(attention_scores)
with seed(ParallelMode.TENSOR):
attention_probs = self.attention_dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute((0, 2, 1, 3)).contiguous()
new_context_layer_shape = context_layer.size()[
:-2] + (self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
output = self.dense(context_layer)
with seed(ParallelMode.TENSOR):
output = self.dropout(output)
attention_output = self.layernorm(hidden_states + output)
return attention_output
@LAYERS.register_module
class TransformerLayer1D(ParallelLayer):
"""Transformer layer which contains a self-attention layer and a MLP layer
:param hidden_size: hidden size
:type hidden_size: int
:param num_attention_heads: number of attention heads
:type num_attention_heads: int
:param act_func: activation function, defaults to 'gelu'
:type act_func: str, optional
:param mlp_ratio: hidden size of MLP divided by embedding dim, defaults to 4.0
:type mlp_ratio: float, optional
:param attention_dropout_prob: dropout probability for attention layer, defaults to 0.
:type attention_dropout_prob: float, optional
:param hidden_dropout_prob: dropout probability for attention layer, defaults to 0.
:type hidden_dropout_prob: float, optional
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
hidden_size: int,
num_attention_heads: int,
act_func: str = 'gelu',
mlp_ratio: float = 4.0,
attention_dropout_prob: float = 0.,
hidden_dropout_prob: float = 0.,
dtype=None,
):
super().__init__()
self.attention = TransformerSelfAttention1D(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
attention_dropout_prob=attention_dropout_prob,
hidden_dropout_prob=hidden_dropout_prob,
dtype=dtype,
)
self.mlp = TransformerMLP1D(
in_features=hidden_size,
dropout_prob=hidden_dropout_prob,
act_func=act_func,
mlp_ratio=mlp_ratio,
dtype=dtype,
)
def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
attention_output = self.attention(hidden_states, attention_mask)
output = self.mlp(attention_output)
return output

View File

@ -13,3 +13,6 @@ def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank):
def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
per_partition_vocab_size = divide(global_vocab_size, world_size)
return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank)

View File

@ -0,0 +1,411 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
from colossalai import context
import torch
from torch import nn as nn, Tensor, distributed as dist
from torch.nn.init import _calculate_fan_in_and_fan_out
from colossalai.context import seed, ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer._common_utils import divide, ACT2FN
from colossalai.registry import LAYERS
from colossalai.utils import checkpoint
from colossalai.utils import get_current_device
from .layers import Linear1D_Col, Linear1D_Row
from ..base_layer import ParallelLayer
from .._common_utils import to_2tuple
from ..fused_bias_gelu import bias_gelu_impl
@LAYERS.register_module
class ViTMLP1D(ParallelLayer):
"""MLP layer for 1D parallel Vision Transformer
:param in_features: size of each input sample
:type in_features: int
:param mlp_ratio: hidden size of MLP divided by embedding dim
:type mlp_ratio: int
:param act_func: activation function, defaults to 'gelu'
:type act_func: str, optional
:param dropout_prob: dropout probability, defaults to 0.
:type dropout_prob: float, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param checkpoint: whether to checkpoint the layer, defaults to False
:type checkpoint: bool, optional
"""
def __init__(self,
in_features: int,
mlp_ratio: int,
act_func: str = 'gelu',
dropout_prob: float = 0.,
dtype=None,
checkpoint: bool = False,
skip_bias_add: bool = False,
weight_init='torch'
):
super().__init__()
self.in_features = in_features
self.mlp_ratio = mlp_ratio
self.checkpoint = checkpoint
self.skip_bias_add = skip_bias_add
assert weight_init in ('torch', 'jax')
if act_func == 'fused_gelu':
self.act = bias_gelu_impl
skip_dense_1_add_bias = True
else:
self.act = ACT2FN[act_func]
skip_dense_1_add_bias = False
# Project to mlp_ratio * h.
self.dense_1 = Linear1D_Col(
self.in_features,
int(self.mlp_ratio * self.in_features),
dtype=dtype,
gather_output=False,
skip_bias_add=skip_dense_1_add_bias,
init_weight=weight_init,
init_bias=weight_init
)
# Project back to h.
self.dense_2 = Linear1D_Row(
int(self.mlp_ratio * self.in_features),
self.in_features,
dtype=dtype,
parallel_input=True,
init_weight=weight_init, init_bias=weight_init
)
self.dropout = nn.Dropout(dropout_prob)
def _forward(self, hidden_states: Tensor) -> Tensor:
if self.act == bias_gelu_impl:
intermediate_output, bias = self.dense_1(hidden_states)
intermediate_output = self.act(intermediate_output, bias)
else:
intermediate_output = self.dense_1(hidden_states)
intermediate_output = self.act(intermediate_output)
with seed(ParallelMode.TENSOR):
intermediate_output = self.dropout(intermediate_output)
output = self.dense_2(intermediate_output)
output = self.dropout(output)
return output
def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
return checkpoint(self._forward, hidden_states)
def forward(self, hidden_states: Tensor) -> Tensor:
if self.checkpoint:
return self._checkpoint_forward(hidden_states)
else:
return self._forward(hidden_states)
@LAYERS.register_module
class ViTSelfAttention1D(ParallelLayer):
"""Self-attention layer for 1D parallel Vision Transformer
:param hidden_size: hidden size
:type hidden_size: int
:param num_attention_heads: number of attention heads
:type num_attention_heads: int
:param attention_dropout_prob: dropout probability for attention layers
:type attention_dropout_prob: float
:param hidden_dropout_prob: dropout probability for hidden layers
:type hidden_dropout_prob: float
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param checkpoint: whether to checkpoint the layer, defaults to False
:type checkpoint: bool, optional
"""
def __init__(self,
hidden_size: int,
num_attention_heads: int,
attention_dropout_prob: float,
hidden_dropout_prob: float,
dtype=None,
checkpoint: bool = False,
weight_init='torch'
):
super().__init__()
self.hidden_size = hidden_size
self.attention_head_size = divide(hidden_size, num_attention_heads)
self.num_attention_heads_per_partition = divide(num_attention_heads, gpc.tensor_parallel_size)
self.hidden_size_per_partition = divide(hidden_size, gpc.tensor_parallel_size)
self.checkpoint = checkpoint
assert weight_init in ('torch', 'jax')
if weight_init == 'jax':
init_bias = 'zero'
else:
init_bias = weight_init
self.query_key_value = Linear1D_Col(
hidden_size,
3 * hidden_size,
dtype=dtype,
init_weight=weight_init,
init_bias=init_bias
)
self.attention_dropout = nn.Dropout(attention_dropout_prob)
self.dense = Linear1D_Row(
hidden_size,
hidden_size,
dtype=dtype,
parallel_input=True,
init_weight=weight_init, init_bias=init_bias
)
self.dropout = nn.Dropout(hidden_dropout_prob)
self.softmax = nn.Softmax(dim=-1)
def _forward(self, hidden_states: Tensor) -> Tensor:
query_key_value = self.query_key_value(hidden_states)
new_qkv_shape = query_key_value.shape[:-1] + \
(self.num_attention_heads_per_partition, 3 * self.attention_head_size)
query_key_value = query_key_value.view(new_qkv_shape)
query_key_value = query_key_value.permute((0, 2, 1, 3))
query_layer, key_layer, value_layer = torch.chunk(
query_key_value, 3, dim=-1)
attention_scores = torch.matmul(
query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / \
math.sqrt(self.attention_head_size)
attention_probs = self.softmax(attention_scores)
with seed(ParallelMode.TENSOR):
attention_probs = self.attention_dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.transpose(1, 2)
new_context_layer_shape = context_layer.size()[
:-2] + (self.hidden_size_per_partition,)
context_layer = context_layer.reshape(new_context_layer_shape)
output = self.dense(context_layer)
output = self.dropout(output)
return output
def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
return checkpoint(self._forward, hidden_states)
def forward(self, hidden_states: Tensor) -> Tensor:
if self.checkpoint:
return self._checkpoint_forward(hidden_states)
else:
return self._forward(hidden_states)
@LAYERS.register_module
class ViTHead1D(ParallelLayer):
"""Output layer for 1D parallel Vision Transformer
:param hidden_size: hidden size
:type hidden_size: int
:param num_classes: number of classes
:type num_classes: int
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
hidden_size,
num_classes,
dtype=None,
weight_init='torch'
):
super().__init__()
assert weight_init in ('torch', 'jax')
if weight_init == 'jax':
init_weight = 'zero'
init_bias = 'zero'
else:
init_weight = weight_init
init_bias = weight_init
self.linear = Linear1D_Col(
hidden_size,
num_classes,
dtype=dtype,
gather_output=True,
init_weight=init_weight,
init_bias=init_bias
)
def forward(self, x: Tensor) -> Tensor:
x = x[:, 0]
x = self.linear(x)
return x
@LAYERS.register_module
class ViTHead(ParallelLayer):
"""Output layer for 1D parallel Vision Transformer
:param hidden_size: hidden size
:type hidden_size: int
:param num_classes: number of classes
:type num_classes: int
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
hidden_size,
num_classes,
dtype=None,
):
super().__init__()
self.linear = nn.Linear(
hidden_size,
num_classes,
dtype=dtype
)
self._broadcast_linear_params()
def _broadcast_linear_params(self) -> None:
self.to(get_current_device())
ranks = gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)
dist.broadcast(self.linear.weight, src=ranks[0],
group=gpc.get_group(ParallelMode.PARALLEL_1D))
dist.broadcast(self.linear.bias, src=ranks[0],
group=gpc.get_group(ParallelMode.PARALLEL_1D))
def forward(self, x: Tensor) -> Tensor:
x = x[:, 0]
x = self.linear(x)
return x
@LAYERS.register_module
class ViTPatchEmbedding1D(ParallelLayer):
""" 2D Image to Patch Embedding
:param img_size: iamge size
:type img_size: int
:param patch_size: patch size
:type patch_size: int
:param embed_dim: dimension of embedding
:type embed_dim: int
:param in_chans: number of channels of input image, defaults to 3
:type in_chans: int, optional
:param flatten: whether to flatten output tensor, defaults to True
:type flatten: bool, optional
"""
def __init__(self,
img_size,
patch_size,
embed_dim,
in_chans=3,
flatten=True,
weight_init='torch'):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0],
img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_chans,
self.embed_dim,
kernel_size=patch_size,
stride=patch_size
)
if weight_init == 'jax':
fan_in, _ = _calculate_fan_in_and_fan_out(self.proj.weight)
std = math.sqrt(1.0 / fan_in)
nn.init.trunc_normal_(self.proj.weight, std=std / .87962566103423978)
nn.init.zeros_(self.proj.bias)
# sync
self._broadcast_conv_params()
def _broadcast_conv_params(self) -> None:
self.to(get_current_device())
ranks = gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)
dist.broadcast(self.proj.weight, src=ranks[0],
group=gpc.get_group(ParallelMode.PARALLEL_1D))
dist.broadcast(self.proj.bias, src=ranks[0],
group=gpc.get_group(ParallelMode.PARALLEL_1D))
def forward(self, x: Tensor) -> Tensor:
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
return x
@LAYERS.register_module
class ViTTokenFuser1D(ParallelLayer):
"""
Fuse cls token and pos embedding to the input
:param img_size: image size
:type img_size: int
:param patch_size: patch size
:type patch_size: int
:param embed_dim: dimension of embedding
:type embed_dim: int
:param drop_rate: dropout probability, defaults to 0.
:type drop_rate: float, optional
"""
def __init__(self,
img_size,
patch_size,
embed_dim,
drop_rate=0.
):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0],
img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.embed_dim = embed_dim
self.cls_token = nn.Parameter(torch.zeros(
1, 1, self.embed_dim))
self.pos_embed = nn.Parameter(torch.empty(
1, self.num_patches + 1, self.embed_dim))
nn.init.trunc_normal_(self.pos_embed, std=.02)
# move to cuda before broadcast
self.to(get_current_device())
dist.broadcast(self.pos_embed,
src=gpc.get_ranks_in_group(ParallelMode.TENSOR)[0],
group=gpc.get_group(ParallelMode.TENSOR))
self.pos_drop = nn.Dropout(p=drop_rate)
def forward(self, x: Tensor) -> Tensor:
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_token, x), dim=1)
x = self.pos_drop(x + self.pos_embed)
return x.contiguous()

View File

@ -1,24 +1,30 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
import numbers
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch import Tensor
from torch.nn.parameter import Parameter
from typing import Tuple
import importlib
from colossalai.context.parallel_mode import ParallelMode
from colossalai.context import seed, ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import LAYERS
from colossalai.utils import get_current_device
from .._common_utils import divide
from ._operation import FusedLayerNormAffineFunction1D
from .._common_utils import divide, set_tensor_parallel_attribute_by_partition
from .._parallel_utilities import reduce_grad, reduce_input, gather_forward_split_backward, \
split_forward_gather_backward
from ..base_layer import ParallelLayer
@LAYERS.register_module
class Linear1D_Col(ParallelLayer):
"""Linear layer with column parallelism.
@ -44,23 +50,29 @@ class Linear1D_Col(ParallelLayer):
output_size: int,
bias: bool = True,
dtype: torch.dtype = None,
gather_output: bool = False):
gather_output: bool = False,
skip_bias_add: bool = False,
init_weight='torch',
init_bias='torch'
):
super().__init__()
# Keep input parameters
self.input_size = in_features
self.output_size = output_size
self.in_features = in_features
self.out_features = output_size
self.gather_output = gather_output
self.skip_bias_add = not bias
self.skip_bias_add = skip_bias_add
world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
self.output_size_per_partition = divide(output_size, world_size)
if skip_bias_add and not bias:
raise ValueError('cannot skip bias addition if bias is None')
self.output_size_per_partition = divide(output_size, gpc.tensor_parallel_size)
# Parameters.
# Initialize weight.
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
self.weight = Parameter(torch.empty(
self.output_size_per_partition, self.input_size,
self.output_size_per_partition, self.in_features,
**factory_kwargs))
if bias:
@ -72,6 +84,45 @@ class Linear1D_Col(ParallelLayer):
self.bias.zero_()
else:
self.register_parameter('bias', None)
with seed(ParallelMode.TENSOR):
self.reset_parameters(init_weight, init_bias)
self._set_tensor_parallel_attributes()
def reset_parameters(self, init_weight, init_bias) -> None:
assert init_weight in ('torch', 'jax', 'zero')
assert init_bias in ('torch', 'jax', 'zero')
# setting
fan_in, fan_out = self.in_features, self.out_features
# init weight
if init_weight == 'torch':
a = math.sqrt(5)
nonlinearity = 'leaky_relu'
std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
bound = math.sqrt(3.0) * std
init.uniform_(self.weight, -bound, bound)
elif init_weight == 'jax':
std = math.sqrt(2.0 / float(fan_in + fan_out))
a = math.sqrt(3.0) * std
init.uniform_(self.weight, -a, a)
elif init_weight == 'zero':
init.zeros_(self.weight)
# init bias
if self.bias is not None:
if init_bias == 'torch':
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
init.uniform_(self.bias, -bound, bound)
elif init_bias == 'jax':
init.normal_(self.bias, std=1e-6)
elif init_bias == 'zero':
init.zeros_(self.bias)
def _set_tensor_parallel_attributes(self):
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
if self.bias is not None:
set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
# Set up backprop all-reduce.
@ -104,7 +155,7 @@ class Linear1D_Row(ParallelLayer):
:type bias: bool, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param parallel_input: If set to ``False``, it's assumed that the input is splitted, defaults to False
:param parallel_input: If set to ``True``, it's assumed that the input is splitted, defaults to False
:type parallel_input: bool, optional
"""
@ -113,7 +164,10 @@ class Linear1D_Row(ParallelLayer):
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
parallel_input: bool = False
parallel_input: bool = False,
skip_bias_add: bool = False,
init_weight='torch',
init_bias='torch'
):
super().__init__()
@ -121,11 +175,13 @@ class Linear1D_Row(ParallelLayer):
self.in_features = in_features
self.out_features = out_features
self.parallel_input = parallel_input
self.skip_bias_add = not bias
self.skip_bias_add = skip_bias_add
if skip_bias_add and not bias:
raise ValueError('cannot skip bias addition if bias is None')
# Divide the weight matrix along the last dimension.
world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
self.input_size_per_partition = divide(in_features, world_size)
self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size)
# Parameters.
# Initialize weight.
@ -146,9 +202,46 @@ class Linear1D_Row(ParallelLayer):
self.bias.zero_()
else:
self.register_parameter('bias', None)
with seed(ParallelMode.TENSOR):
self.reset_parameters(init_weight, init_bias)
self._set_tensor_parallel_attributes()
def reset_parameters(self) -> None:
init.xavier_normal_(self.weight)
def reset_parameters(self, init_weight, init_bias) -> None:
assert init_weight in ('torch', 'jax', 'zero')
assert init_bias in ('torch', 'jax', 'zero')
# setting
fan_in, fan_out = self.in_features, self.out_features
# init weight
if init_weight == 'torch':
a = math.sqrt(5)
nonlinearity = 'leaky_relu'
std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
bound = math.sqrt(3.0) * std
init.uniform_(self.weight, -bound, bound)
elif init_weight == 'jax':
std = math.sqrt(2.0 / float(fan_in + fan_out))
a = math.sqrt(3.0) * std
init.uniform_(self.weight, -a, a)
elif init_weight == 'zero':
init.zeros_(self.weight)
# init bias
if self.bias is not None:
if init_bias == 'torch':
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
init.uniform_(self.bias, -bound, bound)
elif init_bias == 'jax':
init.normal_(self.bias, std=1e-6)
elif init_bias == 'zero':
init.zeros_(self.bias)
dist.broadcast(self.bias,
src=gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0],
group=gpc.get_group(ParallelMode.PARALLEL_1D))
def _set_tensor_parallel_attributes(self):
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
def forward(self, input_: Tensor) -> Tensor:
# Set up backprop all-reduce.
@ -163,4 +256,29 @@ class Linear1D_Row(ParallelLayer):
if not self.skip_bias_add:
output = output + self.bias
return output
return output
else:
return output, self.bias
@LAYERS.register_module
class MixedFusedLayerNorm1D(torch.nn.Module):
def __init__(self, normalized_shape, eps=1e-5):
super(MixedFusedLayerNorm1D, self).__init__()
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
self.normalized_shape = torch.Size(normalized_shape)
self.eps = eps
self.weight = Parameter(torch.Tensor(*normalized_shape))
self.bias = Parameter(torch.Tensor(*normalized_shape))
self.reset_parameters()
def reset_parameters(self):
init.ones_(self.weight)
init.zeros_(self.bias)
def forward(self, input):
return FusedLayerNormAffineFunction1D.apply(
input, self.weight, self.bias, self.normalized_shape, self.eps)

View File

@ -20,7 +20,6 @@ def matmul_2d(a,
col_parallel_mode=ParallelMode.PARALLEL_2D_COL,
):
"""Matrix multiplication for 2D parallelism
:param a: matrix :math:`A`
:type a: torch.tensor
:param b: matrix :math:`B`
@ -86,25 +85,30 @@ class Matmul_AB_2D(torch.autograd.Function):
ctx.save_for_backward(A, B)
A_shape = A.shape
A = A.reshape((-1, A_shape[-1]))
A = A.reshape((-1, A_shape[-1])).contiguous()
B_shape = B.shape
B = B.reshape((-1, B_shape[-1]))
B = B.reshape((-1, B_shape[-1])).contiguous()
C_shape = (A.shape[0], B.shape[-1])
C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device())
for i in range(summa_dim):
A_temp = A.clone()
B_temp = B.clone()
src_a = i + summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dist.broadcast(A_temp, src=src_a,
group=gpc.get_group(row_parallel_mode))
src_b = col_rank + summa_dim * i + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dist.broadcast(B_temp, src=src_b,
group=gpc.get_group(col_parallel_mode))
torch.addmm(C, A_temp, B_temp, out=C)
A_list = [torch.empty_like(A) for _ in range(gpc.get_world_size(row_parallel_mode)-1)]
B_list = [torch.empty_like(B) for _ in range(gpc.get_world_size(col_parallel_mode)-1)]
A_list.insert(gpc.get_local_rank(row_parallel_mode), A)
B_list.insert(gpc.get_local_rank(col_parallel_mode), B)
op_a = dist.all_gather(A_list, A, group=gpc.get_group(row_parallel_mode), async_op=True)
op_a.wait()
op_b = dist.all_gather(B_list, B, group=gpc.get_group(col_parallel_mode), async_op=True)
for op in [op_a, op_b]:
op.wait()
for i in range(summa_dim):
src_a = i + summa_dim * row_rank
src_b = i + summa_dim * col_rank
src_a = src_a % summa_dim
src_b = src_b % summa_dim
A_temp = A_list[src_a]
B_temp = B_list[src_b]
torch.addmm(C, A_temp, B_temp, out=C)
out = C.reshape(out_shape)
if ctx:
@ -499,36 +503,61 @@ class _LayerNorm_2D(torch.autograd.Function):
# return input_grad, None, None, None, None, None
class _ViT_Split_Input_2D(torch.autograd.Function):
class AllGatherLast(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
inputs: Tensor,
batch_size: int,
summa_dim: int,
col_parallel_mode: ParallelMode) -> Tensor:
# inputs: [b, s, h/q]
# output: [b/q, s, h/q]
ctx.BATCH_SIZE = batch_size
ctx.summa_dim = summa_dim
ctx.col_parallel_mode = col_parallel_mode
row_rank = gpc.get_local_rank(col_parallel_mode)
output = torch.chunk(inputs, summa_dim, dim=0)[row_rank]
output = output.clone()
return output
ctx.row_rank = gpc.get_local_rank(col_parallel_mode)
last_dim = summa_dim * inputs.size(-1)
outputs_shape = (last_dim,) + inputs.shape[:-1]
outputs = torch.empty(
outputs_shape, dtype=inputs.dtype, device=get_current_device())
dist.all_gather(
list(outputs.chunk(summa_dim, dim=0)),
inputs.permute(2, 0, 1).contiguous(),
group=gpc.get_group(col_parallel_mode)
)
outputs = outputs.permute(1, 2, 0).contiguous()
return outputs
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
# output_grad: [b/q, s, h/q]
# grads: [b, s, h/q]
grads_shape = (ctx.BATCH_SIZE,) + output_grad.shape[1:]
grads = torch.empty(grads_shape,
dtype=output_grad.dtype,
device=get_current_device())
dist.all_gather(list(grads.chunk(ctx.summa_dim, dim=0)),
output_grad.contiguous(),
group=gpc.get_group(ctx.col_parallel_mode))
return grads, None, None, None
grad = output_grad.chunk(ctx.summa_dim, dim=-1)[ctx.row_rank]
return grad.contiguous(), None, None
class SplitFirst(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
inputs: Tensor,
summa_dim: int,
col_parallel_mode: ParallelMode) -> Tensor:
ctx.summa_dim = summa_dim
ctx.batch_size = inputs.size(0)
ctx.para_mode = col_parallel_mode
row_rank = gpc.get_local_rank(col_parallel_mode)
outputs = inputs.chunk(summa_dim, dim=0)[row_rank]
return outputs
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
grad_shape = (ctx.batch_size,) + output_grad.shape[1:]
grad = torch.empty(
grad_shape, dtype=output_grad.dtype, device=get_current_device())
dist.all_gather(
list(grad.chunk(ctx.summa_dim, dim=0)),
output_grad.contiguous(),
group=gpc.get_group(ctx.para_mode)
)
return grad, None, None

View File

@ -5,19 +5,21 @@ import math
import torch
from torch import nn as nn, Tensor, distributed as dist
from torch.nn.init import _calculate_fan_in_and_fan_out
from colossalai.context import seed, ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer._common_utils import divide, ACT2FN
from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization, get_summa_dim_from_env
from colossalai.nn.layer.vanilla_vision_transformer.layers import to_2tuple
from colossalai.registry import LAYERS
from colossalai.utils import checkpoint
from colossalai.utils import get_current_device
from ._operation import _ViT_Split_Input_2D
from colossalai.core import global_context as gpc
from ._operation import AllGatherLast, SplitFirst
from .layers import Linear2D
from .._common_utils import set_tensor_parallel_attribute
from .._common_utils import set_tensor_parallel_attribute_by_partition, to_2tuple
from ..base_layer import ParallelLayer
from ..fused_bias_gelu import bias_gelu_impl
@LAYERS.register_module
@ -44,8 +46,8 @@ class ViTMLP2D(ParallelLayer):
act_func: str = 'gelu',
dropout_prob: float = 0.,
dtype=None,
checkpoint: bool = False
):
checkpoint: bool = False,
weight_init='torch'):
super().__init__()
assert_summa_initialization()
@ -53,27 +55,40 @@ class ViTMLP2D(ParallelLayer):
self.in_features = in_features
self.mlp_ratio = mlp_ratio
self.checkpoint = checkpoint
assert weight_init in ('torch', 'jax')
if act_func == 'fused_gelu':
self.act = bias_gelu_impl
skip_dense_1_add_bias = True
else:
self.act = ACT2FN[act_func]
skip_dense_1_add_bias = False
# Project to mlp_ratio * h.
self.dense_1 = Linear2D(
self.in_features,
self.mlp_ratio * self.in_features,
dtype=dtype,
init_weight=weight_init, init_bias=weight_init,
skip_bias_add=skip_dense_1_add_bias
)
self.act = ACT2FN[act_func]
# Project back to h.
self.dense_2 = Linear2D(
self.mlp_ratio * self.in_features,
self.in_features,
dtype=dtype,
init_weight=weight_init, init_bias=weight_init
)
self.dropout = nn.Dropout(dropout_prob)
def _forward(self, hidden_states: Tensor) -> Tensor:
intermediate_output = self.dense_1(hidden_states)
intermediate_output = self.act(intermediate_output)
if self.act == bias_gelu_impl:
intermediate_output, bias = self.dense_1(hidden_states)
intermediate_output = self.act(intermediate_output, bias)
else:
intermediate_output = self.dense_1(hidden_states)
intermediate_output = self.act(intermediate_output)
with seed(ParallelMode.TENSOR):
intermediate_output = self.dropout(intermediate_output)
@ -117,8 +132,8 @@ class ViTSelfAttention2D(ParallelLayer):
attention_dropout_prob: float,
hidden_dropout_prob: float,
dtype=None,
checkpoint: bool = False
):
checkpoint: bool = False,
weight_init='torch'):
super().__init__()
assert_summa_initialization()
@ -128,17 +143,24 @@ class ViTSelfAttention2D(ParallelLayer):
self.attention_head_size = divide(hidden_size, num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.checkpoint = checkpoint
assert weight_init in ('torch', 'jax')
if weight_init == 'jax':
self.init_bias = 'zero'
else:
self.init_bias = weight_init
self.query_key_value = Linear2D(
hidden_size,
3 * hidden_size,
dtype=dtype,
init_weight=weight_init, init_bias=self.init_bias
)
self.attention_dropout = nn.Dropout(attention_dropout_prob)
self.dense = Linear2D(
hidden_size,
hidden_size,
dtype=dtype,
init_weight=weight_init, init_bias=self.init_bias
)
self.dropout = nn.Dropout(hidden_dropout_prob)
self.softmax = nn.Softmax(dim=-1)
@ -146,7 +168,7 @@ class ViTSelfAttention2D(ParallelLayer):
def _forward(self, hidden_states: Tensor) -> Tensor:
query_key_value = self.query_key_value(hidden_states)
new_qkv_shape = query_key_value.shape[:-1] + \
(self.num_attention_heads, 3 * self.attention_head_size)
(self.num_attention_heads, 3 * self.attention_head_size)
query_key_value = query_key_value.view(new_qkv_shape)
query_key_value = query_key_value.permute((0, 2, 1, 3))
query_layer, key_layer, value_layer = torch.chunk(
@ -155,7 +177,7 @@ class ViTSelfAttention2D(ParallelLayer):
attention_scores = torch.matmul(
query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / \
math.sqrt(self.attention_head_size)
math.sqrt(self.attention_head_size)
attention_probs = self.softmax(attention_scores)
@ -165,7 +187,7 @@ class ViTSelfAttention2D(ParallelLayer):
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.transpose(1, 2)
new_context_layer_shape = context_layer.size()[
:-2] + (self.all_head_size,)
:-2] + (self.all_head_size,)
context_layer = context_layer.reshape(new_context_layer_shape)
output = self.dense(context_layer)
@ -199,14 +221,22 @@ class ViTHead2D(ParallelLayer):
hidden_size,
num_classes,
dtype=None,
):
weight_init='torch'):
super().__init__()
assert_summa_initialization()
assert weight_init in ('torch', 'jax')
if weight_init == 'jax':
self.init_weight = 'zero'
self.init_bias = 'zero'
else:
self.init_weight = weight_init
self.init_bias = weight_init
self.summa_dim = get_summa_dim_from_env()
self.linear = Linear2D(
hidden_size,
num_classes,
dtype=dtype,
init_weight=self.init_weight, init_bias=self.init_bias
)
def forward(self, x: Tensor) -> Tensor:
@ -236,7 +266,8 @@ class ViTPatchEmbedding2D(ParallelLayer):
patch_size,
embed_dim,
in_chans=3,
flatten=True):
flatten=True,
weight_init='torch'):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
@ -249,39 +280,28 @@ class ViTPatchEmbedding2D(ParallelLayer):
img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
self.embed_dim = embed_dim // self.summa_dim
self.embed_dim = embed_dim // (self.summa_dim ** 2)
with seed(ParallelMode.TENSOR):
# ensure the partitions are initialized differently
self.proj = nn.Conv2d(in_chans,
self.embed_dim,
kernel_size=patch_size,
stride=patch_size
stride=patch_size,
device=get_current_device()
)
self._set_tensor_parallel_attribute()
# sync
self._broadcast_conv_params()
self.proj.weight.register_hook(self._sync_grad_during_backward)
self.proj.bias.register_hook(self._sync_grad_during_backward)
if weight_init == 'jax':
with seed(ParallelMode.TENSOR):
fan_in, _ = _calculate_fan_in_and_fan_out(self.proj.weight)
std = math.sqrt(1.0 / fan_in)
nn.init.trunc_normal_(self.proj.weight, std=std / .87962566103423978)
nn.init.zeros_(self.proj.bias)
def _set_tensor_parallel_attribute(self):
set_tensor_parallel_attribute(self.proj.weight)
set_tensor_parallel_attribute(self.proj.bias)
def _broadcast_conv_params(self) -> None:
self.to(get_current_device())
ranks_in_col = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2D_COL)
dist.broadcast(self.proj.weight, src=ranks_in_col[0],
group=gpc.get_group(ParallelMode.PARALLEL_2D_COL))
dist.broadcast(self.proj.bias, src=ranks_in_col[0],
group=gpc.get_group(ParallelMode.PARALLEL_2D_COL))
def _sync_grad_during_backward(self, grad: Tensor) -> None:
dist.all_reduce(grad, group=gpc.get_group(
ParallelMode.PARALLEL_2D_COL))
grad = grad / self.summa_dim
return grad
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
set_tensor_parallel_attribute_by_partition(self.proj.weight, num_partition)
set_tensor_parallel_attribute_by_partition(self.proj.bias, num_partition)
def forward(self, x: Tensor) -> Tensor:
B, C, H, W = x.shape
@ -293,6 +313,24 @@ class ViTPatchEmbedding2D(ParallelLayer):
return x
@LAYERS.register_module
class ViTInputSplitter2D(ParallelLayer):
"""Split the input tensor for 2D parallel Vision Transformer
"""
def __init__(self):
super().__init__()
assert_summa_initialization()
self.summa_dim = get_summa_dim_from_env()
def forward(self, x: Tensor) -> Tensor:
x = AllGatherLast.apply(
x, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
x = SplitFirst.apply(
x, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
return x
@LAYERS.register_module
class ViTTokenFuser2D(ParallelLayer):
"""
@ -328,64 +366,32 @@ class ViTTokenFuser2D(ParallelLayer):
self.embed_dim = embed_dim
self.cls_token = nn.Parameter(torch.zeros(
1, 1, self.embed_dim // self.summa_dim))
self.pos_embed = nn.Parameter(torch.zeros(
1, self.num_patches + 1, self.embed_dim // self.summa_dim))
(1, 1, self.embed_dim // (self.summa_dim ** 2)),
device=get_current_device()))
self.pos_embed = nn.Parameter(torch.empty(
(1, self.num_patches + 1, self.embed_dim // (self.summa_dim ** 2)),
device=get_current_device()))
with seed(ParallelMode.TENSOR):
nn.init.trunc_normal_(self.pos_embed, std=.02)
# move to cuda before broadcast
self.to(get_current_device())
# sync param in both forward and backward
_cls_token = self.cls_token.view(-1)
_pos_embed = self.pos_embed.view(-1)
self._param = torch.cat([_cls_token, _pos_embed], dim=0)
self._broadcast_params(self._param)
self._param.register_hook(self._sync_grad_hook)
self.pos_drop = nn.Dropout(p=drop_rate)
self._set_tensor_parallel_attribute()
def _set_tensor_parallel_attribute(self):
set_tensor_parallel_attribute(self.cls_token)
set_tensor_parallel_attribute(self.pos_embed)
def _broadcast_params(self, param) -> None:
" broadcast to all column ranks for data consistency "
ranks_in_col = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2D_COL)
col_group = gpc.get_group(ParallelMode.PARALLEL_2D_COL)
dist.broadcast(param, src=ranks_in_col[0],
group=col_group)
def _sync_grad_hook(self, grad) -> None:
dist.all_reduce(grad, group=gpc.get_group(
ParallelMode.PARALLEL_2D_COL))
grad = grad / self.summa_dim
return grad
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
set_tensor_parallel_attribute_by_partition(self.cls_token, num_partition)
set_tensor_parallel_attribute_by_partition(self.pos_embed, num_partition)
def forward(self, x: Tensor) -> Tensor:
# stole cls_tokens impl from Phil Wang, thanks
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
cls_token = AllGatherLast.apply(
self.cls_token, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
cls_token = cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_token, x), dim=1)
pos_embed = AllGatherLast.apply(
self.pos_embed, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
x = x + pos_embed
with seed(ParallelMode.TENSOR):
x = self.pos_drop(x + self.pos_embed)
x = self.pos_drop(x)
return x
@LAYERS.register_module
class ViTInputSplitter2D(ParallelLayer):
"""Split the input tensor for 2D parallel Vision Transformer
"""
def __init__(self):
super().__init__()
assert_summa_initialization()
self.summa_dim = get_summa_dim_from_env()
def forward(self, x: Tensor) -> Tensor:
batch_size = x.size(0)
return _ViT_Split_Input_2D.apply(
x,
batch_size,
self.summa_dim,
ParallelMode.PARALLEL_2D_COL
)

View File

@ -11,7 +11,7 @@ from colossalai.registry import LAYERS
from colossalai.utils import get_current_device
from ._operation import Matmul_AB_2D, Add_Bias_2D, _LayerNorm_2D
from ._utils import get_summa_dim_from_env, assert_summa_initialization
from .._common_utils import divide, set_tensor_parallel_attribute
from .._common_utils import divide, set_tensor_parallel_attribute_by_partition
from ..base_layer import ParallelLayer
@ -36,8 +36,9 @@ class Linear2D(ParallelLayer):
out_features: int,
bias: bool = True,
dtype=None,
skip_bias_add: bool = False
):
skip_bias_add: bool = False,
init_weight='torch',
init_bias='torch'):
super().__init__()
self.in_features = in_features
@ -72,31 +73,45 @@ class Linear2D(ParallelLayer):
self.register_parameter('bias', None)
# initialize parameters
self.reset_parameters()
with seed(ParallelMode.TENSOR):
self.reset_parameters(init_weight, init_bias)
self._set_tensor_parallel_attributes()
def _set_tensor_parallel_attributes(self):
set_tensor_parallel_attribute(self.weight)
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
if self.bias is not None:
set_tensor_parallel_attribute(self.bias)
set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
def reset_parameters(self) -> None:
def reset_parameters(self, init_weight, init_bias) -> None:
assert init_weight in ('torch', 'jax', 'zero')
assert init_bias in ('torch', 'jax', 'zero')
# setting
fan_in = self.in_features
a = math.sqrt(5)
nonlinearity = 'leaky_relu'
fan_in, fan_out = self.in_features, self.out_features
# init weight
std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
bound = math.sqrt(3.0) * std
with seed(ParallelMode.TENSOR):
if init_weight == 'torch':
a = math.sqrt(5)
nonlinearity = 'leaky_relu'
std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
bound = math.sqrt(3.0) * std
init.uniform_(self.weight, -bound, bound)
elif init_weight == 'jax':
std = math.sqrt(2.0 / float(fan_in + fan_out))
a = math.sqrt(3.0) * std
init.uniform_(self.weight, -a, a)
elif init_weight == 'zero':
init.zeros_(self.weight)
# init bias
if self.bias is not None:
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
with seed(ParallelMode.TENSOR):
if init_bias == 'torch':
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
init.uniform_(self.bias, -bound, bound)
elif init_bias == 'jax':
init.normal_(self.bias, std=1e-6)
elif init_bias == 'zero':
init.zeros_(self.bias)
def forward(self, x: Tensor) -> Tensor:
# input: [m/q, n/q, k/q]
@ -192,28 +207,19 @@ class LayerNorm2D(ParallelLayer):
# create parameters
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
if self.row_rank == 0:
self.gamma = Parameter(torch.ones(
self.partitioned_partition,
**factory_kwargs))
self.beta = Parameter(torch.zeros(
self.partitioned_partition,
**factory_kwargs))
else:
self.gamma = Parameter(torch.tensor(
1.0,
requires_grad=True,
**factory_kwargs))
self.beta = Parameter(torch.tensor(
1.0,
requires_grad=True,
**factory_kwargs))
self.gamma = Parameter(torch.ones(
self.partitioned_partition,
**factory_kwargs))
self.beta = Parameter(torch.zeros(
self.partitioned_partition,
**factory_kwargs))
self._set_tensor_parallel_attributes()
def _set_tensor_parallel_attributes(self):
set_tensor_parallel_attribute(self.gamma)
set_tensor_parallel_attribute(self.beta)
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
set_tensor_parallel_attribute_by_partition(self.gamma, num_partition)
set_tensor_parallel_attribute_by_partition(self.beta, num_partition)
def forward(self, x: Tensor) -> Tensor:
with torch.no_grad():

View File

@ -1,11 +1,10 @@
from ._operation import Matmul_AB_2p5D, Matmul_ABT_2p5D, Matmul_ATB_2p5D, Sum_2p5D, Add_Bias_2p5D
from ._operation import Matmul_AB_2p5D, Matmul_ABT_2p5D, Matmul_ATB_2p5D, Add_Bias_2p5D
from ._transformer import TransformerMLP2p5D, TransformerSelfAttention2p5D, TransformerLayer2p5D
from ._vit import (ViTMLP2p5D, ViTSelfAttention2p5D, ViTHead2p5D, ViTPatchEmbedding2p5D, ViTTokenFuser2p5D,
ViTInputSplitter2p5D)
from ._vit import ViTMLP2p5D, ViTSelfAttention2p5D, ViTHead2p5D, ViTPatchEmbedding2p5D, ViTTokenFuser2p5D, ViTInputSplitter2p5D
from .layers import Linear2p5D, LayerNorm2p5D
__all__ = [
'Matmul_AB_2p5D', 'Matmul_ABT_2p5D', 'Matmul_ATB_2p5D', 'Sum_2p5D', 'Add_Bias_2p5D',
'Matmul_AB_2p5D', 'Matmul_ABT_2p5D', 'Matmul_ATB_2p5D', 'Add_Bias_2p5D',
'TransformerMLP2p5D', 'TransformerSelfAttention2p5D', 'TransformerLayer2p5D',
'ViTMLP2p5D', 'ViTSelfAttention2p5D', 'ViTHead2p5D', 'ViTPatchEmbedding2p5D', 'ViTTokenFuser2p5D',
'ViTInputSplitter2p5D',

View File

@ -6,7 +6,8 @@ from torch import Tensor
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device, empty_cache
from colossalai.utils import get_current_device
from torch.cuda.amp import custom_bwd, custom_fwd
def get_parallel_group(parallel_mode: ParallelMode):
@ -26,18 +27,17 @@ class Matmul_AB_2p5D(torch.autograd.Function):
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
A: Tensor,
B: Tensor,
tesseract_dim: int,
tesseract_dep: int,
out_shape: Tuple[int, ...],
row_rank: int,
col_rank: int,
dep_rank: int,
row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode,
dep_parallel_mode: ParallelMode,
data_parallel_rank: int,
pipeline_parallel_rank: int,
pipeline_parallel_size: int,
@ -49,41 +49,43 @@ class Matmul_AB_2p5D(torch.autograd.Function):
assert A.shape[-1] == B.shape[-2], \
'Invalid shapes: A={}, B={} for AB.'.format(A.shape, B.shape)
empty_cache()
if ctx:
ctx.save_for_backward(A, B)
A_shape = A.shape
A = A.reshape((-1, A_shape[-1]))
A = A.reshape((-1, A_shape[-1])).contiguous()
B_shape = B.shape
B = B.reshape((-1, B_shape[-1]))
B = B.reshape((-1, B_shape[-1])).contiguous()
C_shape = (A.shape[0], B.shape[-1])
C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device())
for i in range(tesseract_dim):
A_temp = A.clone()
B_temp = B.clone()
src_a = i + row_rank * tesseract_dim + dep_rank * (
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + pipeline_parallel_rank * tensor_parallel_size
dist.broadcast(A_temp, src=src_a,
group=get_parallel_group(row_parallel_mode))
src_b = col_rank + i * tesseract_dim + dep_rank * (
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + pipeline_parallel_rank * tensor_parallel_size
dist.broadcast(B_temp, src=src_b,
group=get_parallel_group(col_parallel_mode))
torch.addmm(C, A_temp, B_temp, out=C)
A_list = [torch.empty_like(A) for _ in range(gpc.get_world_size(row_parallel_mode)-1)]
B_list = [torch.empty_like(B) for _ in range(gpc.get_world_size(col_parallel_mode)-1)]
A_list.insert(gpc.get_local_rank(row_parallel_mode), A)
B_list.insert(gpc.get_local_rank(col_parallel_mode), B)
op_a = dist.all_gather(A_list, A, group=gpc.get_group(row_parallel_mode), async_op=True)
op_a.wait()
op_b = dist.all_gather(B_list, B, group=gpc.get_group(col_parallel_mode), async_op=True)
for op in [op_a, op_b]:
op.wait()
for i in range(tesseract_dim):
src_a = i + tesseract_dim * row_rank
src_b = i + tesseract_dim * col_rank
src_a = src_a % tesseract_dim
src_b = src_b % tesseract_dim
A_temp = A_list[src_a]
B_temp = B_list[src_b]
torch.addmm(C, A_temp, B_temp, out=C)
out = C.reshape(out_shape)
if ctx:
ctx.tesseract_dim = tesseract_dim
ctx.tesseract_dep = tesseract_dep
ctx.row_rank = row_rank
ctx.col_rank = col_rank
ctx.dep_rank = dep_rank
ctx.row_parallel_mode = row_parallel_mode
ctx.col_parallel_mode = col_parallel_mode
ctx.dep_parallel_mode = dep_parallel_mode
ctx.A_shape = A_shape
ctx.B_shape = B_shape
ctx.data_parallel_rank = data_parallel_rank
@ -94,34 +96,32 @@ class Matmul_AB_2p5D(torch.autograd.Function):
return out
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors
A_grad = Matmul_ABT_2p5D.forward(
None,
output_grad, B,
ctx.tesseract_dim, ctx.tesseract_dep, ctx.A_shape,
ctx.row_rank, ctx.col_rank, ctx.dep_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.dep_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
B_grad = Matmul_ATB_2p5D.forward(
None,
A, output_grad,
ctx.tesseract_dim, ctx.tesseract_dep, ctx.B_shape,
ctx.row_rank, ctx.col_rank, ctx.dep_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.dep_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
with torch.no_grad():
A_grad = Matmul_ABT_2p5D.apply(
output_grad, B,
ctx.tesseract_dim, ctx.A_shape,
ctx.row_rank, ctx.col_rank, ctx.dep_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
B_grad = Matmul_ATB_2p5D.apply(
A, output_grad,
ctx.tesseract_dim, ctx.B_shape,
ctx.row_rank, ctx.col_rank, ctx.dep_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None
@ -130,18 +130,17 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
A: Tensor,
B: Tensor,
tesseract_dim: int,
tesseract_dep: int,
out_shape: Tuple[int, ...],
row_rank: int,
col_rank: int,
dep_rank: int,
row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode,
dep_parallel_mode: ParallelMode,
data_parallel_rank: int,
pipeline_parallel_rank: int,
pipeline_parallel_size: int,
@ -151,7 +150,6 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
assert A.shape[-1] == B.shape[-1], \
'Invalid shapes: A={}, B={} for ABT.'.format(A.shape, B.shape)
empty_cache()
if ctx:
ctx.save_for_backward(A, B)
@ -180,13 +178,11 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
if ctx:
ctx.tesseract_dim = tesseract_dim
ctx.tesseract_dep = tesseract_dep
ctx.row_rank = row_rank
ctx.col_rank = col_rank
ctx.dep_rank = dep_rank
ctx.row_parallel_mode = row_parallel_mode
ctx.col_parallel_mode = col_parallel_mode
ctx.dep_parallel_mode = dep_parallel_mode
ctx.A_shape = A_shape
ctx.B_shape = B_shape
ctx.data_parallel_rank = data_parallel_rank
@ -197,34 +193,32 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
return out
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors
A_grad = Matmul_AB_2p5D.forward(
None,
output_grad, B,
ctx.tesseract_dim, ctx.tesseract_dep, ctx.A_shape,
ctx.row_rank, ctx.col_rank, ctx.dep_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.dep_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
B_grad = Matmul_ATB_2p5D.forward(
None,
output_grad, A,
ctx.tesseract_dim, ctx.tesseract_dep, ctx.B_shape,
ctx.row_rank, ctx.col_rank, ctx.dep_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.dep_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
with torch.no_grad():
A_grad = Matmul_AB_2p5D.apply(
output_grad, B,
ctx.tesseract_dim, ctx.A_shape,
ctx.row_rank, ctx.col_rank, ctx.dep_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
B_grad = Matmul_ATB_2p5D.apply(
output_grad, A,
ctx.tesseract_dim, ctx.B_shape,
ctx.row_rank, ctx.col_rank, ctx.dep_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None
@ -233,18 +227,17 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
A: Tensor,
B: Tensor,
tesseract_dim: int,
tesseract_dep: int,
out_shape: Tuple[int, ...],
row_rank: int,
col_rank: int,
dep_rank: int,
row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode,
dep_parallel_mode: ParallelMode,
data_parallel_rank: int,
pipeline_parallel_rank: int,
pipeline_parallel_size: int,
@ -253,7 +246,6 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
assert A.shape[-2] == B.shape[-2], \
'Invalid shapes: A={}, B={} for ATB.'.format(A.shape, B.shape)
empty_cache()
if ctx:
ctx.save_for_backward(A, B)
@ -284,13 +276,11 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
if ctx:
ctx.tesseract_dim = tesseract_dim
ctx.tesseract_dep = tesseract_dep
ctx.row_rank = row_rank
ctx.col_rank = col_rank
ctx.dep_rank = dep_rank
ctx.row_parallel_mode = row_parallel_mode
ctx.col_parallel_mode = col_parallel_mode
ctx.dep_parallel_mode = dep_parallel_mode
ctx.A_shape = A_shape
ctx.B_shape = B_shape
ctx.data_parallel_rank = data_parallel_rank
@ -301,34 +291,32 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
return out
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors
A_grad = Matmul_ABT_2p5D.forward(
None,
B, output_grad,
ctx.tesseract_dim, ctx.tesseract_dep, ctx.A_shape,
ctx.row_rank, ctx.col_rank, ctx.dep_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.dep_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
B_grad = Matmul_AB_2p5D.forward(
None,
A, output_grad,
ctx.tesseract_dim, ctx.tesseract_dep, ctx.B_shape,
ctx.row_rank, ctx.col_rank, ctx.dep_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.dep_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
with torch.no_grad():
A_grad = Matmul_ABT_2p5D.apply(
B, output_grad,
ctx.tesseract_dim, ctx.A_shape,
ctx.row_rank, ctx.col_rank, ctx.dep_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
B_grad = Matmul_AB_2p5D.apply(
A, output_grad,
ctx.tesseract_dim, ctx.B_shape,
ctx.row_rank, ctx.col_rank, ctx.dep_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None
@ -337,18 +325,16 @@ class Add_Bias_2p5D(torch.autograd.Function):
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
input: Tensor,
bias: Tensor,
output_size_per_partition: int,
tesseract_dim: int,
tesseract_dep: int,
row_rank: int,
col_rank: int,
dep_rank: int,
row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode,
dep_parallel_mode: ParallelMode,
skip_bias_add: bool,
data_parallel_rank: int,
pipeline_parallel_rank: int,
@ -371,10 +357,7 @@ class Add_Bias_2p5D(torch.autograd.Function):
ctx.col_rank = col_rank
ctx.dep_rank = dep_rank
ctx.tesseract_dim = tesseract_dim
ctx.tesseract_dep = tesseract_dep
ctx.row_parallel_mode = row_parallel_mode
ctx.col_parallel_mode = col_parallel_mode
ctx.dep_parallel_mode = dep_parallel_mode
ctx.bias = skip_bias_add
ctx.data_parallel_rank = data_parallel_rank
ctx.pipeline_parallel_rank = pipeline_parallel_rank
@ -388,15 +371,13 @@ class Add_Bias_2p5D(torch.autograd.Function):
return output
@staticmethod
def backward(ctx, output_grad):
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
row_rank = ctx.row_rank
col_rank = ctx.col_rank
dep_rank = ctx.dep_rank
tesseract_dim = ctx.tesseract_dim
tesseract_dep = ctx.tesseract_dep
row_parallel_mode = ctx.row_parallel_mode
col_parallel_mode = ctx.col_parallel_mode
dep_parallel_mode = ctx.dep_parallel_mode
data_parallel_rank = ctx.data_parallel_rank
pipeline_parallel_rank = ctx.pipeline_parallel_rank
pipeline_parallel_size = ctx.pipeline_parallel_size
@ -428,29 +409,25 @@ class Add_Bias_2p5D(torch.autograd.Function):
class _LayerNorm_2p5D(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx: Any,
input: Tensor,
E_x: Tensor,
Var_x: Tensor,
hidden_size: int,
row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode,
dep_parallel_mode: ParallelMode) -> Tensor:
row_parallel_mode: ParallelMode) -> Tensor:
input = input - E_x
# in here, input = x - E[x], Var_x = 1 / sqrt(Var[x] + eps)
ctx.hidden_size = hidden_size
output = input * Var_x
ctx.save_for_backward(output, Var_x)
ctx.row_parallel_mode = row_parallel_mode
ctx.col_parallel_mode = col_parallel_mode
ctx.dep_parallel_mode = dep_parallel_mode
return output
@staticmethod
@custom_bwd
def backward(ctx, output_grad):
row_parallel_mode = ctx.row_parallel_mode
col_parallel_mode = ctx.col_parallel_mode
dep_parallel_mode = ctx.dep_parallel_mode
x, Var_x = ctx.saved_tensors
# in here, Var_x = 1 / sqrt(Var[x] + eps), x = (x - E[x]) * Var_x
with torch.no_grad():
@ -473,63 +450,122 @@ class _LayerNorm_2p5D(torch.autograd.Function):
return input_grad, None, None, None, None, None, None
class Sum_2p5D(torch.autograd.Function):
"""Compute the sum of input tensors
"""
# class Sum_2p5D(torch.autograd.Function):
# """Compute the sum of input tensors
# """
# @staticmethod
# def forward(ctx,
# inputs,
# dim,
# tesseract_dim,
# row_parallel_mode,
# keepdim=False):
# # input: [b/q, s, h/q]
# ctx.save_for_backward(inputs)
# # sum: [b/q, s]
# out = torch.sum(inputs, dim=dim, keepdim=keepdim)
# torch.distributed.all_reduce(
# out, group=gpc.get_group(row_parallel_mode))
# return out
# @staticmethod
# def backward(ctx, output_grad):
# with torch.no_grad():
# inputs = ctx.saved_tensors
# input_grad = torch.ones(inputs.shape, dtype=output_grad.dtype)
# return input_grad, None, None, None, None, None
# class _ViT_Split_2p5D(torch.autograd.Function):
# @staticmethod
# @custom_fwd(cast_inputs=torch.float16)
# def forward(ctx, inputs, batch_size,
# tesseract_dim, tesseract_dep,
# xz_parallel_mode):
# # inputs: [b, s, h/q]
# # output: [b/dq, s, h/q]
# ctx.BATCH_SIZE = batch_size
# ctx.tesseract_dim = tesseract_dim
# ctx.tesseract_dep = tesseract_dep
# ctx.xz_parallel_mode = xz_parallel_mode
# xz_rank = gpc.get_local_rank(xz_parallel_mode)
# output = torch.chunk(inputs, tesseract_dep *
# tesseract_dim, dim=0)[xz_rank]
# output = output.clone()
# return output
# @staticmethod
# @custom_bwd
# def backward(ctx, output_grad):
# # output_grad: [b/dq, s, h/q]
# # grads: [b, s, h/q]
# # *
# grads_shape = (ctx.BATCH_SIZE,) + output_grad.shape[1:]
# grads = torch.empty(grads_shape,
# dtype=output_grad.dtype,
# device=get_current_device())
# dist.all_gather(list(grads.chunk(ctx.tesseract_dim * ctx.tesseract_dep, dim=0)),
# output_grad.contiguous(),
# group=get_parallel_group(ctx.xz_parallel_mode))
# return grads, None, None, None, None
class AllGatherLast(torch.autograd.Function):
@staticmethod
def forward(ctx,
inputs,
dim,
tesseract_dim,
row_parallel_mode,
keepdim=False):
# input: [b/q, s, h/q]
empty_cache()
ctx.save_for_backward(inputs)
# sum: [b/q, s]
out = torch.sum(inputs, dim=dim, keepdim=keepdim)
torch.distributed.all_reduce(
out, group=gpc.get_group(row_parallel_mode))
return out
@staticmethod
def backward(ctx, output_grad):
with torch.no_grad():
inputs = ctx.saved_tensors
input_grad = torch.ones(inputs.shape, dtype=output_grad.dtype)
return input_grad, None, None, None, None, None
class _ViT_Split_2p5D(torch.autograd.Function):
@staticmethod
def forward(ctx, inputs, batch_size,
tesseract_dim, tesseract_dep,
xz_parallel_mode):
# inputs: [b, s, h/q]
# output: [b/dq, s, h/q]
empty_cache()
ctx.batch_size = batch_size
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
inputs: Tensor,
tesseract_dim: int,
col_parallel_mode: ParallelMode) -> Tensor:
ctx.tesseract_dim = tesseract_dim
ctx.tesseract_dep = tesseract_dep
ctx.xz_parallel_mode = xz_parallel_mode
xz_rank = gpc.get_local_rank(xz_parallel_mode)
output = torch.chunk(inputs, tesseract_dep *
tesseract_dim, dim=0)[xz_rank]
output = output.clone()
return output
ctx.row_rank = gpc.get_local_rank(col_parallel_mode)
last_dim = tesseract_dim * inputs.size(-1)
outputs_shape = (last_dim,) + inputs.shape[:-1]
outputs = torch.empty(
outputs_shape, dtype=inputs.dtype, device=get_current_device())
dist.all_gather(
list(outputs.chunk(tesseract_dim, dim=0)),
inputs.permute(2, 0, 1).contiguous(),
group=gpc.get_group(col_parallel_mode)
)
outputs = outputs.permute(1, 2, 0).contiguous()
return outputs
@staticmethod
def backward(ctx, output_grad):
# output_grad: [b/dq, s, h/q]
# grads: [b, s, h/q]
# *
grads_shape = (ctx.batch_size,) + output_grad.shape[1:]
grads = torch.empty(grads_shape,
dtype=output_grad.dtype,
device=get_current_device())
dist.all_gather(list(grads.chunk(ctx.tesseract_dim * ctx.tesseract_dep, dim=0)),
output_grad.contiguous(),
group=get_parallel_group(ctx.xz_parallel_mode))
return grads, None, None, None, None
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
grad = output_grad.chunk(ctx.tesseract_dim, dim=-1)[ctx.row_rank]
return grad.contiguous(), None, None
class SplitFirst(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
inputs: Tensor,
tesseract_dim: int,
col_parallel_mode: ParallelMode) -> Tensor:
ctx.tesseract_dim = tesseract_dim
ctx.batch_size = inputs.size(0)
ctx.para_mode = col_parallel_mode
row_rank = gpc.get_local_rank(col_parallel_mode)
outputs = inputs.chunk(tesseract_dim, dim=0)[row_rank]
return outputs
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
grad_shape = (ctx.batch_size,) + output_grad.shape[1:]
grad = torch.empty(
grad_shape, dtype=output_grad.dtype, device=get_current_device())
dist.all_gather(
list(grad.chunk(ctx.tesseract_dim, dim=0)),
output_grad.contiguous(),
group=gpc.get_group(ctx.para_mode)
)
return grad, None, None

View File

@ -12,10 +12,11 @@ from ._utils import assert_tesseract_initialization, \
get_tesseract_dim_dep_from_env
from .layers import Linear2p5D, LayerNorm2p5D
from .._common_utils import ACT2FN
from ..base_layer import ParallelLayer
@LAYERS.register_module
class TransformerMLP2p5D(nn.Module):
class TransformerMLP2p5D(ParallelLayer):
"""
MLP will take the input with h hidden state, project it to mlp_ratio * h
hidden dimension, perform nonlinear transformation, and project the
@ -36,21 +37,24 @@ class TransformerMLP2p5D(nn.Module):
def __init__(self,
in_features: int,
mlp_ratio: int,
mlp_ratio: int = 4.0,
act_func: str = 'gelu',
dropout_prob: float = 0.,
dtype=None,
skip_bias_add: bool = False
):
super().__init__()
assert_tesseract_initialization()
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
self.in_features = in_features
self.skip_bias_add = skip_bias_add
# Project to h * mlp_ratio.
self.dense_1 = Linear2p5D(
in_features,
mlp_ratio * in_features,
dtype=dtype
int(mlp_ratio * in_features),
dtype=dtype,
skip_bias_add=skip_bias_add
)
assert act_func in ACT2FN.keys(), f'Invalid value for argument act_func, ' \
@ -59,24 +63,34 @@ class TransformerMLP2p5D(nn.Module):
# Project back to h.
self.dense_2 = Linear2p5D(
mlp_ratio * in_features,
int(mlp_ratio * in_features),
in_features,
dtype=dtype
dtype=dtype,
skip_bias_add=skip_bias_add
)
self.dropout = nn.Dropout(dropout_prob)
self.layernorm = LayerNorm2p5D(in_features, dtype=dtype)
def forward(self, x: Tensor) -> Tensor:
intermediate_output = self.dense_1(x)
if self.skip_bias_add:
intermediate_output, _ = self.dense_1(x)
else:
intermediate_output = self.dense_1(x)
intermediate_output = self.activation_func(intermediate_output)
output = self.dense_2(intermediate_output)
if self.skip_bias_add:
output, _ = self.dense_2(intermediate_output)
else:
output = self.dense_2(intermediate_output)
output = self.dropout(output)
output = self.layernorm(x + output)
return output
@LAYERS.register_module
class TransformerSelfAttention2p5D(nn.Module):
class TransformerSelfAttention2p5D(ParallelLayer):
"""Self attention layer for 2.5D parallel Transformer
:param hidden_size: hidden size
@ -92,10 +106,10 @@ class TransformerSelfAttention2p5D(nn.Module):
"""
def __init__(self,
hidden_size,
num_attention_heads,
attention_dropout_prob,
hidden_dropout_prob,
hidden_size: int,
num_attention_heads: int,
attention_dropout_prob: float,
hidden_dropout_prob: float,
dtype=None,
):
super().__init__()
@ -127,7 +141,7 @@ class TransformerSelfAttention2p5D(nn.Module):
def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
query_key_value = self.query_key_value(hidden_states)
new_qkv_shape = query_key_value.shape[:-1] + \
(self.num_attention_heads, 3 * self.attention_head_size)
(self.num_attention_heads, 3 * self.attention_head_size)
query_key_value = query_key_value.view(new_qkv_shape)
query_key_value = query_key_value.permute((0, 2, 1, 3))
query_layer, key_layer, value_layer = torch.chunk(
@ -136,7 +150,7 @@ class TransformerSelfAttention2p5D(nn.Module):
attention_scores = torch.matmul(
query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / \
math.sqrt(self.attention_head_size)
math.sqrt(self.attention_head_size)
attention_scores = attention_scores + attention_mask
attention_probs = nn.Softmax(dim=-1)(attention_scores)
attention_probs = self.attention_dropout(attention_probs)
@ -144,7 +158,7 @@ class TransformerSelfAttention2p5D(nn.Module):
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute((0, 2, 1, 3)).contiguous()
new_context_layer_shape = context_layer.size()[
:-2] + (self.all_head_size,)
:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
output = self.dense(context_layer)
@ -155,7 +169,7 @@ class TransformerSelfAttention2p5D(nn.Module):
@LAYERS.register_module
class TransformerLayer2p5D(nn.Module):
class TransformerLayer2p5D(ParallelLayer):
"""Transformer layer which contains a self-attention layer and a MLP layer
:param hidden_size: hidden size
@ -175,10 +189,10 @@ class TransformerLayer2p5D(nn.Module):
"""
def __init__(self,
hidden_size,
num_attention_heads,
act_func='gelu',
mlp_ratio=4,
hidden_size: int,
num_attention_heads: int,
act_func: str = 'gelu',
mlp_ratio: float = 4.0,
attention_dropout_prob: float = 0.,
hidden_dropout_prob: float = 0.,
dtype=None,

View File

@ -5,22 +5,25 @@ import math
import torch
from torch import nn as nn, Tensor, distributed as dist
from torch.nn.init import _calculate_fan_in_and_fan_out
from colossalai.context.parallel_mode import ParallelMode
from colossalai.context import seed, ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer.vanilla_vision_transformer.layers import to_2tuple
from colossalai.registry import LAYERS
from colossalai.utils import checkpoint
from colossalai.utils import get_current_device
from ._operation import _ViT_Split_2p5D
from ._operation import AllGatherLast, SplitFirst
from ._utils import assert_tesseract_initialization, \
get_tesseract_dim_dep_from_env
from .layers import Linear2p5D
from .._common_utils import ACT2FN, divide, CheckpointModule
from .._common_utils import set_tensor_parallel_attribute
from ..base_layer import ParallelLayer
from ..fused_bias_gelu import bias_gelu_impl
from .._common_utils import (ACT2FN, divide, to_2tuple,
set_tensor_parallel_attribute_by_partition)
@LAYERS.register_module
class ViTMLP2p5D(CheckpointModule):
class ViTMLP2p5D(ParallelLayer):
"""MLP layer for 2.5D parallel Vision Transformer
:param in_features: size of each input sample
@ -43,19 +46,32 @@ class ViTMLP2p5D(CheckpointModule):
act_func: str = 'gelu',
dropout_prob: float = 0.,
dtype=None,
checkpoint: bool = False
checkpoint: bool = False,
weight_init='torch'
):
super().__init__(checkpoint=checkpoint)
super().__init__()
assert_tesseract_initialization()
self.in_features = in_features
self.mlp_ratio = mlp_ratio
self.checkpoint = checkpoint
assert weight_init in ('torch', 'jax')
if act_func == 'fused_gelu':
self.act = bias_gelu_impl
skip_dense_1_add_bias = True
else:
self.act = ACT2FN[act_func]
skip_dense_1_add_bias = False
# Project to mlp_ratio * h.
self.dense_1 = Linear2p5D(
self.in_features,
self.mlp_ratio * self.in_features,
dtype=dtype,
init_weight=weight_init,
init_bias=weight_init,
skip_bias_add=skip_dense_1_add_bias
)
self.act = ACT2FN[act_func]
@ -65,20 +81,39 @@ class ViTMLP2p5D(CheckpointModule):
self.mlp_ratio * self.in_features,
self.in_features,
dtype=dtype,
init_weight=weight_init,
init_bias=weight_init
)
self.dropout = nn.Dropout(dropout_prob)
def _forward(self, hidden_states: Tensor) -> Tensor:
intermediate_output = self.dense_1(hidden_states)
intermediate_output = self.act(intermediate_output)
intermediate_output = self.dropout(intermediate_output)
if self.act == bias_gelu_impl:
intermediate_output, bias = self.dense_1(hidden_states)
intermediate_output = self.act(intermediate_output, bias)
else:
intermediate_output = self.dense_1(hidden_states)
intermediate_output = self.act(intermediate_output)
with seed(ParallelMode.TENSOR):
intermediate_output = self.dropout(intermediate_output)
output = self.dense_2(intermediate_output)
output = self.dropout(output)
with seed(ParallelMode.TENSOR):
output = self.dropout(output)
return output
def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
return checkpoint(self._forward, hidden_states)
def forward(self, hidden_states: Tensor) -> Tensor:
if self.checkpoint:
return self._checkpoint_forward(hidden_states)
else:
return self._forward(hidden_states)
@LAYERS.register_module
class ViTSelfAttention2p5D(CheckpointModule):
class ViTSelfAttention2p5D(ParallelLayer):
"""Self-attention layer for 2.5D parallel Vision Transformer
:param hidden_size: hidden size
@ -101,9 +136,10 @@ class ViTSelfAttention2p5D(CheckpointModule):
attention_dropout_prob,
hidden_dropout_prob,
dtype=None,
checkpoint: bool = False
checkpoint: bool = False,
weight_init='torch'
):
super().__init__(checkpoint=checkpoint)
super().__init__()
assert_tesseract_initialization()
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
@ -112,19 +148,30 @@ class ViTSelfAttention2p5D(CheckpointModule):
num_attention_heads, self.tesseract_dim) # *
self.attention_head_size = divide(hidden_size, num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.checkpoint = checkpoint
assert weight_init in ('torch', 'jax')
if weight_init == 'jax':
self.init_bias = 'zero'
else:
self.init_bias = weight_init
self.query_key_value = Linear2p5D(
hidden_size,
3 * hidden_size,
dtype=dtype,
init_weight=weight_init,
init_bias=self.init_bias
)
self.attention_dropout = nn.Dropout(attention_dropout_prob)
self.dense = Linear2p5D(
hidden_size,
hidden_size,
dtype=dtype,
init_weight=weight_init,
init_bias=self.init_bias
)
self.dropout = nn.Dropout(hidden_dropout_prob)
self.softmax = nn.Softmax(dim=-1)
def _forward(self, hidden_states: Tensor) -> Tensor:
query_key_value = self.query_key_value(hidden_states)
@ -140,8 +187,10 @@ class ViTSelfAttention2p5D(CheckpointModule):
attention_scores = attention_scores / \
math.sqrt(self.attention_head_size)
attention_probs = nn.Softmax(dim=-1)(attention_scores)
attention_probs = self.attention_dropout(attention_probs)
attention_probs = self.softmax(attention_scores)
with seed(ParallelMode.TENSOR):
attention_probs = self.attention_dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.transpose(1, 2)
@ -150,12 +199,22 @@ class ViTSelfAttention2p5D(CheckpointModule):
context_layer = context_layer.reshape(new_context_layer_shape)
output = self.dense(context_layer)
output = self.dropout(output)
with seed(ParallelMode.TENSOR):
output = self.dropout(output)
return output
def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
return checkpoint(self._forward, hidden_states)
def forward(self, hidden_states: Tensor) -> Tensor:
if self.checkpoint:
return self._checkpoint_forward(hidden_states)
else:
return self._forward(hidden_states)
@LAYERS.register_module
class ViTHead2p5D(nn.Module):
class ViTHead2p5D(ParallelLayer):
"""Output layer for 2.5D parallel Vision Transformer
:param hidden_size: hidden size
@ -170,13 +229,24 @@ class ViTHead2p5D(nn.Module):
hidden_size,
num_classes,
dtype=None,
weight_init='torch'
):
super().__init__()
assert_tesseract_initialization()
assert weight_init in ('torch', 'jax')
if weight_init == 'jax':
self.init_weight = 'zero'
self.init_bias = 'zero'
else:
self.init_weight = weight_init
self.init_bias = weight_init
self.linear = Linear2p5D(
hidden_size,
num_classes,
dtype=dtype,
init_weight=self.init_weight,
init_bias=self.init_bias
)
def forward(self, x: Tensor) -> Tensor:
@ -186,7 +256,7 @@ class ViTHead2p5D(nn.Module):
@LAYERS.register_module
class ViTPatchEmbedding2p5D(nn.Module):
class ViTPatchEmbedding2p5D(ParallelLayer):
""" 2.5D Image to Patch Embedding
:param img_size: iamge size
@ -206,7 +276,8 @@ class ViTPatchEmbedding2p5D(nn.Module):
patch_size,
embed_dim,
in_chans=3,
flatten=True):
flatten=True,
weight_init='torch'):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
@ -219,34 +290,28 @@ class ViTPatchEmbedding2p5D(nn.Module):
img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
self.embed_dim = embed_dim // self.tesseract_dim # *
self.embed_dim = embed_dim // (self.tesseract_dep * self.tesseract_dim ** 2) # *
self.proj = nn.Conv2d(in_chans,
self.embed_dim,
kernel_size=patch_size,
stride=patch_size,
)
with seed(ParallelMode.TENSOR):
self.proj = nn.Conv2d(in_chans,
self.embed_dim,
kernel_size=patch_size,
stride=patch_size,
device=get_current_device()
)
self._set_tensor_parallel_attribute()
# move self to cuda before sync
self.to(get_current_device())
if weight_init == 'jax':
with seed(ParallelMode.TENSOR):
fan_in, _ = _calculate_fan_in_and_fan_out(self.proj.weight)
std = math.sqrt(1.0 / fan_in)
nn.init.trunc_normal_(self.proj.weight, std=std / .87962566103423978)
nn.init.zeros_(self.proj.bias)
# sync
self._broadcast_conv_params()
self.proj.weight.register_hook(self._sync_grad_during_backward)
self.proj.bias.register_hook(self._sync_grad_during_backward)
def _broadcast_conv_params(self) -> None:
xz_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2P5D_XZ)
dist.broadcast(self.proj.weight, src=xz_rank[0],
group=gpc.get_group(ParallelMode.PARALLEL_2P5D_XZ))
dist.broadcast(self.proj.bias, src=xz_rank[0],
group=gpc.get_group(ParallelMode.PARALLEL_2P5D_XZ))
def _sync_grad_during_backward(self, grad: Tensor) -> None:
dist.all_reduce(grad, group=gpc.get_group(
ParallelMode.PARALLEL_2P5D_XZ))
grad = grad / self.tesseract_dim / self.tesseract_dep # *
return grad
def _set_tensor_parallel_attribute(self):
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
set_tensor_parallel_attribute_by_partition(self.proj.weight, num_partition)
set_tensor_parallel_attribute_by_partition(self.proj.bias, num_partition)
def forward(self, x: Tensor) -> Tensor:
B, C, H, W = x.shape
@ -259,7 +324,25 @@ class ViTPatchEmbedding2p5D(nn.Module):
@LAYERS.register_module
class ViTTokenFuser2p5D(nn.Module):
class ViTInputSplitter2p5D(ParallelLayer):
"""Split the input tensor for 2D parallel Vision Transformer
"""
def __init__(self):
super().__init__()
assert_tesseract_initialization()
self.tesseract_dim, _ = get_tesseract_dim_dep_from_env()
def forward(self, x: Tensor) -> Tensor:
x = AllGatherLast.apply(
x, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
x = SplitFirst.apply(
x, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
return x
@LAYERS.register_module
class ViTTokenFuser2p5D(ParallelLayer):
"""
Fuse cls token and pos embedding to the input
@ -293,59 +376,46 @@ class ViTTokenFuser2p5D(nn.Module):
self.embed_dim = embed_dim
self.cls_token = nn.Parameter(torch.zeros(
1, 1, self.embed_dim // self.tesseract_dim)) # *
self.pos_embed = nn.Parameter(torch.zeros(
1, self.num_patches + 1, self.embed_dim // self.tesseract_dim)) # *
(1, 1, self.embed_dim // (self.tesseract_dep * self.tesseract_dim ** 2)),
device=get_current_device()))
self.pos_embed = nn.Parameter(torch.empty(
(1, self.num_patches + 1, self.embed_dim // (self.tesseract_dep * self.tesseract_dim ** 2)),
device=get_current_device()))
with seed(ParallelMode.TENSOR):
nn.init.trunc_normal_(self.pos_embed, std=.02)
# move to cuda before broadcast
self.to(get_current_device())
self._broadcast_params()
self.cls_token.register_hook(self._sync_grad_hook)
self.pos_embed.register_hook(self._sync_grad_hook)
self.pos_drop = nn.Dropout(p=drop_rate)
self._set_tensor_parallel_attribute()
def _set_tensor_parallel_attribute(self):
set_tensor_parallel_attribute(self.cls_token)
set_tensor_parallel_attribute(self.pos_embed)
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
set_tensor_parallel_attribute_by_partition(self.cls_token, num_partition)
set_tensor_parallel_attribute_by_partition(self.pos_embed, num_partition)
def _broadcast_params(self) -> None:
def _broadcast_params(self, param) -> None:
" broadcast to all column ranks for data consistency "
xz_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2P5D_XZ)
dist.broadcast(self.cls_token, src=xz_rank[0],
group=gpc.get_group(ParallelMode.PARALLEL_2P5D_XZ))
dist.broadcast(self.pos_embed, src=xz_rank[0],
group=gpc.get_group(ParallelMode.PARALLEL_2P5D_XZ))
if self.tesseract_dep > 1:
xz_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2P5D_XZ)
xz_group = gpc.get_group(ParallelMode.PARALLEL_2P5D_XZ)
dist.broadcast(param, src=xz_rank[0],
group=xz_group)
def _sync_grad_hook(self, grad) -> None:
dist.all_reduce(grad, group=gpc.get_group(
ParallelMode.PARALLEL_2P5D_XZ))
grad = grad / self.tesseract_dim / self.tesseract_dep # *
grad = grad / self.tesseract_dim # / self.tesseract_dep # *
return grad
def forward(self, x: Tensor) -> Tensor:
# stole cls_tokens impl from Phil Wang, thanks
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
cls_token = AllGatherLast.apply(
self.cls_token, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
cls_token = cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_token, x), dim=1)
x = self.pos_drop(x + self.pos_embed)
pos_embed = AllGatherLast.apply(
self.pos_embed, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
x = x + pos_embed
with seed(ParallelMode.TENSOR):
x = self.pos_drop(x)
return x
@LAYERS.register_module
class ViTInputSplitter2p5D(nn.Module):
def __init__(self):
super().__init__()
assert_tesseract_initialization()
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
def forward(self, x: Tensor) -> Tensor:
batch_size = x.size(0)
return _ViT_Split_2p5D.apply(
x,
batch_size,
self.tesseract_dim,
self.tesseract_dep,
ParallelMode.PARALLEL_2P5D_XZ,
)

View File

@ -10,7 +10,7 @@ from colossalai.registry import LAYERS
from colossalai.utils import get_current_device
from ._operation import Matmul_AB_2p5D, Add_Bias_2p5D, _LayerNorm_2p5D
from ._utils import get_tesseract_dim_dep_from_env, assert_tesseract_initialization
from .._common_utils import divide, set_tensor_parallel_attribute
from .._common_utils import divide, set_tensor_parallel_attribute_by_partition
from ..base_layer import ParallelLayer
@ -33,7 +33,9 @@ class Linear2p5D(ParallelLayer):
out_features: int,
bias: bool = True,
dtype=None,
skip_bias_add: bool = False
skip_bias_add: bool = False,
init_weight='torch',
init_bias='torch'
):
super().__init__()
@ -46,7 +48,7 @@ class Linear2p5D(ParallelLayer):
self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
self.dep_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
self.tesseract_dim, _ = get_tesseract_dim_dep_from_env()
# partitioning dimension
self.input_size_per_partition = divide(in_features, self.tesseract_dim)
@ -69,46 +71,59 @@ class Linear2p5D(ParallelLayer):
self.register_parameter('bias', None)
# initialize parameters
self.reset_parameters()
with seed(ParallelMode.TENSOR):
self.reset_parameters(init_weight, init_bias)
self._set_tensor_parallel_attributes()
def _set_tensor_parallel_attributes(self):
set_tensor_parallel_attribute(self.weight)
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
if self.bias is not None:
set_tensor_parallel_attribute(self.bias)
set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
def reset_parameters(self) -> None:
def reset_parameters(self, init_weight, init_bias) -> None:
assert init_weight in ('torch', 'jax', 'zero')
assert init_bias in ('torch', 'jax', 'zero')
# setting
fan_in = self.in_features
a = math.sqrt(5)
nonlinearity = 'leaky_relu'
fan_in, fan_out = self.in_features, self.out_features
# init weight
std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
bound = math.sqrt(3.0) * std
with seed(ParallelMode.TENSOR):
if init_weight == 'torch':
a = math.sqrt(5)
nonlinearity = 'leaky_relu'
std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
bound = math.sqrt(3.0) * std
init.uniform_(self.weight, -bound, bound)
elif init_weight == 'jax':
std = math.sqrt(2.0 / float(fan_in + fan_out))
a = math.sqrt(3.0) * std
init.uniform_(self.weight, -a, a)
elif init_weight == 'zero':
init.zeros_(self.weight)
# init bias
if self.bias is not None:
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
with seed(ParallelMode.TENSOR):
if init_bias == 'torch':
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
init.uniform_(self.bias, -bound, bound)
elif init_bias == 'jax':
init.normal_(self.bias, std=1e-6)
elif init_bias == 'zero':
init.zeros_(self.bias)
def forward(self, x: Tensor) -> Tensor:
# input: [m/dq, n/q, k/q]
# output: [m/dq, n/q, h/q]
out_shape = x.shape[:-1] + (self.hidden_size_per_partition,)
output = Matmul_AB_2p5D.apply(
x,
self.weight,
self.tesseract_dim,
self.tesseract_dep,
out_shape,
self.row_rank, self.col_rank, self.dep_rank,
ParallelMode.PARALLEL_2P5D_ROW,
ParallelMode.PARALLEL_2P5D_COL,
ParallelMode.PARALLEL_2P5D_DEP,
self.data_parallel_rank,
self.pipeline_parallel_rank,
self.pipeline_parallel_size,
@ -121,11 +136,9 @@ class Linear2p5D(ParallelLayer):
None,
self.bias,
self.hidden_size_per_partition,
self.tesseract_dim, self.tesseract_dep,
self.tesseract_dim,
self.row_rank, self.col_rank, self.dep_rank,
ParallelMode.PARALLEL_2P5D_ROW,
ParallelMode.PARALLEL_2P5D_COL,
ParallelMode.PARALLEL_2P5D_DEP,
True,
self.data_parallel_rank,
self.pipeline_parallel_rank,
@ -138,11 +151,9 @@ class Linear2p5D(ParallelLayer):
output,
self.bias,
self.hidden_size_per_partition,
self.tesseract_dim, self.tesseract_dep,
self.tesseract_dim,
self.row_rank, self.col_rank, self.dep_rank,
ParallelMode.PARALLEL_2P5D_ROW,
ParallelMode.PARALLEL_2P5D_COL,
ParallelMode.PARALLEL_2P5D_DEP,
False,
self.data_parallel_rank,
self.pipeline_parallel_rank,
@ -168,6 +179,7 @@ class LayerNorm2p5D(ParallelLayer):
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
normalized_shape: int,
eps: float = 1e-05,
@ -184,7 +196,7 @@ class LayerNorm2p5D(ParallelLayer):
self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
self.dep_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
self.tesseract_dim, _ = get_tesseract_dim_dep_from_env()
# partitioning dimension
self.partitioned_partition = divide(
@ -193,27 +205,19 @@ class LayerNorm2p5D(ParallelLayer):
# create parameters
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
if self.row_rank == 0:
self.gamma = Parameter(torch.ones(
self.partitioned_partition,
**factory_kwargs))
self.beta = Parameter(torch.zeros(
self.partitioned_partition,
**factory_kwargs))
else:
self.gamma = Parameter(torch.tensor(
1.0,
requires_grad=True,
**factory_kwargs))
self.beta = Parameter(torch.tensor(
1.0,
requires_grad=True,
**factory_kwargs))
self.gamma = Parameter(torch.ones(
self.partitioned_partition,
**factory_kwargs))
self.beta = Parameter(torch.zeros(
self.partitioned_partition,
**factory_kwargs))
self._set_tensor_parallel_attribute()
def _set_tensor_parallel_attribute(self):
set_tensor_parallel_attribute(self.gamma)
set_tensor_parallel_attribute(self.beta)
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
set_tensor_parallel_attribute_by_partition(self.gamma, num_partition)
set_tensor_parallel_attribute_by_partition(self.beta, num_partition)
def forward(self, x: Tensor) -> Tensor:
with torch.no_grad():
@ -233,16 +237,12 @@ class LayerNorm2p5D(ParallelLayer):
Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon)
output = _LayerNorm_2p5D.apply(x, E_x, Var_x, self.normalized_shape,
ParallelMode.PARALLEL_2P5D_ROW,
ParallelMode.PARALLEL_2P5D_COL,
ParallelMode.PARALLEL_2P5D_DEP)
ParallelMode.PARALLEL_2P5D_ROW)
bias = Add_Bias_2p5D.apply(
None, self.beta, self.partitioned_partition,
self.tesseract_dim, self.tesseract_dep,
self.tesseract_dim,
self.row_rank, self.col_rank, self.dep_rank,
ParallelMode.PARALLEL_2P5D_ROW,
ParallelMode.PARALLEL_2P5D_COL,
ParallelMode.PARALLEL_2P5D_DEP,
True,
self.data_parallel_rank,
self.pipeline_parallel_rank,
@ -251,11 +251,9 @@ class LayerNorm2p5D(ParallelLayer):
)
scale = Add_Bias_2p5D.apply(
None, self.gamma, self.partitioned_partition,
self.tesseract_dim, self.tesseract_dep,
self.tesseract_dim,
self.row_rank, self.col_rank, self.dep_rank,
ParallelMode.PARALLEL_2P5D_ROW,
ParallelMode.PARALLEL_2P5D_COL,
ParallelMode.PARALLEL_2P5D_DEP,
True,
self.data_parallel_rank,
self.pipeline_parallel_rank,

View File

@ -1,21 +1,223 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Any, Tuple
from typing import Any, Optional, Tuple
import torch
import torch.distributed as dist
from colossalai.communication import all_gather, reduce_scatter, scatter
from colossalai.communication import all_gather, all_reduce, reduce_scatter
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import empty_cache, get_current_device
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
class linear_3d(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
input_: Tensor,
weight: Tensor,
bias: Optional[Tensor],
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
input_dim: int = 0,
weight_dim: int = -1,
output_dim: int = 0) -> Tensor:
assert input_.shape[-1] == weight.shape[0], \
'Invalid shapes: input = {}, weight = {}.'.format(input_.shape, weight.shape)
ctx.use_bias = bias is not None
input_ = all_gather(input_, input_dim, input_parallel_mode)
input_ = torch.cat(input_, dim=input_dim)
# weight = all_gather(weight, weight_dim, weight_parallel_mode)
ctx.save_for_backward(input_, weight)
output = torch.matmul(input_, weight)
output = reduce_scatter(output, output_dim, output_parallel_mode)
if bias is not None:
# ranks_in_group = gpc.get_ranks_in_group(output_parallel_mode)
# src_rank = ranks_in_group[gpc.get_local_rank(input_parallel_mode)]
# dist.broadcast(bias,
# src=src_rank,
# group=gpc.get_group(output_parallel_mode))
# bias = all_gather(bias, -1, weight_parallel_mode)
output += bias
# ctx.src_rank = src_rank
# ctx.save_for_backward(input_, weight)
# output = torch.matmul(input_, weight)
# dist.all_reduce(output, group=gpc.get_group(output_parallel_mode))
# output += bias
ctx.input_parallel_mode = input_parallel_mode
ctx.weight_parallel_mode = weight_parallel_mode
ctx.output_parallel_mode = output_parallel_mode
ctx.input_dim = input_dim
ctx.weight_dim = weight_dim
ctx.output_dim = output_dim
return output
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
input_, weight = ctx.saved_tensors
with torch.no_grad():
# input_grad = torch.matmul(output_grad, weight.transpose(0, 1))
# dist.all_reduce(input_grad,
# group=gpc.get_group(ctx.input_parallel_mode))
# weight_grad = torch.matmul(
# input_.reshape(-1, input_.shape[-1]).transpose(0, 1),
# output_grad.reshape(-1, output_grad.shape[-1]))
# dist.all_reduce(weight_grad,
# group=gpc.get_group(ctx.weight_parallel_mode))
# bias_grad = torch.sum(output_grad,
# dim=tuple(
# range(len(output_grad.shape))[:-1]))
# bias_grad = reduce_scatter(bias_grad, -1,
# ctx.weight_parallel_mode)
# dist.reduce(bias_grad,
# dst=ctx.src_rank,
# group=gpc.get_group(ctx.output_parallel_mode))
# if gpc.get_local_rank(
# ctx.output_parallel_mode) != gpc.get_local_rank(
# ctx.input_parallel_mode):
# bias_grad = None
# input_ = all_gather(input_, ctx.input_dim, ctx.input_parallel_mode)
# weight = all_gather(weight, ctx.weight_dim,
# ctx.weight_parallel_mode)
output_grad = all_gather(output_grad, ctx.output_dim,
ctx.output_parallel_mode)
output_grad = torch.cat(output_grad, dim=ctx.output_dim)
input_grad = torch.matmul(output_grad, weight.transpose(0, 1))
input_grad, input_op = reduce_scatter(input_grad, ctx.input_dim,
ctx.input_parallel_mode,
async_op=True)
weight_grad = torch.matmul(
input_.reshape(-1, input_.shape[-1]).transpose(0, 1),
output_grad.reshape(-1, output_grad.shape[-1]))
# weight_grad = torch.matmul(
# input_.reshape(-1, input_.shape[-1]).transpose(0, 1),
# output_grad.reshape(-1, output_grad.shape[-1]))
# weight_grad = reduce_scatter(weight_grad, ctx.weight_dim,
# ctx.weight_parallel_mode)
if ctx.use_bias:
bias_grad = torch.sum(output_grad,
dim=tuple(
range(len(output_grad.shape))[:-1]))
# bias_grad =all_reduce(bias_grad, ctx.output_parallel_mode)
# dist.all_reduce(bias_grad,
# group=gpc.get_group(ctx.weight_parallel_mode))
weight_grad = torch.cat([weight_grad, torch.unsqueeze(bias_grad, dim=0)])
weight_grad, weight_op = all_reduce(weight_grad, ctx.weight_parallel_mode, async_op=True)
input_op.wait()
weight_op.wait()
if ctx.use_bias:
bias_grad = weight_grad[-1]
weight_grad = weight_grad[:-1]
return input_grad, weight_grad, bias_grad, None, None, None, None, None, None
class layer_norm_3d(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, input_: Tensor, weight: Tensor, bias: Tensor,
normalized_shape: int, eps: float,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode) -> Tensor:
# mean = torch.sum(input_, dim=-1)
# dist.all_reduce(mean, group=gpc.get_group(output_parallel_mode))
# mean /= normalized_shape
# mu = input_ - mean
# var = torch.sum(torch.pow(mu, 2), dim=-1)
# dist.all_reduce(var, group=gpc.get_group(output_parallel_mode))
# var /= normalized_shape
# std_dev = torch.sqrt(var + eps)
# ctx.save_for_backward(input_, mu, std_dev, weight)
# output = weight * mu / std_dev + bias
mean = all_reduce(torch.sum(input_, dim=-1, keepdim=True),
output_parallel_mode) / normalized_shape
mu = input_ - mean
var = all_reduce(torch.sum(mu**2, dim=-1, keepdim=True),
output_parallel_mode) / normalized_shape
sigma = torch.sqrt(var + eps)
# ranks_in_group = gpc.get_ranks_in_group(input_parallel_mode)
# src_rank = ranks_in_group[gpc.get_local_rank(output_parallel_mode)]
# transforms = torch.stack([weight, bias]).contiguous()
# dist.broadcast(transforms,
# src=src_rank,
# group=gpc.get_group(input_parallel_mode))
# transforms = all_gather(transforms, -1, weight_parallel_mode)
# weight, bias = transforms[0], transforms[1]
ctx.save_for_backward(mu, sigma, weight)
z = mu / sigma
output = weight * z + bias
# ctx.src_rank = src_rank
ctx.normalized_shape = normalized_shape
ctx.input_parallel_mode = input_parallel_mode
ctx.weight_parallel_mode = weight_parallel_mode
ctx.output_parallel_mode = output_parallel_mode
return output
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
mu, sigma, weight = ctx.saved_tensors
with torch.no_grad():
bias_grad, weight_grad = output_grad, output_grad * mu / sigma
grads = torch.stack([bias_grad, weight_grad]).contiguous()
grads = torch.sum(grads, dim=tuple(range(len(grads.shape))[1:-1]))
grads = all_reduce(grads, ctx.weight_parallel_mode)
grads = all_reduce(grads, ctx.input_parallel_mode)
bias_grad, weight_grad = grads[0], grads[1]
# grads = reduce_scatter(grads, -1, ctx.weight_parallel_mode)
# dist.reduce(grads,
# dst=ctx.src_rank,
# group=gpc.get_group(ctx.input_parallel_mode))
# if gpc.get_local_rank(
# ctx.input_parallel_mode) == gpc.get_local_rank(
# ctx.output_parallel_mode):
# bias_grad, weight_grad = grads[0], grads[1]
# else:
# bias_grad, weight_grad = None, None
dz = output_grad * weight
dvar = dz * mu * (-0.5) * sigma**(-3)
dvar = all_reduce(torch.sum(dvar, dim=-1, keepdim=True), ctx.output_parallel_mode)
dmean = dz * (-1 / sigma) + dvar * -2 * mu / ctx.normalized_shape
dmean = all_reduce(torch.sum(dmean, dim=-1, keepdim=True), ctx.output_parallel_mode)
input_grad = dz / sigma + dvar * 2 * mu / ctx.normalized_shape + dmean / ctx.normalized_shape
return input_grad, weight_grad, bias_grad, None, None, None, None, None
class Matmul_AB_3D(torch.autograd.Function):
"""Matrix multiplication for :math:`C = AB`
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
A: Tensor,
B: Tensor,
@ -29,7 +231,6 @@ class Matmul_AB_3D(torch.autograd.Function):
# A: [m/q^2, n, k/q]
# B: [k/q, h/q^2]
# C: [m/q^2, n, h/q]
empty_cache()
ctx.save_for_backward(A, B)
assert A.shape[-1] == B.shape[0], \
@ -52,6 +253,7 @@ class Matmul_AB_3D(torch.autograd.Function):
return out
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors
with torch.no_grad():
@ -72,6 +274,7 @@ class Matmul_ABT_3D(torch.autograd.Function):
"""Matrix multiplication for :math:`C = AB^T`
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
A: Tensor,
B: Tensor,
@ -85,7 +288,6 @@ class Matmul_ABT_3D(torch.autograd.Function):
# A: [m/q^2, n, h/q]
# B: [k/q, h/q^2]
# C: [m/q^2, n, k/q]
empty_cache()
ctx.save_for_backward(A, B)
A_temp = all_gather(A, input_dim, input_parallel_mode)
@ -105,6 +307,7 @@ class Matmul_ABT_3D(torch.autograd.Function):
return out
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors
with torch.no_grad():
@ -125,6 +328,7 @@ class Matmul_ATB_3D(torch.autograd.Function):
"""Matrix multiplication for :math:`C = A^TB`
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
A: Tensor,
B: Tensor,
@ -138,7 +342,6 @@ class Matmul_ATB_3D(torch.autograd.Function):
# A: [m/q^2, n, k/q]
# B: [m/q^2, n, h/q]
# C: [k/q, h/q^2]
empty_cache()
ctx.save_for_backward(A, B)
A_temp = all_gather(A, input_dim, input_parallel_mode)
@ -160,6 +363,7 @@ class Matmul_ATB_3D(torch.autograd.Function):
return out
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors
with torch.no_grad():
@ -180,6 +384,7 @@ class Add_3D(torch.autograd.Function):
"""Matrix add bias: :math:`C = A + b`
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, input_: Tensor, bias: Tensor, depth: int,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
@ -206,6 +411,7 @@ class Add_3D(torch.autograd.Function):
return out
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
# output_grad: [m/q^2, n, h/q]
with torch.no_grad():
@ -217,8 +423,8 @@ class Add_3D(torch.autograd.Function):
dst=ctx.src_rank,
group=gpc.get_group(ctx.A_group_parallel_mode))
if gpc.get_local_rank(
ctx.A_group_parallel_mode) != gpc.get_local_rank(
ctx.C_group_parallel_mode):
ctx.A_group_parallel_mode) != gpc.get_local_rank(
ctx.C_group_parallel_mode):
bias_grad = None
return output_grad, bias_grad, None, None, None, None
@ -227,6 +433,7 @@ class Mul_3D(torch.autograd.Function):
"""Matrix multiplication for :math:`C = A * b`
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, input_: Tensor, bias: Tensor, depth: int,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
@ -243,7 +450,7 @@ class Mul_3D(torch.autograd.Function):
# [h/q]
bias_temp = all_gather(bias_temp, -1, weight_parallel_mode)
empty_cache()
# empty_cache()
ctx.save_for_backward(input_, bias_temp)
out = torch.mul(input_, bias_temp)
@ -257,6 +464,7 @@ class Mul_3D(torch.autograd.Function):
return out
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
# output_grad: [m/q^2, n, h/q]
with torch.no_grad():
@ -272,8 +480,8 @@ class Mul_3D(torch.autograd.Function):
dst=ctx.src_rank,
group=gpc.get_group(ctx.A_group_parallel_mode))
if gpc.get_local_rank(
ctx.A_group_parallel_mode) != gpc.get_local_rank(
ctx.C_group_parallel_mode):
ctx.A_group_parallel_mode) != gpc.get_local_rank(
ctx.C_group_parallel_mode):
bias_grad = None
return input_grad, bias_grad, None, None, None, None
@ -282,6 +490,7 @@ class Sum_3D(torch.autograd.Function):
"""Compute the sum of input tensors
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
input_: Tensor,
dim: int,
@ -299,6 +508,7 @@ class Sum_3D(torch.autograd.Function):
return out
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
with torch.no_grad():
output_grad = output_grad.contiguous()
@ -315,35 +525,39 @@ class Reduce_3D(torch.autograd.Function):
"""Reduce input tensors
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, input_: Tensor, depth: int,
parallel_mode: ParallelMode) -> Tensor:
dist.all_reduce(input_, group=gpc.get_group(parallel_mode))
return input_.clone()
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
return output_grad, None, None
class Slice_3D(torch.autograd.Function):
"""Slice input tensor
"""
@staticmethod
def forward(ctx: Any, input_: Tensor, dim: int, depth: int,
parallel_mode: ParallelMode) -> Tensor:
rank = gpc.get_local_rank(parallel_mode)
out = torch.chunk(input_, depth, dim=dim)[rank].contiguous()
# class Slice_3D(torch.autograd.Function):
# """Slice input tensor
# """
# @staticmethod
# @custom_fwd(cast_inputs=torch.float16)
# def forward(ctx: Any, input_: Tensor, dim: int, depth: int,
# parallel_mode: ParallelMode) -> Tensor:
# rank = gpc.get_local_rank(parallel_mode)
# out = torch.chunk(input_, depth, dim=dim)[rank].contiguous()
ctx.depth = depth
ctx.parallel_mode = parallel_mode
ctx.dim = dim
ctx.input_shape = input_.shape
# ctx.depth = depth
# ctx.parallel_mode = parallel_mode
# ctx.dim = dim
# ctx.input_shape = input_.shape
return out
# return out
@staticmethod
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
with torch.no_grad():
input_grad = all_gather(output_grad, ctx.dim, ctx.parallel_mode)
input_grad.reshape(ctx.input_shape)
return input_grad, None, None, None
# @staticmethod
# @custom_bwd
# def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
# with torch.no_grad():
# input_grad = all_gather(output_grad, ctx.dim, ctx.parallel_mode)
# input_grad.reshape(ctx.input_shape)
# return input_grad, None, None, None

View File

@ -3,7 +3,8 @@
import os
from colossalai.constants import DEPTH_3D
from colossalai.constants import (DEPTH_3D, INPUT_GROUP_3D, OUTPUT_GROUP_3D,
WEIGHT_GROUP_3D)
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from torch import Tensor
@ -23,6 +24,10 @@ def get_depth_from_env() -> int:
)
def get_parallel_mode_from_env(group):
return getattr(ParallelMode, os.environ[group])
def get_last_group(a, b):
mapping = {
ParallelMode.PARALLEL_3D_INPUT: 'A',
@ -41,6 +46,11 @@ def get_last_group(a, b):
return ParallelMode.PARALLEL_3D_OUTPUT
def swap_in_out_group():
os.environ[INPUT_GROUP_3D], os.environ[OUTPUT_GROUP_3D] = \
os.environ[OUTPUT_GROUP_3D], os.environ[INPUT_GROUP_3D]
def dbg_check_shape(tensor: Tensor, shape: tuple):
rank = gpc.get_global_rank()
if rank == 0:

View File

@ -1,17 +1,20 @@
import math
from typing import Tuple
import os
from typing import Tuple, Optional
import torch
import torch.distributed as dist
from colossalai.constants import (INPUT_GROUP_3D, OUTPUT_GROUP_3D,
WEIGHT_GROUP_3D)
from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc
from colossalai.registry import LAYERS
from colossalai.nn.init import init_bias_, init_weight_
from colossalai.utils import checkpoint, get_current_device
from torch import Tensor, dtype, nn
from .._common_utils import ACT2FN, divide, set_tensor_parallel_attribute
from ..vanilla_vision_transformer.layers import to_2tuple
from ._utils import get_depth_from_env
from .._common_utils import ACT2FN, divide, set_tensor_parallel_attribute_by_size, to_2tuple
from ._utils import get_depth_from_env, get_parallel_mode_from_env, get_last_group
from .layers import Linear3D
@ -32,34 +35,42 @@ class ViTPatchEmbedding3D(nn.Module):
:param flatten: whether to flatten output tensor, defaults to True
:type flatten: bool, optional
"""
def __init__(self,
img_size: int,
patch_size: int,
in_chans: int,
embed_size: int,
drop_prob: float,
flatten: bool = True):
flatten: bool = True,
init_method: str = 'torch'):
super().__init__()
self.depth = get_depth_from_env()
self.input_parallel_mode = ParallelMode.PARALLEL_3D_INPUT
self.weight_parallel_mode = ParallelMode.PARALLEL_3D_WEIGHT
self.output_parallel_mode = ParallelMode.PARALLEL_3D_OUTPUT
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode,
self.weight_parallel_mode)
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0],
img_size[1] // patch_size[1])
self.in_chans = in_chans
self.embed_size = embed_size
self.embed_size_per_partition = divide(self.embed_size, self.depth)
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
self.init_weight = 'torch'
self.init_bias = 'torch'
if init_method == 'jax':
self.init_weight = 'jax_embed'
self.init_bias = 'zero'
with seed(ParallelMode.TENSOR):
self.proj = nn.Conv2d(in_chans,
self.embed_size_per_partition,
kernel_size=patch_size,
stride=patch_size)
self.proj = nn.Conv2d(self.in_chans,
self.embed_size_per_partition,
kernel_size=patch_size,
stride=patch_size)
self.cls_token = nn.Parameter(
torch.zeros(1, 1, self.embed_size_per_partition))
@ -68,23 +79,26 @@ class ViTPatchEmbedding3D(nn.Module):
self.embed_size_per_partition))
self.pos_drop = nn.Dropout(drop_prob)
self._sync_parameters()
self.proj.weight.register_hook(self._sync_grad_hook)
self.proj.bias.register_hook(self._sync_grad_hook)
self.cls_token.register_hook(self._sync_grad_hook)
self.pos_embed.register_hook(self._sync_grad_hook)
self._set_tensor_parallel_attribute()
self.reset_parameters(self.init_weight, self.init_bias)
self._set_tensor_parallel_attributes()
def _set_tensor_parallel_attribute(self):
set_tensor_parallel_attribute(self.proj.weight)
set_tensor_parallel_attribute(self.proj.bias)
set_tensor_parallel_attribute(self.cls_token)
set_tensor_parallel_attribute(self.pos_embed)
def _set_tensor_parallel_attributes(self):
set_tensor_parallel_attribute_by_size(self.proj.weight, self.in_chans * self.embed_size * self.num_patches)
set_tensor_parallel_attribute_by_size(self.proj.bias, self.embed_size)
set_tensor_parallel_attribute_by_size(self.cls_token, 1 * 1 * self.embed_size)
set_tensor_parallel_attribute_by_size(self.pos_embed, 1 * (self.num_patches + 1) * self.embed_size)
def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
return self.input_parallel_mode, self.weight_parallel_mode
def reset_parameters(self, init_weight, init_bias):
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.proj.weight)
# std = math.sqrt(1.0 / fan_in)
# nn.init.trunc_normal_(self.proj.weight, std=std / .87962566103423978)
# nn.init.zeros_(self.proj.bias)
if init_weight != 'torch':
init_weight_(self.proj.weight, fan_in, init_method=init_weight)
init_bias_(self.pos_embed, fan_in, init_method=init_weight)
if init_bias != 'torch':
init_bias_(self.proj.bias, fan_in, init_method=init_bias)
def _sync_parameters(self):
self.to(get_current_device())
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
dist.broadcast(self.proj.weight,
@ -100,10 +114,11 @@ class ViTPatchEmbedding3D(nn.Module):
dist.broadcast(self.proj.bias,
src=input_src_rank,
group=gpc.get_group(self.input_parallel_mode))
set_tensor_parallel_attribute(self.proj.weight)
set_tensor_parallel_attribute(self.proj.bias)
set_tensor_parallel_attribute(self.cls_token)
set_tensor_parallel_attribute(self.pos_embed)
self.proj.weight.register_hook(self._sync_grad_hook)
self.proj.bias.register_hook(self._sync_grad_hook)
self.cls_token.register_hook(self._sync_grad_hook)
self.pos_embed.register_hook(self._sync_grad_hook)
def _sync_grad_hook(self, grad) -> None:
dist.all_reduce(grad, group=gpc.get_group(self.input_parallel_mode))
@ -111,6 +126,12 @@ class ViTPatchEmbedding3D(nn.Module):
return grad
def forward(self, x: Tensor) -> Tensor:
# split a partition from inputs
x = torch.chunk(x, self.depth, dim=0)[gpc.get_local_rank(
self.weight_parallel_mode)].contiguous()
x = torch.chunk(x, self.depth, dim=0)[gpc.get_local_rank(
self.input_parallel_mode)].contiguous()
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
@ -118,12 +139,6 @@ class ViTPatchEmbedding3D(nn.Module):
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
# split a partition from embedded states
x = torch.chunk(x, self.depth, dim=0)[gpc.get_local_rank(
self.weight_parallel_mode)].contiguous()
x = torch.chunk(x, self.depth, dim=0)[gpc.get_local_rank(
self.input_parallel_mode)].contiguous()
# add cls token & pos embedding
# [b/q^2,s,h/q] --> [b/q^2, 1+s, h/q]
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
@ -158,6 +173,7 @@ class ViTSelfAttention3D(nn.Module):
:param bias: whether to add bias, defaults to True
:type bias: bool, optional
"""
def __init__(self,
hidden_size: int,
num_attention_heads: int,
@ -165,41 +181,52 @@ class ViTSelfAttention3D(nn.Module):
hidden_dropout_prob: float,
dtype: dtype = None,
bias: bool = True,
checkpoint: bool = False):
checkpoint: bool = False,
init_method: str = 'torch'):
super().__init__()
self.depth = get_depth_from_env()
self.input_parallel_mode = ParallelMode.PARALLEL_3D_INPUT
self.weight_parallel_mode = ParallelMode.PARALLEL_3D_WEIGHT
self.output_parallel_mode = ParallelMode.PARALLEL_3D_OUTPUT
# self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
# self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
# self.output_parallel_mode = get_last_group(self.input_parallel_mode,
# self.weight_parallel_mode)
self.hidden_size = hidden_size
self.num_attention_heads = divide(num_attention_heads, self.depth)
self.attention_head_size = divide(hidden_size, num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.checkpoint = checkpoint
self.init_weight = 'torch'
self.init_bias = 'torch'
if init_method == 'jax':
self.init_weight = 'jax'
self.init_bias = 'zero'
self.query_key_value = Linear3D(self.hidden_size,
3 * self.hidden_size,
self.input_parallel_mode,
self.weight_parallel_mode,
# self.input_parallel_mode,
# self.weight_parallel_mode,
dtype=dtype,
bias=bias)
bias=bias,
init_weight=self.init_weight,
init_bias=self.init_bias)
self.attention_dropout = nn.Dropout(attention_probs_dropout_prob)
self.dense = Linear3D(self.hidden_size,
self.hidden_size,
self.output_parallel_mode,
self.weight_parallel_mode,
# self.output_parallel_mode,
# self.weight_parallel_mode,
dtype=dtype,
bias=bias)
bias=bias,
init_weight=self.init_weight,
init_bias=self.init_bias)
self.dropout = nn.Dropout(hidden_dropout_prob)
self.softmax = nn.Softmax(dim=-1)
def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
return self.input_parallel_mode, self.weight_parallel_mode
# def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
# return self.input_parallel_mode, self.weight_parallel_mode
def _forward(self, hidden_states: Tensor) -> Tensor:
query_key_value = self.query_key_value(hidden_states)
new_qkv_shape = query_key_value.shape[:-1] + \
(self.num_attention_heads, 3 * self.attention_head_size)
(self.num_attention_heads, 3 * self.attention_head_size)
query_key_value = query_key_value.view(new_qkv_shape)
query_key_value = query_key_value.permute((0, 2, 1, 3))
query_layer, key_layer, value_layer = torch.chunk(query_key_value,
@ -259,6 +286,7 @@ class ViTMLP3D(nn.Module):
:param bias: whether to add bias, defaults to True
:type bias: bool, optional
"""
def __init__(self,
hidden_size: int,
mlp_ratio: int,
@ -266,33 +294,41 @@ class ViTMLP3D(nn.Module):
hidden_act: str = 'gelu',
dtype: dtype = None,
bias: bool = True,
checkpoint: bool = False):
checkpoint: bool = False,
init_method: str = 'torch'):
super().__init__()
self.depth = get_depth_from_env()
self.input_parallel_mode = ParallelMode.PARALLEL_3D_INPUT
self.weight_parallel_mode = ParallelMode.PARALLEL_3D_WEIGHT
self.output_parallel_mode = ParallelMode.PARALLEL_3D_OUTPUT
# self.depth = get_depth_from_env()
# self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
# self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
# self.output_parallel_mode = get_last_group(self.input_parallel_mode,
# self.weight_parallel_mode)
self.hidden_size = hidden_size
self.mlp_ratio = mlp_ratio
self.checkpoint = checkpoint
self.init_weight = init_method
self.init_bias = init_method
self.dense_1 = Linear3D(self.hidden_size,
self.mlp_ratio * self.hidden_size,
self.input_parallel_mode,
self.weight_parallel_mode,
# self.input_parallel_mode,
# self.weight_parallel_mode,
dtype=dtype,
bias=bias)
bias=bias,
init_weight=self.init_weight,
init_bias=self.init_bias)
self.activation_func = ACT2FN[hidden_act]
self.dense_2 = Linear3D(self.mlp_ratio * self.hidden_size,
self.hidden_size,
self.output_parallel_mode,
self.weight_parallel_mode,
# self.output_parallel_mode,
# self.weight_parallel_mode,
dtype=dtype,
bias=bias)
bias=bias,
init_weight=self.init_weight,
init_bias=self.init_bias)
self.dropout = nn.Dropout(hidden_dropout_prob)
def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
return self.input_parallel_mode, self.weight_parallel_mode
# def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
# return self.input_parallel_mode, self.weight_parallel_mode
def _forward(self, hidden_states: Tensor) -> Tensor:
intermediate_output = self.dense_1(hidden_states)
@ -331,37 +367,46 @@ class ViTHead3D(nn.Module):
:param bias: whether to add bias, defaults to True
:type bias: bool, optional
"""
def __init__(self,
in_features: int,
num_classes: int,
dtype: dtype = None,
bias: bool = True):
bias: bool = True,
init_method: str = 'torch'):
super().__init__()
self.depth = get_depth_from_env()
self.input_parallel_mode = ParallelMode.PARALLEL_3D_INPUT
self.weight_parallel_mode = ParallelMode.PARALLEL_3D_WEIGHT
self.output_parallel_mode = ParallelMode.PARALLEL_3D_OUTPUT
# self.depth = get_depth_from_env()
# self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
# self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
# self.output_parallel_mode = get_last_group(self.input_parallel_mode,
# self.weight_parallel_mode)
self.in_features = in_features
self.num_classes = num_classes
out_features = math.ceil(self.num_classes /
(self.depth**2)) * (self.depth**2)
self.num_classes_per_partition = divide(self.num_classes, self.depth)
self.linear = Linear3D(self.in_features,
out_features,
self.input_parallel_mode,
self.weight_parallel_mode,
dtype=dtype,
bias=bias)
# out_features = math.ceil(self.num_classes /
# (self.depth**2)) * (self.depth**2)
# self.num_classes_per_partition = divide(self.num_classes, self.depth)
self.init_weight = 'torch'
self.init_bias = 'torch'
if init_method == 'jax':
self.init_weight = 'zero'
self.init_bias = 'zero'
def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
return self.linear.groups_for_next_layer()
self.linear = Linear3D(self.in_features,
self.num_classes,
# self.input_parallel_mode,
# self.weight_parallel_mode,
dtype=dtype,
bias=bias,
init_weight=self.init_weight,
init_bias=self.init_bias)
def forward(self, x: Tensor) -> Tensor:
# [b/q^2, s, h/q] --> [b/q^2, h/q]
x = x[:, 0]
# [b/q^2, h/q] --> [b/q^2, c/q]
x = self.linear(x)
return x[:, :self.num_classes_per_partition]
# return x[:, :self.num_classes_per_partition]
return x
def extra_repr(self):
return 'in_features={}, num_classes={}'.format(self.in_features,

View File

@ -2,19 +2,28 @@
# -*- encoding: utf-8 -*-
import math
import os
from typing import Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.constants import (INPUT_GROUP_3D, OUTPUT_GROUP_3D,
WEIGHT_GROUP_3D)
from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc
from colossalai.nn.init import init_bias_, init_weight_
from colossalai.registry import LAYERS
from colossalai.utils import get_current_device
from torch import Tensor, dtype
from torch.nn import Parameter
from torch.nn import init as init
from .._common_utils import divide, set_tensor_parallel_attribute
from ._operation import Add_3D, Matmul_AB_3D, Mul_3D, Sum_3D
from ._utils import get_depth_from_env, get_last_group
from .._common_utils import divide, set_tensor_parallel_attribute_by_size
from ._operation import (Add_3D, Matmul_AB_3D, Mul_3D, Sum_3D, layer_norm_3d,
linear_3d)
from ._utils import (get_depth_from_env, get_last_group,
get_parallel_mode_from_env, swap_in_out_group)
@LAYERS.register_module
@ -22,20 +31,19 @@ class LayerNorm3D(nn.Module):
def __init__(
self,
normalized_shape: int,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
# input_parallel_mode: ParallelMode,
# weight_parallel_mode: ParallelMode,
eps: float = 1e-12,
dtype: dtype = None,
):
super().__init__()
self.input_parallel_mode = input_parallel_mode
self.weight_parallel_mode = weight_parallel_mode
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode,
self.weight_parallel_mode)
self.depth = get_depth_from_env()
self.normalized_shape = normalized_shape
self.normalized_shape_per_partition = divide(normalized_shape,
self.depth**2)
self.normalized_shape_per_partition = divide(normalized_shape, self.depth)
self.weight = Parameter(
torch.ones(self.normalized_shape_per_partition,
@ -49,37 +57,40 @@ class LayerNorm3D(nn.Module):
self._set_tensor_parallel_attributes()
def _set_tensor_parallel_attributes(self):
set_tensor_parallel_attribute(self.weight)
set_tensor_parallel_attribute(self.bias)
def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
return self.input_parallel_mode, self.weight_parallel_mode
set_tensor_parallel_attribute_by_size(self.weight, self.normalized_shape)
set_tensor_parallel_attribute_by_size(self.bias, self.normalized_shape)
def reset_parameters(self):
nn.init.zeros_(self.bias)
nn.init.ones_(self.weight)
init.zeros_(self.bias)
init.ones_(self.weight)
def forward(self, input_: Tensor) -> Tensor:
'''x = weight * (x - mean) / sqrt(var + eps) + bias'''
# input: [m/q^2, n, h/q]
# [m/q^2, n, 1]
mean = Sum_3D.apply(input_, -1, self.depth, self.output_parallel_mode,
True) / self.normalized_shape
# [m/q^2, n, 1]
var = (input_ - mean).pow(2)
var = Sum_3D.apply(var, -1, self.depth, self.output_parallel_mode,
True) / self.normalized_shape
# '''x = weight * (x - mean) / sqrt(var + eps) + bias'''
# # input: [m/q^2, n, h/q]
# # [m/q^2, n, 1]
# mean = Sum_3D.apply(input_, -1, self.depth, self.output_parallel_mode,
# True) / self.normalized_shape
# # [m/q^2, n, 1]
# var = (input_ - mean).pow(2)
# var = Sum_3D.apply(var, -1, self.depth, self.output_parallel_mode,
# True) / self.normalized_shape
output = (input_ - mean) / torch.sqrt(var + self.variance_epsilon)
output = Mul_3D.apply(output, self.weight, self.depth,
self.input_parallel_mode,
self.weight_parallel_mode,
self.output_parallel_mode)
output = Add_3D.apply(output, self.bias, self.depth,
self.input_parallel_mode,
self.weight_parallel_mode,
self.output_parallel_mode)
return output
# output = (input_ - mean) / torch.sqrt(var + self.variance_epsilon)
# output = Mul_3D.apply(output, self.weight, self.depth,
# self.input_parallel_mode,
# self.weight_parallel_mode,
# self.output_parallel_mode)
# output = Add_3D.apply(output, self.bias, self.depth,
# self.input_parallel_mode,
# self.weight_parallel_mode,
# self.output_parallel_mode)
# return output
return layer_norm_3d.apply(input_, self.weight, self.bias,
self.normalized_shape,
self.variance_epsilon,
self.input_parallel_mode,
self.weight_parallel_mode,
self.output_parallel_mode)
def extra_repr(self):
return '{}, eps={}'.format(self.normalized_shape,
@ -88,33 +99,36 @@ class LayerNorm3D(nn.Module):
@LAYERS.register_module
class Linear3D(nn.Module):
def __init__(self,
in_features: int,
out_features: int,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
bias: bool = True,
dtype: dtype = None):
def __init__(
self,
in_features: int,
out_features: int,
# input_parallel_mode: ParallelMode,
# weight_parallel_mode: ParallelMode,
bias: bool = True,
dtype: dtype = None,
init_weight: str = 'torch',
init_bias: str = 'torch'):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.input_parallel_mode = input_parallel_mode
self.weight_parallel_mode = weight_parallel_mode
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode,
self.weight_parallel_mode)
self.with_bias = bias
# self.with_bias = bias
self.depth = get_depth_from_env()
self.in_features_per_partition = divide(in_features, self.depth)
self.out_features_per_partition = divide(out_features, self.depth**2)
self.out_features_per_partition = divide(out_features, self.depth)
# [k/q, h/q^2]
# [k/q, h/q]
self.weight = Parameter(
torch.empty(self.in_features_per_partition,
self.out_features_per_partition,
device=get_current_device(),
dtype=dtype))
# [h/q^2]
# [h/q]
if bias:
self.bias = Parameter(
torch.zeros(self.out_features_per_partition,
@ -123,49 +137,54 @@ class Linear3D(nn.Module):
else:
self.register_parameter('bias', None)
self.reset_parameters()
self.reset_parameters(init_weight, init_bias)
self._set_tensor_parallel_attributes()
swap_in_out_group()
def _set_tensor_parallel_attributes(self):
set_tensor_parallel_attribute(self.weight)
set_tensor_parallel_attribute_by_size(self.weight, self.in_features * self.out_features)
if self.bias is not None:
set_tensor_parallel_attribute(self.bias)
set_tensor_parallel_attribute_by_size(self.bias, self.out_features)
def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
return self.output_parallel_mode, self.weight_parallel_mode
def reset_parameters(self):
def reset_parameters(self, init_weight, init_bias) -> None:
# setting
fan_in = self.in_features
a = math.sqrt(5)
nonlinearity = 'leaky_relu'
fan_in, fan_out = self.in_features, self.out_features
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0]
# init weight
std = nn.init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
bound = math.sqrt(3.0) * std
with seed(ParallelMode.TENSOR):
nn.init.uniform_(self.weight, -bound, bound)
init_weight_(self.weight, fan_in, fan_out, init_method=init_weight)
dist.broadcast(self.weight,
src=weight_src_rank,
group=gpc.get_group(self.weight_parallel_mode))
# init bias
if self.with_bias:
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
with seed(ParallelMode.TENSOR):
nn.init.uniform_(self.bias, -bound, bound)
if self.bias is not None:
init_bias_(self.bias, fan_in, init_method=init_bias)
dist.broadcast(self.bias,
src=weight_src_rank,
group=gpc.get_group(self.weight_parallel_mode))
dist.broadcast(self.bias,
src=output_src_rank,
group=gpc.get_group(self.output_parallel_mode))
def forward(self, input_: Tensor) -> Tensor:
# input: [m/q^2, n, k/q]
# output: [m/q^2, n, h/q]
output = Matmul_AB_3D.apply(input_, self.weight, self.depth,
self.input_parallel_mode,
self.weight_parallel_mode,
self.output_parallel_mode)
# # input: [m/q^2, n, k/q]
# # output: [m/q^2, n, h/q]
# output = Matmul_AB_3D.apply(input_, self.weight, self.depth,
# self.input_parallel_mode,
# self.weight_parallel_mode,
# self.output_parallel_mode)
if self.with_bias:
output = Add_3D.apply(output, self.bias, self.depth,
self.output_parallel_mode,
self.weight_parallel_mode,
self.input_parallel_mode)
return output
# if self.bias is not None:
# output = Add_3D.apply(output, self.bias, self.depth,
# self.output_parallel_mode,
# self.weight_parallel_mode,
# self.input_parallel_mode)
# return output
return linear_3d.apply(input_, self.weight, self.bias,
self.input_parallel_mode,
self.weight_parallel_mode,
self.output_parallel_mode)
def extra_repr(self):
return 'in_features={}, out_features={}, bias={}'.format(

View File

@ -1,3 +0,0 @@
from .layers import ViTBlock
__all__ = ['ViTBlock']

View File

@ -1,59 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from torch import nn as nn
from colossalai.builder import build_layer
from colossalai.registry import LAYERS
@LAYERS.register_module
class ViTBlock(nn.Module):
"""Vision Transformer block
:param attention_cfg: config of attention layer
:type attention_cfg: dict
:param droppath_cfg: config of drop path
:type droppath_cfg: dict
:param mlp_cfg: config of MLP layer
:type mlp_cfg: dict
:param norm_cfg: config of normlization layer
:type norm_cfg: dict
"""
def __init__(self,
attention_cfg: dict,
droppath_cfg: dict,
mlp_cfg: dict,
norm_cfg: dict,
):
super().__init__()
self.norm1 = build_layer(norm_cfg)
self.attn = build_layer(attention_cfg)
self.drop_path = build_layer(
droppath_cfg) if droppath_cfg['drop_path'] > 0. else nn.Identity()
self.norm2 = build_layer(norm_cfg)
self.mlp = build_layer(mlp_cfg)
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
# x_ = x
# x_ = self.norm1(x_)
# if self.checkpoint:
# x_ = checkpoint(self.attn, x_)
# else:
# x_ = self.attn(x_)
# x_ = self.drop_path(x_)
# x = x + x_
#
# x_ = x
# x_ = self.norm2(x_)
# if self.checkpoint:
# x_ = checkpoint(self.mlp, x_)
# else:
# x_ = self.mlp(x_)
# x_ = self.drop_path(x_)
# x = x + x_
return x

View File

@ -1,5 +0,0 @@
from .basic_block import ResNetBasicBlock
from .bottleneck import ResNetBottleneck
from .reslayer import ResLayer
__all__ = ['ResLayer', 'ResNetBottleneck', 'ResNetBasicBlock']

View File

@ -1,64 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Optional, Callable
import torch.nn as nn
from torch import Tensor
from colossalai.registry import LAYERS
from .conv import conv3x3
@LAYERS.register_module
class ResNetBasicBlock(nn.Module):
"""Basic ResNet block
"""
expansion: int = 1
def __init__(
self,
inplanes: int,
planes: int,
stride: int = 1,
downsample: Optional[nn.Module] = None,
groups: int = 1,
base_width: int = 64,
dilation: int = 1,
norm_layer: Optional[Callable[..., nn.Module]] = None
) -> None:
super().__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if groups != 1 or base_width != 64:
raise ValueError(
'BasicBlock only supports groups=1 and base_width=64')
if dilation > 1:
raise NotImplementedError(
"Dilation > 1 not supported in BasicBlock")
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x: Tensor) -> Tensor:
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out

View File

@ -1,69 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Optional, Callable
import torch.nn as nn
from torch import Tensor
from colossalai.registry import LAYERS
from .conv import conv3x3, conv1x1
@LAYERS.register_module
class ResNetBottleneck(nn.Module):
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
# while original implementation places the stride at the first 1x1 convolution(self.conv1)
# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
# This variant is also known as ResNet V1.5 and improves accuracy according to
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
expansion: int = 4
def __init__(
self,
inplanes: int,
planes: int,
stride: int = 1,
downsample: Optional[nn.Module] = None,
groups: int = 1,
base_width: int = 64,
dilation: int = 1,
norm_layer: Optional[Callable[..., nn.Module]] = None
) -> None:
super().__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
width = int(planes * (base_width / 64.)) * groups
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(inplanes, width)
self.bn1 = norm_layer(width)
self.conv2 = conv3x3(width, width, stride, groups, dilation)
self.bn2 = norm_layer(width)
self.conv3 = conv1x1(width, planes * self.expansion)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x: Tensor) -> Tensor:
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out

View File

@ -1,15 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.nn as nn
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=dilation, groups=groups, bias=False, dilation=dilation)
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

View File

@ -1,63 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.nn as nn
from colossalai.registry import LAYERS
from .conv import conv1x1
@LAYERS.register_module
class ResLayer(nn.Module):
def __init__(self,
block_type: str,
norm_layer_type: str,
inplanes: int,
planes: int,
blocks: int,
groups: int,
base_width: int,
stride: int = 1,
dilation: int = 1,
dilate: bool = False,
):
super().__init__()
self.block = LAYERS.get_module(block_type)
self.norm_layer = LAYERS.get_module(norm_layer_type)
self.inplanes = inplanes
self.planes = planes
self.blocks = blocks
self.groups = groups
self.dilation = dilation
self.base_width = base_width
self.dilate = dilate
self.stride = stride
self.layer = self._make_layer()
def _make_layer(self):
norm_layer = self.norm_layer
downsample = None
previous_dilation = self.dilation
if self.dilate:
self.dilation *= self.stride
self.stride = 1
if self.stride != 1 or self.inplanes != self.planes * self.block.expansion:
downsample = nn.Sequential(
conv1x1(self.inplanes, self.planes * self.block.expansion, self.stride),
norm_layer(self.planes * self.block.expansion),
)
layers = []
layers.append(self.block(self.inplanes, self.planes, self.stride, downsample, self.groups,
self.base_width, previous_dilation, norm_layer))
self.inplanes = self.planes * self.block.expansion
for _ in range(1, self.blocks):
layers.append(self.block(self.inplanes, self.planes, groups=self.groups,
base_width=self.base_width, dilation=self.dilation,
norm_layer=norm_layer))
return nn.Sequential(*layers)
def forward(self, x):
return self.layer(x)

View File

@ -1,7 +0,0 @@
from .layers import (VanillaViTBlock, VanillaViTMLP, VanillaViTPatchEmbedding,
VanillaViTAttention, VanillaViTDropPath, VanillaViTHead)
__all__ = [
'VanillaViTBlock', 'VanillaViTMLP', 'VanillaViTPatchEmbedding',
'VanillaViTAttention', 'VanillaViTDropPath', 'VanillaViTHead'
]

View File

@ -1,4 +1,3 @@
from .base_loss import BaseLoss
from .cross_entropy_2d import CrossEntropyLoss2D
from .cross_entropy_2p5d import CrossEntropyLoss2p5D
from .cross_entropy_3d import CrossEntropyLoss3D

View File

@ -1,13 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from abc import ABC, abstractmethod
class BaseLoss(ABC):
"""Absctract loss class
"""
@abstractmethod
def calc_loss(self, *args, **kwargs):
pass

View File

@ -1,120 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer.parallel_1d._utils import vocab_range_from_per_partition_vocab_size
class _VocabParallelCrossEntropy_1D(torch.autograd.Function):
@staticmethod
def forward(ctx, vocab_parallel_logits, target):
# Maximum value along vocab dimension across all GPUs.
logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
torch.distributed.all_reduce(logits_max,
op=torch.distributed.ReduceOp.MAX,
group=gpc.get_group(ParallelMode.PARALLEL_1D))
# Subtract the maximum value.
vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
# Get the partition's vocab indecies
partition_vocab_size = vocab_parallel_logits.size()[-1]
rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
vocab_start_index, vocab_end_index = vocab_range_from_per_partition_vocab_size(
partition_vocab_size, rank, world_size)
# Create a mask of valid vocab ids (1 means it needs to be masked).
target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
masked_target = target.clone() - vocab_start_index
masked_target[target_mask] = 0
# Get predicted-logits = logits[target].
# For Simplicity, we convert logits to a 2-D tensor with size
# [*, partition-vocab-size] and target to a 1-D tensor of size [*].
logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
masked_target_1d = masked_target.view(-1)
arange_1d = torch.arange(start=0, end=logits_2d.size()[0],
device=logits_2d.device)
predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
predicted_logits_1d = predicted_logits_1d.clone().contiguous()
predicted_logits = predicted_logits_1d.view_as(target)
predicted_logits[target_mask] = 0.0
# All reduce is needed to get the chunks from other GPUs.
torch.distributed.all_reduce(predicted_logits,
op=torch.distributed.ReduceOp.SUM,
group=gpc.get_group(ParallelMode.PARALLEL_1D))
# Sum of exponential of logits along vocab dimension across all GPUs.
exp_logits = vocab_parallel_logits
torch.exp(vocab_parallel_logits, out=exp_logits)
sum_exp_logits = exp_logits.sum(dim=-1)
torch.distributed.all_reduce(sum_exp_logits,
op=torch.distributed.ReduceOp.SUM,
group=gpc.get_group(ParallelMode.PARALLEL_1D))
# Loss = log(sum(exp(logits))) - predicted-logit.
loss = torch.log(sum_exp_logits) - predicted_logits
# Store softmax, target-mask and masked-target for backward pass.
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
return loss
@staticmethod
def backward(ctx, grad_output):
# Retreive tensors from the forward path.
softmax, target_mask, masked_target_1d = ctx.saved_tensors
# All the inputs have softmax as thier gradient.
grad_input = softmax
# For simplicity, work with the 2D gradient.
partition_vocab_size = softmax.size()[-1]
grad_2d = grad_input.view(-1, partition_vocab_size)
# Add the gradient from matching classes.
arange_1d = torch.arange(start=0, end=grad_2d.size()[0],
device=grad_2d.device)
grad_2d[arange_1d, masked_target_1d] -= (
1.0 - target_mask.view(-1).float())
# Finally elementwise multiplication with the output gradients.
grad_input.mul_(grad_output.unsqueeze(dim=-1))
return grad_input, None
class LmLoss1D(_Loss):
def forward(self, lm_logits, lm_labels, loss_mask):
lm_loss = _VocabParallelCrossEntropy_1D.apply(lm_logits, lm_labels)
lm_loss = torch.sum(
lm_loss.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
return lm_loss
class SopLoss1D(_Loss):
def forward(self, sop_logits, sentence_order):
sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
sentence_order.view(-1),
ignore_index=-1)
return sop_loss
class BERTDualHeadLoss(_Loss):
def __init__(self):
self.lm_loss = LmLoss1D()
self.sop_loss = SopLoss1D()
def forward(self, lm_logits, sop_logits, lm_labels, loss_mask, sentence_order):
lm_loss = self.lm_loss(lm_logits, lm_labels, loss_mask)
sop_loss = self.sop_loss(sop_logits, sentence_order)
return lm_loss + sop_loss

View File

@ -7,18 +7,18 @@ from colossalai.core import global_context as gpc
from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization, get_summa_dim_from_env
from colossalai.registry import LOSSES
from colossalai.utils import get_current_device
from torch.cuda.amp import custom_bwd, custom_fwd
class _ParallelCrossEntropyLossFunction_2D(torch.autograd.Function):
### Modified based on megatron.mpu.cross_entropy ###
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, logits, targets):
# logits: [b/q, h/q]
# labels: [b/q]
# loss: [b/q]
# vocab_parallel_logits: [b/q, s, v/q]
# target: [b/q, s]
logits_max = torch.max(logits, dim=-1)[0]
torch.distributed.all_reduce(
logits_max,
@ -58,6 +58,7 @@ class _ParallelCrossEntropyLossFunction_2D(torch.autograd.Function):
return loss
@staticmethod
@custom_bwd
def backward(ctx, output_grad):
# Retreive tensors from the forward path.
softmax, target_mask, masked_target = ctx.saved_tensors
@ -91,12 +92,14 @@ class _ReduceByColumn(torch.autograd.Function):
return input_
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, input_):
dist.all_reduce(input_, group=gpc.get_group(
ParallelMode.PARALLEL_2D_COL))
return input_
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
return grad_output

View File

@ -1,32 +1,20 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os
import torch
import torch.distributed as dist
from torch.nn.modules.loss import _Loss
from colossalai.communication import all_gather
from colossalai.constants import (INPUT_GROUP_3D, OUTPUT_GROUP_3D,
WEIGHT_GROUP_3D)
from colossalai.core import global_context as gpc
from colossalai.nn.layer.parallel_3d._operation import Reduce_3D
from colossalai.nn.layer.parallel_3d._utils import get_last_group, get_depth_from_env
from colossalai.nn.layer.parallel_3d._utils import (get_depth_from_env,
get_last_group,
get_parallel_mode_from_env)
from colossalai.registry import LOSSES
from colossalai.utils import get_current_device
def accuracy_3d(output, target, input_parallel_mode, weight_parallel_mode):
depth = get_depth_from_env()
output_parallel_mode = get_last_group(input_parallel_mode,
weight_parallel_mode)
j = gpc.get_local_rank(input_parallel_mode)
i = gpc.get_local_rank(weight_parallel_mode)
target = torch.chunk(target, depth, dim=0)[i]
target = torch.chunk(target, depth, dim=0)[j]
output = all_gather(output, -1, output_parallel_mode)
prediction = torch.argmax(output, dim=-1)
correct = torch.sum(prediction == target)
dist.all_reduce(correct, group=gpc.get_group(input_parallel_mode))
dist.all_reduce(correct, group=gpc.get_group(weight_parallel_mode))
return correct.item()
from torch.nn.modules.loss import _Loss
class _ParallelCrossEntropyLossFunction_3D(torch.autograd.Function):
@ -112,16 +100,18 @@ class CrossEntropyLoss3D(_Loss):
:param reduction: whether to average the loss, defaults to True
:type reduction: bool, optional
"""
def __init__(self,
input_parallel_mode,
weight_parallel_mode,
reduction=True):
def __init__(
self,
# input_parallel_mode,
# weight_parallel_mode,
reduction=True,
label_smoothing=0.0):
super().__init__()
self.depth = get_depth_from_env()
self.input_parallel_mode = input_parallel_mode
self.weight_parallel_mode = weight_parallel_mode
self.output_parallel_mode = get_last_group(input_parallel_mode,
weight_parallel_mode)
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode,
self.weight_parallel_mode)
self.input_rank = gpc.get_local_rank(self.input_parallel_mode)
self.weight_rank = gpc.get_local_rank(self.weight_parallel_mode)
self.reduction_mean = reduction
@ -141,53 +131,53 @@ class CrossEntropyLoss3D(_Loss):
return loss
@LOSSES.register_module
class LabelSmoothingCrossEntropy3D(_Loss):
"""
NLL loss with label smoothing, adapted from timm.loss.LabelSmoothingCrossEntropy
# @LOSSES.register_module
# class LabelSmoothingCrossEntropy3D(_Loss):
# """
# NLL loss with label smoothing, adapted from timm.loss.LabelSmoothingCrossEntropy
:param input_parallel_mode: parallel mode for input tensor
:type input_parallel_mode: ParallelMode
:param weight_parallel_mode: parallel mode for weight
:type weight_parallel_mode: ParallelMode
:param smoothing: label smoothing value, defaults to 0.1
:type smoothing: float
:param reduction: whether to average the loss, defaults to True
:type reduction: bool, optional
"""
def __init__(self,
input_parallel_mode,
weight_parallel_mode,
smoothing=0.1,
reduction=True):
super().__init__()
assert smoothing < 1.0
self.smoothing = smoothing
self.confidence = 1. - smoothing
self.depth = get_depth_from_env()
self.input_parallel_mode = input_parallel_mode
self.weight_parallel_mode = weight_parallel_mode
self.output_parallel_mode = get_last_group(input_parallel_mode,
weight_parallel_mode)
self.reduction_mean = reduction
# :param input_parallel_mode: parallel mode for input tensor
# :type input_parallel_mode: ParallelMode
# :param weight_parallel_mode: parallel mode for weight
# :type weight_parallel_mode: ParallelMode
# :param smoothing: label smoothing value, defaults to 0.1
# :type smoothing: float
# :param reduction: whether to average the loss, defaults to True
# :type reduction: bool, optional
# """
# def __init__(self,
# input_parallel_mode,
# weight_parallel_mode,
# smoothing=0.1,
# reduction=True):
# super().__init__()
# assert smoothing < 1.0
# self.smoothing = smoothing
# self.confidence = 1. - smoothing
# self.depth = get_depth_from_env()
# self.input_parallel_mode = input_parallel_mode
# self.weight_parallel_mode = weight_parallel_mode
# self.output_parallel_mode = get_last_group(input_parallel_mode,
# weight_parallel_mode)
# self.reduction_mean = reduction
def forward(self, logits, targets):
# split label partition from the entire batch
j = gpc.get_local_rank(self.input_parallel_mode)
i = gpc.get_local_rank(self.weight_parallel_mode)
targets = torch.chunk(targets, self.depth, dim=0)[i]
targets = torch.chunk(targets, self.depth, dim=0)[j]
exp_logits = torch.exp(logits)
sum_exp_logits = Sum3D.apply(exp_logits, -1, depth,
self.output_parallel_mode, False)
log_probs = torch.log(sum_exp_logits) - logits
nll_loss = _ParallelCrossEntropyLossFunction_3D.apply(
logits, targets, self.depth, self.output_parallel_mode)
smooth_loss = -log_probs.mean(dim=-1)
loss = self.confidence * nll_loss + self.smoothing * smooth_loss
if self.reduction_mean:
loss = loss.sum()
loss = Reduce_3D.apply(loss, self.depth, self.input_parallel_mode)
loss = Reduce_3D.apply(loss, self.depth, self.weight_parallel_mode)
loss /= batch_size
return loss
# def forward(self, logits, targets):
# # split label partition from the entire batch
# j = gpc.get_local_rank(self.input_parallel_mode)
# i = gpc.get_local_rank(self.weight_parallel_mode)
# targets = torch.chunk(targets, self.depth, dim=0)[i]
# targets = torch.chunk(targets, self.depth, dim=0)[j]
# exp_logits = torch.exp(logits)
# sum_exp_logits = Sum3D.apply(exp_logits, -1, depth,
# self.output_parallel_mode, False)
# log_probs = torch.log(sum_exp_logits) - logits
# nll_loss = _ParallelCrossEntropyLossFunction_3D.apply(
# logits, targets, self.depth, self.output_parallel_mode)
# smooth_loss = -log_probs.mean(dim=-1)
# loss = self.confidence * nll_loss + self.smoothing * smooth_loss
# if self.reduction_mean:
# loss = loss.sum()
# loss = Reduce_3D.apply(loss, self.depth, self.input_parallel_mode)
# loss = Reduce_3D.apply(loss, self.depth, self.weight_parallel_mode)
# loss /= batch_size
# return loss

View File

@ -48,8 +48,10 @@ class DelayerScheduler(_LRScheduler):
if self.finished:
if epoch is None:
self.after_scheduler.step(None)
self._last_lr = self.after_scheduler.get_last_lr()
else:
self.after_scheduler.step(epoch - self.delay_epochs)
self._last_lr = self.after_scheduler.get_last_lr()
else:
return super(DelayerScheduler, self).step(epoch)
@ -66,6 +68,7 @@ class WarmupScheduler(_LRScheduler):
:param last_epoch: The index of last epoch, defaults to -1
:type last_epoch: int, optional
"""
def __init__(self, optimizer, warmup_epochs, after_scheduler, last_epoch=-1):
self.warmup_epochs = int(warmup_epochs)
self.after_scheduler = after_scheduler
@ -85,8 +88,10 @@ class WarmupScheduler(_LRScheduler):
if self.finished:
if epoch is None:
self.after_scheduler.step(None)
self._last_lr = self.after_scheduler.get_last_lr()
else:
self.after_scheduler.step(epoch - self.warmup_epochs)
self._last_lr = self.after_scheduler.get_last_lr()
else:
return super().step(epoch)
@ -136,7 +141,9 @@ class WarmupDelayerScheduler(_LRScheduler):
if self.finished:
if epoch is None:
self.after_scheduler.step(None)
self._last_lr = self.after_scheduler.get_last_lr()
else:
self.after_scheduler.step(epoch - self.warmup_epochs)
self._last_lr = self.after_scheduler.get_last_lr()
else:
return super().step(epoch)

View File

@ -12,7 +12,6 @@ class MultiStepLR(_MultiStepLR):
number of epoch reaches one of the milestones. Notice that such decay can
happen simultaneously with other changes to the learning rate from outside
this scheduler. When last_epoch=-1, sets initial lr as lr.
:param optimizer: Wrapped optimizer
:type optimizer: torch.optim.Optimizer
:param total_steps: number of total training steps
@ -34,7 +33,6 @@ class MultiStepLR(_MultiStepLR):
@LR_SCHEDULERS.register_module
class MultiStepWarmupLR(WarmupScheduler):
"""Multi-step laerning rate scheduler with warmup.
:param optimizer: Wrapped optimizer
:type optimizer: torch.optim.Optimizer
:param total_steps: number of total training steps

View File

@ -12,28 +12,21 @@ class OneCycleLR(_OneCycleLR):
than the initial learning rate.
This policy was initially described in the paper `Super-Convergence:
Very Fast Training of Neural Networks Using Large Learning Rates`_.
The 1cycle learning rate policy changes the learning rate after every batch.
`step` should be called after a batch has been used for training.
This scheduler is not chainable.
Note also that the total number of steps in the cycle can be determined in one
of two ways (listed in order of precedence):
#. A value for total_steps is explicitly provided.
#. A number of epochs (epochs) and a number of steps per epoch
(steps_per_epoch) are provided.
In this case, the number of total steps is inferred by
total_steps = epochs * steps_per_epoch
You must either provide a value for total_steps or provide a value for both
epochs and steps_per_epoch.
The default behaviour of this scheduler follows the fastai implementation of 1cycle, which
claims that "unpublished work has shown even better results by using only two phases". To
mimic the behaviour of the original paper instead, set ``three_phase=True``.
:param optimizer: Wrapped optimizer
:type optimizer: torch.optim.Optimizer
:param total_steps: number of total training steps
@ -71,7 +64,6 @@ class OneCycleLR(_OneCycleLR):
number of *batches* computed, not the total number of epochs computed.
When last_epoch=-1, the schedule is started from the beginning, defaults to -1
:type last_epoch: int, optional
.. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates:
https://arxiv.org/abs/1708.07120
"""

View File

@ -7,7 +7,6 @@ from .delayed import WarmupScheduler
@LR_SCHEDULERS.register_module
class PolynomialLR(_LRScheduler):
"""Polynomial learning rate scheduler.
:param optimizer: Wrapped optimizer
:type optimizer: torch.optim.Optimizer
:param total_steps: number of total training steps
@ -43,7 +42,6 @@ class PolynomialLR(_LRScheduler):
@LR_SCHEDULERS.register_module
class PolynomialWarmupLR(WarmupScheduler):
"""Polynomial learning rate scheduler with warmup.
:param optimizer: Wrapped optimizer
:type optimizer: torch.optim.Optimizer
:param total_steps: number of total training steps

View File

@ -10,7 +10,6 @@ from colossalai.registry import LR_SCHEDULERS
class LambdaLR(_LambdaLR):
"""Sets the learning rate of each parameter group to the initial lr
times a given function. When last_epoch=-1, sets initial lr as lr.
:param optimizer: Wrapped optimizer
:type optimizer: torch.optim.Optimizer
:param total_steps: number of total training steps
@ -33,7 +32,6 @@ class LambdaLR(_LambdaLR):
class MultiplicativeLR(_MultiplicativeLR):
"""Multiply the learning rate of each parameter group by the factor given
in the specified function. When last_epoch=-1, sets initial lr as lr
:param optimizer: Wrapped optimizer
:type optimizer: torch.optim.Optimizer
:param total_steps: number of total training steps
@ -58,7 +56,6 @@ class StepLR(_StepLR):
step_size epochs. Notice that such decay can happen simultaneously with
other changes to the learning rate from outside this scheduler. When
last_epoch=-1, sets initial lr as lr
:param optimizer: Wrapped optimizer
:type optimizer: torch.optim.Optimizer
:param total_steps: number of total training steps
@ -82,7 +79,6 @@ class StepLR(_StepLR):
class ExponentialLR(_ExponentialLR):
"""Decays the learning rate of each parameter group by gamma every epoch.
When last_epoch=-1, sets initial lr as lr
:param optimizer: Wrapped optimizer
:type optimizer: torch.optim.Optimizer
:param total_steps: number of total training steps

View File

@ -1,3 +1,3 @@
from .base_model import BaseModel
from .vanilla_resnet import VanillaResNet
from .vision_transformer import *
from .model_from_config import ModelFromConfig
__all__ = ['ModelFromConfig']

View File

@ -8,10 +8,10 @@ import torch.nn as nn
from colossalai.builder import build_layer
class BaseModel(nn.Module, ABC):
class ModelFromConfig(nn.Module, ABC):
def __init__(self):
super(BaseModel, self).__init__()
super(ModelFromConfig, self).__init__()
self.layers = nn.ModuleList()
self.layers_cfg = []
@ -32,7 +32,6 @@ class BaseModel(nn.Module, ABC):
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""Use this function to override the state dict for
saving checkpoints."""
return self.state_dict(destination, prefix, keep_vars)

View File

@ -1,3 +0,0 @@
from .resnet import VanillaResNet
__all__ = ['VanillaResNet']

View File

@ -1,163 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import List, Optional
import torch
import torch.nn as nn
from torch import Tensor
from colossalai.registry import LAYERS
from colossalai.registry import MODELS
from ..base_model import BaseModel
@MODELS.register_module
class VanillaResNet(BaseModel):
"""ResNet from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
"""
def __init__(
self,
num_cls: int,
block_type: str,
layers: List[int],
norm_layer_type: str = 'BatchNorm2d',
in_channels: int = 3,
groups: int = 1,
width_per_group: int = 64,
zero_init_residual: bool = False,
replace_stride_with_dilation: Optional[List[bool]] = None,
dilations=(1, 1, 1, 1)
) -> None:
super().__init__()
self.inplanes = 64
self.zero_init_residual = zero_init_residual
self.blocks = layers
self.block_expansion = LAYERS.get_module(block_type).expansion
self.dilations = dilations
self.reslayer_common_cfg = dict(
type='ResLayer',
block_type=block_type,
norm_layer_type=norm_layer_type,
groups=groups,
base_width=width_per_group
)
if replace_stride_with_dilation is None:
# each element in the tuple indicates if we should replace
# the 2x2 stride with a dilated convolution instead
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError("replace_stride_with_dilation should be None "
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
self.layers_cfg = [
# conv1
dict(type='Conv2d',
in_channels=in_channels,
out_channels=self.inplanes,
kernel_size=7,
stride=2,
padding=3,
bias=False),
# bn1
dict(
type=norm_layer_type,
num_features=self.inplanes
),
# relu
dict(
type='ReLU',
inplace=True
),
# maxpool
dict(
type='MaxPool2d',
kernel_size=3,
stride=2,
padding=1
),
# layer 1
dict(
inplanes=self.inplanes,
planes=64,
blocks=self.blocks[0],
dilation=self.dilations[0],
**self.reslayer_common_cfg
),
# layer 2
dict(
inplanes=64 * self.block_expansion,
planes=128,
blocks=self.blocks[1],
stride=2,
dilate=replace_stride_with_dilation[0],
dilation=self.dilations[1],
**self.reslayer_common_cfg
),
# layer 3
dict(
inplanes=128 * self.block_expansion,
planes=256,
blocks=layers[2],
stride=2,
dilate=replace_stride_with_dilation[1],
dilation=self.dilations[2],
**self.reslayer_common_cfg
),
# layer 4
dict(
inplanes=256 * self.block_expansion,
planes=512,
blocks=layers[3], stride=2,
dilate=replace_stride_with_dilation[2],
dilation=self.dilations[3],
**self.reslayer_common_cfg
),
# avg pool
dict(
type='AdaptiveAvgPool2d',
output_size=(1, 1)
),
# flatten
dict(
type='LambdaWrapper',
func=lambda mod, x: torch.flatten(x, 1)
),
# linear
dict(
type='Linear',
in_features=512 * self.block_expansion,
out_features=num_cls
)
]
def forward(self, x: Tensor):
for layer in self.layers:
x = layer(x)
return x,
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
if self.zero_init_residual:
for m in self.modules():
if isinstance(m, LAYERS.get_module('ResNetBottleneck')):
# type: ignore[arg-type]
nn.init.constant_(m.bn3.weight, 0)
elif isinstance(m, LAYERS.get_module('ResNetBasicBlock')):
# type: ignore[arg-type]
nn.init.constant_(m.bn2.weight, 0)

View File

@ -1,3 +0,0 @@
from .vision_transformer import VisionTransformerFromConfig
__all__ = ['VisionTransformerFromConfig']

View File

@ -1,14 +1,10 @@
from .fp16_optimizer import FP16Optimizer
from .colossalai_optimizer import ColossalaiOptimizer
from .fused_adam import FusedAdam
from .fused_lamb import FusedLAMB
from .fused_sgd import FusedSGD
from .lamb import Lamb
from .lars import Lars
from .zero_redundancy_optimizer_level_1 import ZeroRedundancyOptimizer_Level_1
from .zero_redundancy_optimizer_level_2 import ZeroRedundancyOptimizer_Level_2
from .zero_redundancy_optimizer_level_3 import ZeroRedundancyOptimizer_Level_3
__all__ = [
'ZeroRedundancyOptimizer_Level_1', 'ZeroRedundancyOptimizer_Level_2', 'ZeroRedundancyOptimizer_Level_3',
'FusedLAMB', 'FusedAdam', 'FusedSGD', 'Lamb', 'FP16Optimizer', 'Lars'
'ColossalaiOptimizer', 'FusedLAMB', 'FusedAdam', 'FusedSGD', 'Lamb', 'Lars'
]

View File

@ -1,168 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
from torch._six import inf
try:
import colossal_C
except:
print('Colossalai should be built with cuda extension to use the FP16 optimizer')
from ..multi_tensor_apply import multi_tensor_applier
from colossalai.constants import IS_TENSOR_PARALLEL, TENSOR_PARALLEL_ATTRIBUTES
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
def is_model_parallel_parameter(p):
return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL)
def _calc_l2_norm(grads):
norm = 0.0
if len(grads) > 0:
dummy_overflow_buf = torch.cuda.IntTensor([0])
norm, _ = multi_tensor_applier(
colossal_C.multi_tensor_l2norm,
dummy_overflow_buf,
[grads],
False # no per-parameter norm
)
return norm
def _calc_lp(grads, norm_type):
norm = 0.0
for grad in grads:
grad_norm = torch.norm(grad, norm_type)
norm += grad_norm ** norm_type
return norm
# ======== Gradient Clipping =========
def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
"""Clips gradient norm of an iterable of parameters whose gradients
are in fp32.
This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
added functionality to handle model parallel parameters. Note that
the gradients are modified in place.
Arguments:
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
single Tensor that will have gradients normalized
max_norm (float or int): max norm of the gradients
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
Returns:
Total norm of the parameters (viewed as a single vector).
"""
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
# Filter parameters based on:
# - grad should not be none
# - parameter should not be shared
# - should not be a replica due to tensor model parallelism
params = []
for param in parameters:
if param.grad is not None:
# Make sure the grads are in fp32
assert param.grad.type() == 'torch.cuda.FloatTensor'
params.append(param)
# Norm parameters.
max_norm = float(max_norm)
norm_type = float(norm_type)
# Calculate norm.
if norm_type == inf:
total_norm = max(p.grad.data.abs().max() for p in params)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
if gpc.is_initialized(ParallelMode.TENSOR):
# Take max across all model-parallel GPUs.
torch.distributed.all_reduce(total_norm_cuda,
op=torch.distributed.ReduceOp.MAX,
group=gpc.get_group(ParallelMode.TENSOR))
total_norm = total_norm_cuda[0].item()
else:
tensor_parallel_grads = []
no_tensor_parallel_grads = []
for p in params:
if is_model_parallel_parameter(p):
tensor_parallel_grads.append(p.grad.data)
else:
no_tensor_parallel_grads.append(p.grad.data)
if norm_type == 2.0:
tensor_parallel_norm = _calc_l2_norm(
tensor_parallel_grads) ** norm_type
no_tensor_parallel_norm = _calc_l2_norm(
no_tensor_parallel_grads) ** norm_type
else:
tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type)
no_tensor_parallel_grads = _calc_lp(
no_tensor_parallel_grads, norm_type)
if gpc.is_initialized(ParallelMode.TENSOR) and len(tensor_parallel_grads) > 0:
# Sum across all model-parallel GPUs.
torch.distributed.all_reduce(tensor_parallel_norm,
op=torch.distributed.ReduceOp.SUM,
group=gpc.get_group(ParallelMode.TENSOR))
total_norm = (tensor_parallel_norm +
no_tensor_parallel_norm) ** (1.0 / norm_type)
if type(total_norm) == 'torch.cuda.FloatTensor':
total_norm = total_norm.item()
# Scale.
clip_coeff = max_norm / (total_norm + 1.0e-6)
if clip_coeff < 1.0:
grads = [p.grad.detach() for p in params]
dummy_overflow_buf = torch.cuda.IntTensor([0])
multi_tensor_applier(colossal_C.multi_tensor_scale,
dummy_overflow_buf,
[grads, grads],
clip_coeff)
return total_norm
def count_zeros_fp32(parameters):
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
# Filter parameters based on:
# - grad should not be none
# - parameter should not be shared
# - should not be a replica due to tensor model parallelism
total_num_zeros = 0.0
for param in parameters:
grad_not_none = param.grad is not None
is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
if grad_not_none and is_not_tp_duplicate:
grad = param.grad.detach()
num_zeros = grad.numel() - torch.count_nonzero(grad)
total_num_zeros = num_zeros + total_num_zeros
# Sum across all model-parallel GPUs.
torch.distributed.all_reduce(total_num_zeros,
op=torch.distributed.ReduceOp.SUM,
group=gpc.get_group(ParallelMode.TENSOR))
total_num_zeros = total_num_zeros.item()
return total_num_zeros
def copy_tensor_parallel_attributes(src_tensor, dst_tensor):
for attr in TENSOR_PARALLEL_ATTRIBUTES:
if hasattr(src_tensor, attr):
val = getattr(src_tensor, attr)
setattr(dst_tensor, attr, val)
def param_is_not_tensor_parallel_duplicate(param):
return (hasattr(param, IS_TENSOR_PARALLEL) and
getattr(param, IS_TENSOR_PARALLEL)) or (
gpc.get_local_rank(ParallelMode.TENSOR) == 0)

View File

@ -0,0 +1,47 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer
from colossalai.utils import clip_grad_norm_fp32
class ColossalaiOptimizer(Optimizer):
def __init__(self, optim: Optimizer):
self.optim = optim
@property
def param_groups(self):
return self.optim.param_groups
@property
def defaults(self):
return self.optim.defaults
def add_param_group(self, *args, **kwargs):
return self.optim.add_param_group(*args, **kwargs)
def step(self, *args, **kwargs):
return self.optim.step(*args, **kwargs)
def zero_grad(self, *args, **kwargs):
self.optim.zero_grad(*args, **kwargs)
def load_state_dict(self, *args, **kwargs):
self.optim.load_state_dict(*args, **kwargs)
def state_dict(self):
return self.optim.state_dict()
def backward(self, loss: Tensor):
loss.backward()
def backward_by_grad(self, tensor: Tensor, grad: Tensor):
torch.autograd.backward(tensors=tensor, grad_tensors=grad)
def clip_grad_norm(self, model: nn.Module, max_norm: float):
if max_norm > 0.0:
clip_grad_norm_fp32(model.parameters(), max_norm)

View File

@ -2,7 +2,7 @@
import torch
from colossalai.registry import OPTIMIZERS
from ..multi_tensor_apply import multi_tensor_applier
from colossalai.utils import multi_tensor_applier
@OPTIMIZERS.register_module

View File

@ -2,7 +2,7 @@
import torch
from colossalai.registry import OPTIMIZERS
from ..multi_tensor_apply import multi_tensor_applier
from colossalai.utils import multi_tensor_applier
@OPTIMIZERS.register_module

View File

@ -3,7 +3,7 @@ import torch
from torch.optim.optimizer import Optimizer, required
from colossalai.registry import OPTIMIZERS
from ..multi_tensor_apply import multi_tensor_applier
from colossalai.utils import multi_tensor_applier
@OPTIMIZERS.register_module

View File

@ -1,707 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
from collections import defaultdict
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.optim import Optimizer
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import OPTIMIZER_WRAPPERS
from colossalai.utils import get_current_device, print_rank_0
def get_alignment_padding(flattened_lean_size, sub_partition_id, sub_partition_size):
sub_partition_high_limit = (sub_partition_id + 1) * sub_partition_size
if sub_partition_high_limit <= flattened_lean_size:
return 0
else:
return min(sub_partition_size, sub_partition_high_limit - flattened_lean_size)
def get_group_alignment_padding(tensor_list, sub_partition_size, sub_partition_count):
group_paddings = []
flattened_size = sum([tensor.numel() for tensor in tensor_list])
for i in range(sub_partition_count):
padding = get_alignment_padding(flattened_size, i, sub_partition_size)
group_paddings.append(padding)
return group_paddings
def _single_range_check(current_index, start_index, end_index, tensor_size):
offset = 0
if (current_index >= start_index) and (current_index < end_index):
# Fully inside bounds
return True, offset
elif (start_index > current_index) and (start_index < (current_index + tensor_size)):
# Partially contained, compute offset
offset = start_index - current_index
return True, offset
else:
return False, offset
def _range_check(current_index, element_intervals, tensor_size):
results = []
for comm_idx, interval in enumerate(element_intervals):
start_index, end_index = interval
contained, offset = _single_range_check(
current_index, start_index, end_index, tensor_size)
if contained:
results.append((contained, offset, comm_idx))
if len(results) == 0:
return [(False, 0, -1)]
return results
@OPTIMIZER_WRAPPERS.register_module
class ZeroRedundancyOptimizer_Level_1(Optimizer):
"""
ZeroRedundancyOptimizer_Level_1 designed to reduce the memory footprint
required for training large deep learning models.
For more details please see ZeRO: Memory Optimization Towards Training A Trillion Parameter Models
https://arxiv.org/abs/1910.02054
This version aligns with stage-1 in the paper above.
"""
def __init__(self,
init_optimizer: Optimizer,
dp_parallel_mode: ParallelMode = ParallelMode.DATA,
max_elements_per_comm=5e8,
verbose=False
):
# TODO: this class does not work with fp16 AMP_TYPE.PARALLEL, fix it
assert get_current_device() != 'cpu', 'ZeRO optimizer cannot be used on CPU only'
self.flatten = _flatten_dense_tensors
self.unflatten = _unflatten_dense_tensors
self.optimizer = init_optimizer
self.dp_parallel_mode = dp_parallel_mode
self.verbose = verbose
# for compatibility with pytorch optim
self.defaults = init_optimizer.defaults
# param flattened by groups
self._param_groups = []
self._param_groups_flat = []
# parallel_sub_partitioned_fp16_groups[group-idx] -> [comm-ids] -> [rank-ids]
self.parallel_sub_partitioned_groups = []
# same underlying data as above but viewed as: [groups] -> [rank-ids] -> [comm-ids]
self.parallel_comm_sub_partitioned_groups = []
# param partition info
# parameters in each group that will not be updated by this process directly
self.params_not_local = []
# parameters that will be updated by this process directly
self.params_in_rank_sub_partitions = []
# parameter offsets for parameters in sub-partitions. Parameter
# boundaries may not align with sub-partition boundaries
# so we need to keep track of the offsets
self.params_in_rank_sub_partitions_offsets = []
# number of elements per sub-partition in each group
self.sub_partition_sizes = []
# number of communication intervals for each group
self.num_comm_intervals_per_group = []
self.local_rank = gpc.get_local_rank(self.dp_parallel_mode)
self.partition_count = self.world_size = gpc.get_world_size(
self.dp_parallel_mode)
self.group_paddings = []
self.default_device = self.optimizer.param_groups[0]['params'][0].device
# max elems per param group
self.max_elems_per_comm = []
# loop to deal with groups
for i, param_group in enumerate(self.optimizer.param_groups):
# push this group to list before modify
self._param_groups.append(param_group['params'])
# calculate best max elements per comm based to minimize padding
self.max_elems_per_comm.append(
self.best_max_elems_per_comm(
num_elements=sum(t.numel() for t in self._param_groups[i]),
max_elements_per_comm=max_elements_per_comm
)
)
# flattens all tensors into single 1d tensor aligned with sub-partition size for later dividing
# RS: create aligned sub-partitions
flat_aligned_params = self.flatten_dense_tensors_sub_partition_aligned(
tensor_list=self._param_groups[i],
max_elements_per_comm=self.max_elems_per_comm[i],
)
self._param_groups_flat.append(flat_aligned_params)
updated_params = self.unflatten(self._param_groups_flat[i],
self._param_groups[i])
for p, q in zip(self._param_groups[i], updated_params):
p.data = q.data
# divide the flat weights into near equal partition equal to the data parallel degree
# each process will compute on a different part of the partition
# RS: split into two layer list -> [comm-id] -> [sub-partitions per rank]
comm_partitions, dp_sub_partitions, element_intervals, sub_partition_size, num_comm_intervals = \
self.get_data_parallel_sub_partitions(
tensor=self._param_groups_flat[i],
max_elements_per_comm=self.max_elems_per_comm[i],
)
self.parallel_comm_sub_partitioned_groups.append(
comm_partitions) # comm -> rank
self.parallel_sub_partitioned_groups.append(
dp_sub_partitions) # rank -> comm
self.sub_partition_sizes.append(sub_partition_size)
self.num_comm_intervals_per_group.append(num_comm_intervals)
# Compute sub_partition paddings
sub_partition_paddings = get_group_alignment_padding(
tensor_list=self._param_groups[i],
sub_partition_size=sub_partition_size,
sub_partition_count=num_comm_intervals * self.partition_count)
self.group_paddings.append(sub_partition_paddings)
# modify optimizer of have flat master weight
param_group['params'] = self.parallel_sub_partitioned_groups[i][self.local_rank]
# RS: divide up the sub-partitions and keep track of offsets for each param
# partition_size = len(self.fp16_groups_flat[i]) / dist.get_world_size(group=self.dp_process_group)
params_in_rank_sub_partition, params_in_rank_sub_partitions_offsets, params_not_local = self.get_all_sub_partition_info(
tensor_list=self._param_groups[i],
all_element_intervals=element_intervals,
)
self.params_in_rank_sub_partitions.append(
params_in_rank_sub_partition)
self.params_not_local.append(params_not_local)
self.params_in_rank_sub_partitions_offsets.append(
params_in_rank_sub_partitions_offsets)
self.local_sub_partitions_of_groups = [
group[self.local_rank] for group in self.parallel_sub_partitioned_groups]
self._initialize_optimizer_states()
@property
def state(self):
return self.optimizer.state
@state.setter
def state(self, value):
self.optimizer.state = value
@property
def param_groups(self):
# LSG: return the full param groups instead of local partitions
# of the param groups for compatibility with torch.cuda.amp
param_groups = []
for group_id, group in enumerate(self.optimizer.param_groups):
group_containing_all_param = {
'params': self._param_groups[group_id],
**{k: v for k, v in group.items() if k != 'params'}
}
# LSG: for compatibility with unknown bug with lr scheduler
# TODO: fix this
group_containing_all_param.setdefault('initial_lr', group['lr'])
param_groups.append(group_containing_all_param)
return param_groups
@param_groups.setter
def param_groups(self, value):
self.optimizer.param_groups = value
def _initialize_optimizer_states(self):
for group_idx, group in enumerate(self.local_sub_partitions_of_groups):
for idx, sub_partition_param in enumerate(group):
sub_partition_grad = torch.zeros(int(
self.sub_partition_sizes[group_idx]),
dtype=sub_partition_param.dtype).cuda()
sub_partition_param.grad = sub_partition_grad
self.optimizer.step()
# LSG: comment out for compatibility with torch.cuda.amp
# for group in self.local_sub_partitions_of_groups:
# for idx, sub_partition_param in enumerate(group):
# sub_partition_param.grad = None
def best_max_elems_per_comm(self, num_elements, max_elements_per_comm):
# if we use max-elems-per-comm as is, how many comm intervals will there be
max_comm_intervals = math.ceil(num_elements / max_elements_per_comm)
padding_for_max_comm = (max_elements_per_comm *
max_comm_intervals) - num_elements
# if we use 1 less comm interval how much extra comm padding would be required
min_comm_intervals = num_elements // max_elements_per_comm
if min_comm_intervals == 0:
if self.verbose:
print_rank_0(
f'Using default max_elements_per_comm {max_elements_per_comm}')
return max_elements_per_comm
padding_for_min_comm = math.ceil(
num_elements / (self.world_size * min_comm_intervals))
# choose padding that uses least amount of overhead
if padding_for_max_comm > padding_for_min_comm:
new_max_elements_per_comm = padding_for_min_comm + max_elements_per_comm
if self.verbose:
print_rank_0(
f'Updating max_elements_per_comm from {max_elements_per_comm} -> {new_max_elements_per_comm}')
return new_max_elements_per_comm
else:
if self.verbose:
print_rank_0(
f'Using default max_elements_per_comm {max_elements_per_comm}')
return max_elements_per_comm
def get_data_parallel_sub_partitions(self,
tensor,
max_elements_per_comm,
):
total_num_elements = tensor.numel()
# if total elements is less than our max, revert to splitting into dp partitions
max_elements_per_comm = min(total_num_elements, max_elements_per_comm)
sub_partition_size = int(max_elements_per_comm // self.world_size)
# Ensure partition alignment was done correctly
num_sub_partitions = int(total_num_elements // sub_partition_size)
assert total_num_elements % sub_partition_size == 0, "{} % {} != 0".format(total_num_elements,
sub_partition_size)
# Ensure comm interval alignment was done correctly.
num_comm_intervals = int(num_sub_partitions // self.world_size)
assert num_sub_partitions % self.world_size == 0, "{} % {} != 0".format(
num_sub_partitions, self.world_size)
if self.verbose:
print_rank_0("**** partition info:")
print_rank_0(f"\t total_num_elements={total_num_elements}")
print_rank_0(f"\t world_size={self.world_size}")
print_rank_0(f"\t max_elements_per_comm={max_elements_per_comm}")
print_rank_0(f"\t sub_partition_size={sub_partition_size}")
print_rank_0(f"\t num_sub_partitions={num_sub_partitions}")
print_rank_0(f"\t num_comm_intervals={num_comm_intervals}")
print_rank_0("****")
# [comm_id] -> [rank]
comm_partitions = []
for _ in range(num_comm_intervals):
comm_partitions.append([])
start = 0
comm_id = 0
element_intervals = defaultdict(
list) # [rank] -> [(start,end), (start,end), ...]
for idx in range(num_sub_partitions):
rank_id = idx % self.world_size
sub_partition = tensor.narrow(
0, start, sub_partition_size).detach()
element_intervals[rank_id].append(
(start, start + sub_partition_size))
comm_partitions[comm_id].append(sub_partition)
start = start + sub_partition_size
if rank_id == (self.world_size - 1):
comm_id += 1
# [rank] -> [comm_id]
sub_partitions = []
for _ in range(self.world_size):
sub_partitions.append([])
for comm_id, partitions in enumerate(comm_partitions):
for rank_id, partition in enumerate(partitions):
sub_partitions[rank_id].append(partition)
return comm_partitions, sub_partitions, element_intervals, sub_partition_size, num_comm_intervals
def get_all_sub_partition_info(self,
tensor_list,
all_element_intervals,
):
params_not_local = []
# [rank] -> [comm-id] -> [param/offset]
params_in_rank_sub_partition = []
params_in_rank_sub_partitions_offsets = []
for rank in range(self.world_size):
params_in_local_sub_partition = []
local_sub_partition_offsets = []
comm_tensor_list = []
comm_offset_list = []
current_index = 0
prev_comm_idx = 0
for iii, tensor in enumerate(tensor_list):
tensor_size = tensor.numel()
results_list = _range_check(current_index,
all_element_intervals[rank],
tensor_size)
for contained, offset, comm_idx in results_list:
if contained:
if prev_comm_idx != comm_idx:
params_in_local_sub_partition.append(
comm_tensor_list)
comm_tensor_list = []
local_sub_partition_offsets.append(
comm_offset_list)
comm_offset_list = []
comm_tensor_list.append(tensor)
comm_offset_list.append(offset)
prev_comm_idx = comm_idx
elif rank == self.local_rank:
params_not_local.append(tensor)
current_index = current_index + tensor_size
# assert len(comm_tensor_list) > 0
# assert len(comm_offset_list) > 0
params_in_local_sub_partition.append(comm_tensor_list)
local_sub_partition_offsets.append(comm_offset_list)
params_in_rank_sub_partition.append(params_in_local_sub_partition)
params_in_rank_sub_partitions_offsets.append(
local_sub_partition_offsets)
return params_in_rank_sub_partition, params_in_rank_sub_partitions_offsets, params_not_local
def get_flat_sub_partitions(self,
comm_tensor_list,
comm_param_offsets,
sub_partition_size,
dtype,
default_device,
num_comm_intervals=None,
return_partition_params=False):
partition_params = []
final_param_offsets = []
flat_sub_partitions = []
for tensor_list, param_offsets in zip(comm_tensor_list, comm_param_offsets):
flat_tensor_list = []
current_size = 0
my_offsets = []
my_params = []
for i, tensor in enumerate(tensor_list):
if tensor.grad is None:
tensor.grad = torch.zeros(tensor.size(),
dtype=tensor.dtype,
device=tensor.device)
param = tensor
tensor = tensor.grad
num_elements = tensor.numel()
tensor_offset = 0
# we need to offset to get to the right element
if i == 0 and param_offsets[i] > 0:
tensor_offset = param_offsets[i]
num_elements = num_elements - tensor_offset
# We don't need all elements of the tensor if this tensor is
# larger than we have space for in our curr sub-partition
if num_elements > (sub_partition_size - current_size):
num_elements = sub_partition_size - current_size
# we need a narrow view of the tensor based on the tensor offset and number of elements that
# we need from this tensor
if tensor_offset > 0 or num_elements < tensor.numel():
flat_tensor_list.append(tensor.contiguous().view(-1).narrow(
0,
int(tensor_offset),
int(num_elements)).to(dtype))
else:
flat_tensor_list.append(tensor.to(dtype))
my_params.append(param)
# remember offset into partition and #elems for this tensor
my_offsets.append((current_size, num_elements))
current_size = current_size + num_elements
# this means its the last partition and does not align with the dp boundary. We need to pad before flattening
if current_size < sub_partition_size:
my_offsets.append((None, None))
my_params.append(None)
if len(tensor_list) == 0:
assert default_device != None
flat_tensor_list.append(
torch.zeros(int(sub_partition_size - current_size),
dtype=dtype,
device=default_device))
else:
flat_tensor_list.append(
torch.zeros(int(sub_partition_size - current_size),
dtype=dtype,
device=tensor_list[0].device))
partition_params.append(my_params) # flat_tensor_list)
final_param_offsets.append(my_offsets)
assert len(flat_tensor_list) == len(my_offsets), "{} {}".format(
len(flat_tensor_list), len(my_offsets))
flat_sub_partitions.append(self.flatten(flat_tensor_list))
if num_comm_intervals is not None and len(
flat_sub_partitions) < num_comm_intervals:
# print("padding w. sub partitions to ensure uniform communication")
device = flat_sub_partitions[0].device
for _ in range(num_comm_intervals - len(flat_sub_partitions)):
flat_sub_partitions.append(
torch.zeros(int(sub_partition_size),
dtype=dtype,
device=device))
partition_params.append([None])
final_param_offsets.append([(None, None)])
if return_partition_params:
assert len(flat_sub_partitions) == len(partition_params)
assert len(partition_params) == len(final_param_offsets), "{} {}".format(len(partition_params),
len(final_param_offsets))
return flat_sub_partitions, partition_params, final_param_offsets
return flat_sub_partitions
def zero_grad(self, set_grads_to_None=False):
"""
Zero FP16 parameter grads.
"""
# FP32 grad should never exist.
# For speed, set model fp16 grad to None by default
for group in self._param_groups:
for p in group:
if set_grads_to_None:
p.grad = None
else:
if p.grad is not None:
p.grad.detach_()
p.grad.zero_()
def free_grad_in_param_list(self, param_list):
for p in param_list:
if isinstance(p, list):
for _p in p:
_p.grad = None
else:
p.grad = None
def flatten_dense_tensors_sub_partition_aligned(self,
tensor_list,
max_elements_per_comm
):
assert max_elements_per_comm >= self.world_size, f"max_elements_per_comm {max_elements_per_comm} < dp {self.world_size}"
num_elements = sum(t.numel() for t in tensor_list)
# Compute aligned partition size based on parameter count
aligned_param_partition_size = math.ceil(
num_elements / self.world_size)
# Compute aligned partition size based on communication size
aligned_comm_partition_size = int(
max_elements_per_comm // self.world_size)
if aligned_param_partition_size <= aligned_comm_partition_size:
sub_partition_count = 1
sub_partition_size = aligned_param_partition_size
else:
sub_partition_count = math.ceil(aligned_param_partition_size /
aligned_comm_partition_size)
sub_partition_size = aligned_comm_partition_size
# Compute required padding for alignment to dp and max_elements_per_comm
padding = (sub_partition_count * sub_partition_size *
self.world_size) - num_elements
if self.verbose:
print_rank_0(
f"sub_partition_count: {sub_partition_count}, sub_partition_size: {sub_partition_size}, padding: {padding}")
print_rank_0(
f"number of elements with padding: {num_elements} + {padding} = {num_elements + padding}")
if padding == 0:
aligned_tensor_list = tensor_list
else:
pad_tensor = torch.zeros(padding,
device=tensor_list[0].device,
dtype=tensor_list[0].dtype)
aligned_tensor_list = tensor_list + [pad_tensor]
flat_tensors = self.flatten(aligned_tensor_list)
return flat_tensors
# def reduce_gradients(self):
# # LSG: this reduce gradients method no longer works
# # after code change, please use DataParallelGradientHandler instead
#
# world_size = gpc.get_world_size(self.parallel_mode)
# local_rank = gpc.get_local_rank(self.parallel_mode)
#
# for i, group in enumerate(self._param_groups):
# num_comm_intervals = self.num_comm_intervals_per_group[i]
# all_sub_partitions = []
# for rank in range(world_size):
# # gsp is list of partitions indexed by comm_idx
# grad_sub_partitions = self.get_flat_sub_partitions(
# comm_tensor_list=self.params_in_rank_sub_partitions[i][rank],
# comm_param_offsets=self.params_in_rank_sub_partitions_offsets[i][rank],
# dtype=self.local_sub_partitions_of_groups[i][0].dtype,
# default_device=self.default_device,
# sub_partition_size=self.sub_partition_sizes[i],
# num_comm_intervals=self.num_comm_intervals_per_group[i])
# all_sub_partitions.append(grad_sub_partitions)
#
# assert len(grad_sub_partitions) == num_comm_intervals
#
# local_comm_partitions = []
# for comm_idx in range(num_comm_intervals):
# single_comm_all_partitions = []
# for rank in range(world_size):
# single_comm_all_partitions.append(all_sub_partitions[rank][comm_idx])
#
# for partition in single_comm_all_partitions:
# partition.div_(world_size)
#
# dist.reduce_scatter(output=single_comm_all_partitions[local_rank],
# input_list=single_comm_all_partitions,
# group=gpc.get_group(self.parallel_mode))
def step(self, closure=None):
local_sub_partitions_grad_groups = []
for i, group in enumerate(self._param_groups):
# RS: update free grads w.r.t. sub partitions
# free gradients for all the parameters that are not updated by this process
self.free_grad_in_param_list(self.params_not_local[i])
# create flat gradient partitions for parameters updated by this process
local_grad_sub_partitions = self.get_flat_sub_partitions(
comm_tensor_list=self.params_in_rank_sub_partitions[i][self.local_rank],
comm_param_offsets=self.params_in_rank_sub_partitions_offsets[i][self.local_rank],
sub_partition_size=self.sub_partition_sizes[i],
dtype=self.local_sub_partitions_of_groups[i][0].dtype,
num_comm_intervals=self.num_comm_intervals_per_group[i],
default_device=self.default_device)
# RS: update all our local params with sub-partition grads
for idx, sub_partition_param in enumerate(self.local_sub_partitions_of_groups[i]):
sub_partition_param.grad = local_grad_sub_partitions[idx]
# RS: update free grads for sub-partitions
# release all the gradient since we have already created a necessary copy in dp_grad_partition
self.free_grad_in_param_list(
self.params_in_rank_sub_partitions[i][self.local_rank])
local_sub_partitions_grad_groups.append(local_grad_sub_partitions)
if closure is None:
loss = self.optimizer.step()
else:
loss = self.optimizer.step(closure=closure)
# RS: clear our sub partition grads
# LSG: not needed as amp is used instead
# get rid of the fp32 gradients. Not needed anymore
# for group in self.local_sub_partitions_of_groups:
# for idx, sub_partition_param in enumerate(group):
# sub_partition_param.grad = None
# RS: all_gather/broadcast sub-partitions in separate comm calls
# gather the updated weights from everyone
for all_sub_partitions in self.parallel_comm_sub_partitioned_groups:
for comm_id, sub_partitions in enumerate(all_sub_partitions):
dist.all_gather(sub_partitions,
sub_partitions[self.local_rank],
group=gpc.get_group(self.dp_parallel_mode))
# TODO: we probably don't need this? just to be safe
for i in range(len(self._param_groups)):
updated_params = self.unflatten(self._param_groups_flat[i],
self._param_groups[i])
for p, q in zip(self._param_groups[i], updated_params):
p.data = q.data
return loss
def _rigid_state_dict(self):
"""Returns a dict that can be loaded for continued training with same DP degree
Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
of the contained Pytorch optimizer.
Example::
checkpoint = {}
checkpoint['model'] = model.state_dict()
checkpoint['optimizer'] = optimizer.state_dict()
torch.save(checkpoint, "saved.pth")
"""
state_dict = {}
for k, v in self.optimizer.state_dict().items():
state_dict[k] = v
state_dict[
'local_sub_partitions_of_groups'] = self.local_sub_partitions_of_groups
return state_dict
def state_dict(self):
"""
Returns a dict containing the current state of this Optimizer instance.
This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
of the contained Pytorch optimizer.
Example::
checkpoint = {}
checkpoint['model'] = model.state_dict()
checkpoint['optimizer'] = optimizer.state_dict()
torch.save(checkpoint, "saved.pth")
"""
return self._rigid_state_dict()
def load_state_dict(self,
state_dict,
load_optimizer_states=True,
):
"""
Loads a state_dict created by an earlier call to state_dict().
If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
whose parameters in turn came from ``model``, it is expected that the user
will call ``model.load_state_dict()`` before
``fp16_optimizer_instance.load_state_dict()`` is called.
Example::
model = torch.nn.Linear(D_in, D_out).cuda().half()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
...
checkpoint = torch.load("saved.pth")
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
"""
self._rigid_load_state_dict(
state_dict,
load_optimizer_states)
def _rigid_load_state_dict(self, state_dict, load_optimizer_states=True):
# I think it should actually be ok to reload the optimizer before the model.
state_dict_ = state_dict.copy()
local_sub_partitions_of_groups = state_dict_.pop(
'local_sub_partitions_of_groups')
if load_optimizer_states:
self.optimizer.load_state_dict(state_dict_)
for curr_group, saved_group in zip(self.local_sub_partitions_of_groups,
local_sub_partitions_of_groups):
for curr_param, saved_param in zip(curr_group, saved_group):
curr_param.data.copy_(saved_param.data)

Some files were not shown because too many files have changed in this diff Show More