[moe] support mixtral (#5309)

* [moe] add mixtral block for single expert

* [moe] mixtral block fwd support uneven ep

* [moe] mixtral block bwd support uneven ep

* [moe] add mixtral moe layer

* [moe] simplify replace

* [meo] support save sharded mixtral

* [meo] support load sharded mixtral

* [meo] support save sharded optim

* [meo] integrate moe manager into plug

* [meo] fix optimizer load

* [meo] fix mixtral layer
This commit is contained in:
Hongxin Liu
2024-01-25 15:48:46 +08:00
committed by ver217
parent c904d2ae99
commit da39d21b71
14 changed files with 996 additions and 550 deletions

View File

@@ -2,22 +2,18 @@ import argparse
import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
from colossal_moe.models.mixtral_layer import replace_moe_layer
from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from colossal_moe.utils import load_checkpoint, load_model, move_to_cuda, save_checkpoint
from colossal_moe.utils import load_checkpoint, move_to_cuda, save_checkpoint
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
from transformers.models.mixtral import MixtralForCausalLM
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.moe import MOE_MANAGER, apply_load_balance
from colossalai.moe.layers import apply_load_balance
from colossalai.moe.manager import MOE_MANAGER
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
@@ -153,45 +149,27 @@ def main():
coordinator = DistCoordinator()
# Set plugin
booster_kwargs = {}
hybrid_dict = {
"tp_size": 1,
"custom_policy": MixtralForCausalLMPolicy(),
"enable_fused_normalization": args.use_layernorm_kernel,
"enable_jit_fused": args.use_kernel,
"precision": args.precision,
"zero_stage": args.zero_stage,
"checkpoint_io": MixtralMoECheckpointIO,
}
mgr_dict = {}
if args.plugin == "hybrid":
plugin = MoeHybridParallelPlugin(
tp_size=1,
pp_size=args.pp_size,
ep_size=args.ep_size,
microbatch_size=args.microbatch_size,
**hybrid_dict,
)
MOE_MANAGER.setup(
parallel="EP",
mode="fixed",
fixed_dp_size=args.dp_size,
fixed_ep_size=args.ep_size,
fixed_pp_size=args.pp_size,
**mgr_dict,
custom_policy=MixtralForCausalLMPolicy(),
enable_fused_normalization=args.use_layernorm_kernel,
enable_jit_fused=args.use_kernel,
precision=args.precision,
zero_stage=args.zero_stage,
checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
)
else:
raise ValueError(f"Invalid plugin {args.plugin}")
coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")
# Build Mixtral model
config = MixtralConfig.from_pretrained(args.model_name)
config.use_cache = False
config.num_local_experts = 1
model = MixtralForCausalLM(config)
model.num_experts = 8
model = model.to(torch.bfloat16) if args.precision == "bf16" else model.to(torch.float16)
model = model.to(get_current_device())
replace_moe_layer(model, enable_kernel=args.use_kernel)
coordinator.print_on_master(f"Finish init model with config:\n{config}")
model = MixtralForCausalLM.from_pretrained(args.model_name)
coordinator.print_on_master(f"Finish init model")
# Enable gradient checkpointing
model.gradient_checkpointing_enable()
@@ -224,7 +202,7 @@ def main():
)
# Set booster
booster = Booster(plugin=plugin, **booster_kwargs)
booster = Booster(plugin=plugin)
model, optimizer, _, dataloader, lr_scheduler = booster.boost(
model=model,
optimizer=optimizer,
@@ -236,10 +214,7 @@ def main():
coordinator.print_on_master(f"Finish init booster")
# Load ckpt
if args.load_checkpoint is None:
load_model(args.model_name, model, booster, optimizer)
coordinator.print_on_master(f"Finish load checkpoint")
else:
if args.load_checkpoint is not None:
load_checkpoint(args.load_checkpoint, booster, model, optimizer, lr_scheduler)
coordinator.print_on_master(f"Finish load optimizer")
@@ -286,13 +261,13 @@ def main():
optimizer.zero_grad()
# Apply load balance
if (
args.load_balance
and args.load_balance_interval > 0
and (step + 1) % args.load_balance_interval == 0
):
coordinator.print_on_master(f"Apply load balance")
apply_load_balance(model, optimizer)
# if (
# args.load_balance
# and args.load_balance_interval > 0
# and (step + 1) % args.load_balance_interval == 0
# ):
# coordinator.print_on_master(f"Apply load balance")
# apply_load_balance(model, optimizer)
# save ckeckpoint
if (step + 1) % args.save_interval == 0:
coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")