Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-01 01:06:00 +00:00)
[booster] fixed the torch ddp plugin with the new checkpoint api (#3442)
@@ -13,6 +13,7 @@ from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 
 from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
+from colossalai.checkpoint_io.utils import save_state_dict
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.tensor.colo_parameter import ColoParameter
@@ -83,7 +84,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
         return super().load_unsharded_model(model, checkpoint, strict=strict)
 
-    def save_unsharded_model(self, model: GeminiDDP, checkpoint: str):
+    def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
         """
         Save model to checkpoint but only on master process.
         """
@@ -91,14 +92,14 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         # as there is communication when get state dict, this must be called on all processes
         state_dict = model.state_dict(only_rank_0=True)
         if self.coordinator.is_master():
-            self.save_checkpoint(state_dict, checkpoint)
+            save_state_dict(state_dict, checkpoint, use_safetensors)
 
-    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str):
+    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
         """
         Save optimizer to checkpoint but only on master process.
         """
         # TODO(ver217): optimizer state dict is sharded
-        super().save_unsharded_optimizer(optimizer, checkpoint)
+        super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
 
     def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
         """
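For context, a minimal usage sketch of the updated API follows. Only the save_unsharded_model and save_unsharded_optimizer signatures are taken from the diff above; the import path, the no-argument constructor, and the pre-existing model and optimizer objects are assumptions for illustration, not part of this commit.

# Hedged sketch of calling the new checkpoint API whose signatures appear
# in the diff above. Everything marked "assumed" is illustrative only.
from colossalai.booster.plugin.gemini_plugin import GeminiCheckpointIO  # assumed module path

checkpoint_io = GeminiCheckpointIO()  # assumed no-argument constructor

# Gathering the state dict communicates across ranks, so every process
# must reach this call; only the master rank actually writes the file.
checkpoint_io.save_unsharded_model(
    model,                       # a GeminiDDP-wrapped model (assumed to exist)
    "ckpt/model.safetensors",
    gather_dtensor=True,         # gather distributed tensors before writing
    use_safetensors=True,        # serialize with safetensors instead of torch.save
)

checkpoint_io.save_unsharded_optimizer(
    optimizer,                   # the training optimizer (assumed to exist)
    "ckpt/optimizer.pt",
    gather_dtensor=True,
)

Passing use_safetensors through to save_state_dict, rather than writing via self.save_checkpoint, is what lets the master process choose the serialization format at the call site.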