[moe] support mixtral (#5309)

* [moe] add mixtral block for single expert

* [moe] mixtral block fwd support uneven ep

* [moe] mixtral block bwd support uneven ep

* [moe] add mixtral moe layer

* [moe] simplify replace

* [meo] support save sharded mixtral

* [meo] support load sharded mixtral

* [meo] support save sharded optim

* [meo] integrate moe manager into plug

* [meo] fix optimizer load

* [meo] fix mixtral layer
This commit is contained in:
Hongxin Liu
2024-01-25 15:48:46 +08:00
committed by ver217
parent c904d2ae99
commit da39d21b71
14 changed files with 996 additions and 550 deletions

View File

@@ -666,10 +666,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
def sync_moe_master_param(self):
for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
master_moe_param.data = working_moe_param.data.clone().to(torch.float32).detach()
def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
r"""
Compute and return the gradient norm for gradient clipping.
@@ -915,9 +911,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
master_param.copy_(working_param.chunk(self.extra_dp_pg_size)[self.extra_dp_pg_rank])
else:
master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
master_moe_param.copy_(working_moe_param)
def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
return self._param_store.working_to_master_param
def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
return self._param_store.master_to_working_param
return {**self._param_store.master_to_working_param, **self.moe_master_to_working_map}