[tutorial] edited hands-on practices (#1899)
* Add hands-on examples to ColossalAI.
* Rename the hands-on examples and edit the sequence parallel example.
* Fix wrong folder name.
* Resolve conflict.
* Delete README.
examples/tutorial/sequence_parallel/loss_func/bert_loss.py (new file, 41 lines)
@@ -0,0 +1,41 @@
import torch
import torch.nn as nn
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.logging import get_dist_logger
import torch.nn.functional as F
import torch.distributed as dist
from .cross_entropy import vocab_cross_entropy


class BertLoss(nn.Module):

    def forward(self,
                lm_loss,
                sop_logits,
                loss_mask,
                sentence_order):
        lm_loss_ = lm_loss.float()
        loss_mask = loss_mask.float()
        loss_mask_sum = loss_mask.sum()
        lm_loss = torch.sum(
            lm_loss_.view(-1) * loss_mask.reshape(-1))

        lm_loss /= loss_mask_sum

        torch.distributed.all_reduce(
            lm_loss,
            group=gpc.get_group(ParallelMode.SEQUENCE)
        )

        if sop_logits is not None:
            sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
                                       sentence_order.view(-1),
                                       ignore_index=-1)
            sop_loss = sop_loss.float()
            loss = lm_loss + sop_loss * gpc.get_world_size(ParallelMode.SEQUENCE)
        else:
            sop_loss = None
            loss = lm_loss

        return loss
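For reference, here is a minimal single-process sketch of what BertLoss.forward computes: the masked mean of the per-token LM loss plus a sentence-order-prediction (SOP) cross entropy. The tensor shapes are assumed for illustration only, and the sequence-parallel all_reduce and the world-size scaling are omitted since they require an initialized ColossalAI parallel context.

# Minimal single-process sketch of the loss math (assumed shapes; the
# sequence-parallel all_reduce and gpc context are intentionally skipped).
import torch
import torch.nn.functional as F

batch, seq_len = 2, 8
lm_loss = torch.rand(batch, seq_len)                     # per-token LM loss from the model head
loss_mask = (torch.rand(batch, seq_len) > 0.2).float()   # 1 where a token contributes to the loss
sop_logits = torch.randn(batch, 2)                       # sentence-order-prediction logits
sentence_order = torch.randint(0, 2, (batch,))           # 0/1 labels; -1 would be ignored

# Masked mean of the per-token LM loss, as in BertLoss.forward
masked_lm_loss = (lm_loss.view(-1) * loss_mask.view(-1)).sum() / loss_mask.sum()

# SOP term is a plain two-class cross entropy
sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
                           sentence_order.view(-1),
                           ignore_index=-1)

# In the real module the SOP term is additionally scaled by the sequence-parallel world size
loss = masked_lm_loss + sop_loss
print(loss.item())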