mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-12 12:47:21 +00:00
[NFC] polish colossalai/nn/loss/loss_2p5d.py code style (#1553)
This commit is contained in:
@@ -30,6 +30,7 @@ class CrossEntropyLoss2p5D(_Loss):
|
||||
More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in
|
||||
`Cross_entropy <https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html#torch.nn.functional.cross_entropy>`_.
|
||||
"""
|
||||
|
||||
def __init__(self, reduction=True, *args, **kwargs):
|
||||
super().__init__()
|
||||
assert_tesseract_initialization()
|
||||
@@ -127,6 +128,7 @@ class VocabParallelCrossEntropyLoss2p5D(_Loss):
|
||||
Args:
|
||||
reduction (bool, optional): whether to average the loss, defaults to True.
|
||||
"""
|
||||
|
||||
def __init__(self, reduction=True):
|
||||
super().__init__()
|
||||
self.reduction_mean = reduction
|
||||
|
Reference in New Issue
Block a user