mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-18 07:31:19 +00:00
Merge pull request #5372 from hpcaitech/exp/mixtral
This commit is contained in:
@@ -9,7 +9,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
||||
|
||||
from colossalai.interface import ModelWrapper
|
||||
|
||||
from .utils import has_index_file
|
||||
from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, has_index_file
|
||||
|
||||
__all__ = ["CheckpointIO"]
|
||||
|
||||
@@ -90,7 +90,15 @@ class CheckpointIO(ABC):
|
||||
if index_file_exists:
|
||||
self.load_sharded_model(model, index_file_path, strict)
|
||||
else:
|
||||
self.load_unsharded_model(model, checkpoint, strict)
|
||||
path = Path(checkpoint, SAFE_WEIGHTS_NAME)
|
||||
if path.is_file():
|
||||
self.load_unsharded_model(model, str(path), strict)
|
||||
else:
|
||||
path = Path(checkpoint, WEIGHTS_NAME)
|
||||
if path.is_file():
|
||||
self.load_unsharded_model(model, str(path), strict)
|
||||
else:
|
||||
self.load_unsharded_model(model, checkpoint, strict)
|
||||
|
||||
return origin_model
|
||||
|
||||
|
Reference in New Issue
Block a user