mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 12:30:42 +00:00
moved env variables to global variables; (#215)
added branch context; added vocab parallel layers; moved split_batch from load_batch to tensor parallel embedding layers; updated gpt model; updated unit test cases; fixed few collective communicator bugs
This commit is contained in:
@@ -1,25 +1,37 @@
|
||||
from colossalai.global_variables import tensor_parallel_env as env
|
||||
from colossalai.nn.layer.utils import get_tensor_parallel_mode
|
||||
from torch import nn
|
||||
from torch.nn.modules.loss import *
|
||||
from torch.nn.modules.loss import _Loss
|
||||
|
||||
from colossalai.nn.layer.utils import get_tensor_parallel_mode
|
||||
from .loss_2d import CrossEntropyLoss2D
|
||||
from .loss_2p5d import CrossEntropyLoss2p5D
|
||||
from .loss_3d import CrossEntropyLoss3D
|
||||
from .loss_1d import VocabParallelCrossEntropyLoss1D
|
||||
from .loss_2d import CrossEntropyLoss2D, VocabParallelCrossEntropyLoss2D
|
||||
from .loss_2p5d import CrossEntropyLoss2p5D, VocabParallelCrossEntropyLoss2p5D
|
||||
from .loss_3d import CrossEntropyLoss3D, VocabParallelCrossEntropyLoss3D
|
||||
from .loss_moe import MoeCrossEntropyLoss, MoeLoss
|
||||
|
||||
_parallel_cross_entropy = {
|
||||
'2d': CrossEntropyLoss2D,
|
||||
'2.5d': CrossEntropyLoss2p5D,
|
||||
'3d': CrossEntropyLoss3D
|
||||
'3d': CrossEntropyLoss3D,
|
||||
}
|
||||
|
||||
_vocab_parallel_cross_entropy = {
|
||||
'1d': VocabParallelCrossEntropyLoss1D,
|
||||
'2d': VocabParallelCrossEntropyLoss2D,
|
||||
'2.5d': VocabParallelCrossEntropyLoss2p5D,
|
||||
'3d': VocabParallelCrossEntropyLoss3D,
|
||||
}
|
||||
|
||||
|
||||
class CrossEntropyLoss(_Loss):
|
||||
|
||||
def __init__(self, reduction: bool = True, *args, **kwargs):
|
||||
super().__init__()
|
||||
tensor_parallel = get_tensor_parallel_mode()
|
||||
if tensor_parallel in ['None', '1d']:
|
||||
if tensor_parallel is not None and env.vocab_parallel:
|
||||
self.loss = _vocab_parallel_cross_entropy[tensor_parallel](reduction=reduction, *args, **kwargs)
|
||||
elif tensor_parallel is None or tensor_parallel == '1d':
|
||||
reduction = 'mean' if reduction else 'none'
|
||||
self.loss = nn.CrossEntropyLoss(reduction=reduction, *args, **kwargs)
|
||||
else:
|
||||
|
Reference in New Issue
Block a user