ColossalAI/examples/community/roberta/pretraining/loss.py
Hongxin Liu 079bf3cb26
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
2023-09-19 14:20:26 +08:00

17 lines
677 B
Python

import torch
# Names exported by `from <this module> import *`: only the loss class defined below.
__all__ = ["LossForPretraining"]
class LossForPretraining(torch.nn.Module):
    """Masked-language-model loss used for pretraining.

    Wraps ``torch.nn.CrossEntropyLoss(ignore_index=-1)``: target positions
    whose label is ``-1`` do not contribute to the loss.

    Args:
        vocab_size: Size of the last dimension of the prediction scores
            (number of classes the cross entropy is computed over).
    """

    def __init__(self, vocab_size: int) -> None:
        super().__init__()
        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-1)
        self.vocab_size = vocab_size

    def forward(
        self,
        prediction_scores: torch.Tensor,
        masked_lm_labels: torch.Tensor,
        next_sentence_labels=None,
    ) -> torch.Tensor:
        """Return the masked-LM cross-entropy loss.

        Args:
            prediction_scores: Logits; flattened to ``(-1, vocab_size)``.
                Presumably ``(batch, seq_len, vocab_size)`` -- confirm
                against the caller.
            masked_lm_labels: Target class indices; flattened to 1-D. Must
                hold one label per row of the flattened scores. Entries
                equal to ``-1`` are ignored.
            next_sentence_labels: Accepted for call-site compatibility but
                not used; no next-sentence-prediction term is added.

        Returns:
            The loss produced by ``self.loss_fn`` for the flattened inputs.
        """
        # `reshape` (rather than `view`) so that non-contiguous inputs, e.g.
        # the result of a transpose or a strided slice, are flattened instead
        # of raising a RuntimeError. For contiguous inputs it is equivalent.
        flat_scores = prediction_scores.reshape(-1, self.vocab_size)
        flat_labels = masked_lm_labels.reshape(-1)

        # Only the masked-LM term makes up the total loss.
        return self.loss_fn(flat_scores, flat_labels)