Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-07 12:01:39 +00:00
[shardformer] llama support DistCrossEntropy (#5176)
* fix aaa fix fix fix
* fix
* fix
* test ci
* fix ci fix
* llama support dist-cross fix fix fix fix fix fix fix fix
* fix
* fix
* fix fix
* test ci
* test ci
* fix
* [Colossal-Llama-2] Add finetuning Colossal-Llama-2 example (#4878)
  * Add finetuning Colossal-Llama-2 example
  * Add finetuning Colossal-Llama-2 example 2
  * Add finetuning Colossal-Llama-2 example and support NEFTuning
  * Add inference example and refine neftune
  * Modify readme file
  * update the imports
  ---------
  Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
  Co-authored-by: Camille Zhong <44392324+Camille7777@users.noreply.github.com>
* llama support dist-cross fix fix fix fix fix fix fix fix
* fix
* fix
* fix fix
* test ci
* test ci
* fix
* fix ci
* fix ci
---------
Co-authored-by: Yuanchen <70520919+chengeharrison@users.noreply.github.com>
Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
Co-authored-by: Camille Zhong <44392324+Camille7777@users.noreply.github.com>
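The diff below modifies DistCrossEntropy, the custom autograd function used when the vocabulary dimension of the logits is sharded across tensor-parallel ranks. For orientation, here is a minimal forward-only sketch of that idea in plain PyTorch; it is not the shardformer implementation, and the names vocab_parallel_cross_entropy, vocab_start, and process_group are illustrative assumptions only.

import torch
import torch.distributed as dist

def vocab_parallel_cross_entropy(local_logits, target, vocab_start,
                                 ignore_index=-100, process_group=None):
    # local_logits: [num_tokens, local_vocab] shard of the logits on this rank.
    # target: [num_tokens] holding *global* vocabulary ids.
    local_vocab = local_logits.size(-1)

    # Subtract the global max for numerical stability.
    logits_max = local_logits.max(dim=-1).values
    dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group)
    local_logits = local_logits - logits_max.unsqueeze(-1)

    # Pick the target logit if it lives on this shard, zero it elsewhere,
    # then sum across ranks so every rank sees the true target logit.
    mask = (target < vocab_start) | (target >= vocab_start + local_vocab)
    masked_target = (target - vocab_start).clamp(0, local_vocab - 1)
    pred_logits = local_logits.gather(-1, masked_target.unsqueeze(-1)).squeeze(-1)
    pred_logits = pred_logits.masked_fill(mask, 0.0)
    dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM, group=process_group)

    # Global normalizer sum(exp(x)) over the sharded vocabulary.
    sum_exp_logits = local_logits.exp().sum(dim=-1)
    dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=process_group)

    # loss = log(sum(exp(x))) - x[class], averaged over non-ignored tokens.
    loss = torch.where(target == ignore_index,
                       torch.zeros_like(sum_exp_logits),
                       torch.log(sum_exp_logits) - pred_logits)
    return torch.sum(loss) / torch.sum(loss != 0.0)

This requires an initialized torch.distributed process group; the actual code wraps the computation in a torch.autograd.Function so the backward pass (changed in the later hunks) is written by hand.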
@@ -78,10 +78,13 @@ class DistCrossEntropy(Function):
         # calculate the loss
         # loss = log(sum(exp(x[i]))) - x[class]
         loss = torch.where(target == ignore_index, 0.0, torch.log(sum_exp_logits) - pred_logits)
-        loss = torch.sum(loss).div_(torch.sum(loss != 0.0))
+        num_non_zero = torch.sum(loss != 0.0)
+        ctx.inv_num_non_zero = 1.0 / num_non_zero
+        loss = torch.sum(loss).div_(num_non_zero)
 
         # calculate the softmax
         exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
+        exp_logits[target == ignore_index] = 0.0
         ctx.save_for_backward(exp_logits, mask, masked_target_1d)
 
         return loss
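The hunk above keeps the same masked average but now records the inverse token count on ctx and zeroes the softmax rows of ignored tokens, so the backward pass can reuse both. A tiny standalone illustration of the averaging step, using hypothetical toy tensors rather than the real call site:

import torch

ignore_index = -100
target = torch.tensor([2, ignore_index, 1])
sum_exp_logits = torch.tensor([5.0, 3.0, 4.0])   # per-token sum(exp(logits))
pred_logits = torch.tensor([1.2, 0.7, 0.9])      # per-token target-class logit

# Per-token loss, forced to zero at ignored positions.
loss = torch.where(target == ignore_index, 0.0, torch.log(sum_exp_logits) - pred_logits)

# Average only over non-ignored tokens; the patch caches 1/num_non_zero
# on ctx so backward can apply the same normalization.
num_non_zero = torch.sum(loss != 0.0)
inv_num_non_zero = 1.0 / num_non_zero
loss = torch.sum(loss).div_(num_non_zero)
print(loss.item(), inv_num_non_zero.item())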
@@ -89,6 +92,7 @@ class DistCrossEntropy(Function):
     @staticmethod
     def backward(ctx, grad_output):
         # retrieve the saved tensors
+        grad_output = grad_output * ctx.inv_num_non_zero
         exp_logits, mask, masked_target_1d = ctx.saved_tensors
 
         # use exp logits as the input grad
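The backward hunk pre-scales the upstream gradient by the cached 1/num_non_zero, so each contributing token receives its share of the mean's gradient. A self-contained toy autograd Function, not the shardformer class and with made-up names, showing the same cache-in-forward, scale-in-backward pattern:

import torch

class MaskedMeanLoss(torch.autograd.Function):
    # Toy example: cache a normalization factor in forward and
    # reuse it to scale the incoming gradient in backward.

    @staticmethod
    def forward(ctx, per_token_loss):
        num_non_zero = torch.sum(per_token_loss != 0.0)
        ctx.inv_num_non_zero = 1.0 / num_non_zero
        ctx.save_for_backward(per_token_loss)
        return torch.sum(per_token_loss) * ctx.inv_num_non_zero

    @staticmethod
    def backward(ctx, grad_output):
        (per_token_loss,) = ctx.saved_tensors
        # Same move as the patch: fold 1/num_non_zero into grad_output once.
        grad_output = grad_output * ctx.inv_num_non_zero
        grad = torch.ones_like(per_token_loss) * grad_output
        grad[per_token_loss == 0.0] = 0.0   # ignored tokens get no gradient
        return grad

per_token = torch.tensor([0.8, 0.0, 1.2], requires_grad=True)
MaskedMeanLoss.apply(per_token).backward()
print(per_token.grad)   # tensor([0.5000, 0.0000, 0.5000])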
@@ -100,7 +104,7 @@ class DistCrossEntropy(Function):
         grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update
 
         grad_logits.mul_(grad_output.unsqueeze(dim=-1))
-        return grad_logits, None, None
+        return grad_logits, None, None, None
 
 
 def cross_entropy_1d(
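backward of a torch.autograd.Function must return one value per argument of forward, with None for arguments that need no gradient, so the extra None here implies forward gained one more non-tensor parameter in this change; the new parameter itself is not visible in these hunks. A minimal, hypothetical illustration of that rule (Scale and its argument names are invented):

import torch

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, factor, ignore_index, extra_arg):
        # Four forward arguments besides ctx ...
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_output):
        # ... so backward returns four values: a gradient for x,
        # and None for each non-differentiable argument.
        return grad_output * ctx.factor, None, None, None

x = torch.randn(3, requires_grad=True)
Scale.apply(x, 2.0, -100, None).sum().backward()
print(x.grad)   # 2.0 at every position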