[moe] add mixtral dp grad scaling when not all experts are activated

Author: botbw
Date: 2024-07-12 03:27:20 +00:00
Committed by: Hongxin Liu
Parent: e28e05345b
Commit: 9b9b76bdcd
8 changed files with 98 additions and 42 deletions
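
The core of the change: the existing expert-parallel scalers are renamed (MoeInGradScaler -> EPGradScalerIn, MoeOutGradScaler -> EPGradScalerOut), and a new DPGradScalerIn/DPGradScalerOut pair is added so that gradient averaging over the MoE data-parallel group stays unbiased when only activated_experts of the moe_dp_size replicas actually received tokens. Below is a minimal, self-contained sketch of the new input-side scaler; the scaling logic is taken from the diff further down, while the condensed docstring, the toy numbers, and the explicit-gradient backward call are illustrative only.

from typing import Any, Tuple

import torch
from torch import Tensor


class DPGradScalerIn(torch.autograd.Function):
    """Scale the gradient entering the expert by activated_experts / moe_dp_size."""

    @staticmethod
    def forward(ctx: Any, inputs: Tensor, moe_dp_size: int, activated_experts: int) -> Tensor:
        assert activated_experts != 0, "shouldn't be called when no expert is activated"
        ctx.moe_dp_size = moe_dp_size
        ctx.activated_experts = activated_experts
        return inputs

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None, None]:
        assert len(grad_outputs) == 1
        grad = grad_outputs[0]
        if ctx.moe_dp_size != ctx.activated_experts:
            grad.mul_(ctx.activated_experts / ctx.moe_dp_size)
        return grad, None, None


# Toy check: with a MoE DP group of size 4 but only 2 experts activated,
# the gradient flowing back into the dispatched tokens is scaled by 2 / 4 = 0.5.
x = torch.ones(3, requires_grad=True)
y = DPGradScalerIn.apply(x, 4, 2)
y.backward(torch.ones_like(y))
print(x.grad)  # tensor([0.5000, 0.5000, 0.5000])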


@@ -290,7 +290,7 @@ def moe_cumsum(inputs: Tensor, use_kernel: bool = False):
     return torch.cumsum(inputs, dim=0) - 1
-class MoeInGradScaler(torch.autograd.Function):
+class EPGradScalerIn(torch.autograd.Function):
     """
     Scale the gradient back by the number of experts
     because the batch size increases in the moe stage
@@ -298,8 +298,7 @@ class MoeInGradScaler(torch.autograd.Function):
     @staticmethod
     def forward(ctx: Any, inputs: Tensor, ep_size: int) -> Tensor:
-        if ctx is not None:
-            ctx.ep_size = ep_size
+        ctx.ep_size = ep_size
         return inputs
     @staticmethod
@@ -311,7 +310,7 @@ class MoeInGradScaler(torch.autograd.Function):
         return grad, None
-class MoeOutGradScaler(torch.autograd.Function):
+class EPGradScalerOut(torch.autograd.Function):
     """
     Scale the gradient by the number of experts
     because the batch size increases in the moe stage
@@ -331,6 +330,50 @@ class MoeOutGradScaler(torch.autograd.Function):
         return grad, None
+class DPGradScalerIn(torch.autograd.Function):
+    """
+    Scale the gradient back by the number of experts
+    because the batch size increases in the moe stage
+    """
+    @staticmethod
+    def forward(ctx: Any, inputs: Tensor, moe_dp_size: int, activated_experts: int) -> Tensor:
+        assert activated_experts != 0, f"shouldn't be called when no expert is activated"
+        ctx.moe_dp_size = moe_dp_size
+        ctx.activated_experts = activated_experts
+        return inputs
+    @staticmethod
+    def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None, None]:
+        assert len(grad_outputs) == 1
+        grad = grad_outputs[0]
+        if ctx.moe_dp_size != ctx.activated_experts:
+            grad.mul_(ctx.activated_experts / ctx.moe_dp_size)
+        return grad, None, None
+class DPGradScalerOut(torch.autograd.Function):
+    """
+    Scale the gradient by the number of experts
+    because the batch size increases in the moe stage
+    """
+    @staticmethod
+    def forward(ctx: Any, inputs: Tensor, moe_dp_size: int, activated_experts: int) -> Tensor:
+        assert activated_experts != 0, f"shouldn't be called when no expert is activated"
+        ctx.moe_dp_size = moe_dp_size
+        ctx.activated_experts = activated_experts
+        return inputs
+    @staticmethod
+    def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None, None]:
+        assert len(grad_outputs) == 1
+        grad = grad_outputs[0]
+        if ctx.moe_dp_size != ctx.activated_experts:
+            grad.mul_(ctx.moe_dp_size / ctx.activated_experts)
+        return grad, None, None
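
For orientation, the two scalers are presumably applied as a pair around each expert's forward call, so activations are untouched and only backward gradients are rescaled: DPGradScalerOut.backward multiplies the gradient entering the expert output by moe_dp_size / activated_experts, and DPGradScalerIn.backward applies the reciprocal factor on the way out, so the expert-internal gradients are rescaled while the gradient returned upstream is unchanged. A hypothetical composition, where expert and hidden_states are placeholders rather than code from this commit:

# Sketch only: bracket the expert computation with the new DP scalers.
hidden_states = DPGradScalerIn.apply(hidden_states, moe_dp_size, activated_experts)
expert_output = expert(hidden_states)
expert_output = DPGradScalerOut.apply(expert_output, moe_dp_size, activated_experts)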
 def _all_to_all(
     inputs: torch.Tensor,
     input_split_sizes: Optional[List[int]] = None,