Mirror of https://github.com/hpcaitech/ColossalAI.git
[shardformer] add Dropout layer support different dropout pattern (#3856)
* add dropout layer, add dropout test
* modify seed manager as context manager
* add a copy of col_nn.layer
* add dist_crossentropy loss; separate module test
* polish the code
* fix dist crossentropy loss
colossalai/shardformer/layer/__init__.py (new file, empty)
colossalai/shardformer/layer/_operation.py (new file, 97 lines)
@@ -0,0 +1,97 @@
import torch
import torch.distributed as dist

from colossalai.core import global_context as gpc

try:
    import fused_mix_prec_layer_norm_cuda
except ImportError:
    fused_mix_prec_layer_norm_cuda = None


class FusedLayerNormAffineFunction1D(torch.autograd.Function):
    r"""Layernorm

    Args:
        input: input matrix.
        weight: weight matrix.
        bias: bias matrix.
        normalized_shape: input shape from an expected input of size
            :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`.
            If a single integer is used, it is treated as a singleton list, and this module will
            normalize over the last dimension, which is expected to be of that specific size.
        eps: a value added to the denominator for numerical stability.
    """

    @staticmethod
    def forward(ctx, input, weight, bias, normalized_shape, eps):
        ctx.normalized_shape = normalized_shape
        ctx.eps = eps
        input_ = input.contiguous()
        weight_ = weight.contiguous()
        bias_ = bias.contiguous()
        output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(input_, ctx.normalized_shape, weight_,
                                                                             bias_, ctx.eps)
        ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_, weight_, bias_, mean, invvar = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        grad_input, grad_weight, grad_bias \
            = fused_mix_prec_layer_norm_cuda.backward_affine(
                grad_output.contiguous(), mean, invvar,
                input_, ctx.normalized_shape,
                weight_, bias_, ctx.eps)

        return grad_input, grad_weight, grad_bias, None, None


class LinearWithAsyncCommunication(torch.autograd.Function):
    """
    Linear layer execution with asynchronous communication in backprop.
    """

    @staticmethod
    def forward(ctx, input_, weight, bias, parallel_mode, async_grad_allreduce):
        ctx.save_for_backward(input_, weight)
        ctx.use_bias = bias is not None
        ctx.parallel_mode = parallel_mode
        ctx.async_grad_allreduce = async_grad_allreduce

        output = torch.matmul(input_, weight.t())
        if bias is not None:
            output = output + bias
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        use_bias = ctx.use_bias

        total_input = input
        grad_input = grad_output.matmul(weight)
        grad_output = grad_output.contiguous()
        # Convert the tensor shapes to 2D for execution compatibility
        grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2])
        total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2])

        if ctx.async_grad_allreduce:
            # Asynchronous all-reduce
            handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True)
            # Delay the start of the weight-gradient computation briefly so that the
            # all-reduce is scheduled first and GPU resources are allocated to it.
            _ = torch.empty(1, device=grad_output.device) + 1

        grad_weight = grad_output.t().matmul(total_input)
        grad_bias = grad_output.sum(dim=0) if use_bias else None

        if ctx.async_grad_allreduce:
            handle.wait()

        return grad_input, grad_weight, grad_bias, None, None, None


def linear_with_async_comm(input_, weight, bias, parallel_mode, async_grad_allreduce):
    return LinearWithAsyncCommunication.apply(input_, weight, bias, parallel_mode, async_grad_allreduce)
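For reference only (not part of the commit): a minimal single-process sketch of how `linear_with_async_comm` could be exercised. With `async_grad_allreduce=False` the function reduces to a plain linear layer, so `parallel_mode` can be left as `None` and the result compared against `torch.nn.functional.linear`; the tensor shapes below are illustrative assumptions.

import torch
import torch.nn.functional as F

from colossalai.shardformer.layer._operation import linear_with_async_comm

# assumed shapes: [batch_size, seq_len, in_features] input and [out_features, in_features] weight
x = torch.randn(2, 4, 8, requires_grad=True)
w = torch.randn(16, 8, requires_grad=True)
b = torch.randn(16, requires_grad=True)

# parallel_mode is only consulted when async_grad_allreduce=True, so None suffices here
out = linear_with_async_comm(x, w, b, None, False)
assert torch.allclose(out, F.linear(x, w, b), atol=1e-6)

out.sum().backward()    # backward flattens to 2D internally and fills x.grad, w.grad, b.grad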
colossalai/shardformer/layer/dist_crossentropy.py (new file, 105 lines)
@@ -0,0 +1,105 @@
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function


class DistCrossEntropy(Function):
    r"""
    Overwrite the forward and backward functions to calculate the cross entropy loss before the gather.

    Args:
        Function (:class:`torch.autograd.Function`): default
    """

    @staticmethod
    def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor):
        r"""
        Calculate the cross entropy loss before the gather. The original loss function is:
            loss = -log(exp(x[class]) / sum(exp(x[i])))
        which can be rewritten as:
            loss = log(sum(exp(x[i]))) - x[class]

        To avoid `nan` in log(sum(exp(x[i]))), we subtract the max of x[i] first.

        Args:
            vocab_logits (:class:`torch.Tensor`): The logits of the vocabulary, shape is
                [batch_size, seq_len, vocab_size]
            target (:class:`torch.Tensor`): The labels of the vocabulary, shape is
                [batch_size, seq_len]

        Returns:
            :class:`torch.Tensor`: The cross entropy loss
        """
        # get the max
        logits_max = torch.max(vocab_logits, dim=-1)[0]
        dist.all_reduce(logits_max, op=dist.ReduceOp.MAX)

        # subtract the max so that the sum of exp does not overflow and the log is not nan
        vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1)

        # mask the target on the local device
        partition_vocab_size = vocab_logits.size()[-1]
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        global_vocab_size = partition_vocab_size * world_size

        # [down, up) => false, other device and -100 => true
        delta = (global_vocab_size + world_size - 1) // world_size
        down_threshold = rank * delta
        up_threshold = down_threshold + delta
        mask = (target < down_threshold) | (target >= up_threshold)
        masked_target = target.clone() - down_threshold
        masked_target[mask] = 0

        # reshape the logits and the target
        # reshape the vocab_logits to [batch_size * seq_len, vocab_size]
        # reshape the labels to [batch_size * seq_len]
        logits_2d = vocab_logits.view(-1, partition_vocab_size)
        masked_target_1d = masked_target.view(-1)

        # extract x[class] and zero out the entries owned by other devices
        pred_logits_1d = logits_2d[torch.arange(start=0, end=logits_2d.shape[0], device=logits_2d.device),
                                   masked_target_1d]
        pred_logits_1d = pred_logits_1d.clone().contiguous()
        pred_logits = pred_logits_1d.view_as(target)
        pred_logits[mask] = 0.0

        # all-reduce to collect every x[i, y]
        dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM)
        exp_logits = vocab_logits
        torch.exp(vocab_logits, out=exp_logits)
        sum_exp_logits = torch.sum(exp_logits, dim=-1)
        dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM)

        # calculate the loss
        # loss = log(sum(exp(x[i]))) - x[class]
        loss = torch.log(sum_exp_logits) - pred_logits
        loss = torch.sum(loss).div_(loss.numel())

        # calculate the softmax
        exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
        ctx.save_for_backward(exp_logits, mask, masked_target_1d)

        return loss

    @staticmethod
    def backward(ctx, grad_output):
        # retrieve the saved tensors
        exp_logits, mask, masked_target_1d = ctx.saved_tensors

        # use the exp logits as the input grad
        grad_logits = exp_logits
        partition_vocab_size = grad_logits.shape[-1]
        grad_logits_2d = grad_logits.view(-1, partition_vocab_size)

        update = 1.0 - mask.view(-1).float()
        grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update

        grad_logits.mul_(grad_output.unsqueeze(dim=-1))
        return grad_logits, None, None


def applyDistCrossEntropy(vocab_logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    return DistCrossEntropy.apply(vocab_logits, labels)
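For reference only (not part of the commit, and not the separate module test it mentions): a minimal sketch that checks `applyDistCrossEntropy` against `torch.nn.functional.cross_entropy` in a single-process gloo group on CPU, where the local vocabulary partition equals the full vocabulary. The init method, port, and tensor shapes are assumptions.

import torch
import torch.distributed as dist
import torch.nn.functional as F

from colossalai.shardformer.layer.dist_crossentropy import applyDistCrossEntropy

if not dist.is_initialized():
    # single-process group; the address and port are placeholders
    dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

vocab_size = 32
logits = torch.randn(2, 5, vocab_size)           # [batch_size, seq_len, vocab_size]
labels = torch.randint(0, vocab_size, (2, 5))    # [batch_size, seq_len]

loss = applyDistCrossEntropy(logits, labels)
ref = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1))
assert torch.allclose(loss, ref, atol=1e-5)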
colossalai/shardformer/layer/dropout.py (new file, 58 lines)
@@ -0,0 +1,58 @@
import os
import time
from contextlib import contextmanager

import torch
import torch.nn as nn


class SeedManager:
    """
    This class manages the CUDA random state so that dropout can draw from its own random seed.
    """

    def __init__(self):
        original_state = torch.cuda.get_rng_state()
        seed = int(f"{int(time.time())}{os.environ['RANK']}")
        torch.cuda.manual_seed(int(seed))
        self.dropout_state = torch.cuda.get_rng_state()
        torch.cuda.set_rng_state(original_state)

    def set_mode(self, rng_state):
        torch.cuda.set_rng_state(rng_state)

    def get_current_mode(self):
        current_state = torch.cuda.get_rng_state()
        return current_state

    @contextmanager
    def dropout_mode(self):
        """
        A context manager that switches to the dropout random state and restores the original state on exit.

        Usage:
        ::
            >>> with _seed_manager.dropout_mode():
            >>>     input = super().forward(input)
        """
        try:
            current_mode = self.get_current_mode()
            yield self.set_mode(self.dropout_state)
        finally:
            self.dropout_state = self.get_current_mode()
            self.set_mode(current_mode)


_seed_manager = SeedManager()


class Dropout1D(nn.Dropout):

    def __init__(self, p=0.5, inplace=False):
        super().__init__(p, inplace)

    def forward(self, input):
        with _seed_manager.dropout_mode():
            input = super().forward(input)
        return input
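For reference only (not part of the commit): a minimal sketch of `Dropout1D`. It assumes a CUDA device and that `RANK` is set in the environment, since `SeedManager` reads both when the module is imported. The point illustrated is that the dropout mask is drawn from a dedicated RNG state, so the surrounding CUDA random stream continues as if the dropout call had consumed no randomness.

import os
os.environ.setdefault("RANK", "0")    # SeedManager reads RANK at import time (assumption: single process)

import torch
from colossalai.shardformer.layer.dropout import Dropout1D

torch.cuda.manual_seed(1024)                 # seed the "main" CUDA RNG stream
dropout = Dropout1D(p=0.5)
x = torch.ones(4, 8, device="cuda")

out = dropout(x)                             # mask drawn from the dropout-specific RNG state
noise = torch.rand(4, 8, device="cuda")      # the main stream resumes exactly where it left off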
colossalai/shardformer/layer/layers.py (new file, 1043 lines)
File diff suppressed because it is too large.