[moe] support mixtral (#5309)

* [moe] add mixtral block for single expert

* [moe] mixtral block fwd support uneven ep

* [moe] mixtral block bwd support uneven ep

* [moe] add mixtral moe layer

* [moe] simplify replace

* [moe] support save sharded mixtral

* [moe] support load sharded mixtral

* [moe] support save sharded optim

* [moe] integrate moe manager into plugin

* [moe] fix optimizer load

* [moe] fix mixtral layer

Author: Hongxin Liu
Date: 2024-01-25 15:48:46 +08:00
Committed by: ver217
Parent: c904d2ae99
Commit: da39d21b71
14 changed files with 996 additions and 550 deletions
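
The "uneven ep" bullets above concern expert parallelism where the number of experts need not divide evenly across EP ranks. Below is a minimal sketch of one such partitioning scheme; the `uneven_ep_partition` helper and the exact split are assumptions for illustration, not this PR's actual code:

def uneven_ep_partition(num_experts: int, ep_size: int, ep_rank: int) -> range:
    # When num_experts % ep_size != 0, the first `remainder` ranks
    # each hold one extra expert.
    base, remainder = divmod(num_experts, ep_size)
    num_local = base + (1 if ep_rank < remainder else 0)
    start = ep_rank * base + min(ep_rank, remainder)
    return range(start, start + num_local)

# e.g. 8 experts over 3 EP ranks -> experts [0..2], [3..5], [6..7]
assert [list(uneven_ep_partition(8, 3, r)) for r in range(3)] == [
    [0, 1, 2], [3, 4, 5], [6, 7],
]

Assigning the remainder to the lowest ranks keeps per-rank expert counts within one of each other, so forward and backward work stays roughly balanced even with an uneven split.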


@@ -3,7 +3,6 @@ import os
 from typing import Any, Dict, Tuple, Union
 
 import torch
-from huggingface_hub import snapshot_download
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.optim.optimizer import Optimizer
 
@@ -15,23 +14,6 @@ def move_to_cuda(batch, device):
     return {k: v.to(device) for k, v in batch.items()}
 
 
-@torch.no_grad()
-def load_model(ckpt_path: str, model, booster: Booster, optimizer=None):
-    # pytorch ckpt
-    if os.path.exists(os.path.join(ckpt_path, "model.safetensors.index.json")):
-        ckpt_path = os.path.join(ckpt_path, "model.safetensors.index.json")
-    # saved ckpt
-    elif os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin.index.json")):
-        ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin.index.json")
-    # download
-    else:
-        ckpt_path = snapshot_download(ckpt_path)
-    booster.load_model(model, ckpt_path)
-    if optimizer is not None:
-        optimizer.sync_moe_master_param()
-        optimizer.update_master_params(model)
-
-
 def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]:
     """
     Load file in JSON format
@@ -90,7 +72,7 @@ def load_checkpoint(
     """
     # Update booster params states.
-    load_model(os.path.join(load_dir, "modeling"), model, booster, optimizer)
+    booster.load_model(model, os.path.join(load_dir, "modeling"))
     booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
     booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))
 
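
With the custom `load_model` helper removed, checkpoint loading goes through the booster directly. A minimal usage sketch of the revised flow, assuming the same on-disk layout (`modeling`/`optimizer`/`lr_scheduler` subdirectories); the model, optimizer, and checkpoint path here are placeholders, not values from this PR:

import os

import torch
from colossalai.booster import Booster

# Placeholder objects: the real training script builds a Mixtral model and
# passes a MoE-aware parallel plugin to Booster, both omitted here for brevity.
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters())
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

booster = Booster()
model, optimizer, _, _, lr_scheduler = booster.boost(model, optimizer, lr_scheduler=lr_scheduler)

load_dir = "checkpoints/step_1000"  # hypothetical checkpoint directory
booster.load_model(model, os.path.join(load_dir, "modeling"))
booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))

Passing the checkpoint directory straight to `booster.load_model` suggests the booster's checkpoint IO now resolves sharded index files itself, so the safetensors/pytorch-index handling no longer needs to live in the training script.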