[booster] fixed the torch ddp plugin with the new checkpoint api (#3442)
commit 7d8d825681 (parent 8f740deb53)
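For orientation, these checkpoint hooks sit behind the booster front end. Below is a minimal usage sketch under assumed setup (launch via torchrun, the Booster API of this era of the codebase); it is illustrative, not code from this commit:

```python
# Hedged sketch: how the TorchDDP plugin's checkpoint IO is typically reached.
# Assumes launch under torchrun; exact Booster/plugin signatures are taken
# from the surrounding codebase, not from this diff.
import torch
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

colossalai.launch_from_torch(config={})  # distributed init (assumed entry point)

model = nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

booster = Booster(plugin=TorchDDPPlugin())
model, optimizer, *_ = booster.boost(model, optimizer)

# save_model routes to TorchDDPCheckpointIO.save_unsharded_model, which this
# commit updates to accept the new checkpoint-api arguments.
booster.save_model(model, "ckpt.pt")
```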
The imports in the Gemini plugin pull in the new `save_state_dict` helper:

```diff
@@ -13,6 +13,7 @@ from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 
 from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
+from colossalai.checkpoint_io.utils import save_state_dict
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.tensor.colo_parameter import ColoParameter
```
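For context, a rough sketch of what a helper with this signature plausibly does. The real implementation lives in `colossalai.checkpoint_io.utils` and is not shown in this diff; this is an assumption-labeled approximation:

```python
# Hypothetical approximation of save_state_dict's behavior; the real helper in
# colossalai.checkpoint_io.utils may differ in detail.
import torch

def save_state_dict(state_dict: dict, checkpoint: str, use_safetensors: bool) -> None:
    if use_safetensors:
        # safetensors stores raw tensors only, so this branch assumes a flat,
        # tensor-valued state dict.
        from safetensors.torch import save_file
        save_file(state_dict, checkpoint)
    else:
        torch.save(state_dict, checkpoint)
```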
`GeminiCheckpointIO.save_unsharded_model` and `save_unsharded_optimizer` adopt the new signatures, and the master-rank write now goes through `save_state_dict`:

```diff
@@ -83,7 +84,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
         return super().load_unsharded_model(model, checkpoint, strict=strict)
 
-    def save_unsharded_model(self, model: GeminiDDP, checkpoint: str):
+    def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
         """
         Save model to checkpoint but only on master process.
         """
@@ -91,14 +92,14 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         # as there is communication when get state dict, this must be called on all processes
         state_dict = model.state_dict(only_rank_0=True)
         if self.coordinator.is_master():
-            self.save_checkpoint(state_dict, checkpoint)
+            save_state_dict(state_dict, checkpoint, use_safetensors)
 
-    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str):
+    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
         """
         Save optimizer to checkpoint but only on master process.
         """
         # TODO(ver217): optimizer state dict is sharded
-        super().save_unsharded_optimizer(optimizer, checkpoint)
+        super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
 
     def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
         """
```
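The Gemini path has a collective-save contract worth distilling: every rank must enter `state_dict()` because it communicates, even though only rank 0 receives the full result and writes it. A standalone sketch of that pattern, with the wiring assumed (`model` a GeminiDDP-wrapped module, `coordinator` a DistCoordinator, as in the plugin above):

```python
# Illustrative only: the collective-save pattern from save_unsharded_model,
# pulled out as a free function. Skipping state_dict() on non-master ranks
# would hang the job, since the call is a collective.
def save_gemini_model(model, coordinator, checkpoint: str, use_safetensors: bool) -> None:
    # Every rank participates in the communication; only rank 0 gets the
    # assembled state dict (only_rank_0=True).
    state_dict = model.state_dict(only_rank_0=True)
    if coordinator.is_master():
        save_state_dict(state_dict, checkpoint, use_safetensors)
```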
The `TorchDDPCheckpointIO` methods get the same treatment, forwarding the new arguments to the generic implementation:

```diff
@@ -33,20 +33,20 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
         return super().load_unsharded_model(model, checkpoint, strict=strict)
 
-    def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool):
+    def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
         """
         Save model to checkpoint but only on master process.
         """
         # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
         if self.coordinator.is_master():
-            super().save_unsharded_model(model, checkpoint)
+            super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
 
     def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
         """
         Save optimizer to checkpoint but only on master process.
         """
         if self.coordinator.is_master():
-            super().save_unsharded_optimizer(optimizer, checkpoint)
+            super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
 
     def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
         """
```
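Unlike the Gemini path, DDP keeps a full weight replica on every rank, so no gather or collective call is needed before the master-only write. A minimal standalone sketch of that guard, assuming a plain `torch.distributed` process group (not the plugin's actual code):

```python
# Hedged sketch of the rank-0-writes pattern used by TorchDDPCheckpointIO.
import torch
import torch.distributed as dist

def save_on_master(state_dict: dict, path: str) -> None:
    # Under DDP every rank holds identical weights, so only rank 0 writes;
    # the other ranks skip the filesystem entirely and return immediately.
    if not dist.is_initialized() or dist.get_rank() == 0:
        torch.save(state_dict, path)
```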
Finally, the ResNet-on-CIFAR10 example README picks up two typo fixes:

```diff
@@ -2,10 +2,10 @@
 
 ## 🚀 Quick Start
 
-This example provides a training script and and evaluation script. The training script provides a an example of training ResNet on CIFAR10 dataset from scratch.
+This example provides a training script and an evaluation script. The training script provides an example of training ResNet on CIFAR10 dataset from scratch.
 
 - Training Arguments
-  - `-r, `--resume`: resume from checkpoint file path
+  - `-r`, `--resume`: resume from checkpoint file path
   - `-c`, `--checkpoint`: the folder to save checkpoints
   - `-i`, `--interval`: epoch interval to save checkpoints
   - `-f`, `--fp16`: use fp16
@@ -41,4 +41,4 @@ Expected accuracy performance will be:
 | --------- | ------------------------ | --------------------- | --------------------- |
 | ResNet-18 | 85.85%                   | 85.03%                | 85.12%                |
 
-**Note: the baseline is a adapted from the [script](https://pytorch-tutorial.readthedocs.io/en/latest/tutorial/chapter03_intermediate/3_2_2_cnn_resnet_cifar10/) to use `torchvision.models.resnet18`**
+**Note: the baseline is adapted from the [script](https://pytorch-tutorial.readthedocs.io/en/latest/tutorial/chapter03_intermediate/3_2_2_cnn_resnet_cifar10/) to use `torchvision.models.resnet18`**
```
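As a sketch of how the documented flags might map onto an argument parser, here is a hypothetical `argparse` wiring; the example's real training script is not part of this diff and may declare these differently:

```python
# Hypothetical argparse declaration matching the README's documented flags;
# names and defaults are assumptions, not taken from the actual script.
import argparse

parser = argparse.ArgumentParser(description="Train ResNet on CIFAR10 from scratch")
parser.add_argument("-r", "--resume", type=str, default=None,
                    help="resume from checkpoint file path")
parser.add_argument("-c", "--checkpoint", type=str, default="./checkpoint",
                    help="the folder to save checkpoints")
parser.add_argument("-i", "--interval", type=int, default=1,
                    help="epoch interval to save checkpoints")
parser.add_argument("-f", "--fp16", action="store_true",
                    help="use fp16")
args = parser.parse_args()
```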