mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 09:07:51 +00:00
Hotfix/Colossalai layers (#92)
* optimized 1d layer apis; reorganized nn.layer modules; fixed tests * fixed 2.5d runtime issue * reworked split batch, now called in trainer.schedule.load_batch Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
This commit is contained in:
@@ -2,6 +2,7 @@ from torch import nn
|
||||
from torch.nn.modules.loss import *
|
||||
from torch.nn.modules.loss import _Loss
|
||||
|
||||
from colossalai.nn.layer.utils import get_tensor_parallel_mode
|
||||
from .loss_2d import CrossEntropyLoss2D
|
||||
from .loss_2p5d import CrossEntropyLoss2p5D
|
||||
from .loss_3d import CrossEntropyLoss3D
|
||||
@@ -14,9 +15,10 @@ _parallel_cross_entropy = {
|
||||
|
||||
|
||||
class CrossEntropyLoss(_Loss):
|
||||
def __init__(self, reduction: bool = True, tensor_parallel: str = None, *args, **kwargs):
|
||||
def __init__(self, reduction: bool = True, *args, **kwargs):
|
||||
super().__init__()
|
||||
if tensor_parallel in [None, '1d']:
|
||||
tensor_parallel = get_tensor_parallel_mode()
|
||||
if tensor_parallel in ['None', '1d']:
|
||||
reduction = 'mean' if reduction else 'none'
|
||||
self.loss = nn.CrossEntropyLoss(reduction=reduction, *args, **kwargs)
|
||||
else:
|
||||
|
@@ -1,4 +1,4 @@
|
||||
from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d
|
||||
from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d
|
||||
from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization
|
||||
from colossalai.registry import LOSSES
|
||||
from torch.nn.functional import cross_entropy
|
||||
@@ -20,11 +20,8 @@ class CrossEntropyLoss2D(_Loss):
|
||||
self.loss_kwargs = kwargs
|
||||
|
||||
def forward(self, logits, targets):
|
||||
batch_size = targets.size(0)
|
||||
targets = split_batch_2d(targets)
|
||||
loss = cross_entropy(logits, targets, reduction='sum', *self.loss_args, **self.loss_kwargs)
|
||||
loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs)
|
||||
if self.reduction_mean:
|
||||
loss = loss.sum()
|
||||
loss = reduce_by_batch_2d.apply(loss)
|
||||
loss /= batch_size
|
||||
loss = loss.mean()
|
||||
loss = reduce_by_batch_2d.apply(loss, True)
|
||||
return loss
|
||||
|
@@ -1,4 +1,4 @@
|
||||
from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d
|
||||
from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d
|
||||
from colossalai.nn.layer.parallel_2p5d._utils import assert_tesseract_initialization
|
||||
from colossalai.registry import LOSSES
|
||||
from torch.nn.functional import cross_entropy
|
||||
@@ -19,11 +19,8 @@ class CrossEntropyLoss2p5D(_Loss):
|
||||
self.loss_kwargs = kwargs
|
||||
|
||||
def forward(self, logits, targets):
|
||||
batch_size = targets.size(0)
|
||||
targets = split_batch_2p5d(targets)
|
||||
loss = cross_entropy(logits, targets, reduction='sum', *self.loss_args, **self.loss_kwargs)
|
||||
loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs)
|
||||
if self.reduction_mean:
|
||||
loss = loss.sum()
|
||||
loss = reduce_by_batch_2p5d.apply(loss)
|
||||
loss /= batch_size
|
||||
loss = loss.mean()
|
||||
loss = reduce_by_batch_2p5d.apply(loss, True)
|
||||
return loss
|
||||
|
@@ -1,11 +1,10 @@
|
||||
from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
|
||||
from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d, split_batch_3d
|
||||
from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d
|
||||
from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
|
||||
from colossalai.registry import LOSSES
|
||||
from torch.nn.functional import cross_entropy
|
||||
from torch.nn.modules.loss import _Loss
|
||||
|
||||
|
||||
@LOSSES.register_module
|
||||
class CrossEntropyLoss3D(_Loss):
|
||||
"""Cross entropy loss for 3D parallelism
|
||||
@@ -28,11 +27,8 @@ class CrossEntropyLoss3D(_Loss):
|
||||
self.loss_kwargs = kwargs
|
||||
|
||||
def forward(self, logits, targets):
|
||||
batch_size = targets.size(0)
|
||||
targets = split_batch_3d(targets, self.input_parallel_mode, self.weight_parallel_mode)
|
||||
loss = cross_entropy(logits, targets, reduction='sum', *self.loss_args, **self.loss_kwargs)
|
||||
loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs)
|
||||
if self.reduction_mean:
|
||||
loss = loss.sum()
|
||||
loss = reduce_by_batch_3d.apply(loss, self.input_parallel_mode, self.weight_parallel_mode)
|
||||
loss /= batch_size
|
||||
loss = loss.mean()
|
||||
loss = reduce_by_batch_3d.apply(loss, self.input_parallel_mode, self.weight_parallel_mode, True)
|
||||
return loss
|
||||
|
Reference in New Issue
Block a user