mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-24 11:08:50 +00:00
[moe] support mixtral (#5309)
* [moe] add mixtral block for single expert * [moe] mixtral block fwd support uneven ep * [moe] mixtral block bwd support uneven ep * [moe] add mixtral moe layer * [moe] simplify replace * [meo] support save sharded mixtral * [meo] support load sharded mixtral * [meo] support save sharded optim * [meo] integrate moe manager into plug * [meo] fix optimizer load * [meo] fix mixtral layer
This commit is contained in:
@@ -3,7 +3,6 @@ import os
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
from torch.optim.optimizer import Optimizer
|
||||
|
||||
@@ -15,23 +14,6 @@ def move_to_cuda(batch, device):
|
||||
return {k: v.to(device) for k, v in batch.items()}
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def load_model(ckpt_path: str, model, booster: Booster, optimizer=None):
|
||||
# pytorch ckpt
|
||||
if os.path.exists(os.path.join(ckpt_path, "model.safetensors.index.json")):
|
||||
ckpt_path = os.path.join(ckpt_path, "model.safetensors.index.json")
|
||||
# saved ckpt
|
||||
elif os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin.index.json")):
|
||||
ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin.index.json")
|
||||
# download
|
||||
else:
|
||||
ckpt_path = snapshot_download(ckpt_path)
|
||||
booster.load_model(model, ckpt_path)
|
||||
if optimizer is not None:
|
||||
optimizer.sync_moe_master_param()
|
||||
optimizer.update_master_params(model)
|
||||
|
||||
|
||||
def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]:
|
||||
"""
|
||||
Load file in JSON format
|
||||
@@ -90,7 +72,7 @@ def load_checkpoint(
|
||||
"""
|
||||
|
||||
# Update booster params states.
|
||||
load_model(os.path.join(load_dir, "modeling"), model, booster, optimizer)
|
||||
booster.load_model(model, os.path.join(load_dir, "modeling"))
|
||||
booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
|
||||
booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))
|
||||
|
||||
|
Reference in New Issue
Block a user