Mirror of https://github.com/hpcaitech/ColossalAI.git
Merge pull request #5372 from hpcaitech/exp/mixtral
@@ -22,7 +22,7 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
)
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.moe import MoECheckpintIO
from colossalai.moe import MOE_MANAGER, MoECheckpintIO
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig
@@ -150,6 +150,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self,
tp_size: int,
pp_size: int,
ep_size: int,
extra_dp_size: int = 1,
precision: str = "fp16",
zero_stage: int = 0,
@@ -181,6 +182,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
overlap_communication: bool = True,
use_ep_inside: bool = True,
custom_policy: Policy = None,
checkpoint_io: Optional[MoECheckpintIO] = None,
) -> None:
assert (
dist.get_world_size() % (tp_size * pp_size) == 0
@@ -188,10 +190,26 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):

if enable_sequence_parallelism:
assert tp_size > 1, "Sequence parallelism must be enabled when using tensor parallelism"

assert (
dist.get_world_size() % (tp_size * pp_size) == 0
), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
assert (
dist.get_world_size() % (tp_size * pp_size * ep_size) == 0
), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size} * ep_size {ep_size}"
self.real_dp_size = dist.get_world_size() // (tp_size * pp_size * ep_size)
MOE_MANAGER.setup(
parallel="EP",
mode="fixed",
fixed_dp_size=self.real_dp_size,
fixed_ep_size=ep_size,
fixed_pp_size=pp_size,
use_ep_inside=use_ep_inside,
)
self.tp_size = tp_size
self.pp_size = pp_size
self.dp_size = dist.get_world_size() // (tp_size * pp_size)
self.ep_size = ep_size
self.moe_info = MOE_MANAGER.get_info(0)[1]
self.precision = precision
self.zero_stage = zero_stage
self.cpu_offload = cpu_offload
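These assertions pin down how the process-group sizes fit together: the world size must be divisible by tp_size * pp_size and by tp_size * pp_size * ep_size, and the remaining factor (real_dp_size) is how many times each expert shard is replicated for data parallelism, while dp_size itself only divides out tp and pp. A small arithmetic sketch with illustrative sizes:

world_size = 16
tp_size, pp_size, ep_size = 1, 2, 4

assert world_size % (tp_size * pp_size) == 0
assert world_size % (tp_size * pp_size * ep_size) == 0

dp_size = world_size // (tp_size * pp_size)                  # 8, the plugin-level data-parallel size
real_dp_size = world_size // (tp_size * pp_size * ep_size)   # 2, replicas of each expert shard
assert (dp_size, real_dp_size) == (8, 2)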
@@ -200,6 +218,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self.enable_flash_attention = enable_flash_attention
self.enable_jit_fused = enable_jit_fused
self.enable_sequence_parallelism = enable_sequence_parallelism
self.checkpoint_io = checkpoint_io
# we change pg mesh to (pp, dp, tp) for better moe performance
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size)

@@ -323,7 +342,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
)

def get_checkpoint_io(self) -> MoECheckpintIO:
self.checkpoint_io = MoECheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
if self.checkpoint_io is None:
self.checkpoint_io = MoECheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
else:
self.checkpoint_io = self.checkpoint_io(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
return self.checkpoint_io

def configure(

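Because checkpoint_io defaults to None, get_checkpoint_io either builds the stock MoECheckpintIO or instantiates whatever class was passed in, using the same (dp_group, pp_group, tp_group, zero_stage) arguments. A minimal sketch of the intended call pattern; the subclass name is hypothetical:

from colossalai.moe import MoECheckpintIO

class MyMoECheckpointIO(MoECheckpintIO):  # hypothetical customization point
    pass

# Pass the class, not an instance; the plugin constructs it later with its own groups:
# plugin = MoeHybridParallelPlugin(tp_size=1, pp_size=1, ep_size=2, checkpoint_io=MyMoECheckpointIO)
# plugin.get_checkpoint_io()  # -> MyMoECheckpointIO(dp_group, pp_group, tp_group, zero_stage)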
@@ -9,7 +9,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler

from colossalai.interface import ModelWrapper

from .utils import has_index_file
from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, has_index_file

__all__ = ["CheckpointIO"]

@@ -90,7 +90,15 @@ class CheckpointIO(ABC):
if index_file_exists:
self.load_sharded_model(model, index_file_path, strict)
else:
self.load_unsharded_model(model, checkpoint, strict)
path = Path(checkpoint, SAFE_WEIGHTS_NAME)
if path.is_file():
self.load_unsharded_model(model, str(path), strict)
else:
path = Path(checkpoint, WEIGHTS_NAME)
if path.is_file():
self.load_unsharded_model(model, str(path), strict)
else:
self.load_unsharded_model(model, checkpoint, strict)

return origin_model

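The extra branch gives load_model a fallback order when no sharded index file exists: a safetensors weight file is preferred, then a legacy PyTorch weight file, and only then is the path handed to load_unsharded_model unchanged. A standalone sketch of that resolution order, assuming the conventional "model.safetensors" / "pytorch_model.bin" values behind SAFE_WEIGHTS_NAME and WEIGHTS_NAME:

from pathlib import Path

SAFE_WEIGHTS_NAME = "model.safetensors"  # assumed value
WEIGHTS_NAME = "pytorch_model.bin"       # assumed value

def resolve_unsharded_checkpoint(checkpoint: str) -> str:
    # Prefer safetensors, then the legacy .bin file, else fall back to the path as given.
    for name in (SAFE_WEIGHTS_NAME, WEIGHTS_NAME):
        candidate = Path(checkpoint, name)
        if candidate.is_file():
            return str(candidate)
    return checkpoint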
@@ -1,6 +1,7 @@
from .checkpoint import MoECheckpintIO
from .experts import MLPExperts
from .layers import SparseMLP
from .layers import SparseMLP, apply_load_balance
from .manager import MOE_MANAGER
from .routers import MoeRouter, Top1Router, Top2Router, TopKRouter
from .utils import NormalNoiseGenerator, UniformNoiseGenerator

@@ -14,4 +15,6 @@ __all__ = [
"UniformNoiseGenerator",
"SparseMLP",
"MoECheckpintIO",
"MOE_MANAGER",
"apply_load_balance",
]

@@ -1,4 +1,4 @@
from typing import Any, Optional, Tuple
from typing import Any, List, Optional, Tuple

import torch
import torch.distributed as dist
@@ -329,3 +329,68 @@ class MoeOutGradScaler(torch.autograd.Function):
if ctx.ep_size != 1:
grad = grad / ctx.ep_size
return grad, None


def _all_to_all(
inputs: torch.Tensor,
input_split_sizes: Optional[List[int]] = None,
output_split_sizes: Optional[List[int]] = None,
group=None,
async_op: bool = False,
):
"""
Returns:
outputs: Tensor
handle: Optional[Work], if overlap is True
"""
outputs_shape = list(inputs.shape)
if output_split_sizes is not None:
outputs_shape[0] = sum(output_split_sizes)
outputs = torch.empty(outputs_shape, dtype=inputs.dtype, device=inputs.device)
inputs = inputs.contiguous()
outputs = outputs.contiguous()
handle = dist.all_to_all_single(
outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=async_op
)
return outputs, handle


class AllToAllUneven(torch.autograd.Function):
@staticmethod
def forward(
ctx,
inputs,
input_split_sizes=None,
output_split_sizes=None,
group=None,
overlap: bool = False,
):
"""
Returns:
outputs: Tensor
handle: Optional[Work], if overlap is True
"""
ctx.input_split_sizes = input_split_sizes
ctx.output_split_sizes = output_split_sizes
ctx.group = group
return _all_to_all(inputs, input_split_sizes, output_split_sizes, group, overlap)

@staticmethod
def backward(ctx: Any, *grad_outputs):
return (
_all_to_all(grad_outputs[0], ctx.output_split_sizes, ctx.input_split_sizes, ctx.group, False)[0],
None,
None,
None,
None,
)


def all_to_all_uneven(
inputs: torch.Tensor,
input_split_sizes: Optional[List[int]] = None,
output_split_sizes: Optional[List[int]] = None,
group=None,
overlap: bool = False,
):
return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap)

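all_to_all_uneven wraps dist.all_to_all_single so that each rank can send a different number of rows to every peer, with the backward pass running the transposed exchange. A minimal usage sketch, assuming two GPUs launched with torchrun (one process per GPU) and that the module path is colossalai.moe._operation:

import torch
import torch.distributed as dist
from colossalai.moe._operation import all_to_all_uneven

dist.init_process_group(backend="nccl")   # torchrun --nproc_per_node=2 supplies the env vars
rank = dist.get_rank()
torch.cuda.set_device(rank)

# rank 0 sends 1 row to rank 0 and 3 rows to rank 1; rank 1 sends 2 rows to each peer
input_split_sizes = [[1, 3], [2, 2]][rank]
output_split_sizes = [[1, 2], [3, 2]][rank]   # column-wise transpose of the send matrix
x = torch.randn(sum(input_split_sizes), 8, device="cuda")

out, handle = all_to_all_uneven(x, input_split_sizes, output_split_sizes)
assert out.shape[0] == sum(output_split_sizes)   # handle is None because overlap=False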
@@ -224,6 +224,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
"""
torch.cuda.empty_cache()
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
@@ -265,6 +266,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
f"index located at {save_index_file}."
)
dist.barrier()
torch.cuda.empty_cache()

# ========================================================
# Abstract methods for optimizer loading/saving implementation
@@ -332,10 +334,12 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"

def _get_param_id_from_optimizer_param(
param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None, optimizer=None
):
if master_to_working_map is not None and id(param) in master_to_working_map:
working_param = master_to_working_map[id(param)]
elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map:
working_param = optimizer.moe_master_to_working_map[id(param)]
else:
working_param = param
return optimizer.param_info["param2id"][id(working_param)]
@@ -347,7 +351,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
master_to_working_map = optimizer.get_master_to_working_map()
for pg in optimizer.optim.param_groups:
for param in pg["params"]:
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map, optimizer)
id_map[param_id] = param

# Read checkpoint index file.
@@ -371,14 +375,10 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
new_pg = copy.deepcopy(saved_pg)
new_pg["params"] = old_pg["params"] # The parameters in the same group shouldn't change.
updated_groups.append(new_pg)
# ep extra group
if MOE_MANAGER.parallel == "EP":
# ep param group
if len(optimizer.optim.param_groups) > len(saved_groups):
new_pg = copy.deepcopy(saved_pg)
new_pg["params"] = optimizer.optim.param_groups[-1][
"params"
] # Only keep the parameters kept by current pipeline stage.
for param in new_pg["params"]:
param.data = param.data.to(torch.float32)
new_pg["params"] = optimizer.optim.param_groups[-1]["params"]
updated_groups.append(new_pg)
optimizer.optim.__dict__.update({"param_groups": updated_groups})

@@ -389,7 +389,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
for param in pg["params"]:
if param is None:
continue
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map, optimizer)
if param_id not in weight_map:
continue
filename = weight_map[param_id]
@@ -400,27 +400,34 @@ class MoECheckpintIO(HybridParallelCheckpointIO):

file_path = os.path.join(ckpt_root_path, filename)
state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)

# Then shard the loaded optimizer states if using tp/zero.
for pid, state in list(state_dict.items()):
if pid in id_map:
param = id_map[pid]
if master_to_working_map is not None and id(param) in master_to_working_map:
working_param = master_to_working_map[id(param)]
elif (
hasattr(optimizer, "moe_master_to_working_map")
and id(param) in optimizer.moe_master_to_working_map
):
working_param = optimizer.moe_master_to_working_map[id(param)]
else:
working_param = param
original_shape = optimizer.param_info["param2shape"][id(working_param)]
sharded_state = self.pre_load_optim(
state,
working_param,
current_shape=working_param.shape,
original_shape=original_shape,
device="cpu",
inplace=True,
)
state_dict[pid] = sharded_state

load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
loaded_file.add(filename)

# Then shard the loaded optimizer states if using tp/zero.
for param, state in optimizer.optim.state.items():
device = param.device
if master_to_working_map is not None and id(param) in master_to_working_map:
working_param = master_to_working_map[id(param)]
else:
working_param = param
original_shape = optimizer.param_info["param2shape"][id(working_param)]
sharded_state = self.pre_load_optim(
state,
param,
current_shape=working_param.shape,
original_shape=original_shape,
device=device,
inplace=True,
)
optimizer.optim.state[param] = sharded_state

sharded_optimizer_loading_epilogue(optimizer.optim)
if self.verbose and self.coordinator.is_master():
logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
@@ -576,6 +583,8 @@ class MoECheckpintIO(HybridParallelCheckpointIO):

if master_to_working_map is not None and id(param) in master_to_working_map:
working_param = master_to_working_map[id(param)]
elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map:
working_param = optimizer.moe_master_to_working_map[id(param)]
else:
working_param = param

@@ -618,6 +627,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
prefix (str): Prefix of file to save
size_per_shard (int): Max file size of each file shard that store state tensors
"""
torch.cuda.empty_cache()
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
@@ -723,6 +733,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
f"You can find where each parameters has been saved in the "
f"index located at {final_index_file_path}."
)
torch.cuda.empty_cache()

def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
"""

@@ -67,7 +67,11 @@ class MLPExperts(nn.Module):
self.ep_size = 1

if gated:
self.wi_gate = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size * 2))
self.wi_gate = nn.Parameter(
torch.empty(
num_experts, hidden_size, intermediate_size * 2 if activation == "swiglu" else intermediate_size
)
)
self.wi_up = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
else:
self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))

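The reworked wi_gate sizing doubles the last dimension only for SwiGLU, where the gate and up projections are packed into a single tensor; other gated activations keep a single-width gate. A shape-only sketch with illustrative sizes:

num_experts, hidden_size, intermediate_size = 8, 1024, 4096

def gate_weight_shape(activation: str) -> tuple:
    # mirrors the conditional above: SwiGLU packs two projections into wi_gate
    width = intermediate_size * 2 if activation == "swiglu" else intermediate_size
    return (num_experts, hidden_size, width)

assert gate_weight_shape("swiglu") == (8, 1024, 8192)
assert gate_weight_shape("silu") == (8, 1024, 4096)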
@@ -51,6 +51,8 @@ class SparseMLP(nn.Module):
hidden_size: int,
intermediate_size: int,
router_top_k: int = 1,
router_loss: bool = True,
router_norm: bool = False,
router_capacity_factor_train: float = 1.25,
router_capacity_factor_eval: float = 2.0,
router_min_capacity: int = 4,
@@ -65,15 +67,19 @@ class SparseMLP(nn.Module):
enable_kernel: bool = False,
enable_comm_overlap: bool = False,
enable_hierarchical_comm: bool = False,
return_gate_logits: bool = False,
):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_experts = num_experts
self.gated = mlp_gated
self.return_gate_logits = return_gate_logits
self.enable_kernel = enable_kernel
self.enable_comm_overlap = enable_comm_overlap
self.expert_parallel = MOE_MANAGER.get_parallel()
self.router_loss = router_loss
self.router_norm = router_norm

# moe router
noisy_func = get_noise_generator(router_noisy_policy, num_experts)
@@ -150,9 +156,8 @@ class SparseMLP(nn.Module):
tokens = inputs.reshape(-1, self.hidden_size)

# the data type of the inputs in the gating should be fp32
fp32_input = tokens.to(torch.float)
fp32_weight = self.gate_weight.to(torch.float)
gate_output = F.linear(fp32_input, fp32_weight)
gate_logits = F.linear(tokens, self.gate_weight)
gate_output = gate_logits.to(torch.float)

# update expert load
if self.enable_load_balance == True:
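The gating matmul now runs in the tokens' own dtype, and only the resulting logits are cast to fp32 for the router; keeping the raw gate_logits around is also what makes return_gate_logits possible. A minimal sketch of that ordering, assuming bf16 activations:

import torch
import torch.nn.functional as F

tokens = torch.randn(16, 32, dtype=torch.bfloat16)        # (num_tokens, hidden_size)
gate_weight = torch.randn(4, 32, dtype=torch.bfloat16)    # (num_experts, hidden_size)

gate_logits = F.linear(tokens, gate_weight)   # stays in bf16; can be handed back to the caller
gate_output = gate_logits.to(torch.float)     # fp32 copy consumed by the router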
@@ -165,7 +170,12 @@ class SparseMLP(nn.Module):

# the result from the router
used_capacity, *route_result_list = self.router(
inputs=gate_output, use_kernel=self.enable_kernel, ep_group=self.ep_group)
inputs=gate_output,
use_kernel=self.enable_kernel,
ep_group=self.ep_group,
use_loss=self.router_loss,
use_norm=self.router_norm,
)

# dispatch_data: (num_experts, capacity, hidden_size)
if self.enable_kernel:
@@ -177,22 +187,15 @@ class SparseMLP(nn.Module):

# expert_output: (num_groups, num_experts, capacity, hidden_size)
if self.expert_parallel == "EP":
expert_output = self._ep_process(
dispatch_data,
used_capacity,
overlap=self.enable_comm_overlap
)
expert_output = self._ep_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
elif self.expert_parallel == "TP":
expert_output = self._tp_process(
dispatch_data,
used_capacity,
overlap=self.enable_comm_overlap
)
expert_output = self._tp_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
elif self.expert_parallel is None:
expert_output = self._local_process(dispatch_data)
else:
raise NotImplementedError("This kind of communication has not been implemented yet.\n"
"Please use Experts build function.")
raise NotImplementedError(
"This kind of communication has not been implemented yet.\n" "Please use Experts build function."
)

if self.enable_kernel:
expert_output = expert_output.reshape(-1, self.hidden_size)
@@ -204,7 +207,11 @@ class SparseMLP(nn.Module):
ans = torch.matmul(combine_weights, expert_output)

ans = ans.reshape(inputs.shape)
return ans

if self.return_gate_logits:
return ans, gate_logits
else:
return ans

def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor:
expert_in = expert_in.unsqueeze(0)
@@ -212,10 +219,7 @@ class SparseMLP(nn.Module):
return expert_out

def _ep_process(
self,
dispatch_data: torch.Tensor,
used_capacity: torch.Tensor,
overlap: bool = False
self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False
) -> torch.Tensor:
"""
Expert Parallel
@@ -228,10 +232,14 @@ class SparseMLP(nn.Module):
"""
if not overlap or dist.get_world_size(self.ep_group) == 1:
if self.ep_hierarchical_group is not None:
expert_input = HierarchicalAllToAll.apply(dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank)
expert_input = HierarchicalAllToAll.apply(
dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank
)
expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
expert_output = self.experts(expert_input)
expert_output = HierarchicalAllToAll.apply(expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank)
expert_output = HierarchicalAllToAll.apply(
expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank
)
return expert_output
else:
expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0]
@@ -249,7 +257,7 @@ class SparseMLP(nn.Module):
NUM_CHUNK = 4
NUM_STAGES = 4

assert (dispatch_data.shape[1] % NUM_CHUNK == 0), "arbitrary chunk num is not supported yet"
assert dispatch_data.shape[1] % NUM_CHUNK == 0, "arbitrary chunk num is not supported yet"
chunk_size = dispatch_data.shape[1] // NUM_CHUNK
input_shape = (self.ep_size, self.num_local_experts, -1, self.hidden_size)
dispatch_data = dispatch_data.reshape(*input_shape)
@@ -262,13 +270,15 @@ class SparseMLP(nn.Module):
for i in range(NUM_CHUNK + NUM_STAGES - 1):
if expert_out is not None:
expert_out.handle.wait()
output[:, :, offset:offset + chunk_size, :] = expert_out.data
output[:, :, offset : offset + chunk_size, :] = expert_out.data
offset += chunk_size
expert_out = None

# all2all last output
if _expert_out is not None:
expert_out = Capsule(*AllToAll.apply(_expert_out.data, self.ep_group, True),)
expert_out = Capsule(
*AllToAll.apply(_expert_out.data, self.ep_group, True),
)
_expert_out = None

# all2all next input
@@ -288,10 +298,7 @@ class SparseMLP(nn.Module):
return output

def _tp_process(
self,
dispatch_data: torch.Tensor,
used_capacity: torch.Tensor,
overlap: bool = False
self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False
) -> torch.Tensor:
"""
without overlap:
@@ -326,8 +333,9 @@ class SparseMLP(nn.Module):
NUM_CHUNK = 4
NUM_STAGES = 4

assert dispatch_data.shape[0] % NUM_CHUNK == 0, \
"arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
assert (
dispatch_data.shape[0] % NUM_CHUNK == 0
), "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
chunk_size = dispatch_data.shape[0] // NUM_CHUNK
chunk_data = torch.split(dispatch_data, chunk_size, dim=0)
output = torch.empty_like(dispatch_data)

@@ -45,9 +45,13 @@ class MoeRouter(nn.Module, ABC):
self._z_loss = None
self.use_kernel = use_kernel

def get_capacity(self, logits_shape):
def get_capacity(self, num_tokens, num_experts, ep_group=None):
if ep_group is not None:
num_tokens_tensor = torch.tensor(num_tokens, device=get_accelerator().get_current_device())
dist.all_reduce(num_tokens_tensor, group=ep_group)
num_tokens = num_tokens_tensor.item() // dist.get_world_size(ep_group)
capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
capacity = math.floor(self.k_value * capacity_factor * logits_shape[-2] / logits_shape[-1])
capacity = math.floor(self.k_value * capacity_factor * num_tokens / num_experts)
capacity += capacity % 2
capacity = max(capacity, self.min_capacity)
assert capacity > 0
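get_capacity is now driven by an explicit token count (averaged over the EP group when one is passed) and the expert count, rather than by the logits shape. A quick worked example of the formula, assuming top-2 routing, 1024 local tokens, 8 experts, and the default 1.25 training capacity factor:

import math

k_value, capacity_factor = 2, 1.25
num_tokens, num_experts, min_capacity = 1024, 8, 4

capacity = math.floor(k_value * capacity_factor * num_tokens / num_experts)  # 320
capacity += capacity % 2               # bump odd values up to the next even number
capacity = max(capacity, min_capacity)
assert capacity == 320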
@@ -150,7 +154,14 @@ class Top1Router(MoeRouter):
high=torch.tensor(1.0, device=get_accelerator().get_current_device()),
).rsample

def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
def forward(
self,
inputs: torch.Tensor,
use_kernel: bool = False,
ep_group: Optional[ProcessGroup] = None,
use_loss: bool = False,
use_norm: bool = False,
) -> Tuple:
"""
Args:
inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
@@ -168,7 +179,8 @@ class Top1Router(MoeRouter):
assert inputs.dtype == torch.float
probs = F.softmax(inputs, dim=-1)
num_experts = probs.size(-1)
capacity = self.get_capacity(inputs.shape)
num_tokens = inputs.size(0)
capacity = self.get_capacity(num_tokens, num_experts, ep_group)

top1_idx = torch.argmax(inputs, dim=-1)
mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
@@ -207,7 +219,7 @@ class Top1Router(MoeRouter):
weight = mask * probs.type_as(inputs)
combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
sec_mask = combine_weights.bool()
return used_capacity, combine_weights, sec_mask
return used_capacity, combine_weights, sec_mask, probs


class Top2Router(MoeRouter):
@@ -240,7 +252,14 @@ class Top2Router(MoeRouter):
drop_tks=drop_tks,
)

def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
def forward(
self,
inputs: torch.Tensor,
use_kernel: bool = False,
ep_group: Optional[ProcessGroup] = None,
use_norm: bool = False,
use_loss: bool = True,
) -> Tuple:
"""
Args:
inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
@@ -257,8 +276,13 @@ class Top2Router(MoeRouter):

assert inputs.dtype == torch.float
probs = F.softmax(inputs, dim=-1)
if use_norm:
routing_weights, _ = torch.topk(probs, 2, dim=-1)
probs = probs / routing_weights.sum(dim=-1, keepdim=True)

num_experts = probs.size(-1)
capacity = self.get_capacity(inputs.shape)
num_tokens = inputs.size(0)
capacity = self.get_capacity(num_tokens, num_experts, ep_group)
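The use_norm path rescales each token's probabilities so that its two selected experts sum to one, which is the renormalization used by Mixtral-style top-2 routing. A minimal sketch:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 8)                    # (tokens, experts)
probs = F.softmax(logits, dim=-1)
routing_weights, _ = torch.topk(probs, 2, dim=-1)
probs = probs / routing_weights.sum(dim=-1, keepdim=True)

top2, _ = torch.topk(probs, 2, dim=-1)        # each token's two winners now sum to 1
assert torch.allclose(top2.sum(dim=-1), torch.ones(4))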
@@ -270,10 +294,11 @@ class Top2Router(MoeRouter):
cmask = cmask.float() / 2.0 # div 2 to normalize it to 1

# calculate loss
expert_indices = torch.stack([top1_idx, top2_idx], dim=-1)
self.set_aux_loss(probs, expert_indices, num_experts)
self.set_z_loss(inputs)
self.pop_router_loss()
if use_loss:
expert_indices = torch.stack([top1_idx, top2_idx], dim=-1)
self.set_aux_loss(probs, expert_indices, num_experts)
self.set_z_loss(inputs)
self.pop_router_loss()

if not self.training and not self.drop_tks and ep_group is not None:
max_num = torch.max(torch.sum(cmask, dim=0))

@@ -83,6 +83,8 @@ def get_activation(act: str) -> Callable:
return torch.nn.GELU()
elif act == "swiglu":
return SwiGLU
elif act == "silu":
return torch.nn.SiLU()
else:
raise NotImplementedError("Unsupported activation function")

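For context, SwiGLU combines a SiLU-activated gate projection with a linear up projection. A functional sketch of the usual formulation (not the repo's SwiGLU class):

import torch
import torch.nn.functional as F

def swiglu(x_gate: torch.Tensor, x_up: torch.Tensor) -> torch.Tensor:
    # silu(gate) * up, matching the split wi_gate / wi_up expert weights
    return F.silu(x_gate) * x_up

out = swiglu(torch.randn(2, 16), torch.randn(2, 16))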
@@ -26,3 +26,5 @@ class MoeParallelInfo:
self.ep_group_ranks = self.pg.get_ranks_in_group(self.ep_group)
self.dp_group = self.pg.get_group_along_axis(self.dp_axis)
self.dp_group_ranks = self.pg.get_ranks_in_group(self.dp_group)
self.ep_rank = self.pg.coordinate(self.ep_axis)
self.dp_rank = self.pg.coordinate(self.dp_axis)

@@ -141,7 +141,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# because they have different parallel strategy
# so we need to store them separately in param_groups
# instead of working_groups
moe_params = list()
self.working_moe_params = list()

# iterate over the param group in the optimizer
# partition these param groups for data parallel training
@@ -153,7 +153,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
if self.moe_extra_dp_pg is None:
# skip moe param
if is_moe_tensor(param):
moe_params.append(param)
self.working_moe_params.append(param)
continue
group_params.append(param)

@@ -168,13 +168,23 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# managed by this data parallel rank
param_group["params"] = master_param_current_rank

# if there are moe params, store in additional group in optim
if len(moe_params) > 0:
# if there are moe params, store in addtional group in optim
if len(self.working_moe_params) > 0:
self._sync_master_param = False
param_group = dict()
# create fp32 master param
for key, value in self.optim.param_groups[0].items():
if key != "params":
param_group[key] = value
param_group["params"] = moe_params
self.master_moe_params = []
for param in self.working_moe_params:
self.master_moe_params.append(param.clone().to(torch.float32).detach())
# create mapping from master to working for optimizer io
self.moe_master_to_working_map = {}
for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
self.moe_master_to_working_map[id(master_moe_param)] = working_moe_param
# add to optim
param_group["params"] = self.master_moe_params
self.optim.param_groups.append(param_group)

# initialize communication stream for
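The ZeRO optimizer now keeps fp32 master copies of the MoE parameters in their own param group and records an id-keyed master-to-working map so checkpoint IO can translate between the two. A stripped-down sketch of that bookkeeping with a single stand-in expert weight:

import torch

working_moe_params = [torch.randn(4, 8, dtype=torch.float16, requires_grad=True)]

master_moe_params = [p.clone().to(torch.float32).detach() for p in working_moe_params]
moe_master_to_working_map = {
    id(master): working for master, working in zip(master_moe_params, working_moe_params)
}
# the fp32 masters are then appended to the optimizer as an extra param group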
@@ -593,24 +603,40 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# update the params in the optimizer
self.optim.param_groups[group_id]["params"] = real_master_params[group_id]

# update param for moe ep
# move grad to master param and compute norm
if len(self.working_moe_params) > 0:
moe_grads = []
for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
if master_moe_param.grad is not None:
raise RuntimeError("Moe param should not have grad here")
grad = working_moe_param.grad
# no need to copy fp32 grad if master_weights is False
if self._master_weights:
grad = grad.to(master_moe_param.dtype).to(master_moe_param.device)
master_moe_param.grad = grad
working_moe_param.grad = None
moe_grads.append(grad)
grad_partition_groups.append(grad)
norm_group = self._compute_grad_norm(gradients=moe_grads)
norm_groups.append(norm_group)
self.optim.param_groups[-1]["params"] = self.master_moe_params
del moe_grads

# unscale and clip grads
global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
self._unscale_and_clip_grads(grad_partition_groups, global_norm)

# TODO: we should store master param for ep
if len(self.param_groups) > len(self._working_param_groups):
for param in self.param_groups[-1]["params"]:
param.data = param.data.to(torch.float32)
param.grad = param.grad.to(torch.float32)

# update the parameters
self.optim.step()

# release the moe gradm
if len(self.param_groups) > len(self._working_param_groups):
for param in self.param_groups[-1]["params"]:
param.grad = None
param.data = param.data.to(self._dtype)
# release moe grad
if len(self.working_moe_params) > 0:
for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
master_moe_param.grad = None
working_moe_param.data = (
master_moe_param.data.to(working_moe_param.device).to(working_moe_param.dtype).detach()
)

# release the grad
grad_partition_groups = []
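Inside step(), the MoE parameters follow the usual mixed-precision round trip: each working grad is moved onto its fp32 master, the update runs on the master copy, and the result is cast back into the working tensor before the grad is released. A self-contained sketch of that round trip, with a plain SGD-style update standing in for optim.step():

import torch

working = torch.randn(4, 8, dtype=torch.float16, requires_grad=True)
working.grad = torch.randn(4, 8, dtype=torch.float16)
master = working.clone().to(torch.float32).detach()

master.grad = working.grad.to(master.dtype)     # move the grad onto the fp32 master
working.grad = None

master.data -= 0.1 * master.grad                # stand-in for the real optimizer update

master.grad = None                              # release the moe grad
working.data = master.data.to(working.dtype).detach()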
@@ -885,9 +911,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
master_param.copy_(working_param.chunk(self.extra_dp_pg_size)[self.extra_dp_pg_rank])
else:
master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
if hasattr(self, "master_moe_params"):
for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
master_moe_param.copy_(working_moe_param)

def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
return self._param_store.working_to_master_param

def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
if hasattr(self, "moe_master_to_working_map"):
return {**self._param_store.master_to_working_param, **self.moe_master_to_working_map}
return self._param_store.master_to_working_param