Mirror of https://github.com/hpcaitech/ColossalAI.git
Synced 2025-09-06 03:20:52 +00:00
[moe] support optimizer checkpoint (#5015)

* Refactor MoE Manager setup method
* unshard optim ckpt
* optim io
* update transformer version
* update requirements
* update ckpt
* update ckpt
* update ckpt
* fix engine
* fix engine
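For context, the feature landing here is optimizer checkpoint I/O for MoE models. Below is a minimal, hedged sketch of how such a checkpoint round-trip looks with ColossalAI's Booster API; the plugin variable and the path are illustrative assumptions, not taken from this commit:

from colossalai.booster import Booster

booster = Booster(plugin=moe_plugin)  # moe_plugin: a MoE-aware plugin, assumed configured elsewhere
model, optimizer, *_ = booster.boost(model, optimizer)

booster.save_optimizer(optimizer, "optim.ckpt")  # persist optimizer state
booster.load_optimizer(optimizer, "optim.ckpt")  # restore it on resume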
@@ -8,14 +8,13 @@ from colossalai.tensor.moe_tensor.api import get_moe_info
 from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo


-class MoeManager(metaclass=SingletonMeta):
+class MoEManager(metaclass=SingletonMeta):
     """MoE manager. This class manages different
     parallel groups in MoE context and MoE loss in training.
     """

     def __init__(self):
         self.parallel = None
-        self.seed = None
         self.mode = None
         self.use_ep_inside = None
         self.world_size = None
@@ -48,7 +47,6 @@ class MoeManager(metaclass=SingletonMeta):

     def setup(
         self,
-        seed: int,
         parallel: str = None,
         mode: str = "dynamic",
         max_ep_size: int = 8,
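With seed removed from the signature, initialization reduces to one call. A minimal sketch, assuming the import path and the "EP" parallel mode string (neither appears in this hunk):

from colossalai.moe.manager import MOE_MANAGER  # import path assumed

MOE_MANAGER.setup(parallel="EP", mode="dynamic", max_ep_size=8)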
@@ -73,10 +71,9 @@ class MoeManager(metaclass=SingletonMeta):
             fixed_pp_size (int, optional): Fixed pp size in fixed mode. Defaults to 0.
             use_ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if False. Defaults to True.
         """
-        assert (not self.is_initialized), "MoE distributed context shouldn't be set up again"
+        assert not self.is_initialized, "MoE distributed context shouldn't be set up again"
         assert torch.cuda.is_available(), "MoE requires CUDA to be enabled first"

-        self.seed = seed + dist.get_rank()
         self.parallel = parallel
         self.use_ep_inside = use_ep_inside
         self.world_size = dist.get_world_size()
@@ -87,10 +84,12 @@ class MoeManager(metaclass=SingletonMeta):
         if self.mode == "dynamic":
             self.max_ep_size = min(max_ep_size, self.world_size)
         else:
-            assert (fixed_dp_size > 0 and fixed_ep_size > 0
-                    and fixed_pp_size > 0), "dp_size, ep_size and pp_size should be greater than 0"
-            assert (isinstance(fixed_dp_size, int) and isinstance(fixed_ep_size, int)
-                    and isinstance(fixed_pp_size, int)), "dp_size, ep_size and pp_size should be int"
+            assert (
+                fixed_dp_size > 0 and fixed_ep_size > 0 and fixed_pp_size > 0
+            ), "dp_size, ep_size and pp_size should be greater than 0"
+            assert (
+                isinstance(fixed_dp_size, int) and isinstance(fixed_ep_size, int) and isinstance(fixed_pp_size, int)
+            ), "dp_size, ep_size and pp_size should be int"
             self.ep_size = fixed_ep_size
             self.dp_size = fixed_dp_size
             self.pp_size = fixed_pp_size
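In fixed mode the three sizes must be positive ints, typically chosen so dp * ep * pp covers the world size; a hypothetical call for 16 ranks:

MOE_MANAGER.setup(parallel="EP", mode="fixed", fixed_dp_size=2, fixed_ep_size=8, fixed_pp_size=1)  # 2 * 8 * 1 = 16, values illustrative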
@@ -112,10 +111,12 @@ class MoeManager(metaclass=SingletonMeta):
         """

         if self.mode == "dynamic":
-            gt_flag = (num_experts % self.max_ep_size == 0)    # check whether num_experts is greater
-            lt_flag = (self.max_ep_size % num_experts == 0)    # check whether num_experts is less
-            assert gt_flag or lt_flag, ("Automatic experts placement does not support an expert number"
-                                        " that is not a multiple of ep size or vice versa.")
+            gt_flag = num_experts % self.max_ep_size == 0  # check whether num_experts is greater
+            lt_flag = self.max_ep_size % num_experts == 0  # check whether num_experts is less
+            assert gt_flag or lt_flag, (
+                "Automatic experts placement does not support an expert number"
+                " that is not a multiple of ep size or vice versa."
+            )
             dp_size = 1 if gt_flag else self.world_size // num_experts
             ep_size = min(self.world_size // dp_size, self.max_ep_size)
             dp_size = self.world_size // ep_size
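The dynamic branch derives both group sizes from the world size and expert count. A standalone sketch of that arithmetic (function name hypothetical, logic copied from the lines above):

def dynamic_placement(world_size: int, num_experts: int, max_ep_size: int):
    gt_flag = num_experts % max_ep_size == 0  # experts are a multiple of the ep group size
    lt_flag = max_ep_size % num_experts == 0  # ep group size is a multiple of the experts
    assert gt_flag or lt_flag
    dp_size = 1 if gt_flag else world_size // num_experts
    ep_size = min(world_size // dp_size, max_ep_size)
    dp_size = world_size // ep_size  # re-derive dp from the clamped ep
    return dp_size, ep_size

dynamic_placement(16, 4, 8)  # -> (4, 4): four data-parallel replicas of a 4-way expert-parallel group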
@@ -159,4 +160,4 @@ class MoeManager(metaclass=SingletonMeta):
         return self.parallel


-MOE_MANAGER = MoeManager()
+MOE_MANAGER = MoEManager()
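Both the old and new class names instantiate through SingletonMeta, so MOE_MANAGER stays the single process-wide instance. A minimal sketch of such a metaclass, assuming the usual caching pattern (the actual ColossalAI implementation may differ):

class SingletonMeta(type):
    """Metaclass that caches the first instance of each class and returns it thereafter."""

    _instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]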