ColossalAI/examples/tutorial/sequence_parallel/loss_func/bert_loss.py

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.logging import get_dist_logger

from .cross_entropy import vocab_cross_entropy


class BertLoss(nn.Module):
    def forward(self, lm_loss, sop_logits, loss_mask, sentence_order):
        # Masked mean of the per-token language-modeling loss.
        lm_loss_ = lm_loss.float()
        loss_mask = loss_mask.float()
        loss_mask_sum = loss_mask.sum()
        lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1))

        lm_loss /= loss_mask_sum

        # Sum the partial LM loss across the sequence-parallel group so every
        # rank ends up with the loss for the full sequence.
        torch.distributed.all_reduce(lm_loss, group=gpc.get_group(ParallelMode.SEQUENCE))

        if sop_logits is not None:
            # Sentence-order-prediction (SOP) loss; labels of -1 are ignored.
            sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(), sentence_order.view(-1), ignore_index=-1)
            sop_loss = sop_loss.float()
            # Scale the SOP loss by the sequence-parallel world size so its
            # weight matches the all-reduced (summed) LM loss.
            loss = lm_loss + sop_loss * gpc.get_world_size(ParallelMode.SEQUENCE)
        else:
            sop_loss = None
            loss = lm_loss

        return loss
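
As a usage note, the masked averaging that BertLoss applies to the per-token LM loss can be reproduced in isolation. The sketch below uses hypothetical dummy tensors and skips the sequence-parallel all-reduce, so it only illustrates the single-rank reduction step, not the full distributed path.

import torch

# Hypothetical per-token LM loss for 2 sequences of length 4.
lm_loss = torch.rand(2, 4)
# 1 = this position contributes to the loss, 0 = ignore (e.g. padding).
loss_mask = torch.tensor([[1.0, 1.0, 0.0, 0.0],
                          [1.0, 1.0, 1.0, 0.0]])

# Same reduction as BertLoss: sum the masked per-token losses,
# then divide by the number of contributing tokens.
mean_lm_loss = torch.sum(lm_loss.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
print(mean_lm_loss)  # scalar tensor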