mirror of
https://github.com/hpcaitech/ColossalAI.git
[moe] init mixtral impl
@@ -141,7 +141,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         # because they have different parallel strategy
         # so we need to store them separately in param_groups
         # instead of working_groups
-        moe_params = list()
+        self.working_moe_params = list()
 
         # iterate over the param group in the optimizer
         # partition these param groups for data parallel training
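The rename above promotes the local `moe_params` list to a persistent attribute, `self.working_moe_params`, so the MoE expert weights stay reachable after `__init__` for master-param creation, the `step()` path, and optimizer io. The underlying pattern is the usual mixed-precision split; a minimal sketch with illustrative names, not ColossalAI's actual classes:

    import torch

    # "working" param: the low-precision tensor the model computes with
    working = torch.randn(4, 4, dtype=torch.float16)
    # "master" param: an fp32 copy that the optimizer actually updates
    master = working.clone().to(torch.float32).detach()

    # after the optimizer steps on `master`, the result is cast back
    working.data = master.data.to(working.dtype)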
@@ -153,7 +153,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                     if self.moe_extra_dp_pg is None:
                         # skip moe param
                         if is_moe_tensor(param):
-                            moe_params.append(param)
+                            self.working_moe_params.append(param)
                             continue
                     group_params.append(param)
 
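MoE expert parameters are skipped here because, as the earlier comment says, they follow a different parallel strategy: they are sharded by expert parallelism rather than partitioned with the dense params under ZeRO data parallelism. `is_moe_tensor` is ColossalAI's check for that; a hypothetical stand-in, assuming expert-parallel setup tags each expert weight with metadata (the `ep_group` attribute below is invented for illustration):

    import torch

    def is_moe_tensor_sketch(param: torch.Tensor) -> bool:
        # hypothetical: assume expert-parallel setup attaches an
        # `ep_group` attribute to each expert weight
        return hasattr(param, "ep_group")

    params = [torch.randn(2, requires_grad=True) for _ in range(3)]
    params[1].ep_group = "ep0"  # pretend this one is an expert weight

    working_moe_params = [p for p in params if is_moe_tensor_sketch(p)]
    group_params = [p for p in params if not is_moe_tensor_sketch(p)]
    assert len(working_moe_params) == 1 and len(group_params) == 2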
@@ -168,13 +168,23 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             # managed by this data parallel rank
             param_group["params"] = master_param_current_rank
 
         # if there are moe params, store in additional group in optim
-        if len(moe_params) > 0:
+        if len(self.working_moe_params) > 0:
+            self._sync_master_param = False
             param_group = dict()
+            # create fp32 master param
             for key, value in self.optim.param_groups[0].items():
                 if key != "params":
                     param_group[key] = value
-            param_group["params"] = moe_params
+            self.master_moe_params = []
+            for param in self.working_moe_params:
+                self.master_moe_params.append(param.clone().to(torch.float32).detach())
+            # create mapping from master to working for optimizer io
+            self.moe_master_to_working_map = {}
+            for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
+                self.moe_master_to_working_map[id(master_moe_param)] = working_moe_param
+            # add to optim
+            param_group["params"] = self.master_moe_params
             self.optim.param_groups.append(param_group)
 
         # initialize communication stream for
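Instead of handing the optimizer the raw working params, the new block gives it fp32 master copies and records an `id(master) -> working` map so optimizer-state io can translate between the two; the map is keyed by `id(master_param)`, i.e. by object identity. A condensed sketch of the same bookkeeping on bare tensors (standalone tensors instead of module params):

    import torch

    working_moe_params = [torch.randn(2, 2, dtype=torch.float16) for _ in range(2)]

    # fp32 master copies, detached so they are independent leaves
    master_moe_params = [p.clone().to(torch.float32).detach() for p in working_moe_params]

    # map master -> working by identity, mirroring moe_master_to_working_map
    moe_master_to_working_map = {
        id(m): w for m, w in zip(master_moe_params, working_moe_params)
    }

    for m in master_moe_params:
        assert moe_master_to_working_map[id(m)].dtype == torch.float16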
@@ -593,24 +603,40 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             # update the params in the optimizer
             self.optim.param_groups[group_id]["params"] = real_master_params[group_id]
 
-        # update param for moe ep
+        # move grad to master param and compute norm
+        if len(self.working_moe_params) > 0:
+            moe_grads = []
+            for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
+                if master_moe_param.grad is not None:
+                    raise RuntimeError("Moe param should not have grad here")
+                grad = working_moe_param.grad
+                # no need to copy fp32 grad if master_weights is False
+                if self._master_weights:
+                    grad = grad.to(master_moe_param.dtype).to(master_moe_param.device)
+                master_moe_param.grad = grad
+                working_moe_param.grad = None
+                moe_grads.append(grad)
+                grad_partition_groups.append(grad)
+            norm_group = self._compute_grad_norm(gradients=moe_grads)
+            norm_groups.append(norm_group)
+            self.optim.param_groups[-1]["params"] = self.master_moe_params
+            del moe_grads
 
         # unscale and clip grads
         global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
         self._unscale_and_clip_grads(grad_partition_groups, global_norm)
 
-        # TODO: we should store master param for ep
-        if len(self.param_groups) > len(self._working_param_groups):
-            for param in self.param_groups[-1]["params"]:
-                param.data = param.data.to(torch.float32)
-                param.grad = param.grad.to(torch.float32)
-
         # update the parameters
         self.optim.step()
 
-        # release the moe gradm
-        if len(self.param_groups) > len(self._working_param_groups):
-            for param in self.param_groups[-1]["params"]:
-                param.grad = None
-                param.data = param.data.to(self._dtype)
+        # release moe grad
+        if len(self.working_moe_params) > 0:
+            for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
+                master_moe_param.grad = None
+                working_moe_param.data = (
+                    master_moe_param.data.to(working_moe_param.device).to(working_moe_param.dtype).detach()
+                )
 
         # release the grad
        grad_partition_groups = []
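The step flow for MoE params is: move each working grad onto its fp32 master (casting only when master weights are enabled), include those grads in norm computation and clipping, step the optimizer, then cast the updated master value back into the low-precision working param while releasing both grads. A condensed, self-contained sketch of that flow, assuming a plain `torch.optim.SGD` in place of the wrapped optimizer and master weights enabled:

    import torch

    working = torch.randn(4, dtype=torch.float16, requires_grad=True)
    working.grad = torch.randn_like(working)

    # fp32 master copy of the working param
    master = working.clone().to(torch.float32).detach().requires_grad_(True)
    optim = torch.optim.SGD([master], lr=0.1)

    # move grad to master param (cast to the master dtype), release working grad
    master.grad = working.grad.to(master.dtype)
    working.grad = None

    optim.step()

    # release the grad and cast the updated value back to the working param
    master.grad = None
    working.data = master.data.to(working.dtype).detach()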
@@ -640,6 +666,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
             self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
 
+    def sync_moe_master_param(self):
+        for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
+            master_moe_param.data = working_moe_param.data.clone().to(torch.float32).detach()
+
     def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
         r"""
         Compute and return the gradient norm for gradient clipping.
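`sync_moe_master_param` refreshes the fp32 masters from the working params, the inverse of the copy-back at the end of `step()`; presumably it is meant for when the working weights are overwritten behind the optimizer's back, e.g. after loading a model checkpoint. A small sketch of the same idea on bare tensors (hypothetical shapes and names):

    import torch

    working = [torch.randn(2, dtype=torch.float16) for _ in range(2)]
    master = [w.clone().to(torch.float32).detach() for w in working]

    # simulate a checkpoint load that replaces the working weights
    for w in working:
        w.data = torch.zeros_like(w)

    # re-sync the masters, mirroring sync_moe_master_param
    for m, w in zip(master, working):
        m.data = w.data.clone().to(torch.float32).detach()

    assert all(torch.equal(m, torch.zeros_like(m)) for m in master)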