Merge pull request #5372 from hpcaitech/exp/mixtral

This commit is contained in:
Frank Lee
2024-02-08 16:30:05 +08:00
committed by GitHub
33 changed files with 2530 additions and 267 deletions

View File

@@ -9,7 +9,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from colossalai.interface import ModelWrapper
from .utils import has_index_file
from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, has_index_file
__all__ = ["CheckpointIO"]
@@ -90,7 +90,15 @@ class CheckpointIO(ABC):
if index_file_exists:
self.load_sharded_model(model, index_file_path, strict)
else:
self.load_unsharded_model(model, checkpoint, strict)
path = Path(checkpoint, SAFE_WEIGHTS_NAME)
if path.is_file():
self.load_unsharded_model(model, str(path), strict)
else:
path = Path(checkpoint, WEIGHTS_NAME)
if path.is_file():
self.load_unsharded_model(model, str(path), strict)
else:
self.load_unsharded_model(model, checkpoint, strict)
return origin_model