adapted for sequence parallel (#163)
Commit e2089c5c15 (parent a2e649da39)
@@ -32,7 +32,7 @@ class NaiveAMPOptimizer(ColossalaiOptimizer):
         loss.backward()
 
     def step(self):
-        self.optim.step()
+        return self.optim.step()
 
     def clip_grad_norm(self, model: nn.Module, max_norm: float):
         pass
@@ -26,6 +26,7 @@ class ParallelMode(Enum):
 
     # sequence parallel
     SEQUENCE = 'sequence'
+    SEQUENCE_DP = 'sequence_dp'
 
     # 1D Parallel
     PARALLEL_1D = '1d'
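Note: SEQUENCE already names the process group that splits the sequence dimension, while the new SEQUENCE_DP names the complementary group whose members each hold a full weight replica and therefore need gradient all-reduce. A hedged sketch of how downstream code queries the two modes (assumes gpc has already been initialized by colossalai.launch):

    from colossalai.core import global_context as gpc
    from colossalai.context.parallel_mode import ParallelMode

    seq_size = gpc.get_world_size(ParallelMode.SEQUENCE)          # ranks that share one sequence
    replica_size = gpc.get_world_size(ParallelMode.SEQUENCE_DP)   # ranks that hold full weight copies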
@@ -1,5 +1,6 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
+import torch.distributed as dist
 
 from colossalai.registry import DIST_GROUP_INITIALIZER
 from .initializer_tensor import Initializer_Tensor
@@ -7,6 +8,43 @@ from .process_group_initializer import ProcessGroupInitializer
 from ..parallel_mode import ParallelMode
 
 
+@DIST_GROUP_INITIALIZER.register_module
+class Initializer_Sequence_DP(ProcessGroupInitializer):
+    '''A ProcessGroupInitializer for sequence parallelism all-reduce.
+
+    In Sequence Parallelism, each GPU holds the full copy of model weights,
+    so gradient all-reduce occurs across all processes in the same pipeline stage.
+
+    '''
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.dp_size = self.world_size // self.pipeline_parallel_size
+        self.num_group = self.pipeline_parallel_size
+
+    def init_dist_group(self):
+        '''Initialize Sequence Parallel process groups used for gradient all-reduce.
+
+        :return: (local_rank, group_world_size, process_group, ranks_in_group, mode)
+        :rtype: tuple
+        '''
+        local_rank = None
+        ranks_in_group = None
+        process_group = None
+        group_world_size = None
+        mode = ParallelMode.SEQUENCE_DP
+
+        for i in range(self.num_group):
+            ranks = [i * self.dp_size + j for j in range(self.dp_size)]
+            group = dist.new_group(ranks)
+
+            if self.rank in ranks:
+                local_rank = ranks.index(self.rank)
+                group_world_size = len(ranks)
+                process_group = group
+                ranks_in_group = ranks
+        return local_rank, group_world_size, process_group, ranks_in_group, mode
+
+
 @DIST_GROUP_INITIALIZER.register_module
 class Initializer_Sequence(ProcessGroupInitializer):
     '''A ProcessGroupInitializer for sequence parallelism.
@@ -15,13 +53,27 @@ class Initializer_Sequence(ProcessGroupInitializer):
     def __init__(self,
                  *args, **kwargs):
         super().__init__(*args, **kwargs)
-        # reuse tensor parallel code
-        self._initializer = Initializer_Tensor(*args, **kwargs)
+        # reuse tensor parallel initializer code
+        self._sequence_initializer = Initializer_Tensor(*args, **kwargs)
+        self._sequence_dp_initializer = Initializer_Sequence_DP(*args, **kwargs)
 
     def init_dist_group(self):
-        local_rank, group_world_size, process_group, ranks_in_group, mode = self._initializer.init_dist_group()
+        '''Initialize Sequence parallel process groups and assign local_ranks and groups to each GPU.
+
+        Sequence parallelism requires 2 process groups. The first is for the model forward, where several processes
+        exchange partial query, key and value embeddings to compute self-attention values. The second is for
+        all-reduce to synchronize the model parameters.
+
+        :return: sequence parallelism's information
+        :rtype: list of tuples (local_rank, group_world_size, process_group, ranks_in_group, mode)
+        '''
+
+        parallel_setting = []
+
+        local_rank, group_world_size, process_group, ranks_in_group, mode = self._sequence_initializer.init_dist_group()
        # change mode to sequence
         mode = ParallelMode.SEQUENCE
 
-        return local_rank, group_world_size, process_group, ranks_in_group, mode
+        parallel_setting.append((local_rank, group_world_size, process_group, ranks_in_group, mode))
+        parallel_setting.append(self._sequence_dp_initializer.init_dist_group())
+        return parallel_setting
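Note: the SEQUENCE_DP grouping above is a plain block layout: the world is cut into pipeline_parallel_size contiguous blocks of dp_size ranks, and gradient all-reduce runs inside each block. A minimal standalone sketch of the resulting rank layout (pure Python; world_size=8 and pipeline_parallel_size=2 are illustrative assumptions, not values taken from this PR):

    # reproduce the rank grouping built by Initializer_Sequence_DP.init_dist_group
    world_size = 8                      # illustrative
    pipeline_parallel_size = 2          # illustrative
    dp_size = world_size // pipeline_parallel_size
    num_group = pipeline_parallel_size

    groups = []
    for i in range(num_group):
        # every rank in one block holds a full weight copy, so its gradients are
        # all-reduced with the other ranks of the same block
        groups.append([i * dp_size + j for j in range(dp_size)])

    print(groups)   # [[0, 1, 2, 3], [4, 5, 6, 7]]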
@@ -77,7 +77,7 @@ class Engine:
         """
         self._all_reduce_gradients()
         self.optimizer.clip_grad_norm(self.model, self._clip_grad_norm)
-        self.optimizer.step()
+        return self.optimizer.step()
 
     def backward(self, loss: Tensor):
         """Start backward propagation given the loss value computed by a loss function
@@ -1,9 +1,12 @@
 from ._base_gradient_handler import BaseGradientHandler
 from ._data_parallel_gradient_handler import DataParallelGradientHandler
 from ._zero_gradient_handler import ZeROGradientHandler
+from ._sequence_parallel_gradient_handler import SequenceParallelGradientHandler
 from ._pipeline_parallel_gradient_handler import PipelineSharedModuleGradientHandler
 from ._moe_gradient_handler import MoeGradientHandler
+from ._sequence_parallel_gradient_handler import SequenceParallelGradientHandler
+
 
 __all__ = ['BaseGradientHandler', 'DataParallelGradientHandler',
            'ZeROGradientHandler', 'PipelineSharedModuleGradientHandler',
-           'MoeGradientHandler']
+           'MoeGradientHandler', 'SequenceParallelGradientHandler']
@@ -0,0 +1,51 @@
+#!/usr/bin/env python
+from functools import total_ordering
+import torch
+import torch.distributed as dist
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+
+from colossalai.core import global_context as gpc
+from colossalai.registry import GRADIENT_HANDLER
+from ._base_gradient_handler import BaseGradientHandler
+from ...context.parallel_mode import ParallelMode
+import colossalai
+
+
+@GRADIENT_HANDLER.register_module
+class SequenceParallelGradientHandler(BaseGradientHandler):
+    """A helper class to handle all-reduce operations in a data parallel group.
+    An all-reduce collective communication will be operated in
+    :func:`handle_gradient` among a data parallel group.
+    For better performance, it bucketizes the gradients of all parameters that are
+    the same type to improve the efficiency of communication.
+    """
+
+    def handle_gradient(self):
+        """A method running an all-reduce operation in a data parallel group.
+        """
+
+        # bucketize and all-reduce
+        buckets = {}
+
+        # Pack the buckets.
+        for param in self._model.parameters():
+            if param.requires_grad and param.grad is not None:
+                tp = param.data.type()
+                if tp not in buckets:
+                    buckets[tp] = []
+                buckets[tp].append(param)
+
+        # For each bucket, all-reduce and copy all-reduced grads.
+        for tp in buckets:
+            bucket = buckets[tp]
+            grads = [param.grad.data for param in bucket]
+            coalesced = _flatten_dense_tensors(grads)
+
+            coalesced /= gpc.get_world_size(ParallelMode.SEQUENCE_DP)
+
+            dist.all_reduce(
+                coalesced, group=gpc.get_group(ParallelMode.SEQUENCE_DP))
+
+            for buf, synced in zip(grads, _unflatten_dense_tensors(
+                    coalesced, grads)):
+                buf.copy_(synced)
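Note: the handler is the usual flatten / all-reduce / unflatten pattern: gradients are grouped by tensor type, each group is coalesced into one flat buffer, pre-divided by the SEQUENCE_DP world size so the SUM all-reduce yields a mean, and then copied back in place. A hedged standalone sketch of the same pattern outside the handler class (model and group are placeholders the caller would supply; this is not the handler's public API):

    import torch
    import torch.distributed as dist
    from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

    def allreduce_grads_by_bucket(model: torch.nn.Module, group=None):
        # group gradients by dtype so each bucket can be fused into one flat buffer
        buckets = {}
        for param in model.parameters():
            if param.requires_grad and param.grad is not None:
                buckets.setdefault(param.data.type(), []).append(param)

        world_size = dist.get_world_size(group=group)
        for params in buckets.values():
            grads = [p.grad.data for p in params]
            coalesced = _flatten_dense_tensors(grads)    # one contiguous buffer per bucket
            coalesced /= world_size                      # pre-divide so SUM == mean
            dist.all_reduce(coalesced, group=group)
            for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                buf.copy_(synced)                        # write averaged grads back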
@@ -222,7 +222,6 @@ class PipelineSchedule(BaseSchedule):
 
         assert forward_only or return_loss, \
             'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
-
         self.load_batch(data_iter)
         num_warmup_microbatches = \
             (gpc.get_world_size(ParallelMode.PIPELINE) -
@@ -17,7 +17,7 @@ from colossalai.core import global_context as gpc
 from colossalai.engine import Engine
 from colossalai.logging import get_dist_logger
 from colossalai.utils import (accumulate_gradient, get_current_device,
-                              sync_model_param_in_dp, is_using_ddp, is_using_pp)
+                              sync_model_param, is_using_ddp, is_using_pp, is_using_sequence)
 from colossalai.zero import convert_to_zero, ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3
 from colossalai.builder.builder import build_gradient_handler
 from torch.optim.optimizer import Optimizer
|
|||||||
backend: str = 'nccl',
|
backend: str = 'nccl',
|
||||||
seed: int = 1024,
|
seed: int = 1024,
|
||||||
verbose: bool = True):
|
verbose: bool = True):
|
||||||
'''A wrapper for colossalai.launch for torchrun or torch.distributed.launch by reading rank and world size
|
'''A wrapper for colossalai.launch for torchrun or torch.distributed.launch by reading rank and world size
|
||||||
from the environment variables set by PyTorch
|
from the environment variables set by PyTorch
|
||||||
|
|
||||||
:param config: config file or config file path are both acceptable
|
:param config: config file or config file path are both acceptable
|
||||||
@@ -270,12 +270,15 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
         model.to(get_current_device())
     use_zero3 = hasattr(gpc.config, 'zero') and gpc.config.zero.level == 3
     if not moe_env.is_initialized() and not use_zero3:
-        sync_model_param_in_dp(model)
+        if is_using_sequence():
+            sync_model_param(model, ParallelMode.SEQUENCE_DP)
+        elif is_using_ddp():
+            sync_model_param(model, ParallelMode.DATA)
     else:
-        print(
-            "Warning: The parameters of models is not automatically synchronized.\n"
+        logger.warning(
+            "The parameters of models is not automatically synchronized.\n"
             "Please make sure that all parameters are the same in data parallel group.",
-            flush=True)
+            ranks=[0])
 
     # check amp and zero
     fp16_cfg = gpc.config.get('fp16', None)
|
|||||||
"Data parallel training is detected with moe parallel, MoeGradientHandler is automatically "
|
"Data parallel training is detected with moe parallel, MoeGradientHandler is automatically "
|
||||||
"added even though not specified in the configuration",
|
"added even though not specified in the configuration",
|
||||||
ranks=[0])
|
ranks=[0])
|
||||||
|
elif is_using_sequence():
|
||||||
|
model = DDP(model, process_group=gpc.get_group(ParallelMode.SEQUENCE_DP))
|
||||||
|
if verbose:
|
||||||
|
logger.info(
|
||||||
|
'Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism', ranks=[0])
|
||||||
elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE:
|
elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE:
|
||||||
model = DDP(model, process_group=gpc.get_group(ParallelMode.DATA))
|
model = DDP(model, process_group=gpc.get_group(ParallelMode.DATA))
|
||||||
if verbose:
|
if verbose:
|
||||||
logger.info(
|
logger.info(
|
||||||
'Model is using torch.nn.parallel.DistributedDataParallel', ranks=[0])
|
'Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism', ranks=[0])
|
||||||
elif is_using_ddp():
|
elif is_using_ddp():
|
||||||
gradient_handler_cfg = [dict(type='DataParallelGradientHandler')]
|
gradient_handler_cfg = [dict(type='DataParallelGradientHandler')]
|
||||||
if verbose:
|
if verbose:
|
||||||
|
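Note: because every sequence-parallel rank keeps a full copy of the weights, initialize() can wrap the model in plain torch DDP over the SEQUENCE_DP group rather than sharding parameters. A hedged sketch of the config that selects this path (the field layout follows the test config at the end of this diff; size=4 is an example value):

    # colossalai config sketch: 4-way sequence parallelism, no pipeline stages.
    # With this, is_using_sequence() returns True and initialize() builds
    # DDP(model, process_group=gpc.get_group(ParallelMode.SEQUENCE_DP)).
    parallel = dict(
        tensor=dict(size=4, mode='sequence')
    )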
@@ -6,6 +6,7 @@ import numbers
 import torch
 from torch.nn.parameter import Parameter
 from torch.nn import init
+from torch.cuda.amp import custom_fwd, custom_bwd
 import importlib
 
 global colossal_layer_norm_cuda
@@ -15,6 +16,7 @@ colossal_layer_norm_cuda = None
 class FusedLayerNormAffineFunction(torch.autograd.Function):
 
     @staticmethod
+    @custom_fwd(cast_inputs=torch.float32)
     def forward(ctx, input, weight, bias, normalized_shape, eps):
 
         ctx.normalized_shape = normalized_shape
@@ -29,6 +31,7 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
         return output
 
     @staticmethod
+    @custom_bwd
     def backward(ctx, grad_output):
 
         input_, weight_, bias_, mean, invvar = ctx.saved_tensors
@@ -71,3 +74,6 @@ class MixedFusedLayerNorm(torch.nn.Module):
 
         return FusedLayerNormAffineFunction.apply(input, self.weight, self.bias,
                                                   self.normalized_shape, self.eps)
+
+    def __repr__(self):
+        return f'MixedFusedLayerNorm(normalized_shape={self.normalized_shape}, eps={self.eps})'
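Note: decorating the fused layer-norm autograd.Function with @custom_fwd(cast_inputs=torch.float32) and @custom_bwd makes it behave correctly under torch.cuda.amp.autocast: inputs are cast to fp32 on entry and the backward pass reuses the autocast state of the forward. A minimal self-contained sketch of the same decorator pattern (the toy element-wise op is illustrative; it is not the CUDA layer-norm kernel used here):

    import torch
    from torch.cuda.amp import custom_fwd, custom_bwd

    class ToyFp32Op(torch.autograd.Function):
        """Toy op showing the autocast-safe decorator pattern."""

        @staticmethod
        @custom_fwd(cast_inputs=torch.float32)
        def forward(ctx, x, weight):
            # under autocast, x and weight arrive here already cast to fp32
            ctx.save_for_backward(x, weight)
            return x * weight

        @staticmethod
        @custom_bwd
        def backward(ctx, grad_output):
            x, weight = ctx.saved_tensors
            return grad_output * weight, grad_output * x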
@@ -6,7 +6,7 @@ JIT_OPTIONS_SET = False
 def set_jit_fusion_options():
     """Set PyTorch JIT layer fusion options.
     """
     # LSG: the latest pytorch and CUDA versions may not support
     # the following jit settings
     global JIT_OPTIONS_SET
     if JIT_OPTIONS_SET == False:
@@ -9,6 +9,7 @@ from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.nn.layer.parallel_sequence._utils import _calc_incoming_device_range, _calc_current_device_range
 from colossalai.utils import get_current_device
+from torch.cuda.amp import custom_bwd, custom_fwd
 
 
 class RingQK(torch.autograd.Function):
@@ -17,6 +18,7 @@ class RingQK(torch.autograd.Function):
     """
 
     @staticmethod
+    @custom_fwd
     def forward(ctx,
                 sub_q,
                 sub_k,
@@ -54,6 +56,7 @@ class RingQK(torch.autograd.Function):
         return attention_score
 
     @staticmethod
+    @custom_bwd
     def backward(ctx, grad_output):
         sub_q, sub_k, = ctx.saved_tensors
         local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
@@ -64,6 +67,7 @@ class RingQK(torch.autograd.Function):
             grad_output.transpose(2, 1),
             sub_q
         )
+
         dist.all_reduce(grad_k, group=gpc.get_group(ParallelMode.SEQUENCE))
         grad_k = grad_k[:, local_rank * ctx.sub_seq_length: (local_rank + 1) * ctx.sub_seq_length]
         grad_k /= local_world_size
@@ -94,6 +98,7 @@ class RingAV(torch.autograd.Function):
     """
 
     @staticmethod
+    @custom_fwd
     def forward(ctx,
                 attention_score,
                 sub_v,
@@ -131,6 +136,7 @@ class RingAV(torch.autograd.Function):
         return sub_attention_result
 
     @staticmethod
+    @custom_bwd
     def backward(ctx, grad_output):
         local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
         local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
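Note: functionally, RingQK leaves each rank with the score block for its own query slice against the full key sequence (the unit test at the end of this diff checks exactly this), and RingAV does the analogous contraction of scores with values. A single-process reference sketch of what rank r is expected to hold, with illustrative sizes:

    import torch

    # illustrative sizes; in the real run world_size is the SEQUENCE group size
    batch_size, num_heads, seq_len, head_size, world_size = 2, 4, 16, 8, 4
    sub_seq = seq_len // world_size
    rank = 1                      # any rank, used only for slicing below

    q = torch.rand(batch_size * num_heads, seq_len, head_size)
    k = torch.rand(batch_size * num_heads, seq_len, head_size)

    # full attention scores as a single device would compute them
    full_scores = torch.matmul(q, k.transpose(2, 1))              # [b*h, seq_len, seq_len]

    # what RingQK should produce on rank `rank`: its local query slice against all keys
    expected_sub_scores = full_scores[:, rank * sub_seq:(rank + 1) * sub_seq]   # [b*h, sub_seq, seq_len]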
@@ -2,15 +2,20 @@
 # -*- encoding: utf-8 -*-
 
 import math
+import colossalai
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.nn import Parameter
 
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.nn.layer.parallel_sequence._operation import RingQK, RingAV
 from colossalai.registry import LAYERS
+from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
+from colossalai.kernel import FusedScaleMaskSoftmax
+from colossalai.context import seed
 
 
 @LAYERS.register_module
|
|||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
kv_channels,
|
|
||||||
num_attention_heads,
|
num_attention_heads,
|
||||||
attention_dropout,
|
attention_dropout,
|
||||||
|
attention_mask_func,
|
||||||
|
layer_number,
|
||||||
|
apply_query_key_layer_scaling: bool = False,
|
||||||
|
convert_fp16_to_fp32_in_softmax: bool = False,
|
||||||
|
attn_mask_type=AttnMaskType.padding,
|
||||||
|
masked_softmax_fusion=True,
|
||||||
|
fp16=False,
|
||||||
|
bf16=False
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.convert_fp16_to_fp32_in_softmax = convert_fp16_to_fp32_in_softmax
|
||||||
|
self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
|
||||||
|
self.attention_mask_func = attention_mask_func
|
||||||
|
self.layer_number = layer_number
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.num_attention_heads = num_attention_heads
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.attn_mask_type = attn_mask_type
|
||||||
|
assert self.layer_number > 0
|
||||||
|
self.attention_dropout = attention_dropout
|
||||||
|
|
||||||
projection_size = kv_channels * num_attention_heads
|
if self.apply_query_key_layer_scaling:
|
||||||
self.hidden_size_per_attention_head = projection_size // num_attention_heads
|
self.convert_fp16_to_fp32_in_softmax = True
|
||||||
|
|
||||||
|
assert self.hidden_size % self.num_attention_heads == 0, \
|
||||||
|
'hidden size is not divisible by the number of attention heads'
|
||||||
|
|
||||||
|
self.hidden_size_per_attention_head = self.hidden_size // num_attention_heads
|
||||||
|
|
||||||
self.world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
|
self.world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
|
||||||
|
|
||||||
# Strided linear layer.
|
# Strided linear layer.
|
||||||
self.query_key_value = nn.Linear(
|
self.query_key_value = _Linear(
|
||||||
hidden_size,
|
hidden_size,
|
||||||
3 * projection_size,
|
3 * self.hidden_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
# coeff = None
|
self.coeff = None
|
||||||
self.norm_factor = math.sqrt(self.hidden_size)
|
self.norm_factor = math.sqrt(self.hidden_size)
|
||||||
|
|
||||||
# TODO: add apply_query_key_layer_scaling when we have the kernel module
|
if self.apply_query_key_layer_scaling:
|
||||||
# if self.apply_query_key_layer_scaling:
|
self.coeff = layer_number
|
||||||
# coeff = self.layer_number
|
self.norm_factor *= self.coeff
|
||||||
# self.norm_factor *= coeff
|
|
||||||
|
|
||||||
# TODO: add fused scale mask softmax kernel when we have the kernel module
|
self.scale_mask_softmax = FusedScaleMaskSoftmax(
|
||||||
# self.scale_mask_softmax = FusedScaleMaskSoftmax(
|
fp16, bf16,
|
||||||
# self.fp16, self.bf16,
|
self.attn_mask_type,
|
||||||
# self.attn_mask_type,
|
masked_softmax_fusion,
|
||||||
# masked_softmax_fusion,
|
self.attention_mask_func,
|
||||||
# attention_mask_func,
|
self.convert_fp16_to_fp32_in_softmax,
|
||||||
# self.attention_softmax_in_fp32,
|
self.coeff)
|
||||||
# coeff)
|
|
||||||
|
|
||||||
self.attention_dropout = nn.Dropout(attention_dropout)
|
self.attention_dropout = nn.Dropout(attention_dropout)
|
||||||
|
|
||||||
# Output.
|
# Output.
|
||||||
self.dense = nn.Linear(
|
self.dense = _Linear(hidden_size,
|
||||||
projection_size,
|
hidden_size,
|
||||||
hidden_size,
|
bias=True,
|
||||||
bias=True)
|
skip_bias_add=True)
|
||||||
|
|
||||||
def forward(self, hidden_states, attention_mask):
|
def forward(self, hidden_states, attention_mask):
|
||||||
# hidden_states: [sq, b, h]
|
# hidden_states: [sub_seq_len, batch_size, hidden_size]
|
||||||
|
# attention_mask: [batch_size, 1, sub_seq_len, seq_len]
|
||||||
sub_seq_length, batch_size, hidden_size = hidden_states.size()
|
sub_seq_length, batch_size, hidden_size = hidden_states.size()
|
||||||
|
|
||||||
# =====================
|
# =====================
|
||||||
# Query, Key, and Value
|
# Query, Key, and Value
|
||||||
# =====================
|
# =====================
|
||||||
|
|
||||||
# Attention heads [sq, b, h] --> [sq, b, (3 * hn * num_heads)]
|
# Attention heads shape change:
|
||||||
|
# [sub_seq_len, batch_size, hidden_size] --> [sub_seq_len, batch_size, (3 * head_size * num_heads)]
|
||||||
mixed_x_layer = self.query_key_value(hidden_states)
|
mixed_x_layer = self.query_key_value(hidden_states)
|
||||||
|
|
||||||
# [sq, b, num_heads, 3 * hn] --> 3 [sq, b, num_heads, hn]
|
# [sub_seq_len, batch_size, num_heads, 3 * head_size] --> 3 [sub_seq_len, batch_size, num_heads, head_size]
|
||||||
new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_attention_heads,
|
new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_attention_heads,
|
||||||
3 * self.hidden_size_per_attention_head)
|
3 * self.hidden_size_per_attention_head)
|
||||||
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
|
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
|
||||||
|
|
||||||
# split into query, key and value
|
# split into query, key and value
|
||||||
last_dim = mixed_x_layer.dim() - 1
|
last_dim = mixed_x_layer.dim() - 1
|
||||||
last_dim_value = mixed_x_layer.size()[-1]
|
last_dim_value = mixed_x_layer.size(-1)
|
||||||
assert last_dim_value % 3 == 0, 'the last dimension is not a multiple of 3, ' \
|
assert last_dim_value % 3 == 0, 'the last dimension is not a multiple of 3, ' \
|
||||||
'cannot be divided into query, key and value'
|
'cannot be divided into query, key and value'
|
||||||
partition_size = last_dim_value // 3
|
partition_size = last_dim_value // 3
|
||||||
(query_layer, key_layer, value_layer) = torch.split(
|
(query_layer, key_layer, value_layer) = torch.split(
|
||||||
mixed_x_layer, partition_size, dim=last_dim)
|
mixed_x_layer, partition_size, dim=last_dim)
|
||||||
|
|
||||||
# ===================================
|
# attention scores: [batch_size, num_heads, sub_seq_len, seq_len]
|
||||||
# Raw attention scores. [b, num_heads, s, s]
|
|
||||||
# ===================================
|
|
||||||
|
|
||||||
# [b, num_heads, sq, sk]
|
|
||||||
output_size = (query_layer.size(1),
|
output_size = (query_layer.size(1),
|
||||||
query_layer.size(2),
|
query_layer.size(2),
|
||||||
query_layer.size(0),
|
query_layer.size(0),
|
||||||
key_layer.size(0) * self.world_size)
|
key_layer.size(0) * self.world_size)
|
||||||
|
|
||||||
# [sq, b, num_heads, hn] -> [sq, b * num_heads, hn]
|
# [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size * num_heads, head_size]
|
||||||
query_layer = query_layer.view(output_size[2],
|
query_layer = query_layer.view(output_size[2],
|
||||||
output_size[0] * output_size[1], -1)
|
output_size[0] * output_size[1], -1)
|
||||||
# [sk, b, num_heads, hn] -> [sk, b * num_heads, hn]
|
# [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size * num_heads, head_size]
|
||||||
key_layer = key_layer.view(key_layer.size(0),
|
key_layer = key_layer.view(key_layer.size(0),
|
||||||
output_size[0] * output_size[1], -1)
|
output_size[0] * output_size[1], -1)
|
||||||
|
|
||||||
# [b, sq, sk]
|
# attention_scores: [batch_size * num_heads, sub_seq_len, seq_len]
|
||||||
attention_scores = RingQK.apply(
|
attention_scores = RingQK.apply(
|
||||||
# [b * num_heads, sq, hn]
|
query_layer.transpose(0, 1).contiguous(), # [batch_size * num_heads, sub_seq_len, head_size]
|
||||||
query_layer.transpose(0, 1).contiguous(),
|
key_layer.transpose(0, 1).contiguous(), # [batch_size * num_heads, sub_seq_len, head_size],
|
||||||
key_layer.transpose(0, 1).contiguous(), # [b * num_heads, sk, hn],
|
|
||||||
batch_size,
|
batch_size,
|
||||||
self.num_attention_heads,
|
self.num_attention_heads,
|
||||||
sub_seq_length
|
sub_seq_length
|
||||||
)
|
)
|
||||||
|
|
||||||
attention_scores /= self.norm_factor
|
attention_scores /= self.norm_factor
|
||||||
|
|
||||||
# change view to [b, num_heads, sq, sk]
|
# change view to [batch_size, num_heads, sub_seq_len, seq_len]
|
||||||
attention_scores = attention_scores.view(*output_size)
|
attention_scores = attention_scores.view(*output_size)
|
||||||
attention_scores = attention_scores.unsqueeze(1)
|
|
||||||
|
|
||||||
attention_scores = attention_scores + attention_mask
|
|
||||||
attention_probs = F.softmax(attention_scores, dim=-1)
|
|
||||||
attention_probs = attention_probs.squeeze(1)
|
|
||||||
|
|
||||||
|
# change shape to [batch_size, num_heads, sub_seq_len, seq_len]
|
||||||
|
attention_probs = self.scale_mask_softmax(attention_scores, attention_mask)
|
||||||
# This is actually dropping out entire tokens to attend to, which might
|
# This is actually dropping out entire tokens to attend to, which might
|
||||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||||
# with mpu.get_cuda_rng_tracker().fork():
|
with seed(ParallelMode.TENSOR):
|
||||||
# TODO: check if a rng tracker is needed
|
attention_probs = self.attention_dropout(attention_probs)
|
||||||
attention_probs = self.attention_dropout(attention_probs)
|
|
||||||
|
|
||||||
# context layer shape: [b, num_heads, sq, hn]
|
# context layer shape: [batch_size, num_heads, sub_seq_len, head_size]
|
||||||
output_size = (value_layer.size(1),
|
output_size = (value_layer.size(1),
|
||||||
value_layer.size(2),
|
value_layer.size(2),
|
||||||
query_layer.size(0),
|
query_layer.size(0),
|
||||||
value_layer.size(3))
|
value_layer.size(3))
|
||||||
#
|
|
||||||
# # change view [sk, b * num_heads, hn]
|
# change view [sub_seq_len, batch_size * num_heads, head_size]
|
||||||
value_layer = value_layer.contiguous().view(value_layer.size(0),
|
value_layer = value_layer.contiguous().view(value_layer.size(0),
|
||||||
output_size[0] * output_size[1], -1)
|
output_size[0] * output_size[1], -1)
|
||||||
|
|
||||||
# # change view [b * num_heads, sq, sk]
|
# # change view [b * num_heads, sub_seq_len, seq_len]
|
||||||
attention_probs = attention_probs.view(attention_probs.size(0) * attention_probs.size(1),
|
attention_probs = attention_probs.view(attention_probs.size(0) * attention_probs.size(1),
|
||||||
attention_probs.size(2),
|
attention_probs.size(2),
|
||||||
attention_probs.size(3))
|
attention_probs.size(3))
|
||||||
|
|
||||||
# matmul: [b*num_heads, sq, hn]
|
# matmul: [batch_size * num_heads, sub_seq_len, head_size]
|
||||||
# context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
|
|
||||||
context_layer = RingAV.apply(
|
context_layer = RingAV.apply(
|
||||||
attention_probs,
|
attention_probs,
|
||||||
value_layer.transpose(0, 1).contiguous(),
|
value_layer.transpose(0, 1).contiguous(),
|
||||||
@@ -170,19 +183,83 @@ class TransformerSelfAttentionRing(nn.Module):
             sub_seq_length
         )
 
-        # # change view [b, num_heads, sq, hn]
+        # change view [batch_size, num_heads, sub_seq_len, head_size]
         context_layer = context_layer.view(*output_size)
 
-        # # [b, np, sq, hn] --> [sq, b, np, hn]
+        # [batch_size, num_heads, sub_seq_len, head_size] -> [sub_seq_len, batch_size, num_heads, head_size]
         context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
 
-        # # [sq, b, np, hn] --> [sq, b, hp]
+        # [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size, hidden_size]
         new_context_layer_shape = context_layer.size()[:-2] + (
             self.hidden_size_per_attention_head * self.num_attention_heads,)
         context_layer = context_layer.view(*new_context_layer_shape)
 
-        # context_layer = context_layer.transpose(1, 0).contiguous()
-        output = self.dense(context_layer)
-        bias = self.dense.bias
+        output, bias = self.dense(context_layer)
 
         return output, bias
+
+    def __repr__(self):
+        return f'TransformerSelfAttentionRing(apply_query_key_layer_scaling={self.apply_query_key_layer_scaling}, ' \
+            f'layer_number={self.layer_number}, hidden_size:{self.hidden_size}, attention_dropout={self.attention_dropout}, ' \
+            f'attn_mask_type={self.attn_mask_type}, num_attention_heads={self.num_attention_heads}, ' \
+            f'hidden_size_per_attention_head={self.hidden_size_per_attention_head}, coeff={self.coeff}, norm_factor={self.norm_factor}, ' \
+            f'convert_fp16_to_fp32_in_softmax={self.convert_fp16_to_fp32_in_softmax})'
+
+
+class _Linear(nn.Module):
+    """Linear layer with column parallelism.
+    The linear layer is defined as Y = XA + b. A is parallelized along
+    its second dimension as A = [A_1, ..., A_p].
+    Arguments:
+        input_size: first dimension of matrix A.
+        output_size: second dimension of matrix A.
+        bias: If true, add bias
+        init_method: method to initialize weights. Note that bias is always set
+                     to zero.
+        stride: For the strided linear layers.
+        keep_master_weight_for_test: This was added for testing and should be
+                                     set to False. It returns the master weights
+                                     used for initialization.
+        skip_bias_add: This was added to enable performance optimizations where bias
+                       can be fused with other elementwise operations. We skip
+                       adding bias but instead return it.
+    """
+
+    def __init__(self,
+                 input_size,
+                 output_size,
+                 bias=True,
+                 skip_bias_add=False):
+        super(_Linear, self).__init__()
+
+        # Keep input parameters
+        self.input_size = input_size
+        self.output_size = output_size
+        self.skip_bias_add = skip_bias_add
+
+        self.weight = Parameter(torch.empty(self.output_size,
+                                            self.input_size,
+                                            ))
+        nn.init.xavier_normal_(self.weight)
+
+        if bias:
+            self.bias = Parameter(torch.empty(self.output_size))
+            # Always initialize bias to zero.
+            with torch.no_grad():
+                self.bias.zero_()
+        else:
+            self.register_parameter('bias', None)
+
+    def forward(self, input_):
+        # Matrix multiply.
+        bias = self.bias if not self.skip_bias_add else None
+        output = F.linear(input_, self.weight, bias)
+
+        if self.skip_bias_add:
+            return output, self.bias
+        else:
+            return output
+
+    def __repr__(self):
+        return f'Linear(in_features={self.input_size}, out_features={self.output_size}, ' + \
+            f'bias={self.bias is not None}, skip_bias_add={self.skip_bias_add})'
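Note: skip_bias_add=True makes the output projection return (output, bias) without applying the bias, so the caller can fuse the bias add into a later element-wise step (here the pair is returned straight from forward()). A hedged usage sketch of that contract, assuming the _Linear defined above and illustrative shapes:

    import torch
    import torch.nn.functional as F

    dense = _Linear(input_size=64, output_size=64, bias=True, skip_bias_add=True)
    x = torch.rand(8, 64)

    out, bias = dense(x)          # bias is returned, not applied
    y = out + bias                # caller adds it wherever fusion is cheapest
    assert torch.allclose(y, F.linear(x, dense.weight, dense.bias))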
@@ -1,8 +1,8 @@
 from .activation_checkpoint import checkpoint
 from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_parallel_attributes, count_zeros_fp32,
                      free_port, is_dp_rank_0, is_model_parallel_parameter, is_no_pp_or_last_stage, is_tp_rank_0,
-                     is_using_ddp, is_using_pp, multi_tensor_applier, param_is_not_tensor_parallel_duplicate,
-                     print_rank_0, switch_virtual_pipeline_parallel_rank, sync_model_param_in_dp)
+                     is_using_ddp, is_using_pp, is_using_sequence, multi_tensor_applier, param_is_not_tensor_parallel_duplicate,
+                     print_rank_0, switch_virtual_pipeline_parallel_rank, sync_model_param)
 from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize
 from .data_sampler import DataParallelSampler, get_dataloader
 from .gradient_accumulation import accumulate_gradient
@@ -10,9 +10,9 @@ from .memory import report_memory_usage
 from .timer import MultiTimer, Timer
 
 __all__ = [
-    'checkpoint', 'free_port', 'print_rank_0', 'sync_model_param_in_dp', 'is_dp_rank_0', 'is_tp_rank_0',
-    'is_no_pp_or_last_stage', 'is_using_ddp', 'is_using_pp', 'conditional_context', 'is_model_parallel_parameter',
-    'clip_grad_norm_fp32', 'count_zeros_fp32', 'copy_tensor_parallel_attributes',
+    'checkpoint', 'free_port', 'print_rank_0', 'sync_model_param', 'is_dp_rank_0', 'is_tp_rank_0',
+    'is_no_pp_or_last_stage', 'is_using_ddp', 'is_using_pp', 'is_using_sequence', 'conditional_context',
+    'is_model_parallel_parameter', 'clip_grad_norm_fp32', 'count_zeros_fp32', 'copy_tensor_parallel_attributes',
     'param_is_not_tensor_parallel_duplicate', 'get_current_device', 'synchronize', 'empty_cache', 'set_to_cuda',
     'report_memory_usage', 'Timer', 'MultiTimer', 'multi_tensor_applier', 'accumulate_gradient', 'DataParallelSampler',
     'get_dataloader', 'switch_virtual_pipeline_parallel_rank'
@@ -47,16 +47,16 @@ def free_port():
             continue
 
 
-def sync_model_param_in_dp(model):
+def sync_model_param(model, parallel_mode):
     '''Make sure data parameters are consistent during Data Parallel Mode
 
     :param model: A pyTorch nn.model on whose parameters you check the consistency
     '''
-    if gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1:
+    if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1:
         for param in model.parameters():
-            ranks = gpc.get_ranks_in_group(ParallelMode.DATA)
+            ranks = gpc.get_ranks_in_group(parallel_mode)
             dist.broadcast(
-                param, src=ranks[0], group=gpc.get_group(ParallelMode.DATA))
+                param, src=ranks[0], group=gpc.get_group(parallel_mode))
 
 
 def is_dp_rank_0():
|
|||||||
return gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1
|
return gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1
|
||||||
|
|
||||||
|
|
||||||
|
def is_using_sequence():
|
||||||
|
return gpc.is_initialized(ParallelMode.SEQUENCE) and gpc.get_world_size(ParallelMode.SEQUENCE) > 1
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def conditional_context(context_manager, enable=True):
|
def conditional_context(context_manager, enable=True):
|
||||||
if enable:
|
if enable:
|
||||||
@ -240,16 +244,20 @@ def count_zeros_fp32(parameters):
|
|||||||
num_zeros = grad.numel() - torch.count_nonzero(grad)
|
num_zeros = grad.numel() - torch.count_nonzero(grad)
|
||||||
total_num_zeros = num_zeros + total_num_zeros
|
total_num_zeros = num_zeros + total_num_zeros
|
||||||
|
|
||||||
|
total_num_zeros = torch.IntTensor([int(total_num_zeros)]).cuda()
|
||||||
|
|
||||||
# Sum across all model-parallel GPUs.
|
# Sum across all model-parallel GPUs.
|
||||||
ops = []
|
ops = []
|
||||||
ops.append(dist.all_reduce(total_num_zeros,
|
ops.append(dist.all_reduce(total_num_zeros,
|
||||||
op=dist.ReduceOp.SUM,
|
op=dist.ReduceOp.SUM,
|
||||||
group=gpc.get_group(ParallelMode.TENSOR),
|
group=gpc.get_group(ParallelMode.TENSOR),
|
||||||
async_op=True))
|
async_op=True))
|
||||||
ops.append(dist.all_reduce(total_num_zeros,
|
if gpc.is_initialized(ParallelMode.PIPELINE):
|
||||||
op=dist.ReduceOp.SUM,
|
ops.append(dist.all_reduce(total_num_zeros,
|
||||||
group=gpc.get_group(ParallelMode.PIPELINE),
|
op=dist.ReduceOp.SUM,
|
||||||
async_op=True))
|
group=gpc.get_group(ParallelMode.PIPELINE),
|
||||||
|
async_op=True))
|
||||||
|
|
||||||
for req in ops:
|
for req in ops:
|
||||||
req.wait()
|
req.wait()
|
||||||
total_num_zeros = total_num_zeros.item()
|
total_num_zeros = total_num_zeros.item()
|
||||||
|
@@ -40,8 +40,6 @@ def report_memory_usage(message, logger=None, report_cpu=False):
     :type report_cpu: bool
     :raises EnvironmentError: raise error if no distributed environment has been initialized
     '''
-    if not gpc.is_initialized(ParallelMode.GLOBAL):
-        raise EnvironmentError("No distributed environment is initialized")
 
     gpu_allocated = bytes_to_MB(torch.cuda.memory_allocated())
     gpu_max_allocated = bytes_to_MB(torch.cuda.max_memory_allocated())
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
# -*- encoding: utf-8 -*-
|
# -*- encoding: utf-8 -*-
|
||||||
import time
|
import time
|
||||||
|
from typing import Tuple
|
||||||
from .cuda import synchronize
|
from .cuda import synchronize
|
||||||
|
|
||||||
|
|
||||||
@ -8,6 +9,7 @@ class Timer:
|
|||||||
'''
|
'''
|
||||||
A timer object which helps to log the execution times, and provides different tools to assess the times.
|
A timer object which helps to log the execution times, and provides different tools to assess the times.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._started = False
|
self._started = False
|
||||||
self._start_time = time.time()
|
self._start_time = time.time()
|
||||||
@ -129,6 +131,6 @@ class MultiTimer:
|
|||||||
def set_status(self, mode: bool):
|
def set_status(self, mode: bool):
|
||||||
self._on = mode
|
self._on = mode
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self) -> Tuple[str, Timer]:
|
||||||
for name, timer in self._timers.items():
|
for name, timer in self._timers.items():
|
||||||
yield name, timer
|
yield name, timer
|
||||||
|
@ -1,48 +1,150 @@
|
|||||||
#!/usr/bin/env python
|
import colossalai
|
||||||
# -*- encoding: utf-8 -*-
|
import colossalai.nn as col_nn
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
from colossalai.initialize import launch
|
import pytest
|
||||||
from colossalai.logging import get_dist_logger
|
|
||||||
from checks_seq.check_layer_seq import *
|
from colossalai.core import global_context as gpc
|
||||||
|
from colossalai.context import ParallelMode
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from colossalai.utils import free_port
|
|
||||||
|
|
||||||
|
|
||||||
CONFIG = dict(
|
CONFIG = dict(
|
||||||
parallel=dict(
|
parallel=dict(
|
||||||
pipeline=1,
|
tensor=dict(size=4, mode='sequence')
|
||||||
tensor=dict(mode='sequence', size=4)
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def check_layer():
|
def check_ring_qk(rank, world_size):
|
||||||
check_selfattention()
|
# params
|
||||||
|
batch_size = 4
|
||||||
|
num_heads = 4
|
||||||
|
seq_length = 32
|
||||||
|
attention_head_size = 32
|
||||||
|
sub_seq_length = seq_length // world_size
|
||||||
|
|
||||||
|
# create master tensors
|
||||||
|
q = torch.rand(batch_size*num_heads, seq_length, attention_head_size).cuda()
|
||||||
|
k = torch.rand(batch_size*num_heads, seq_length, attention_head_size).cuda()
|
||||||
|
dist.broadcast(q, src=0, group=gpc.get_group(ParallelMode.SEQUENCE))
|
||||||
|
dist.broadcast(k, src=0, group=gpc.get_group(ParallelMode.SEQUENCE))
|
||||||
|
|
||||||
|
# create distributed tensors
|
||||||
|
sub_q = q.clone()[:, rank*sub_seq_length:(rank+1)*sub_seq_length].contiguous()
|
||||||
|
sub_k = k.clone()[:, rank*sub_seq_length:(rank+1)*sub_seq_length].contiguous()
|
||||||
|
|
||||||
|
# set autograd attributes
|
||||||
|
q.requires_grad = True
|
||||||
|
k.requires_grad = True
|
||||||
|
q.retain_grad()
|
||||||
|
k.retain_grad()
|
||||||
|
sub_q.requires_grad = True
|
||||||
|
sub_k.requires_grad = True
|
||||||
|
sub_q.retain_grad()
|
||||||
|
sub_k.retain_grad()
|
||||||
|
|
||||||
|
# compute master attention scores
|
||||||
|
a = torch.matmul(q, k.transpose(2, 1))
|
||||||
|
|
||||||
|
# compute distributed attention scores
|
||||||
|
ring_qk = colossalai.nn.layer.parallel_sequence.RingQK.apply
|
||||||
|
sub_a = ring_qk(sub_q, sub_k, batch_size, num_heads, sub_seq_length)
|
||||||
|
|
||||||
|
# check master and distributed attetion scores
|
||||||
|
sub_master_a = a[:, rank*sub_seq_length:(rank+1)*sub_seq_length]
|
||||||
|
assert torch.allclose(sub_a, sub_master_a, rtol=1e-5, atol=1e-2)
|
||||||
|
|
||||||
|
# run master backward
|
||||||
|
a.retain_grad()
|
||||||
|
a.mean().backward()
|
||||||
|
|
||||||
|
# run distributed backward
|
||||||
|
partial_master_a_grad = a.grad[:, rank*sub_seq_length:(rank+1)*sub_seq_length]
|
||||||
|
torch.autograd.backward(sub_a, partial_master_a_grad)
|
||||||
|
|
||||||
|
# check master and distributed grads
|
||||||
|
partial_master_q_grad = q.grad[:, rank*sub_seq_length:(rank+1)*sub_seq_length]
|
||||||
|
assert torch.allclose(sub_q.grad, partial_master_q_grad, rtol=1e-5, atol=1e-2), \
|
||||||
|
'attention score cannot match'
|
||||||
|
|
||||||
|
|
||||||
def run_check_sequence(rank, world_size, port):
|
def check_ring_av(rank, world_size):
|
||||||
# init dist
|
# params
|
||||||
launch(config=CONFIG,
|
batch_size = 4
|
||||||
rank=rank,
|
num_heads = 4
|
||||||
world_size=world_size,
|
seq_length = 16
|
||||||
host='localhost',
|
attention_head_size = 32
|
||||||
port=port,
|
sub_seq_length = seq_length // world_size
|
||||||
backend='nccl')
|
|
||||||
logger = get_dist_logger()
|
|
||||||
logger.info('Distributed environment is initialzied.', ranks=[0])
|
|
||||||
|
|
||||||
# check layers
|
# create master tensors
|
||||||
check_layer()
|
a = torch.rand(batch_size*num_heads, seq_length, seq_length).cuda()
|
||||||
|
v = torch.rand(batch_size*num_heads, seq_length, attention_head_size).cuda()
|
||||||
|
dist.broadcast(a, src=0, group=gpc.get_group(ParallelMode.SEQUENCE))
|
||||||
|
dist.broadcast(v, src=0, group=gpc.get_group(ParallelMode.SEQUENCE))
|
||||||
|
|
||||||
|
# create distributed tensors
|
||||||
|
sub_a = a.clone()[:, rank*sub_seq_length:(rank+1)*sub_seq_length].contiguous()
|
||||||
|
sub_v = v.clone()[:, rank*sub_seq_length:(rank+1)*sub_seq_length].contiguous()
|
||||||
|
|
||||||
|
# set autograd attributes
|
||||||
|
a.requires_grad = True
|
||||||
|
v.requires_grad = True
|
||||||
|
a.retain_grad()
|
||||||
|
v.retain_grad()
|
||||||
|
sub_a.requires_grad = True
|
||||||
|
sub_v.requires_grad = True
|
||||||
|
sub_a.retain_grad()
|
||||||
|
sub_v.retain_grad()
|
||||||
|
|
||||||
|
# compute master attention scores
|
||||||
|
out = torch.matmul(a, v)
|
||||||
|
|
||||||
|
# compute distributed attention scores
|
||||||
|
ring_av = colossalai.nn.layer.parallel_sequence.RingAV.apply
|
||||||
|
sub_out = ring_av(sub_a, sub_v, batch_size, num_heads, attention_head_size, sub_seq_length)
|
||||||
|
|
||||||
|
# print(f'master output shape: {out.shape}, partial output shape: {sub_out.shape}')
|
||||||
|
|
||||||
|
# check master and distributed output
|
||||||
|
sub_master_out = out[:, rank*sub_seq_length:(rank+1)*sub_seq_length]
|
||||||
|
assert torch.allclose(sub_out, sub_master_out, rtol=1e-5, atol=1e-2)
|
||||||
|
|
||||||
|
# # run master backward
|
||||||
|
out.retain_grad()
|
||||||
|
out.mean().backward()
|
||||||
|
|
||||||
|
# # run distributed backward
|
||||||
|
partial_master_out_grad = out.grad[:, rank*sub_seq_length:(rank+1)*sub_seq_length]
|
||||||
|
torch.autograd.backward(sub_out, partial_master_out_grad)
|
||||||
|
|
||||||
|
# # check master and distributed grads
|
||||||
|
partial_master_a_grad = a.grad[:, rank*sub_seq_length:(rank+1)*sub_seq_length]
|
||||||
|
assert torch.allclose(sub_a.grad, partial_master_a_grad, rtol=1e-5, atol=1e-2), \
|
||||||
|
'attention output cannot match'
|
||||||
|
|
||||||
|
|
||||||
|
def run_test(rank, world_size):
|
||||||
|
colossalai.launch(
|
||||||
|
rank=rank,
|
||||||
|
world_size=world_size,
|
||||||
|
config=CONFIG,
|
||||||
|
host='localhost',
|
||||||
|
port=29500
|
||||||
|
)
|
||||||
|
|
||||||
|
# check_ring_qk(rank, world_size)
|
||||||
|
check_ring_av(rank, world_size)
|
||||||
|
|
||||||
|
gpc.destroy()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
def test_sequence():
|
def test_sequence():
|
||||||
world_size = 4
|
world_size = 4
|
||||||
run_func = partial(run_check_sequence, world_size=world_size, port=free_port())
|
run_func = partial(run_test, world_size=world_size)
|
||||||
mp.spawn(run_func, nprocs=world_size)
|
mp.spawn(run_func, nprocs=world_size)
|
||||||
|
|
||||||
|
|
||||||
|