Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-08 04:24:47 +00:00
[Shardformer] Add parallel output for shardformer models (bloom, falcon) (#5702)
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* add parallel cross entropy output for falcon model & fix some typos in bloom.py
* fix module name error, self.model -> self.transformers in bloom, falcon model
* Fix the overflow bug of distributed cross entropy loss function when training with fp16 (see the sketch below)
* add dtype to parallel cross entropy loss function
* fix dtype related typos and prettify the loss.py
* fix grad dtype and update dtype mismatch error
* fix typo bugs
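A minimal, standalone sketch (not from the repository) of the fp16 overflow this commit fixes: summing exp(logits) in float16 overflows to inf once the running sum passes the float16 maximum of 65504, while accumulating the sum in float32, as the patched line in the diff below does, stays finite.

import torch

# Values chosen only for illustration: exp(11) ~ 5.99e4 still fits in float16,
# but the sum over even a tiny vocabulary dimension does not.
logits = torch.full((4, 12), 11.0, dtype=torch.float16)
exp_logits = torch.exp(logits)                                  # finite in float16

sum_fp16 = torch.sum(exp_logits, dim=-1)                        # result dtype float16 -> inf
sum_fp32 = torch.sum(exp_logits, dim=-1, dtype=torch.float32)   # accumulated in float32, ~7.2e5

print(sum_fp16)  # tensor([inf, inf, inf, inf], dtype=torch.float16)
print(sum_fp32)  # finite float32 values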
@@ -22,6 +22,7 @@ class DistCrossEntropy(Function):
         ignore_index: int,
         process_group: ProcessGroup,
         vocab_size: int,
+        dtype=torch.float32,
     ):
         r"""
         Calculate the cross entropy loss before gather, the origin loss function is as follows:
@@ -34,7 +35,7 @@ class DistCrossEntropy(Function):
         Args:
             vocab_logits (:class:`torch.Tensor`): The logits of the vocabulary, shape is
               [batch_size, seq_len, vocab_size]
-            labels (:class:`torch.Tensor`): The labels of the vocabulary, shape is
+            target (:class:`torch.Tensor`): The labels of the vocabulary, shape is
               [batch_size, seq_len]

         Returns:
@@ -86,7 +87,7 @@ class DistCrossEntropy(Function):
         dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM, group=process_group)
         exp_logits = vocab_logits
         torch.exp(vocab_logits, out=exp_logits)
-        sum_exp_logits = torch.sum(exp_logits, dim=-1)
+        sum_exp_logits = torch.sum(exp_logits, dim=-1, dtype=torch.float32)
         dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=process_group)

         # calculate the loss
@@ -97,9 +98,10 @@ class DistCrossEntropy(Function):
         loss = torch.sum(loss).div_(num_non_zero)

         # calculate the softmax
-        exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
+        exp_logits = exp_logits.div(sum_exp_logits.unsqueeze(dim=-1)).to(dtype)
         exp_logits[target == ignore_index] = 0.0
         ctx.save_for_backward(exp_logits, mask, masked_target_1d)
+        ctx.dtype = dtype

         return loss
@@ -114,11 +116,11 @@ class DistCrossEntropy(Function):
         partion_vocab_size = grad_logits.shape[-1]
         grad_logits_2d = grad_logits.view(-1, partion_vocab_size)

-        update = 1.0 - mask.view(-1).float()
+        update = 1.0 - mask.view(-1).float().to(ctx.dtype)
         grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update

         grad_logits.mul_(grad_output.unsqueeze(dim=-1))
-        return grad_logits, None, None, None, None
+        return grad_logits, None, None, None, None, None


 def cross_entropy_1d(
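The .to(ctx.dtype) cast in the hunk above avoids a dtype mismatch in the backward pass: once grad_logits is float16, PyTorch refuses to subtract a float32 update in place. A small standalone sketch (not repository code, with torch.float16 standing in for ctx.dtype):

import torch

grad_logits_2d = torch.randn(6, 8, dtype=torch.float16)        # stand-in for a fp16 gradient shard
masked_target_1d = torch.zeros(6, dtype=torch.long)
mask = torch.zeros(6, dtype=torch.bool)
rows = torch.arange(0, grad_logits_2d.shape[0])

update_fp32 = 1.0 - mask.view(-1).float()                      # float32, as before the patch
update_cast = 1.0 - mask.view(-1).float().to(torch.float16)    # cast to the grad dtype, as patched

try:
    grad_logits_2d[rows, masked_target_1d] -= update_fp32      # raises: result type Float can't be
except RuntimeError as e:                                      # cast to the desired output type Half
    print(e)

grad_logits_2d[rows, masked_target_1d] -= update_cast          # matching dtypes subtract cleanly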
@@ -127,5 +129,6 @@ def cross_entropy_1d(
     ignore_index: int = -100,
     process_group: ProcessGroup = None,
     vocab_size: int = None,
+    dtype: torch.dtype = None,
 ) -> torch.Tensor:
-    return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size)
+    return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype)
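For context, a hedged usage sketch of the updated helper: how a tensor-parallel causal-LM forward might call cross_entropy_1d with the new dtype argument. The import path, the logits/labels shifting, and the surrounding setup are assumptions for illustration and are not part of this diff; the call must run inside an initialized tensor-parallel process group.

import torch
from colossalai.shardformer.layer.loss import cross_entropy_1d  # assumed module path


def parallel_lm_loss(shard_logits: torch.Tensor, labels: torch.Tensor, tp_group, full_vocab_size: int) -> torch.Tensor:
    # shard_logits: [batch, seq_len, vocab_size // tp_size], typically fp16/bf16 under mixed precision
    shift_logits = shard_logits[..., :-1, :].contiguous()   # predict the next token
    shift_labels = labels[..., 1:].contiguous()
    return cross_entropy_1d(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,
        process_group=tp_group,
        vocab_size=full_vocab_size,
        dtype=shard_logits.dtype,  # new argument: dtype the saved softmax (and hence the gradient) is kept in
    )

Passing the logits' own dtype mirrors the grad-dtype fix above: the gradient produced in backward then matches the fp16/bf16 logits it corresponds to.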