[moe] support mixtral (#5309)

* [moe] add mixtral block for single expert

* [moe] mixtral block fwd support uneven ep

* [moe] mixtral block bwd support uneven ep

* [moe] add mixtral moe layer

* [moe] simplify replace

* [meo] support save sharded mixtral

* [meo] support load sharded mixtral

* [meo] support save sharded optim

* [meo] integrate moe manager into plug

* [meo] fix optimizer load

* [meo] fix mixtral layer
This commit is contained in:
Hongxin Liu
2024-01-25 15:48:46 +08:00
committed by ver217
parent c904d2ae99
commit da39d21b71
14 changed files with 996 additions and 550 deletions

View File

@@ -9,7 +9,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from colossalai.interface import ModelWrapper
from .utils import has_index_file
from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, has_index_file
__all__ = ["CheckpointIO"]
@@ -90,7 +90,15 @@ class CheckpointIO(ABC):
if index_file_exists:
self.load_sharded_model(model, index_file_path, strict)
else:
self.load_unsharded_model(model, checkpoint, strict)
path = Path(checkpoint, SAFE_WEIGHTS_NAME)
if path.is_file():
self.load_unsharded_model(model, str(path), strict)
else:
path = Path(checkpoint, WEIGHTS_NAME)
if path.is_file():
self.load_unsharded_model(model, str(path), strict)
else:
self.load_unsharded_model(model, checkpoint, strict)
return origin_model