mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 12:01:39 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
This commit is contained in:
@@ -3,7 +3,7 @@ import torch.distributed as dist
|
||||
from torch.autograd import Function
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
__all__ = ['DistCrossEntropy', 'cross_entropy_1d']
|
||||
__all__ = ["DistCrossEntropy", "cross_entropy_1d"]
|
||||
|
||||
|
||||
class DistCrossEntropy(Function):
|
||||
@@ -61,8 +61,9 @@ class DistCrossEntropy(Function):
|
||||
masked_target_1d = masked_target.view(-1)
|
||||
|
||||
# extract the x[class] and set the x[other device] to zero
|
||||
pred_logits_1d = logits_2d[torch.arange(start=0, end=logits_2d.shape[0], device=logits_2d.device),
|
||||
masked_target_1d]
|
||||
pred_logits_1d = logits_2d[
|
||||
torch.arange(start=0, end=logits_2d.shape[0], device=logits_2d.device), masked_target_1d
|
||||
]
|
||||
pred_logits_1d = pred_logits_1d.clone().contiguous()
|
||||
pred_logits = pred_logits_1d.view_as(target)
|
||||
pred_logits[mask] = 0.0
|
||||
@@ -102,8 +103,7 @@ class DistCrossEntropy(Function):
|
||||
return grad_logits, None, None
|
||||
|
||||
|
||||
def cross_entropy_1d(vocab_logits: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
ignore_index: int = -100,
|
||||
process_group: ProcessGroup = None) -> torch.Tensor:
|
||||
def cross_entropy_1d(
    vocab_logits: torch.Tensor,
    labels: torch.Tensor,
    ignore_index: int = -100,
    process_group: ProcessGroup = None,
) -> torch.Tensor:
    """Functional entry point for ``DistCrossEntropy``.

    Forwards the four arguments positionally, in the order they are declared
    here, to ``DistCrossEntropy.apply`` and returns whatever that call yields.

    Args:
        vocab_logits: Logits tensor handed to ``DistCrossEntropy``.
            NOTE(review): presumably the vocabulary dimension is split across
            the ranks of ``process_group`` -- confirm against
            ``DistCrossEntropy.forward``.
        labels: Target tensor handed to ``DistCrossEntropy``.
        ignore_index: Passed through unchanged; defaults to ``-100``.
            NOTE(review): presumably the label value excluded from the
            loss -- confirm against ``DistCrossEntropy.forward``.
        process_group: Passed through unchanged; defaults to ``None``.

    Returns:
        The result of ``DistCrossEntropy.apply``.
    """
    # Custom autograd Functions are invoked via ``.apply`` so that the
    # backward pass defined on the class is recorded.
    loss = DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group)
    return loss
|
||||
|
Reference in New Issue
Block a user