Mirror of https://github.com/hpcaitech/ColossalAI.git
[moe] support optimizer checkpoint (#5015)
* Refactor MoE Manager setup method
* unshard optim ckpt
* optim io
* update transformer version
* update requirements
* update ckpt
* update ckpt
* update ckpt
* fix engine
* fix engine
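For context on how the new optimizer checkpoint path is exercised, here is a minimal usage sketch built on the public Booster checkpoint API. The MoE-aware plugin, model, and optimizer setup are assumed to exist already, and reading "unshard optim ckpt" as `shard=False` is my interpretation rather than something stated in this commit.

```python
from colossalai.booster import Booster


def save_moe_optimizer(booster: Booster, optimizer, path: str = "moe_optim.pth") -> None:
    # shard=False asks the plugin's checkpoint IO to write the optimizer state
    # as a single, unsharded file (assumed to match "unshard optim ckpt" above).
    booster.save_optimizer(optimizer, path, shard=False)


def load_moe_optimizer(booster: Booster, optimizer, path: str = "moe_optim.pth") -> None:
    # Restore the gathered state back onto the current expert-parallel layout.
    booster.load_optimizer(optimizer, path)
```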
@@ -9,7 +9,7 @@ from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.moe.utils import get_activation
 from colossalai.shardformer.layer.utils import Randomizer
-from colossalai.tensor.moe_tensor.api import get_ep_size, set_moe_tensor_info
+from colossalai.tensor.moe_tensor.api import get_ep_rank, get_ep_size, set_moe_tensor_info
 
 if HAS_TRITON:
     from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine
@@ -53,7 +53,8 @@ class MLPExperts(nn.Module):
         # get expert parallel info
         if expert_parallel is not None:
             self.num_local_experts, self.moe_info = MOE_MANAGER.get_info(
-                num_experts, use_tp=True if expert_parallel == "TP" else False)
+                num_experts, use_tp=True if expert_parallel == "TP" else False
+            )
             # get settings for different parallel
             self.ep_size = get_ep_size(self)
             if expert_parallel == "TP":
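The call above splits the global expert pool across the expert-parallel group. A tiny stand-alone sketch of that partitioning follows; the real MOE_MANAGER.get_info also returns parallel metadata, and the even-split rule here is an assumption for illustration, not its actual implementation.

```python
def local_expert_count(num_experts: int, ep_size: int) -> int:
    # Hypothetical illustration: with pure expert parallelism, each EP rank
    # hosts an equal share of the global experts.
    assert num_experts % ep_size == 0, "experts must divide evenly across EP ranks"
    return num_experts // ep_size


print(local_expert_count(num_experts=8, ep_size=4))  # -> 2 local experts per rank
```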
@@ -87,7 +88,7 @@ class MLPExperts(nn.Module):
     def reset_parameters(self):
         # expert param should be different
         if self.expert_parallel is not None:
-            seed_ctx = Randomizer(MOE_MANAGER.seed).fork_rng(enable_cpu=True)
+            seed_ctx = Randomizer(get_ep_rank(self)).fork_rng(enable_cpu=True)
         else:
             seed_ctx = Randomizer(42).fork_rng(enable_cpu=True)
         with seed_ctx:
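This hunk replaces the shared MOE_MANAGER.seed with the expert-parallel rank, so each EP rank forks its RNG with a different seed and initializes different expert weights. A small self-contained illustration of that effect (plain PyTorch, not the ColossalAI Randomizer):

```python
import torch


def init_expert_weight(seed: int, shape=(4, 4)) -> torch.Tensor:
    # Fork the CPU RNG so the per-"rank" seed does not leak into global state,
    # mirroring what Randomizer(...).fork_rng(enable_cpu=True) is used for here.
    with torch.random.fork_rng(devices=[]):
        torch.manual_seed(seed)
        return torch.empty(shape).normal_(std=0.1)


w_rank0 = init_expert_weight(seed=0)  # pretend EP rank 0
w_rank1 = init_expert_weight(seed=1)  # pretend EP rank 1
assert not torch.allclose(w_rank0, w_rank1)  # per-rank seeds -> distinct experts
```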
@@ -99,10 +100,10 @@ class MLPExperts(nn.Module):
             torch.nn.init.normal_(self.wo, std=math.sqrt(0.1 / self.intermediate_size))
 
     def forward(
-            self,
-            x: torch.Tensor,
-            param_slice: Tuple[slice] = (slice(None),),
-            use_sparse: bool = True,
+        self,
+        x: torch.Tensor,
+        param_slice: Tuple[slice] = (slice(None),),
+        use_sparse: bool = True,
     ) -> torch.Tensor:
         """
         forward: hidden_size --> intermediate_size --> hidden_size
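The `param_slice` default shown above, `(slice(None),)`, selects every local expert when it is used to index the expert weights. A short illustration; the weight shape is assumed for the example, not taken from this file:

```python
import torch

w = torch.randn(8, 16, 32)        # assumed (num_local_experts, hidden, intermediate)
full = w[(slice(None),)]          # default param_slice: equivalent to w[:]
subset = w[(slice(0, 2),)]        # a custom slice would pick the first two experts
print(full.shape, subset.shape)   # torch.Size([8, 16, 32]) torch.Size([2, 16, 32])
```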
@@ -129,7 +130,7 @@ class MLPExperts(nn.Module):
             mask = torch.sum(mask, dim=-1)
             x_list = []
             for i in range(e):
-                x_list.append(x[i, :mask[i]])
+                x_list.append(x[i, : mask[i]])
             x = x_list
 
         if self.gated:
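In this hunk `x` holds padded per-expert token batches and `mask[i]` counts how many slots of expert `i` are actually filled, so `x[i, : mask[i]]` drops the padding before the expert MLP runs. The shapes below are illustrative:

```python
import torch

num_experts, capacity, hidden = 2, 4, 3
x = torch.arange(num_experts * capacity * hidden, dtype=torch.float32)
x = x.view(num_experts, capacity, hidden)
mask = torch.tensor([3, 1])  # expert 0 got 3 real tokens, expert 1 got 1

x_list = [x[i, : mask[i]] for i in range(num_experts)]
print([t.shape for t in x_list])  # [torch.Size([3, 3]), torch.Size([1, 3])]
```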