Merge pull request #5372 from hpcaitech/exp/mixtral

This commit is contained in:
Frank Lee
2024-02-08 16:30:05 +08:00
committed by GitHub
33 changed files with 2530 additions and 267 deletions

View File

@@ -22,7 +22,7 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
)
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.moe import MoECheckpintIO
from colossalai.moe import MOE_MANAGER, MoECheckpintIO
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig
@@ -150,6 +150,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self,
tp_size: int,
pp_size: int,
ep_size: int,
extra_dp_size: int = 1,
precision: str = "fp16",
zero_stage: int = 0,
@@ -181,6 +182,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
overlap_communication: bool = True,
use_ep_inside: bool = True,
custom_policy: Policy = None,
checkpoint_io: Optional[MoECheckpintIO] = None,
) -> None:
assert (
dist.get_world_size() % (tp_size * pp_size) == 0
@@ -188,10 +190,26 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
if enable_sequence_parallelism:
assert tp_size > 1, "Sequence parallelism must be enabled when using tensor parallelism"
assert (
dist.get_world_size() % (tp_size * pp_size) == 0
), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
assert (
dist.get_world_size() % (tp_size * pp_size * ep_size) == 0
), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size} * ep_size {ep_size}"
self.real_dp_size = dist.get_world_size() // (tp_size * pp_size * ep_size)
MOE_MANAGER.setup(
parallel="EP",
mode="fixed",
fixed_dp_size=self.real_dp_size,
fixed_ep_size=ep_size,
fixed_pp_size=pp_size,
use_ep_inside=use_ep_inside,
)
self.tp_size = tp_size
self.pp_size = pp_size
self.dp_size = dist.get_world_size() // (tp_size * pp_size)
self.ep_size = ep_size
self.moe_info = MOE_MANAGER.get_info(0)[1]
self.precision = precision
self.zero_stage = zero_stage
self.cpu_offload = cpu_offload
@@ -200,6 +218,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self.enable_flash_attention = enable_flash_attention
self.enable_jit_fused = enable_jit_fused
self.enable_sequence_parallelism = enable_sequence_parallelism
self.checkpoint_io = checkpoint_io
# we change pg mesh to (pp, dp, tp) for better moe performance
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size)
@@ -323,7 +342,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
)
def get_checkpoint_io(self) -> MoECheckpintIO:
    """Return the checkpoint IO object for this plugin, building it on first use.

    ``self.checkpoint_io`` initially holds the value given to the constructor:
    ``None`` (use the default ``MoECheckpintIO``) or a class/factory that is
    called with ``(dp_group, pp_group, tp_group, zero_stage)``. The built object
    is stored back on ``self.checkpoint_io``.

    Returns:
        The checkpoint IO instance stored on ``self.checkpoint_io``.
    """
    configured = self.checkpoint_io
    if isinstance(configured, MoECheckpintIO):
        # A previous call already replaced the factory with the built instance;
        # it must be returned as-is rather than invoked like a factory again.
        return configured
    # NOTE(review): a custom factory whose product does not subclass
    # MoECheckpintIO is not recognised above and would be invoked again on the
    # next call - confirm custom IO classes always derive from MoECheckpintIO.
    factory = MoECheckpintIO if configured is None else configured
    self.checkpoint_io = factory(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
    return self.checkpoint_io
def configure(

View File

@@ -9,7 +9,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from colossalai.interface import ModelWrapper
from .utils import has_index_file
from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, has_index_file
__all__ = ["CheckpointIO"]
@@ -90,7 +90,15 @@ class CheckpointIO(ABC):
if index_file_exists:
self.load_sharded_model(model, index_file_path, strict)
else:
self.load_unsharded_model(model, checkpoint, strict)
path = Path(checkpoint, SAFE_WEIGHTS_NAME)
if path.is_file():
self.load_unsharded_model(model, str(path), strict)
else:
path = Path(checkpoint, WEIGHTS_NAME)
if path.is_file():
self.load_unsharded_model(model, str(path), strict)
else:
self.load_unsharded_model(model, checkpoint, strict)
return origin_model

View File

@@ -1,6 +1,7 @@
from .checkpoint import MoECheckpintIO
from .experts import MLPExperts
from .layers import SparseMLP
from .layers import SparseMLP, apply_load_balance
from .manager import MOE_MANAGER
from .routers import MoeRouter, Top1Router, Top2Router, TopKRouter
from .utils import NormalNoiseGenerator, UniformNoiseGenerator
@@ -14,4 +15,6 @@ __all__ = [
"UniformNoiseGenerator",
"SparseMLP",
"MoECheckpintIO",
"MOE_MANAGER",
"apply_load_balance",
]

View File

@@ -1,4 +1,4 @@
from typing import Any, Optional, Tuple
from typing import Any, List, Optional, Tuple
import torch
import torch.distributed as dist
@@ -329,3 +329,68 @@ class MoeOutGradScaler(torch.autograd.Function):
if ctx.ep_size != 1:
grad = grad / ctx.ep_size
return grad, None
def _all_to_all(
    inputs: torch.Tensor,
    input_split_sizes: Optional[List[int]] = None,
    output_split_sizes: Optional[List[int]] = None,
    group=None,
    async_op: bool = False,
):
    """Exchange ``inputs`` across ranks with ``dist.all_to_all_single``.

    Args:
        inputs: Tensor to send. It is made contiguous before the collective.
        input_split_sizes: Forwarded to ``dist.all_to_all_single`` as the
            input split sizes; ``None`` is passed through unchanged.
        output_split_sizes: Forwarded as the output split sizes. When given,
            the first dimension of the receive buffer is ``sum(output_split_sizes)``;
            otherwise the buffer has the same shape as ``inputs``.
        group: Process group forwarded to the collective.
        async_op: Forwarded to the collective as ``async_op``.

    Returns:
        outputs: Newly allocated tensor (dtype and device of ``inputs``) that
            the collective writes into.
        handle: Whatever ``dist.all_to_all_single`` returns; per the PyTorch
            documentation this is an async work handle when ``async_op`` is
            True and ``None`` otherwise.
    """
    outputs_shape = list(inputs.shape)
    if output_split_sizes is not None:
        outputs_shape[0] = sum(output_split_sizes)
    # torch.empty already returns a contiguous buffer, so only the input
    # needs to be made contiguous before the collective.
    outputs = torch.empty(outputs_shape, dtype=inputs.dtype, device=inputs.device)
    handle = dist.all_to_all_single(
        outputs,
        inputs.contiguous(),
        output_split_sizes,
        input_split_sizes,
        group=group,
        async_op=async_op,
    )
    return outputs, handle
class AllToAllUneven(torch.autograd.Function):
    """Autograd function around ``_all_to_all`` with per-rank split sizes.

    The forward pass remembers the split sizes and group on ``ctx``; the
    backward pass runs the same exchange on the incoming gradient with the
    input and output split sizes swapped.
    """

    @staticmethod
    def forward(
        ctx,
        inputs,
        input_split_sizes=None,
        output_split_sizes=None,
        group=None,
        overlap: bool = False,
    ):
        """
        Returns:
            outputs: Tensor
            handle: Optional[Work], if overlap is True
        """
        # Stash everything backward() needs to run the reverse exchange.
        ctx.group = group
        ctx.output_split_sizes = output_split_sizes
        ctx.input_split_sizes = input_split_sizes
        return _all_to_all(
            inputs,
            input_split_sizes=input_split_sizes,
            output_split_sizes=output_split_sizes,
            group=group,
            async_op=overlap,
        )

    @staticmethod
    def backward(ctx: Any, *grad_outputs):
        # Only the gradient of the first forward output (the tensor) is used;
        # the exchange is always synchronous here.
        grad_inputs, _ = _all_to_all(
            grad_outputs[0],
            ctx.output_split_sizes,
            ctx.input_split_sizes,
            ctx.group,
            False,
        )
        # One gradient slot per forward argument after ctx; only `inputs` gets one.
        return grad_inputs, None, None, None, None
def all_to_all_uneven(
    inputs: torch.Tensor,
    input_split_sizes: Optional[List[int]] = None,
    output_split_sizes: Optional[List[int]] = None,
    group=None,
    overlap: bool = False,
):
    """Functional entry point for ``AllToAllUneven``.

    All arguments are forwarded positionally, in order, to
    ``AllToAllUneven.apply`` and its result is returned unchanged.
    """
    forward_args = (inputs, input_split_sizes, output_split_sizes, group, overlap)
    return AllToAllUneven.apply(*forward_args)

View File

@@ -224,6 +224,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
"""
torch.cuda.empty_cache()
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
@@ -265,6 +266,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
f"index located at {save_index_file}."
)
dist.barrier()
torch.cuda.empty_cache()
# ========================================================
# Abstract methods for optimizer loading/saving implementation
@@ -332,10 +334,12 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
def _get_param_id_from_optimizer_param(
param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None, optimizer=None
):
if master_to_working_map is not None and id(param) in master_to_working_map:
working_param = master_to_working_map[id(param)]
elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map:
working_param = optimizer.moe_master_to_working_map[id(param)]
else:
working_param = param
return optimizer.param_info["param2id"][id(working_param)]
@@ -347,7 +351,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
master_to_working_map = optimizer.get_master_to_working_map()
for pg in optimizer.optim.param_groups:
for param in pg["params"]:
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map, optimizer)
id_map[param_id] = param
# Read checkpoint index file.
@@ -371,14 +375,10 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
new_pg = copy.deepcopy(saved_pg)
new_pg["params"] = old_pg["params"] # The parameters in the same group shouln't change.
updated_groups.append(new_pg)
# ep extra group
if MOE_MANAGER.parallel == "EP":
# ep param group
if len(optimizer.optim.param_groups) > len(saved_groups):
new_pg = copy.deepcopy(saved_pg)
new_pg["params"] = optimizer.optim.param_groups[-1][
"params"
] # Only keep the parameters kept by current pipeline stage.
for param in new_pg["params"]:
param.data = param.data.to(torch.float32)
new_pg["params"] = optimizer.optim.param_groups[-1]["params"]
updated_groups.append(new_pg)
optimizer.optim.__dict__.update({"param_groups": updated_groups})
@@ -389,7 +389,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
for param in pg["params"]:
if param is None:
continue
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map, optimizer)
if param_id not in weight_map:
continue
filename = weight_map[param_id]
@@ -400,27 +400,34 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
file_path = os.path.join(ckpt_root_path, filename)
state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
# Then shard the loaded optimizer states if using tp/zero.
for pid, state in list(state_dict.items()):
if pid in id_map:
param = id_map[pid]
if master_to_working_map is not None and id(param) in master_to_working_map:
working_param = master_to_working_map[id(param)]
elif (
hasattr(optimizer, "moe_master_to_working_map")
and id(param) in optimizer.moe_master_to_working_map
):
working_param = optimizer.moe_master_to_working_map[id(param)]
else:
working_param = param
original_shape = optimizer.param_info["param2shape"][id(working_param)]
sharded_state = self.pre_load_optim(
state,
working_param,
current_shape=working_param.shape,
original_shape=original_shape,
device="cpu",
inplace=True,
)
state_dict[pid] = sharded_state
load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
loaded_file.add(filename)
# Then shard the loaded optimizer states if using tp/zero.
for param, state in optimizer.optim.state.items():
device = param.device
if master_to_working_map is not None and id(param) in master_to_working_map:
working_param = master_to_working_map[id(param)]
else:
working_param = param
original_shape = optimizer.param_info["param2shape"][id(working_param)]
sharded_state = self.pre_load_optim(
state,
param,
current_shape=working_param.shape,
original_shape=original_shape,
device=device,
inplace=True,
)
optimizer.optim.state[param] = sharded_state
sharded_optimizer_loading_epilogue(optimizer.optim)
if self.verbose and self.coordinator.is_master():
logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
@@ -576,6 +583,8 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
if master_to_working_map is not None and id(param) in master_to_working_map:
working_param = master_to_working_map[id(param)]
elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map:
working_param = optimizer.moe_master_to_working_map[id(param)]
else:
working_param = param
@@ -618,6 +627,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
prefix (str): Prefix of file to save
size_per_shard (int): Max file size of each file shard that store state tensors
"""
torch.cuda.empty_cache()
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
@@ -723,6 +733,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
f"You can find where each parameters has been saved in the "
f"index located at {final_index_file_path}."
)
torch.cuda.empty_cache()
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
"""

View File

@@ -67,7 +67,11 @@ class MLPExperts(nn.Module):
self.ep_size = 1
if gated:
self.wi_gate = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size * 2))
self.wi_gate = nn.Parameter(
torch.empty(
num_experts, hidden_size, intermediate_size * 2 if activation == "swiglu" else intermediate_size
)
)
self.wi_up = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
else:
self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))

View File

@@ -51,6 +51,8 @@ class SparseMLP(nn.Module):
hidden_size: int,
intermediate_size: int,
router_top_k: int = 1,
router_loss: bool = True,
router_norm: bool = False,
router_capacity_factor_train: float = 1.25,
router_capacity_factor_eval: float = 2.0,
router_min_capacity: int = 4,
@@ -65,15 +67,19 @@ class SparseMLP(nn.Module):
enable_kernel: bool = False,
enable_comm_overlap: bool = False,
enable_hierarchical_comm: bool = False,
return_gate_logits: bool = False,
):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_experts = num_experts
self.gated = mlp_gated
self.return_gate_logits = return_gate_logits
self.enable_kernel = enable_kernel
self.enable_comm_overlap = enable_comm_overlap
self.expert_parallel = MOE_MANAGER.get_parallel()
self.router_loss = router_loss
self.router_norm = router_norm
# moe router
noisy_func = get_noise_generator(router_noisy_policy, num_experts)
@@ -150,9 +156,8 @@ class SparseMLP(nn.Module):
tokens = inputs.reshape(-1, self.hidden_size)
# the data type of the inputs in the gating should be fp32
fp32_input = tokens.to(torch.float)
fp32_weight = self.gate_weight.to(torch.float)
gate_output = F.linear(fp32_input, fp32_weight)
gate_logits = F.linear(tokens, self.gate_weight)
gate_output = gate_logits.to(torch.float)
# update expert load
if self.enable_load_balance == True:
@@ -165,7 +170,12 @@ class SparseMLP(nn.Module):
# the result from the router
used_capacity, *route_result_list = self.router(
inputs=gate_output, use_kernel=self.enable_kernel, ep_group=self.ep_group)
inputs=gate_output,
use_kernel=self.enable_kernel,
ep_group=self.ep_group,
use_loss=self.router_loss,
use_norm=self.router_norm,
)
# dispatch_data: (num_experts, capacity, hidden_size)
if self.enable_kernel:
@@ -177,22 +187,15 @@ class SparseMLP(nn.Module):
# expert_output: (num_groups, num_experts, capacity, hidden_size)
if self.expert_parallel == "EP":
expert_output = self._ep_process(
dispatch_data,
used_capacity,
overlap=self.enable_comm_overlap
)
expert_output = self._ep_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
elif self.expert_parallel == "TP":
expert_output = self._tp_process(
dispatch_data,
used_capacity,
overlap=self.enable_comm_overlap
)
expert_output = self._tp_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
elif self.expert_parallel is None:
expert_output = self._local_process(dispatch_data)
else:
raise NotImplementedError("This kind of communication has not been implemented yet.\n"
"Please use Experts build function.")
raise NotImplementedError(
"This kind of communication has not been implemented yet.\n" "Please use Experts build function."
)
if self.enable_kernel:
expert_output = expert_output.reshape(-1, self.hidden_size)
@@ -204,7 +207,11 @@ class SparseMLP(nn.Module):
ans = torch.matmul(combine_weights, expert_output)
ans = ans.reshape(inputs.shape)
return ans
if self.return_gate_logits:
return ans, gate_logits
else:
return ans
def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor:
expert_in = expert_in.unsqueeze(0)
@@ -212,10 +219,7 @@ class SparseMLP(nn.Module):
return expert_out
def _ep_process(
self,
dispatch_data: torch.Tensor,
used_capacity: torch.Tensor,
overlap: bool = False
self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False
) -> torch.Tensor:
"""
Expert Parallel
@@ -228,10 +232,14 @@ class SparseMLP(nn.Module):
"""
if not overlap or dist.get_world_size(self.ep_group) == 1:
if self.ep_hierarchical_group is not None:
expert_input = HierarchicalAllToAll.apply(dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank)
expert_input = HierarchicalAllToAll.apply(
dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank
)
expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
expert_output = self.experts(expert_input)
expert_output = HierarchicalAllToAll.apply(expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank)
expert_output = HierarchicalAllToAll.apply(
expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank
)
return expert_output
else:
expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0]
@@ -249,7 +257,7 @@ class SparseMLP(nn.Module):
NUM_CHUNK = 4
NUM_STAGES = 4
assert (dispatch_data.shape[1] % NUM_CHUNK == 0), "arbitrary chunk num is not supported yet"
assert dispatch_data.shape[1] % NUM_CHUNK == 0, "arbitrary chunk num is not supported yet"
chunk_size = dispatch_data.shape[1] // NUM_CHUNK
input_shape = (self.ep_size, self.num_local_experts, -1, self.hidden_size)
dispatch_data = dispatch_data.reshape(*input_shape)
@@ -262,13 +270,15 @@ class SparseMLP(nn.Module):
for i in range(NUM_CHUNK + NUM_STAGES - 1):
if expert_out is not None:
expert_out.handle.wait()
output[:, :, offset:offset + chunk_size, :] = expert_out.data
output[:, :, offset : offset + chunk_size, :] = expert_out.data
offset += chunk_size
expert_out = None
# all2all last output
if _expert_out is not None:
expert_out = Capsule(*AllToAll.apply(_expert_out.data, self.ep_group, True),)
expert_out = Capsule(
*AllToAll.apply(_expert_out.data, self.ep_group, True),
)
_expert_out = None
# all2all next input
@@ -288,10 +298,7 @@ class SparseMLP(nn.Module):
return output
def _tp_process(
self,
dispatch_data: torch.Tensor,
used_capacity: torch.Tensor,
overlap: bool = False
self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False
) -> torch.Tensor:
"""
without overlap:
@@ -326,8 +333,9 @@ class SparseMLP(nn.Module):
NUM_CHUNK = 4
NUM_STAGES = 4
assert dispatch_data.shape[0] % NUM_CHUNK == 0, \
"arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
assert (
dispatch_data.shape[0] % NUM_CHUNK == 0
), "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
chunk_size = dispatch_data.shape[0] // NUM_CHUNK
chunk_data = torch.split(dispatch_data, chunk_size, dim=0)
output = torch.empty_like(dispatch_data)

View File

@@ -45,9 +45,13 @@ class MoeRouter(nn.Module, ABC):
self._z_loss = None
self.use_kernel = use_kernel
def get_capacity(self, logits_shape):
def get_capacity(self, num_tokens, num_experts, ep_group=None):
if ep_group is not None:
num_tokens_tensor = torch.tensor(num_tokens, device=get_accelerator().get_current_device())
dist.all_reduce(num_tokens_tensor, group=ep_group)
num_tokens = num_tokens_tensor.item() // dist.get_world_size(ep_group)
capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
capacity = math.floor(self.k_value * capacity_factor * logits_shape[-2] / logits_shape[-1])
capacity = math.floor(self.k_value * capacity_factor * num_tokens / num_experts)
capacity += capacity % 2
capacity = max(capacity, self.min_capacity)
assert capacity > 0
@@ -150,7 +154,14 @@ class Top1Router(MoeRouter):
high=torch.tensor(1.0, device=get_accelerator().get_current_device()),
).rsample
def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
def forward(
self,
inputs: torch.Tensor,
use_kernel: bool = False,
ep_group: Optional[ProcessGroup] = None,
use_loss: bool = False,
use_norm: bool = False,
) -> Tuple:
"""
Args:
inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
@@ -168,7 +179,8 @@ class Top1Router(MoeRouter):
assert inputs.dtype == torch.float
probs = F.softmax(inputs, dim=-1)
num_experts = probs.size(-1)
capacity = self.get_capacity(inputs.shape)
num_tokens = inputs.size(0)
capacity = self.get_capacity(num_tokens, num_experts, ep_group)
top1_idx = torch.argmax(inputs, dim=-1)
mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
@@ -207,7 +219,7 @@ class Top1Router(MoeRouter):
weight = mask * probs.type_as(inputs)
combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
sec_mask = combine_weights.bool()
return used_capacity, combine_weights, sec_mask
return used_capacity, combine_weights, sec_mask, probs
class Top2Router(MoeRouter):
@@ -240,7 +252,14 @@ class Top2Router(MoeRouter):
drop_tks=drop_tks,
)
def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
def forward(
self,
inputs: torch.Tensor,
use_kernel: bool = False,
ep_group: Optional[ProcessGroup] = None,
use_norm: bool = False,
use_loss: bool = True,
) -> Tuple:
"""
Args:
inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
@@ -257,8 +276,13 @@ class Top2Router(MoeRouter):
assert inputs.dtype == torch.float
probs = F.softmax(inputs, dim=-1)
if use_norm:
routing_weights, _ = torch.topk(probs, 2, dim=-1)
probs = probs / routing_weights.sum(dim=-1, keepdim=True)
num_experts = probs.size(-1)
capacity = self.get_capacity(inputs.shape)
num_tokens = inputs.size(0)
capacity = self.get_capacity(num_tokens, num_experts, ep_group)
top1_idx = torch.argmax(probs, dim=-1)
mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
@@ -270,10 +294,11 @@ class Top2Router(MoeRouter):
cmask = cmask.float() / 2.0 # div 2 to normalize it to 1
# calculate loss
expert_indices = torch.stack([top1_idx, top2_idx], dim=-1)
self.set_aux_loss(probs, expert_indices, num_experts)
self.set_z_loss(inputs)
self.pop_router_loss()
if use_loss:
expert_indices = torch.stack([top1_idx, top2_idx], dim=-1)
self.set_aux_loss(probs, expert_indices, num_experts)
self.set_z_loss(inputs)
self.pop_router_loss()
if not self.training and not self.drop_tks and ep_group is not None:
max_num = torch.max(torch.sum(cmask, dim=0))

