mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-10-29 04:40:36 +00:00
[tutorial] edited hands-on practices (#1899)
* Add handson to ColossalAI. * Change names of handsons and edit sequence parallel example. * Edit wrong folder name * resolve conflict * delete readme
This commit is contained in:
@@ -0,0 +1,75 @@
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
import torch
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
|
||||
class _VocabCrossEntropy(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd
|
||||
def forward(ctx, vocab_parallel_logits, target):
|
||||
# Maximum value along vocab dimension across all GPUs.
|
||||
logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
|
||||
|
||||
# Subtract the maximum value.
|
||||
vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
|
||||
|
||||
# Create a mask of valid vocab ids (1 means it needs to be masked).
|
||||
target_mask = target < 0
|
||||
masked_target = target.clone()
|
||||
masked_target[target_mask] = 0
|
||||
|
||||
# Get predicted-logits = logits[target].
|
||||
# For Simplicity, we convert logits to a 2-D tensor with size
|
||||
# [*, partition-vocab-size] and target to a 1-D tensor of size [*].
|
||||
logits_2d = vocab_parallel_logits.view(-1, vocab_parallel_logits.size(-1))
|
||||
masked_target_1d = masked_target.view(-1)
|
||||
arange_1d = torch.arange(start=0, end=logits_2d.size()[0],
|
||||
device=logits_2d.device)
|
||||
predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
|
||||
predicted_logits_1d = predicted_logits_1d.clone().contiguous()
|
||||
predicted_logits = predicted_logits_1d.view_as(target)
|
||||
predicted_logits[target_mask] = 0.0
|
||||
|
||||
# Sum of exponential of logits along vocab dimension across all GPUs.
|
||||
exp_logits = vocab_parallel_logits
|
||||
torch.exp(vocab_parallel_logits, out=exp_logits)
|
||||
sum_exp_logits = exp_logits.sum(dim=-1)
|
||||
|
||||
# Loss = log(sum(exp(logits))) - predicted-logit.
|
||||
loss = torch.log(sum_exp_logits) - predicted_logits
|
||||
|
||||
# Store softmax, target-mask and masked-target for backward pass.
|
||||
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
|
||||
ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
|
||||
|
||||
return loss
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, grad_output):
|
||||
# Retreive tensors from the forward path.
|
||||
softmax, target_mask, masked_target_1d = ctx.saved_tensors
|
||||
|
||||
# All the inputs have softmax as their gradient.
|
||||
grad_input = softmax
|
||||
# For simplicity, work with the 2D gradient.
|
||||
partition_vocab_size = softmax.size()[-1]
|
||||
grad_2d = grad_input.view(-1, partition_vocab_size)
|
||||
|
||||
# Add the gradient from matching classes.
|
||||
arange_1d = torch.arange(start=0, end=grad_2d.size()[0],
|
||||
device=grad_2d.device)
|
||||
grad_2d[arange_1d, masked_target_1d] -= (
|
||||
1.0 - target_mask.view(-1).float())
|
||||
|
||||
# Finally elementwise multiplication with the output gradients.
|
||||
grad_input.mul_(grad_output.unsqueeze(dim=-1))
|
||||
|
||||
return grad_input, None
|
||||
|
||||
|
||||
def vocab_cross_entropy(vocab_logits, target):
|
||||
"""helper function for the cross entropy."""
|
||||
|
||||
return _VocabCrossEntropy.apply(vocab_logits, target)
|
||||
Reference in New Issue
Block a user