[moe] support optimizer checkpoint (#5015)

* Refactor MoE Manager setup method

* Unshard optimizer checkpoint

* Add optimizer checkpoint IO

* Update transformers version

* Update requirements

* Update checkpoint handling

* Update checkpoint handling

* Update checkpoint handling

* Fix engine

* Fix engine
Author: Xuanlei Zhao
Date: 2023-11-08 23:07:03 +08:00
Committed by: GitHub
Commit: f71e63b0f3
Parent: 67f5331754
20 changed files with 738 additions and 150 deletions
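
The headline change is optimizer-state checkpointing for MoE models, where each expert-parallel rank only holds the optimizer states of its local experts. The snippet below is a minimal usage sketch, not code from this PR: it assumes the generic ColossalAI booster checkpoint API (booster.save_optimizer / booster.load_optimizer), and the plugin choice, toy model, and paths are placeholders for illustration.

# Minimal sketch of saving/restoring optimizer state through the booster API.
# Assumptions: the plugin choice, toy model, and checkpoint path are illustrative
# only; a real MoE run would build experts whose tensors carry expert-parallel
# metadata so each rank checkpoints the states of the experts it owns.
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

colossalai.launch_from_torch(config={})

model = torch.nn.Linear(8, 8)  # stand-in for an MoE model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

booster = Booster(plugin=LowLevelZeroPlugin())
model, optimizer, *_ = booster.boost(model, optimizer)

# Save a sharded optimizer checkpoint to a directory of shard files.
booster.save_optimizer(optimizer, "optim_ckpt", shard=True)

# Load it back before resuming training.
booster.load_optimizer(optimizer, "optim_ckpt")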


@@ -9,7 +9,7 @@ from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.moe.utils import get_activation
 from colossalai.shardformer.layer.utils import Randomizer
-from colossalai.tensor.moe_tensor.api import get_ep_size, set_moe_tensor_info
+from colossalai.tensor.moe_tensor.api import get_ep_rank, get_ep_size, set_moe_tensor_info
 
 if HAS_TRITON:
     from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine
@@ -53,7 +53,8 @@ class MLPExperts(nn.Module):
         # get expert parallel info
         if expert_parallel is not None:
             self.num_local_experts, self.moe_info = MOE_MANAGER.get_info(
-                num_experts, use_tp=True if expert_parallel == "TP" else False)
+                num_experts, use_tp=True if expert_parallel == "TP" else False
+            )
             # get settings for different parallel
             self.ep_size = get_ep_size(self)
             if expert_parallel == "TP":
@@ -87,7 +88,7 @@ class MLPExperts(nn.Module):
     def reset_parameters(self):
         # expert param should be different
         if self.expert_parallel is not None:
-            seed_ctx = Randomizer(MOE_MANAGER.seed).fork_rng(enable_cpu=True)
+            seed_ctx = Randomizer(get_ep_rank(self)).fork_rng(enable_cpu=True)
         else:
             seed_ctx = Randomizer(42).fork_rng(enable_cpu=True)
         with seed_ctx:
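
Context for the hunk above: reset_parameters now forks the RNG from the expert-parallel rank (get_ep_rank(self)) instead of the shared MOE_MANAGER.seed, tying each rank's expert initialization to its EP rank, in line with the "expert param should be different" comment. The standalone sketch below illustrates the effect; the helper name and constants are made up, and plain torch.manual_seed stands in for Randomizer.fork_rng.

# Standalone sketch (not the library code): why per-rank seeding matters for
# expert parallelism. Each rank initializes only its local experts; if every
# rank forks the RNG from the same shared seed, all local experts start with
# identical weights. Seeding with the expert-parallel rank keeps them distinct.
import torch

def init_local_expert(ep_rank: int, hidden: int = 4) -> torch.Tensor:
    # fork_rng keeps the seed local to this initialization, similar in spirit
    # to Randomizer(...).fork_rng(enable_cpu=True).
    with torch.random.fork_rng():
        torch.manual_seed(ep_rank)   # per-rank seed (the new behaviour)
        # torch.manual_seed(1234)    # shared seed would make all ranks identical
        return torch.empty(hidden, hidden).normal_(std=0.1)

w0 = init_local_expert(ep_rank=0)
w1 = init_local_expert(ep_rank=1)
print(torch.allclose(w0, w1))  # False: experts on different EP ranks differ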
@@ -99,10 +100,10 @@ class MLPExperts(nn.Module):
             torch.nn.init.normal_(self.wo, std=math.sqrt(0.1 / self.intermediate_size))
 
     def forward(
-            self,
-            x: torch.Tensor,
-            param_slice: Tuple[slice] = (slice(None),),
-            use_sparse: bool = True,
+        self,
+        x: torch.Tensor,
+        param_slice: Tuple[slice] = (slice(None),),
+        use_sparse: bool = True,
     ) -> torch.Tensor:
         """
         forward: hidden_size --> intermediate_size --> hidden_size
@@ -129,7 +130,7 @@ class MLPExperts(nn.Module):
             mask = torch.sum(mask, dim=-1)
             x_list = []
             for i in range(e):
-                x_list.append(x[i, :mask[i]])
+                x_list.append(x[i, : mask[i]])
             x = x_list
 
         if self.gated:
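
For reference, the slicing pattern touched by the last hunk keeps, for each expert i, only the first mask[i] token slots, i.e. the tokens actually routed to that expert, and drops the unused padding positions. A self-contained sketch with made-up shapes, not code from the PR:

# Illustrative sketch: with use_sparse=True, x[i, : mask[i]] trims each expert's
# slice of the (experts, capacity, hidden) buffer down to its real token count.
import torch

e, capacity, hidden = 2, 4, 3    # experts, slots per expert, hidden size (assumed)
x = torch.arange(e * capacity * hidden, dtype=torch.float32).reshape(e, capacity, hidden)
mask = torch.tensor([3, 1])      # tokens actually routed to each expert

x_list = [x[i, : mask[i]] for i in range(e)]
print([t.shape for t in x_list])  # [torch.Size([3, 3]), torch.Size([1, 3])]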