View File

@@ -83,6 +83,8 @@ def get_activation(act: str) -> Callable:
return torch.nn.GELU()
elif act == "swiglu":
return SwiGLU
elif act == "silu":
return torch.nn.SiLU()
else:
raise NotImplementedError("Unsupported activation function")

View File

@@ -26,3 +26,5 @@ class MoeParallelInfo:
self.ep_group_ranks = self.pg.get_ranks_in_group(self.ep_group)
self.dp_group = self.pg.get_group_along_axis(self.dp_axis)
self.dp_group_ranks = self.pg.get_ranks_in_group(self.dp_group)
self.ep_rank = self.pg.coordinate(self.ep_axis)
self.dp_rank = self.pg.coordinate(self.dp_axis)

View File

@@ -141,7 +141,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# because they have different parallel strategy
# so we need to store them separately in param_groups
# instead of working_groups
moe_params = list()
self.working_moe_params = list()
# iterate over the param group in the optimizer
# partition these param groups for data parallel training
@@ -153,7 +153,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
if self.moe_extra_dp_pg is None:
# skip moe param
if is_moe_tensor(param):
moe_params.append(param)
self.working_moe_params.append(param)
continue
group_params.append(param)
@@ -168,13 +168,23 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# managed by this data parallel rank
param_group["params"] = master_param_current_rank
# if there are moe params, store in additional group in optim
if len(moe_params) > 0:
# if there are moe params, store in additional group in optim
if len(self.working_moe_params) > 0:
self._sync_master_param = False
param_group = dict()
# create fp32 master param
for key, value in self.optim.param_groups[0].items():
if key != "params":
param_group[key] = value
param_group["params"] = moe_params
self.master_moe_params = []
for param in self.working_moe_params:
self.master_moe_params.append(param.clone().to(torch.float32).detach())
# create mapping from master to working for optimizer io
self.moe_master_to_working_map = {}
for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
self.moe_master_to_working_map[id(master_moe_param)] = working_moe_param
# add to optim
param_group["params"] = self.master_moe_params
self.optim.param_groups.append(param_group)
# initialize communication stream for
@@ -593,24 +603,40 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# update the params in the optimizer
self.optim.param_groups[group_id]["params"] = real_master_params[group_id]
# update param for moe ep
# move grad to master param and compute norm
if len(self.working_moe_params) > 0:
moe_grads = []
for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
if master_moe_param.grad is not None:
raise RuntimeError("Moe param should not have grad here")
grad = working_moe_param.grad
# no need to copy fp32 grad if master_weights is False
if self._master_weights:
grad = grad.to(master_moe_param.dtype).to(master_moe_param.device)
master_moe_param.grad = grad
working_moe_param.grad = None
moe_grads.append(grad)
grad_partition_groups.append(grad)
norm_group = self._compute_grad_norm(gradients=moe_grads)
norm_groups.append(norm_group)
self.optim.param_groups[-1]["params"] = self.master_moe_params
del moe_grads
# unscale and clip grads
global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
self._unscale_and_clip_grads(grad_partition_groups, global_norm)
# TODO: we should store master param for ep
if len(self.param_groups) > len(self._working_param_groups):
for param in self.param_groups[-1]["params"]:
param.data = param.data.to(torch.float32)
param.grad = param.grad.to(torch.float32)
# update the parameters
self.optim.step()
# release the moe gradm
if len(self.param_groups) > len(self._working_param_groups):
for param in self.param_groups[-1]["params"]:
param.grad = None
param.data = param.data.to(self._dtype)
# release moe grad
if len(self.working_moe_params) > 0:
for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
master_moe_param.grad = None
working_moe_param.data = (
master_moe_param.data.to(working_moe_param.device).to(working_moe_param.dtype).detach()
)
# release the grad
grad_partition_groups = []
@@ -885,9 +911,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
master_param.copy_(working_param.chunk(self.extra_dp_pg_size)[self.extra_dp_pg_rank])
else:
master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
if hasattr(self, "master_moe_params"):
for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
master_moe_param.copy_(working_moe_param)
def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
    """Return the parameter store's ``working_to_master_param`` mapping as-is."""
    param_store = self._param_store
    return param_store.working_to_master_param
def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
    """Return the master-to-working parameter mapping.

    Without a ``moe_master_to_working_map`` attribute on the optimizer, the
    parameter store's mapping is returned as-is. Otherwise a new dict is
    returned that merges the store's mapping with the MoE mapping, with MoE
    entries taking precedence on key collisions.
    """
    base_map = self._param_store.master_to_working_param
    if not hasattr(self, "moe_master_to_working_map"):
        return base_map
    merged = dict(base_map)
    merged.update(self.moe_master_to_working_map)
    return merged