[moe] support optimizer checkpoint (#5015)

* Refactor MoE Manager setup method

* unshard optim ckpt

* optim io

* update transformer version

* update requirements

* update ckpt

* update ckpt

* update ckpt

* fix engine

* fix engine
Author: Xuanlei Zhao
Date: 2023-11-08 23:07:03 +08:00
Committed by: GitHub
Parent: 67f5331754
Commit: f71e63b0f3

20 changed files with 738 additions and 150 deletions
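The headline change is optimizer-state checkpointing for MoE training. As a rough sketch of the flow this enables through ColossalAI's Booster API (the plugin choice, model, and paths here are placeholders, and the MoE-specific plugin wiring is not part of the diff shown below):

```python
import torch
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

colossalai.launch_from_torch(config={})

model = nn.Linear(16, 16).cuda()  # stand-in for a real MoE model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# LowLevelZeroPlugin stands in for whichever MoE-aware plugin is used in
# practice; boosting hands checkpoint IO for both objects to the plugin.
booster = Booster(plugin=LowLevelZeroPlugin())
model, optimizer, *_ = booster.boost(model, optimizer)

# ... training steps ...

booster.save_optimizer(optimizer, "optim_ckpt", shard=True)  # sharded save
booster.load_optimizer(optimizer, "optim_ckpt")              # and reload
```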


@@ -8,14 +8,13 @@ from colossalai.tensor.moe_tensor.api import get_moe_info
 from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo
 
 
-class MoeManager(metaclass=SingletonMeta):
+class MoEManager(metaclass=SingletonMeta):
     """MoE manager. This class manages different
     parallel groups in MoE context and MoE loss in training.
     """
 
     def __init__(self):
         self.parallel = None
-        self.seed = None
         self.mode = None
         self.use_ep_inside = None
         self.world_size = None
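The rename to MoEManager keeps the singleton behavior provided by SingletonMeta. For readers unfamiliar with the pattern, here is a minimal sketch of such a metaclass (the actual colossalai implementation may differ in detail):

```python
class SingletonMeta(type):
    """Instantiate each class at most once; later calls return the cached object."""

    _instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]


class MoEManager(metaclass=SingletonMeta):
    def __init__(self):
        self.parallel = None


assert MoEManager() is MoEManager()  # every call yields the same manager
```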
@@ -48,7 +47,6 @@ class MoeManager(metaclass=SingletonMeta):
 
     def setup(
         self,
-        seed: int,
         parallel: str = None,
         mode: str = "dynamic",
         max_ep_size: int = 8,
@@ -73,10 +71,9 @@ class MoeManager(metaclass=SingletonMeta):
             fixed_pp_size (int, optional): Fixed pp size in fixed mode. Defaults to 0.
             use_ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if False. Defaults to True.
         """
-        assert (not self.is_initialized), "MoE distributed context shouldn't be set up again"
+        assert not self.is_initialized, "MoE distributed context shouldn't be set up again"
         assert torch.cuda.is_available(), "MoE requires to enable CUDA first"
-        self.seed = seed + dist.get_rank()
         self.parallel = parallel
         self.use_ep_inside = use_ep_inside
         self.world_size = dist.get_world_size()
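With seeding dropped from setup(), configuring the manager in dynamic mode now looks roughly like this (a sketch: the import path and the "EP" value for parallel are assumptions, not shown in this diff):

```python
from colossalai.moe import MOE_MANAGER  # import path assumed

# dynamic mode: expert placement is decided per layer at runtime,
# with expert parallelism capped at max_ep_size ranks
MOE_MANAGER.setup(
    parallel="EP",       # assumed value for illustration
    mode="dynamic",
    max_ep_size=8,
    use_ep_inside=True,  # nest the EP group inside the DP group
)
```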
@@ -87,10 +84,12 @@ class MoeManager(metaclass=SingletonMeta):
         if self.mode == "dynamic":
             self.max_ep_size = min(max_ep_size, self.world_size)
         else:
-            assert (fixed_dp_size > 0 and fixed_ep_size > 0
-                    and fixed_pp_size > 0), "dp_size, ep_size and pp_size should be greater than 0"
-            assert (isinstance(fixed_dp_size, int) and isinstance(fixed_ep_size, int)
-                    and isinstance(fixed_pp_size, int)), "dp_size, ep_size and pp_size should be int"
+            assert (
+                fixed_dp_size > 0 and fixed_ep_size > 0 and fixed_pp_size > 0
+            ), "dp_size, ep_size and pp_size should be greater than 0"
+            assert (
+                isinstance(fixed_dp_size, int) and isinstance(fixed_ep_size, int) and isinstance(fixed_pp_size, int)
+            ), "dp_size, ep_size and pp_size should be int"
             self.ep_size = fixed_ep_size
             self.dp_size = fixed_dp_size
             self.pp_size = fixed_pp_size
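In fixed mode all three sizes must be positive integers supplied by the caller. On, say, 8 ranks one valid layout might be (values illustrative, continuing the sketch above):

```python
# fixed mode on 8 ranks: ep * dp * pp = 4 * 2 * 1
MOE_MANAGER.setup(
    parallel="EP",  # assumed value, as above
    mode="fixed",
    fixed_ep_size=4,
    fixed_dp_size=2,
    fixed_pp_size=1,
)
```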
@@ -112,10 +111,12 @@ class MoeManager(metaclass=SingletonMeta):
         """
         if self.mode == "dynamic":
-            gt_flag = (num_experts % self.max_ep_size == 0)  # check whether num_experts is greater
-            lt_flag = (self.max_ep_size % num_experts == 0)  # check whether num_experts is less
-            assert gt_flag or lt_flag, ("Automatic experts placement does not support an expert number"
-                                        " that is not a multiple of ep size or vice versa.")
+            gt_flag = num_experts % self.max_ep_size == 0  # check whether num_experts is greater
+            lt_flag = self.max_ep_size % num_experts == 0  # check whether num_experts is less
+            assert gt_flag or lt_flag, (
+                "Automatic experts placement does not support an expert number"
+                " that is not a multiple of ep size or vice versa."
+            )
             dp_size = 1 if gt_flag else self.world_size // num_experts
             ep_size = min(self.world_size // dp_size, self.max_ep_size)
             dp_size = self.world_size // ep_size
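The dynamic branch derives the EP and DP group sizes from the expert count. Re-deriving that arithmetic as a standalone function (a sketch mirroring the logic above, not a library API) makes the two cases concrete:

```python
def plan_expert_placement(world_size: int, max_ep_size: int, num_experts: int):
    """Mirror of the dynamic-mode placement above: the expert count must be
    a multiple of the EP size, or a divisor of it."""
    gt_flag = num_experts % max_ep_size == 0  # at least one expert per EP rank
    lt_flag = max_ep_size % num_experts == 0  # fewer experts than EP ranks
    assert gt_flag or lt_flag, "unsupported expert count"
    dp_size = 1 if gt_flag else world_size // num_experts
    ep_size = min(world_size // dp_size, max_ep_size)
    dp_size = world_size // ep_size
    return dp_size, ep_size


print(plan_expert_placement(8, 8, 16))  # (1, 8): every rank holds 2 experts
print(plan_expert_placement(8, 8, 2))   # (4, 2): each expert replicated 4x
```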
@@ -159,4 +160,4 @@ class MoeManager(metaclass=SingletonMeta):
         return self.parallel
 
 
-MOE_MANAGER = MoeManager()
+MOE_MANAGER = MoEManager()