[shardformer] Align bert value (#3907)
* add bert align test, fix dist loss bug
* forward and backward align
* add ignore index
* add shardformer CI
* add optional gather_output for user in shardconfig
* update readme with optional gather_output
* add dist crossentropy loss test, remove unused files
* remove unused file
* remove unused file
* rename the file
* polish code
@@ -14,7 +14,7 @@ class DistCrossEntropy(Function):
     """
 
     @staticmethod
-    def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor):
+    def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor, ignore_index: int):
         r"""
         Calculate the cross entropy loss before gather, the original loss function is as follows:
             loss = -log(exp(x[class]) / sum(exp(x[i])))
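The docstring formula is standard softmax cross entropy, rearranged as log(sum(exp(x[i]))) - x[class]. A minimal single-device sketch of that identity, checked against torch.nn.functional.cross_entropy (the shapes here are illustrative assumptions, not the shardformer test setup):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)              # (batch, vocab)
target = torch.randint(0, 10, (4,))      # class indices

# loss = -log(exp(x[class]) / sum(exp(x[i]))) = log(sum(exp(x[i]))) - x[class]
sum_exp = torch.exp(logits).sum(dim=-1)
picked = logits[torch.arange(4), target]
manual = (torch.log(sum_exp) - picked).mean()

assert torch.allclose(manual, F.cross_entropy(logits, target))
```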
@@ -75,8 +75,8 @@ class DistCrossEntropy(Function):
 
         # calculate the loss
         # loss = log(sum(exp(x[i]))) - x[class]
-        loss = torch.log(sum_exp_logits) - pred_logits
-        loss = torch.sum(loss).div_(loss.numel())
+        loss = torch.where(target == ignore_index, 0.0, torch.log(sum_exp_logits) - pred_logits)
+        loss = torch.sum(loss).div_(torch.sum(loss != 0.0))
 
         # calculate the softmax
         exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
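A minimal single-device sketch of what this masking does, with illustrative shapes: positions whose target equals ignore_index are zeroed with torch.where, and the sum is divided by the number of surviving positions, matching F.cross_entropy's ignore_index semantics:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(6, 10)               # (tokens, vocab)
target = torch.randint(0, 10, (6,))
target[2] = -100                          # token to ignore

sum_exp = torch.exp(logits).sum(dim=-1)
picked = logits[torch.arange(6), target.clamp(min=0)]  # clamp so ignored index is gatherable
per_token = torch.log(sum_exp) - picked   # log(sum(exp(x[i]))) - x[class]

# zero out ignored positions, then divide by the count that survived
loss = torch.where(target == -100, torch.zeros_like(per_token), per_token)
loss = loss.sum() / (target != -100).sum()

assert torch.allclose(loss, F.cross_entropy(logits, target, ignore_index=-100))
```

One subtlety of the patched divisor: torch.sum(loss != 0.0) counts nonzero per-token losses rather than non-ignored targets, so the two agree only as long as no kept position happens to have exactly zero loss.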
@@ -101,5 +101,5 @@ class DistCrossEntropy(Function):
         return grad_logits, None, None
 
 
-def applyDistCrossEntropy(vocab_logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
-    return DistCrossEntropy.apply(vocab_logits, labels)
+def applyDistCrossEntropy(vocab_logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> torch.Tensor:
+    return DistCrossEntropy.apply(vocab_logits, labels, ignore_index)
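A hypothetical call site for the updated wrapper. The import path is an assumption based on the shardformer layout in this PR, and the shapes are illustrative: DistCrossEntropy expects vocab_logits sharded along the vocab dimension across tensor-parallel ranks, so a process group must be initialized first (here a single-rank gloo group, so the shard is the full vocabulary):

```python
import os
import torch
import torch.distributed as dist

# assumed module path; adjust to wherever DistCrossEntropy actually lives
from colossalai.shardformer.layer.dist_crossentropy import applyDistCrossEntropy

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

# with world_size=1 the per-rank shard is the whole vocabulary
vocab_logits = torch.randn(2, 8, 100)        # (batch, seq, vocab_shard)
labels = torch.randint(0, 100, (2, 8))
labels[:, :2] = -100                          # e.g. masked prompt tokens

loss = applyDistCrossEntropy(vocab_logits, labels)  # ignore_index defaults to -100
print(loss.item())
```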