hpcaitech/ColossalAI (mirror of https://github.com/hpcaitech/ColossalAI.git)

commit 1b15cc97f5 (parent 2f9bce6686)
[moe] add mixtral dp grad scaling when not all experts are activated
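
Why the scaling is needed (a sketch of the reasoning, not text from the commit): experts are replicated across the MoE data-parallel (moe_dp) group and their gradients are synchronized over that group. If, in a given step, only some of the moe_dp replicas route tokens to an expert, an average taken over the full moe_dp_size dilutes the gradient of the replicas that actually contributed. The DPGradScalerIn/DPGradScalerOut functions added below rescale gradients around the expert computation by moe_dp_size / activated_experts (and undo the scaling for the rest of the graph), presumably so that the synchronized gradient matches a single-process reference model, which is what the updated test compares against. Illustrative arithmetic only, with made-up numbers:

    # moe_dp_size replicas of one expert; only `activated` of them saw tokens this step.
    moe_dp_size = 4
    activated = 2
    local_grads = [1.0, 1.0, 0.0, 0.0]            # per-replica gradient contribution

    naive_avg = sum(local_grads) / moe_dp_size    # 0.5: diluted by the idle replicas
    scaled = [g * moe_dp_size / activated for g in local_grads]
    corrected_avg = sum(scaled) / moe_dp_size     # 1.0: mean over the active replicas

    print(naive_avg, corrected_avg)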
@@ -141,6 +141,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         # set ep_group after super init
         # TODO do it in a better way
         self.shard_config.ep_group = self.ep_group
+        self.shard_config.moe_dp_group = self.moe_dp_group
+        self.shard_config.moe_tp_group = self.moe_tp_group
 
         self.force_overlap_comm = force_overlap_comm
 
@@ -159,7 +161,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
 
         # create groups from submesh
         for stage_idx, stage_rank in enumerate(ranks_by_pp_stage):
-            # axis 0 is dp, axis 1 is tp, axis 2 is sp
+            # axis 0 is moe_dp, axis 1 is ep, axis 2 is moe_tp
             submesh = np.array(stage_rank).reshape(self.moe_dp_size, self.ep_size, self.moe_tp_size)
 
             # hardcode here since we only have 3 axis
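
The relabelled axes in the hunk above can be checked with a small standalone sketch (hypothetical sizes, not taken from the plugin): a pipeline stage's ranks are folded into a (moe_dp, ep, moe_tp) cube, and a process group along one axis is obtained by fixing the other two indices.

    import numpy as np

    # Hypothetical: 8 ranks in one pipeline stage, split 2 x 2 x 2.
    moe_dp_size, ep_size, moe_tp_size = 2, 2, 2
    stage_rank = list(range(8))

    # axis 0 is moe_dp, axis 1 is ep, axis 2 is moe_tp
    submesh = np.array(stage_rank).reshape(moe_dp_size, ep_size, moe_tp_size)

    print(submesh[:, 0, 0])  # moe_dp group containing rank 0 -> [0 4]
    print(submesh[0, :, 0])  # ep group containing rank 0     -> [0 2]
    print(submesh[0, 0, :])  # moe_tp group containing rank 0 -> [0 1]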
@@ -188,7 +190,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                 assert self.moe_tp_group is None
                 self.moe_tp_group = group
 
-        self.logger.info(f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}")
+        self.logger.info(f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}", ranks=[0])
 
     def get_checkpoint_io(self) -> MoECheckpointIO:
         return MoECheckpointIO(
@@ -290,7 +290,7 @@ def moe_cumsum(inputs: Tensor, use_kernel: bool = False):
         return torch.cumsum(inputs, dim=0) - 1
 
 
-class MoeInGradScaler(torch.autograd.Function):
+class EPGradScalerIn(torch.autograd.Function):
     """
     Scale the gradient back by the number of experts
     because the batch size increases in the moe stage
@@ -298,8 +298,7 @@ class MoeInGradScaler(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx: Any, inputs: Tensor, ep_size: int) -> Tensor:
-        if ctx is not None:
-            ctx.ep_size = ep_size
+        ctx.ep_size = ep_size
         return inputs
 
     @staticmethod
@@ -311,7 +310,7 @@ class MoeInGradScaler(torch.autograd.Function):
         return grad, None
 
 
-class MoeOutGradScaler(torch.autograd.Function):
+class EPGradScalerOut(torch.autograd.Function):
     """
     Scale the gradient by the number of experts
     because the batch size increases in the moe stage
@@ -331,6 +330,50 @@ class MoeOutGradScaler(torch.autograd.Function):
         return grad, None
 
 
+class DPGradScalerIn(torch.autograd.Function):
+    """
+    Scale the gradient back by the number of experts
+    because the batch size increases in the moe stage
+    """
+
+    @staticmethod
+    def forward(ctx: Any, inputs: Tensor, moe_dp_size: int, activated_experts: int) -> Tensor:
+        assert activated_experts != 0, f"shouldn't be called when no expert is activated"
+        ctx.moe_dp_size = moe_dp_size
+        ctx.activated_experts = activated_experts
+        return inputs
+
+    @staticmethod
+    def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None, None]:
+        assert len(grad_outputs) == 1
+        grad = grad_outputs[0]
+        if ctx.moe_dp_size != ctx.activated_experts:
+            grad.mul_(ctx.activated_experts / ctx.moe_dp_size)
+        return grad, None, None
+
+
+class DPGradScalerOut(torch.autograd.Function):
+    """
+    Scale the gradient by the number of experts
+    because the batch size increases in the moe stage
+    """
+
+    @staticmethod
+    def forward(ctx: Any, inputs: Tensor, moe_dp_size: int, activated_experts: int) -> Tensor:
+        assert activated_experts != 0, f"shouldn't be called when no expert is activated"
+        ctx.moe_dp_size = moe_dp_size
+        ctx.activated_experts = activated_experts
+        return inputs
+
+    @staticmethod
+    def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None, None]:
+        assert len(grad_outputs) == 1
+        grad = grad_outputs[0]
+        if ctx.moe_dp_size != ctx.activated_experts:
+            grad.mul_(ctx.moe_dp_size / ctx.activated_experts)
+        return grad, None, None
+
+
 def _all_to_all(
     inputs: torch.Tensor,
     input_split_sizes: Optional[List[int]] = None,
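
A minimal single-process check of the new autograd functions' backward scaling (not part of the commit; it assumes a ColossalAI build that already includes the classes added above, importable from colossalai.moe._operation as in this diff):

    import torch

    from colossalai.moe._operation import DPGradScalerIn, DPGradScalerOut

    moe_dp_size, activated_experts = 4, 2

    x = torch.ones(3, requires_grad=True)
    DPGradScalerIn.apply(x, moe_dp_size, activated_experts).sum().backward()
    print(x.grad)  # forward is identity; grad scaled by activated/moe_dp -> tensor([0.5, 0.5, 0.5])

    y = torch.ones(3, requires_grad=True)
    DPGradScalerOut.apply(y, moe_dp_size, activated_experts).sum().backward()
    print(y.grad)  # grad scaled by moe_dp/activated -> tensor([2., 2., 2.])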
@@ -5,7 +5,7 @@ import torch
 import torch.nn as nn
 
 from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
-from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler
+from colossalai.moe._operation import EPGradScalerIn, EPGradScalerOut
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.moe.utils import get_activation
 from colossalai.shardformer.layer.utils import Randomizer
@@ -118,7 +118,7 @@ class MLPExperts(nn.Module):
         Returns:
             torch.Tensor: The output tensor of shape (num_groups, num_experts, capacity, hidden_size)
         """
-        x = MoeInGradScaler.apply(x, self.ep_size)
+        x = EPGradScalerIn.apply(x, self.ep_size)
 
         e = x.size(1)
         h = x.size(-1)
@@ -157,5 +157,5 @@ class MLPExperts(nn.Module):
         x = torch.cat([x[i].unsqueeze(0) for i in range(e)], dim=0)
         x = x.reshape(inshape)
         x = x.transpose(0, 1).contiguous()
-        x = MoeOutGradScaler.apply(x, self.ep_size)
+        x = EPGradScalerOut.apply(x, self.ep_size)
         return x
@@ -14,18 +14,23 @@ from transformers.models.mixtral.modeling_mixtral import (
 from transformers.utils import is_flash_attn_2_available, logging
 
 from colossalai.lazy import LazyInitContext
-from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven, drop_tokens, gather_tokens
+from colossalai.moe._operation import DPGradScalerIn, DPGradScalerOut, EPGradScalerIn, EPGradScalerOut, all_to_all_uneven, drop_tokens, gather_tokens
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.shard import ShardConfig
 from colossalai.shardformer.shard.utils import set_tensors_to_none
+from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group
 
 
 class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
-    def __init__(self, config, ep_group: ProcessGroup, tp_group: Optional[ProcessGroup]=None, moe_tp_group: Optional[ProcessGroup]=None):
-        super().__init__(config)
-        self.setup_process_groups(ep_group, tp_group, moe_tp_group)
+    def __init__(self, *args, **kwargs):
+        raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")
 
-    def setup_process_groups(self, ep_group: ProcessGroup, tp_group: Optional[ProcessGroup]=None, moe_tp_group: Optional[ProcessGroup]=None):
+    def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, moe_tp_group: ProcessGroup):
+        assert tp_group is not None
+        assert moe_dp_group is not None
+        assert ep_group is not None
+        assert moe_tp_group is not None
+
         # setup ep group
         self.ep_size = dist.get_world_size(ep_group)
         self.ep_rank = dist.get_rank(ep_group)
@@ -40,7 +45,11 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
 
         set_tensors_to_none(self.experts, exclude=set(held_experts))
         for p in self.experts.parameters():
-            p.ep_group = ep_group
+            set_moe_tensor_ep_group(p, ep_group)
+
+        # setup moe_dp group
+        self.moe_dp_group = moe_dp_group
+        self.moe_dp_size = moe_dp_group.size()
 
         # setup global tp group
         self.tp_group = tp_group
@@ -50,11 +59,12 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
 
     @staticmethod
     def from_native_module(
-        module: MixtralSparseMoeBlock, ep_group: ProcessGroup, tp_group: Optional[ProcessGroup]=None, moe_tp_group: Optional[ProcessGroup]=None, *args, **kwargs
+        module: MixtralSparseMoeBlock, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, moe_tp_group: ProcessGroup, *args, **kwargs
     ) -> "EPMixtralSparseMoeBlock":
+        # TODO: better init
         LazyInitContext.materialize(module)
         module.__class__ = EPMixtralSparseMoeBlock
-        module.setup_process_groups(ep_group)
+        module.setup_process_groups(tp_group, moe_dp_group, ep_group, moe_tp_group)
         return module
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -76,36 +86,48 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
         output_split_sizes = torch.zeros_like(input_split_sizes)
         dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
 
+        with torch.no_grad():
+            activate_experts = output_split_sizes[: self.num_experts_per_ep].clone()
+            for i in range(1, self.ep_size):
+                activate_experts += output_split_sizes[i * self.num_experts_per_ep : (i + 1) * self.num_experts_per_ep]
+            activate_experts = (activate_experts > 0).float()
+            dist.all_reduce(activate_experts, group=self.moe_dp_group)
+
         input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
         output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
 
-        if self.tp_group is not None and self.tp_group.size() > 1:
+        if self.tp_group.size() > 1:
             dispatch_states = drop_tokens(dispatch_states, -1, self.tp_group)
 
         output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
         # compute expert output
-        output_states = MoeInGradScaler.apply(output_states, self.ep_size)
+        output_states = EPGradScalerIn.apply(output_states, self.ep_size)
         if output_states.size(0) > 0:
             if self.num_experts_per_ep == 1:
                 # no need to split
                 expert = self.experts[self.expert_start_idx]
+                output_states = DPGradScalerIn.apply(output_states, self.moe_dp_size, activate_experts[0].item())
                 output_states = expert.act_fn(expert.w1(output_states)) * expert.w3(output_states)
                 output_states = expert.w2(output_states)
+                output_states = DPGradScalerOut.apply(output_states, self.moe_dp_size, activate_experts[0].item())
             else:
                 output_states_splits = output_states.split(output_split_sizes.tolist())
                 output_states_list = []
                 for i, split_states in enumerate(output_states_splits):
                     if split_states.size(0) == 0:
                         continue
+                    split_states = DPGradScalerIn.apply(split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep].item())
                     expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
                     split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states)
                     split_states = expert.w2(split_states)
+                    split_states = DPGradScalerOut.apply(split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep].item())
                     output_states_list.append(split_states)
                 output_states = torch.cat(output_states_list)
-        output_states = MoeOutGradScaler.apply(output_states, self.ep_size)
+
+        output_states = EPGradScalerOut.apply(output_states, self.ep_size)
         dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
 
-        if self.tp_group is not None and self.tp_group.size() > 1:
+        if self.tp_group.size() > 1:
             dispatch_states = gather_tokens(dispatch_states, -1, self.tp_group)
 
         recover_experts_idx = torch.empty_like(selected_experts_idx)
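
To see how activate_experts is built in the forward pass above, here is a standalone sketch with made-up split sizes; the dist.all_reduce over the moe_dp group is omitted because this runs in a single process. After that all-reduce (a sum), each entry counts how many moe_dp replicas routed at least one token to the corresponding local expert, and that count is what DPGradScalerIn/DPGradScalerOut receive.

    import torch

    ep_size, num_experts_per_ep = 2, 2
    # tokens this rank will receive, laid out as (ep_rank, local_expert)
    output_split_sizes = torch.tensor([5, 0, 3, 0])

    activate_experts = output_split_sizes[:num_experts_per_ep].clone()
    for i in range(1, ep_size):
        activate_experts += output_split_sizes[i * num_experts_per_ep : (i + 1) * num_experts_per_ep]
    activate_experts = (activate_experts > 0).float()

    print(activate_experts)  # tensor([1., 0.]): local expert 0 was activated, expert 1 was not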
@@ -76,18 +76,6 @@ class MixtralPolicy(Policy):
                     suffix="self_attn.o_proj",
                     target_module=Linear1D_Row,
                 ),
-                # SubModuleReplacementDescription(  # TODO: enable moe tp parallel
-                #     suffix="mlp.gate_proj",
-                #     target_module=Linear1D_Col,
-                # ),
-                # SubModuleReplacementDescription(
-                #     suffix="mlp.up_proj",
-                #     target_module=Linear1D_Col,
-                # ),
-                # SubModuleReplacementDescription(
-                #     suffix="mlp.down_proj",
-                #     target_module=Linear1D_Row,
-                # ),
             ],
         )
 
@@ -98,7 +86,7 @@ class MixtralPolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="block_sparse_moe",
                     target_module=EPMixtralSparseMoeBlock,
-                    kwargs={"ep_group": self.shard_config.ep_group, "tp_group": self.shard_config.tensor_parallel_process_group},
+                    kwargs={"ep_group": self.shard_config.ep_group, "tp_group": self.shard_config.tensor_parallel_process_group, "moe_dp_group": self.shard_config.moe_dp_group, "moe_tp_group": self.shard_config.moe_tp_group},
                 )
             ],
             policy=policy,
@@ -46,6 +46,9 @@ class ShardConfig:
     make_vocab_size_divisible_by: int = 64
     gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None
     extra_kwargs: Dict[str, Any] = field(default_factory=dict)
+
+    # for moe related
+    moe_dp_group: Optional[ProcessGroup] = None
     ep_group: Optional[ProcessGroup] = None
     moe_tp_group: Optional[ProcessGroup] = None
 
@@ -18,8 +18,7 @@ NUM_BATCH=4
 NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4
 HIDDEN_SIZE_PER_HEAD = 4
 NUM_HEADS=2
-TOP_K = 2
-
+TOP_K = 1
 
 def split_grad(grad, world_size):
     with torch.no_grad():
@@ -96,7 +95,6 @@ def run_zero_with_original_model(stage: int, ep_size: int):
     # check grad
     name_to_p = {n: p for n, p in ddp_model.named_parameters()}
     for n, p in zero_model.named_parameters():
-        print(f"rank {dist.get_rank()} {n}")
         zero_grad = zero_optimizer.get_param_grad(p)
         if name_to_p[n].grad is None:
             name_to_p[n].grad = torch.zeros_like(name_to_p[n].data)
@@ -124,9 +122,9 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [4])
 @rerun_if_address_is_in_use()
-def test_moe_ep_tp(world_size):
+def test_moe_ep_zero(world_size):
     spawn(run_dist, world_size)
 
 
 if __name__ == "__main__":
-    test_moe_ep_tp(world_size=4)
+    test_moe_ep_zero(world_size=4)