Mirror of https://github.com/hpcaitech/ColossalAI.git
[moe] add mixtral dp grad scaling when not all experts are activated
@@ -290,7 +290,7 @@ def moe_cumsum(inputs: Tensor, use_kernel: bool = False):
         return torch.cumsum(inputs, dim=0) - 1
 
 
-class MoeInGradScaler(torch.autograd.Function):
+class EPGradScalerIn(torch.autograd.Function):
     """
     Scale the gradient back by the number of experts
     because the batch size increases in the moe stage
@@ -298,8 +298,7 @@ class MoeInGradScaler(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx: Any, inputs: Tensor, ep_size: int) -> Tensor:
-        if ctx is not None:
-            ctx.ep_size = ep_size
+        ctx.ep_size = ep_size
         return inputs
 
     @staticmethod
@@ -311,7 +310,7 @@ class MoeInGradScaler(torch.autograd.Function):
         return grad, None
 
 
-class MoeOutGradScaler(torch.autograd.Function):
+class EPGradScalerOut(torch.autograd.Function):
     """
     Scale the gradient by the number of experts
     because the batch size increases in the moe stage
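
All four scaler classes share the same autograd pattern: the forward pass is an identity that only stashes a group size on ctx, and the rescaling happens purely in the backward pass. Below is a minimal, self-contained sketch of that pattern; the class name and the scale factor are hypothetical, and the exact factors used by the EP scalers live outside the hunks shown here.

import torch
from torch import Tensor
from typing import Any, Tuple

class GradScaleSketch(torch.autograd.Function):
    """Identity in forward; multiplies the incoming gradient by `scale` in backward."""

    @staticmethod
    def forward(ctx: Any, inputs: Tensor, scale: float) -> Tensor:
        ctx.scale = scale  # stash the factor for the backward pass
        return inputs      # forward is a no-op on the data

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
        (grad,) = grad_outputs
        return grad * ctx.scale, None  # one None per non-tensor forward argument

x = torch.ones(3, requires_grad=True)
GradScaleSketch.apply(x, 2.0).sum().backward()
print(x.grad)  # tensor([2., 2., 2.])
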
@@ -331,6 +330,50 @@ class MoeOutGradScaler(torch.autograd.Function):
         return grad, None
 
 
+class DPGradScalerIn(torch.autograd.Function):
+    """
+    Scale the gradient back by the number of experts
+    because the batch size increases in the moe stage
+    """
+
+    @staticmethod
+    def forward(ctx: Any, inputs: Tensor, moe_dp_size: int, activated_experts: int) -> Tensor:
+        assert activated_experts != 0, f"shouldn't be called when no expert is activated"
+        ctx.moe_dp_size = moe_dp_size
+        ctx.activated_experts = activated_experts
+        return inputs
+
+    @staticmethod
+    def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None, None]:
+        assert len(grad_outputs) == 1
+        grad = grad_outputs[0]
+        if ctx.moe_dp_size != ctx.activated_experts:
+            grad.mul_(ctx.activated_experts / ctx.moe_dp_size)
+        return grad, None, None
+
+
+class DPGradScalerOut(torch.autograd.Function):
+    """
+    Scale the gradient by the number of experts
+    because the batch size increases in the moe stage
+    """
+
+    @staticmethod
+    def forward(ctx: Any, inputs: Tensor, moe_dp_size: int, activated_experts: int) -> Tensor:
+        assert activated_experts != 0, f"shouldn't be called when no expert is activated"
+        ctx.moe_dp_size = moe_dp_size
+        ctx.activated_experts = activated_experts
+        return inputs
+
+    @staticmethod
+    def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None, None]:
+        assert len(grad_outputs) == 1
+        grad = grad_outputs[0]
+        if ctx.moe_dp_size != ctx.activated_experts:
+            grad.mul_(ctx.moe_dp_size / ctx.activated_experts)
+        return grad, None, None
+
+
 def _all_to_all(
     inputs: torch.Tensor,
     input_split_sizes: Optional[List[int]] = None,
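
The net effect of the new DP scalers can be read off the backward passes above: DPGradScalerIn multiplies the gradient by activated_experts / moe_dp_size and DPGradScalerOut by moe_dp_size / activated_experts, so a gradient that passes through both is left unchanged, while anything computed between the two (the expert itself, assuming In is applied at the expert input and Out at its output) sees its gradient amplified by moe_dp_size / activated_experts. The following self-contained sketch checks that arithmetic with stand-in re-implementations and a toy one-parameter "expert"; where exactly the Mixtral policy inserts the pair is not shown in these hunks, so the placement is only illustrative.

import torch

# Hypothetical group sizes: 4 ranks in the expert's moe-dp group, 3 of them activated.
moe_dp_size, activated_experts = 4, 3

class _DPInSketch(torch.autograd.Function):
    """Stand-in mirroring DPGradScalerIn: identity forward, grad *= activated / moe_dp."""

    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, grad):
        return grad * (activated_experts / moe_dp_size)

class _DPOutSketch(torch.autograd.Function):
    """Stand-in mirroring DPGradScalerOut: identity forward, grad *= moe_dp / activated."""

    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, grad):
        return grad * (moe_dp_size / activated_experts)

x = torch.randn(8, requires_grad=True)
w = torch.randn((), requires_grad=True)           # toy "expert" parameter
y = _DPOutSketch.apply(_DPInSketch.apply(x) * w)  # In -> expert -> Out
y.sum().backward()

# The expert parameter's gradient is amplified by moe_dp_size / activated_experts ...
print(torch.allclose(w.grad, x.detach().sum() * moe_dp_size / activated_experts))  # True
# ... while the gradient flowing back past the pair keeps its original scale.
print(torch.allclose(x.grad, torch.full_like(x, w.item())))                        # True
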