mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 20:10:17 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -11,28 +11,27 @@ from .loss_2p5d import CrossEntropyLoss2p5D, VocabParallelCrossEntropyLoss2p5D
|
||||
from .loss_3d import CrossEntropyLoss3D, VocabParallelCrossEntropyLoss3D
|
||||
|
||||
_parallel_cross_entropy = {
|
||||
'2d': CrossEntropyLoss2D,
|
||||
'2.5d': CrossEntropyLoss2p5D,
|
||||
'3d': CrossEntropyLoss3D,
|
||||
"2d": CrossEntropyLoss2D,
|
||||
"2.5d": CrossEntropyLoss2p5D,
|
||||
"3d": CrossEntropyLoss3D,
|
||||
}
|
||||
|
||||
_vocab_parallel_cross_entropy = {
|
||||
'1d': VocabParallelCrossEntropyLoss1D,
|
||||
'2d': VocabParallelCrossEntropyLoss2D,
|
||||
'2.5d': VocabParallelCrossEntropyLoss2p5D,
|
||||
'3d': VocabParallelCrossEntropyLoss3D,
|
||||
"1d": VocabParallelCrossEntropyLoss1D,
|
||||
"2d": VocabParallelCrossEntropyLoss2D,
|
||||
"2.5d": VocabParallelCrossEntropyLoss2p5D,
|
||||
"3d": VocabParallelCrossEntropyLoss3D,
|
||||
}
|
||||
|
||||
|
||||
class CrossEntropyLoss(_Loss):
|
||||
|
||||
def __init__(self, reduction: bool = True, *args, **kwargs):
|
||||
super().__init__()
|
||||
tensor_parallel = get_tensor_parallel_mode()
|
||||
if tensor_parallel is not None and env.vocab_parallel:
|
||||
self.loss = _vocab_parallel_cross_entropy[tensor_parallel](reduction=reduction, *args, **kwargs)
|
||||
elif tensor_parallel is None or tensor_parallel == '1d':
|
||||
reduction = 'mean' if reduction else 'none'
|
||||
elif tensor_parallel is None or tensor_parallel == "1d":
|
||||
reduction = "mean" if reduction else "none"
|
||||
self.loss = nn.CrossEntropyLoss(reduction=reduction, *args, **kwargs)
|
||||
else:
|
||||
self.loss = _parallel_cross_entropy[tensor_parallel](reduction=reduction, *args, **kwargs)
|
||||
|
@@ -9,7 +9,6 @@ from colossalai.legacy.registry import LOSSES
|
||||
|
||||
|
||||
class _VocabParallelCrossEntropy1D(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float32)
|
||||
def forward(ctx, vocab_parallel_logits, targets, process_group):
|
||||
@@ -61,7 +60,6 @@ class _VocabParallelCrossEntropy1D(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, grad_output):
|
||||
|
||||
# Retrieve tensors from the forward path.
|
||||
softmax, target_mask, masked_target_1d = ctx.saved_tensors
|
||||
|
||||
@@ -73,7 +71,7 @@ class _VocabParallelCrossEntropy1D(torch.autograd.Function):
|
||||
|
||||
# Add the gradient from matching classes.
|
||||
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
|
||||
grad_2d[arange_1d, masked_target_1d] -= (1.0 - target_mask.view(-1).float())
|
||||
grad_2d[arange_1d, masked_target_1d] -= 1.0 - target_mask.view(-1).float()
|
||||
|
||||
# Finally elementwise multiplication with the output gradients.
|
||||
grad_input.mul_(grad_output.unsqueeze(dim=-1))
|
||||
|
@@ -50,7 +50,7 @@ class CrossEntropyLoss2D(_Loss):
|
||||
float: the loss between logits and targets.
|
||||
"""
|
||||
targets = split_batch_2d(targets)
|
||||
loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs)
|
||||
loss = cross_entropy(logits, targets, reduction="none", *self.loss_args, **self.loss_kwargs)
|
||||
if self.reduction_mean:
|
||||
loss = loss.mean()
|
||||
loss = reduce_by_batch_2d(loss, True)
|
||||
@@ -69,9 +69,9 @@ class _VocabParallelCrossEntropy2D(torch.autograd.Function):
|
||||
# vocab_parallel_logits: [b/q, s, v/q]
|
||||
# target: [b/q, s]
|
||||
logits_max = torch.max(logits, dim=-1)[0]
|
||||
torch.distributed.all_reduce(logits_max,
|
||||
op=torch.distributed.ReduceOp.MAX,
|
||||
group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW))
|
||||
torch.distributed.all_reduce(
|
||||
logits_max, op=torch.distributed.ReduceOp.MAX, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW)
|
||||
)
|
||||
# Subtract the maximum value.
|
||||
# vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
|
||||
logits = logits - logits_max.unsqueeze(dim=-1)
|
||||
@@ -90,7 +90,7 @@ class _VocabParallelCrossEntropy2D(torch.autograd.Function):
|
||||
end=logits.size()[0],
|
||||
)
|
||||
predicted_logits = logits[arange_1d, masked_target]
|
||||
predicted_logits[target_mask] = 0.
|
||||
predicted_logits[target_mask] = 0.0
|
||||
dist.all_reduce(predicted_logits, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW))
|
||||
|
||||
exp_logits = torch.exp(logits)
|
||||
@@ -119,7 +119,7 @@ class _VocabParallelCrossEntropy2D(torch.autograd.Function):
|
||||
|
||||
# Add the gradient from matching classes.
|
||||
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device())
|
||||
grad_2d[arange_1d, masked_target] -= (1.0 - target_mask.view(-1).float())
|
||||
grad_2d[arange_1d, masked_target] -= 1.0 - target_mask.view(-1).float()
|
||||
|
||||
# Finally elementwise multiplication with the output gradients.
|
||||
grad_input.mul_(output_grad.unsqueeze(dim=-1))
|
||||
|
@@ -47,7 +47,7 @@ class CrossEntropyLoss2p5D(_Loss):
|
||||
targets (:class:`torch.tensor`): Ground truth class indices or class probabilities.
|
||||
"""
|
||||
targets = split_batch_2p5d(targets)
|
||||
loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs)
|
||||
loss = cross_entropy(logits, targets, reduction="none", *self.loss_args, **self.loss_kwargs)
|
||||
if self.reduction_mean:
|
||||
loss = loss.mean()
|
||||
loss = reduce_by_batch_2p5d(loss, True)
|
||||
@@ -64,9 +64,9 @@ class _VocabParallelCrossEntropy2p5D(torch.autograd.Function):
|
||||
# loss: [b/dq]
|
||||
# targets: [b/dq, h/q]
|
||||
logits_max = torch.max(logits, dim=-1)[0]
|
||||
torch.distributed.all_reduce(logits_max,
|
||||
op=torch.distributed.ReduceOp.MAX,
|
||||
group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW))
|
||||
torch.distributed.all_reduce(
|
||||
logits_max, op=torch.distributed.ReduceOp.MAX, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW)
|
||||
)
|
||||
# Subtract the maximum value.
|
||||
logits = logits - logits_max.unsqueeze(dim=-1)
|
||||
|
||||
@@ -84,7 +84,7 @@ class _VocabParallelCrossEntropy2p5D(torch.autograd.Function):
|
||||
end=logits.size()[0],
|
||||
)
|
||||
predicted_logits = logits[arange_1d, masked_target]
|
||||
predicted_logits[target_mask] = 0.
|
||||
predicted_logits[target_mask] = 0.0
|
||||
dist.all_reduce(predicted_logits, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW))
|
||||
|
||||
exp_logits = torch.exp(logits)
|
||||
@@ -113,7 +113,7 @@ class _VocabParallelCrossEntropy2p5D(torch.autograd.Function):
|
||||
|
||||
# Add the gradient from matching classes.
|
||||
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device())
|
||||
grad_2d[arange_1d, masked_target] -= (1.0 - target_mask.view(-1).float())
|
||||
grad_2d[arange_1d, masked_target] -= 1.0 - target_mask.view(-1).float()
|
||||
|
||||
# Finally elementwise multiplication with the output gradients.
|
||||
grad_input.mul_(output_grad.unsqueeze(dim=-1))
|
||||
|
@@ -49,7 +49,7 @@ class CrossEntropyLoss3D(_Loss):
|
||||
"""
|
||||
targets = split_tensor_3d(targets, 0, self.weight_parallel_mode)
|
||||
targets = split_tensor_3d(targets, 0, self.input_parallel_mode)
|
||||
loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs)
|
||||
loss = cross_entropy(logits, targets, reduction="none", *self.loss_args, **self.loss_kwargs)
|
||||
if self.reduction_mean:
|
||||
loss = loss.mean()
|
||||
loss = reduce_by_batch_3d(loss, self.input_parallel_mode, self.weight_parallel_mode, True)
|
||||
@@ -83,7 +83,7 @@ class _VocabParallelCrossEntropy3D(torch.autograd.Function):
|
||||
arange_1d = torch.arange(start=0, end=logits.size()[0], device=get_current_device())
|
||||
predicted_logits = logits[arange_1d, masked_target]
|
||||
predicted_logits = predicted_logits.clone().contiguous().view_as(targets)
|
||||
predicted_logits[target_mask] = 0.
|
||||
predicted_logits[target_mask] = 0.0
|
||||
dist.all_reduce(predicted_logits, group=gpc.get_group(output_parallel_mode))
|
||||
|
||||
# Loss = log(sum(exp(logits))) - predicted-logit.
|
||||
@@ -111,7 +111,7 @@ class _VocabParallelCrossEntropy3D(torch.autograd.Function):
|
||||
|
||||
# Add the gradient from matching classes.
|
||||
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device())
|
||||
grad_2d[arange_1d, masked_target] -= (1.0 - target_mask.view(-1).float())
|
||||
grad_2d[arange_1d, masked_target] -= 1.0 - target_mask.view(-1).float()
|
||||
input_grad.mul_(output_grad.unsqueeze(dim=-1))
|
||||
|
||||
return input_grad, None, None, None
|
||||
|
Reference in New Issue
Block a user