Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-05-03 05:58:09 +00:00
[booster] implemented the torch ddp + resnet example (#3232)

* [booster] implemented the torch ddp + resnet example
* polish code
This commit is contained in: parent 1a229045af, commit 73d3e4d309
Changed: colossalai/booster, colossalai/checkpoint_io, colossalai/cluster, colossalai/interface, examples/tutorial/new_api, tests
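In short, the commit wires the TorchDDP plugin, mixed precision and checkpoint IO into the new Booster API. A minimal sketch of the resulting workflow, condensed from the torch_ddp example added further below (dataset, model and hyperparameters are placeholders, not part of the commit):

import colossalai
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

# launch the distributed environment, e.g. via `colossalai run --nproc_per_node 2 demo.py`
colossalai.launch_from_torch(config={})

plugin = TorchDDPPlugin()
booster = Booster(mixed_precision='fp16', plugin=plugin)    # plugin controls device placement and checkpoint IO

model = torchvision.models.resnet18(num_classes=10).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

train_dataset = torchvision.datasets.CIFAR10(root='./data/', train=True, download=True,
                                             transform=transforms.ToTensor())
train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=100, shuffle=True)

# boost() lets the plugin wrap the model (TorchDDPModel) and the optimizer (OptimizerWrapper)
model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)

for images, labels in train_dataloader:
    loss = criterion(model(images.cuda()), labels.cuda())
    optimizer.zero_grad()
    booster.backward(loss, optimizer)    # delegates loss scaling to the mixed-precision optimizer
    optimizer.step()

booster.save_model(model, 'model.pth')   # only rank 0 writes, via TorchDDPCheckpointIO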
@@ -1,4 +1,3 @@
 from .accelerator import Accelerator
 from .booster import Booster
-from .environment_table import EnvironmentTable
 from .plugin import Plugin
@@ -8,6 +8,8 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
 
+from colossalai.checkpoint_io import GeneralCheckpointIO
+
 from .accelerator import Accelerator
 from .mixed_precision import MixedPrecision, mixed_precision_factory
 from .plugin import Plugin
@@ -61,19 +63,21 @@ class Booster:
         self.plugin = plugin
 
         # set accelerator
-        if self.plugin and self.plugin.control_device:
+        if self.plugin and self.plugin.control_device():
             self.accelerator = None
             warnings.warn('The plugin will control the accelerator, so the device argument will be ignored.')
         else:
             self.accelerator = Accelerator(device)
 
         # set precision
-        if mixed_precision is None or (self.plugin and self.plugin.control_precision):
-            self.mixed_precision = None
+        if self.plugin and self.plugin.control_precision():
             warnings.warn('The plugin will control the precision, so the mixed_precision argument will be ignored.')
+            self.mixed_precision = None
+        elif mixed_precision is None:
+            self.mixed_precision = None
         else:
             # validate and set precision
-            if isinstance(MixedPrecision, str):
+            if isinstance(mixed_precision, str):
                 # the user will take the default arguments for amp training
                 self.mixed_precision = mixed_precision_factory(mixed_precision)
             elif isinstance(mixed_precision, MixedPrecision):
@@ -84,6 +88,11 @@ class Booster:
                     f'Expected the argument mixed_precision to be a string or an instance of Precision, but got {type(mixed_precision)}.'
                 )
 
+        if self.plugin is not None and self.plugin.control_checkpoint_io():
+            self.checkpoint_io = self.plugin.get_checkpoint_io()
+        else:
+            self.checkpoint_io = GeneralCheckpointIO()
+
     def boost(
         self,
         model: nn.Module,
@@ -109,12 +118,13 @@ class Booster:
             model, optimizer, criterion, dataloader, lr_scheduler = self.plugin.configure(
                 model, optimizer, criterion, dataloader, lr_scheduler)
 
-        if self.plugin and not self.plugin.control_device:
+        if self.plugin and not self.plugin.control_device():
             # transform model for accelerator
             model = self.accelerator.configure(model)
 
-        if self.mixed_precision and self.plugin and not self.plugin.control_precision:
+        if self.mixed_precision and (self.plugin is None or self.plugin and not self.plugin.control_precision()):
             # transform model for mixed precision
+            # when mixed_precision is specified and the plugin is not given or does not control the precision
             model, optimizer, criterion = self.mixed_precision.configure(model, optimizer, criterion)
 
         return model, optimizer, criterion, dataloader, lr_scheduler
@@ -140,18 +150,25 @@ class Booster:
         assert self.plugin.support_no_sync, f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
         return self.plugin.no_sync(model)
 
-    def save(self,
-             obj: Union[nn.Module, Optimizer, LRScheduler],
-             path_like: str,
-             plan: str = 'torch',
-             **kwargs) -> None:
-        # TODO: implement this method
-        pass
+    def load_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
+        self.checkpoint_io.load_model(model, checkpoint, strict)
 
-    def load(self,
-             obj: Union[nn.Module, Optimizer, LRScheduler],
-             path_like: str,
-             plan: str = 'torch',
-             **kwargs) -> None:
-        # TODO: implement this method
-        pass
+    def save_model(self,
+                   model: nn.Module,
+                   checkpoint: str,
+                   prefix: str = None,
+                   shard: bool = False,
+                   size_per_shard: int = 1024):
+        self.checkpoint_io.save_model(model, checkpoint, prefix, shard, size_per_shard)
+
+    def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
+        self.checkpoint_io.load_optimizer(optimizer, checkpoint)
+
+    def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, size_per_shard: int = 1024):
+        self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, size_per_shard)
+
+    def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
+        self.checkpoint_io.save_lr_scheduler(lr_scheduler, checkpoint)
+
+    def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
+        self.checkpoint_io.load_lr_scheduler(lr_scheduler, checkpoint)
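Together with `boost()`, these methods make checkpointing symmetric with training. A hedged sketch of the round trip, assuming `booster`, `model`, `optimizer` and `lr_scheduler` come from a setup like the torch_ddp example further below (paths and the `epoch` variable are illustrative only):

# save after an epoch: each call forwards to the booster's CheckpointIO backend
booster.save_model(model, f'./checkpoint/model_{epoch}.pth')
booster.save_optimizer(optimizer, f'./checkpoint/optimizer_{epoch}.pth')
booster.save_lr_scheduler(lr_scheduler, f'./checkpoint/lr_scheduler_{epoch}.pth')

# resume later by loading the same files back into the objects
booster.load_model(model, f'./checkpoint/model_{epoch}.pth')
booster.load_optimizer(optimizer, f'./checkpoint/optimizer_{epoch}.pth')
booster.load_lr_scheduler(lr_scheduler, f'./checkpoint/lr_scheduler_{epoch}.pth')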
@@ -1,18 +0,0 @@ (deleted file)
-from typing import List
-
-__all__ = ['EnvironmentTable']
-
-
-class EnvironmentTable:
-
-    def __init__(self, intra_op_world_sizes: List[int]):
-        # TODO: implement this method
-        pass
-
-    @property
-    def is_master(self) -> bool:
-        # TODO: implement this method
-        pass
-
-    # TODO: implement more utility methods as given in
-    # https://github.com/hpcaitech/ColossalAI/issues/3051
@@ -1,3 +0,0 @@ (deleted file)
-from .optimizer import OptimizerWrapper
-
-__all__ = ['OptimizerWrapper']
@@ -5,7 +5,8 @@ import torch.nn as nn
 from torch import Tensor
 from torch.optim import Optimizer
 
-from ..interface import OptimizerWrapper
+from colossalai.interface import ModelWrapper, OptimizerWrapper
+
 from .mixed_precision_base import MixedPrecision
 
 __all__ = ['FP16_Torch_MixedPrecision', 'TorchAMPOptimizer', 'TorchAMPModule']
@@ -45,7 +46,9 @@ class TorchAMPOptimizer(OptimizerWrapper):
         scaled_loss.backward(*args, **kwargs)
 
     def step(self, *args, **kwargs) -> Optional[float]:
-        return self.scaler.step(self.optim, *args, **kwargs)
+        out = self.scaler.step(self.optim, *args, **kwargs)
+        self.scaler.update()
+        return out
 
     def scale_loss(self, loss: Tensor) -> Tensor:
         return self.scaler.scale(loss)
@@ -67,7 +70,7 @@ class TorchAMPOptimizer(OptimizerWrapper):
         super().clip_grad_by_norm(max_norm, norm_type, error_if_nonfinite, *args, **kwargs)
 
 
-class TorchAMPModule(nn.Module):
+class TorchAMPModule(ModelWrapper):
     """
     Module wrapper for mixed precision training in FP16 using PyTorch AMP.
 
@@ -76,8 +79,7 @@ class TorchAMPModule(nn.Module):
     """
 
     def __init__(self, module: nn.Module):
-        super().__init__()
-        self.module = module
+        super().__init__(module)
 
     def forward(self, *args, **kwargs):
         with torch.cuda.amp.autocast():
@@ -4,7 +4,7 @@ from typing import Callable, Tuple
 import torch.nn as nn
 from torch.optim import Optimizer
 
-from ..interface import OptimizerWrapper
+from colossalai.interface import OptimizerWrapper
 
 
 class MixedPrecision(ABC):
@@ -6,34 +6,30 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
 
-from colossalai.booster.interface import OptimizerWrapper
+from colossalai.checkpoint_io import CheckpointIO
+from colossalai.interface import OptimizerWrapper
 
 __all__ = ['Plugin']
 
 
 class Plugin(ABC):
 
-    @property
     @abstractmethod
     def supported_devices(self) -> List[str]:
         pass
 
-    @property
     @abstractmethod
     def supported_precisions(self) -> List[str]:
         pass
 
-    @property
     @abstractmethod
     def control_precision(self) -> bool:
         pass
 
-    @property
     @abstractmethod
     def control_device(self) -> bool:
         pass
 
-    @property
     @abstractmethod
     def support_no_sync(self) -> bool:
         pass
@@ -49,3 +45,17 @@ class Plugin(ABC):
     ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
         # implement this method
         pass
+
+    @abstractmethod
+    def control_checkpoint_io(self) -> bool:
+        """
+        Whether the plugin controls the checkpoint io
+        """
+        pass
+
+    @abstractmethod
+    def get_checkpoint_io(self) -> CheckpointIO:
+        """
+        Get checkpoint io object for this plugin, only invoked when control_checkpoint_io is True.
+        """
+        pass
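The two new hooks let a plugin take over checkpointing. A minimal sketch of a custom plugin satisfying them; `NaivePlugin` is a hypothetical name, the `configure` signature is assumed to mirror how `Booster.boost` calls it above, and the plugin simply falls back to `GeneralCheckpointIO`:

from typing import List

from colossalai.booster.plugin.plugin_base import Plugin
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
from colossalai.interface import OptimizerWrapper


class NaivePlugin(Plugin):
    """Hypothetical do-nothing plugin used only to illustrate the new checkpoint hooks."""

    def supported_devices(self) -> List[str]:
        return ['cuda']

    def supported_precisions(self) -> List[str]:
        return ['fp32']

    def control_precision(self) -> bool:
        return False

    def control_device(self) -> bool:
        return False

    def support_no_sync(self) -> bool:
        return False

    def configure(self, model, optimizer, criterion=None, dataloader=None, lr_scheduler=None):
        # assumed to match how Booster.boost invokes plugin.configure
        if not isinstance(optimizer, OptimizerWrapper):
            optimizer = OptimizerWrapper(optimizer)
        return model, optimizer, criterion, dataloader, lr_scheduler

    # the two hooks introduced in this commit
    def control_checkpoint_io(self) -> bool:
        return True

    def get_checkpoint_io(self) -> CheckpointIO:
        # fall back to the general single-process implementation
        return GeneralCheckpointIO()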
@@ -11,13 +11,61 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 
-from colossalai.booster.interface import OptimizerWrapper
+from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
+from colossalai.cluster import DistCoordinator
+from colossalai.interface import ModelWrapper, OptimizerWrapper
 
 from .plugin_base import Plugin
 
 __all__ = ['TorchDDPPlugin']
 
 
+class TorchDDPCheckpointIO(GeneralCheckpointIO):
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.coordinator = DistCoordinator()
+
+    def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
+        """
+        Load model from checkpoint with automatic unwrapping.
+        """
+        # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
+        return super().load_unsharded_model(model, checkpoint, strict=strict)
+
+    def save_unsharded_model(self, model: nn.Module, checkpoint: str):
+        """
+        Save model to checkpoint but only on master process.
+        """
+        # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
+        if self.coordinator.is_master():
+            super().save_unsharded_model(model, checkpoint)
+
+    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str):
+        """
+        Save optimizer to checkpoint but only on master process.
+        """
+        if self.coordinator.is_master():
+            super().save_unsharded_optimizer(optimizer, checkpoint)
+
+    def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
+        """
+        Save lr scheduler to checkpoint but only on master process.
+        """
+        if self.coordinator.is_master():
+            super().save_lr_scheduler(lr_scheduler, checkpoint)
+
+
+class TorchDDPModel(ModelWrapper):
+
+    def __init__(self, module: nn.Module, *args, **kwargs) -> None:
+        super().__init__(module)
+        self.module = DDP(module, *args, **kwargs)
+
+    def unwrap(self):
+        return self.module.module
+
+
 class TorchDDPPlugin(Plugin):
     """
     Plugin for PyTorch DDP.
@@ -138,10 +186,19 @@ class TorchDDPPlugin(Plugin):
         # cast model to cuda
         model = model.cuda()
 
+        # convert model to sync bn
+        model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None)
+
         # wrap the model with PyTorch DDP
-        model = DDP(model, **self.ddp_kwargs)
+        model = TorchDDPModel(model, **self.ddp_kwargs)
 
         if not isinstance(optimizer, OptimizerWrapper):
             optimizer = OptimizerWrapper(optimizer)
 
         return model, optimizer, criterion, dataloader, lr_scheduler
+
+    def control_checkpoint_io(self) -> bool:
+        return True
+
+    def get_checkpoint_io(self) -> CheckpointIO:
+        return TorchDDPCheckpointIO()
@@ -1,13 +1,15 @@
 import json
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Any
+from typing import Any, Union
 
 import torch
 import torch.nn as nn
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 
+from colossalai.interface import ModelWrapper
+
 __all__ = ['CheckpointIO', 'ShardCheckpointIndexFile']
 
 
@@ -37,15 +39,15 @@ class CheckpointIO(ABC):
     >>>
     >>> # save optimizer to checkpoint
     >>> checkpoint_io.save_optimizer(optimizer, 'optimizer.pt')
 
     """
 
     # ======================================
-    # Abstract methods for implementation
+    # Public methods
     # ======================================
-    @abstractmethod
-    def load_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
+    def load_model(self,
+                   model: Union[nn.Module, ModelWrapper],
+                   checkpoint: str,
+                   strict: bool = True) -> Union[nn.Module, ModelWrapper]:
         """
         Load model from checkpoint.
 
@@ -59,14 +61,26 @@ class CheckpointIO(ABC):
             strict (bool): whether to strictly enforce that the param name in
                 the checkpoint match the keys returned by this module's.
         """
-        pass
+        ckpt_path = Path(checkpoint)
+        is_sharded = self.is_sharded_checkpoint(ckpt_path)
+
+        origin_model = model
+
+        if isinstance(model, ModelWrapper):
+            model = model.unwrap()
+
+        if is_sharded:
+            self.load_sharded_model(model, ckpt_path, strict)
+        else:
+            self.load_unsharded_model(model, ckpt_path, strict)
+
+        return origin_model
 
-    @abstractmethod
     def save_model(self,
-                   model: nn.Module,
+                   model: Union[nn.Module, ModelWrapper],
                    checkpoint: str,
-                   prefix: str = None,
                    shard: bool = False,
+                   prefix: str = None,
                    size_per_shard: int = 1024):
         """
         Save model to checkpoint.
@@ -83,17 +97,24 @@ class CheckpointIO(ABC):
 
         Args:
             model (nn.Module): model to be saved.
-            checkpoint: checkpoint path. The checkpoint path can be :
+            checkpoint (str): checkpoint path. The checkpoint path can be :
                 1. a file path, e.g. 'model.pt'
                 2. a directory path to save the sharded checkpoint, e.g. './checkpoints/' when shard = True.
-            shard: whether to shard the checkpoint. Default: False. If set to True, the checkpoint will be sharded into
+            shard (bool): whether to shard the checkpoint. Default: False. If set to True, the checkpoint will be sharded into
                 multiple files. The model shards will be specified by a `model.index.json` file. When shard = True, please ensure
                 that the checkpoint path is a directory path instead of a file path.
-            size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
+            prefix (str): prefix for the model checkpoint file name when shard=True. Default: None.
+            size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True.
         """
-        pass
-
-    @abstractmethod
+        if isinstance(model, ModelWrapper):
+            model = model.unwrap()
+
+        if shard:
+            self.save_sharded_model(model, checkpoint, prefix, size_per_shard)
+        else:
+            self.save_unsharded_model(model, checkpoint)
+
     def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
         """
         Load optimizer from checkpoint.
@@ -102,19 +123,139 @@ class CheckpointIO(ABC):
             optimizer (Optimizer): optimizer to be loaded.
             checkpoint (str): checkpoint path. This value is made compatibility with the model checkpoints in the
         """
-        pass
+        ckpt_path = Path(checkpoint)
+        is_sharded = self.is_sharded_checkpoint(ckpt_path)
+
+        if is_sharded:
+            self.load_sharded_optimizer(optimizer, ckpt_path)
+        else:
+            self.load_unsharded_optimizer(optimizer, ckpt_path)
 
-    @abstractmethod
-    def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, size_per_shard: int = 1024):
+    def save_optimizer(self,
+                       optimizer: Optimizer,
+                       checkpoint: str,
+                       shard: bool = False,
+                       prefix: str = None,
+                       size_per_shard: int = 1024):
         """
         Save optimizer to checkpoint.
 
         Args:
             optimizer (Optimizer): optimizer to be saved.
-            checkpoint: checkpoint path. The checkpoint path can be :
+            checkpoint (str): checkpoint path. The checkpoint path can be :
                 1. a file path, e.g. 'model.pt'
                 2. a path to a json file which defines the index to the sharded checkpoint for the optimizer
                 3. a path to a folder containing a unique .index.json file for sharded checkpoint
+            shard (bool): whether to shard the checkpoint. Default: False. If set to True, the checkpoint will be sharded into
+                multiple files. The optimizer shards will be specified by a `optimizer.index.json` file.
+            prefix (str): prefix for the optimizer checkpoint when shard = True. Default: None.
+            size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
+        """
+        if shard:
+            self.save_sharded_optimizer(optimizer, checkpoint, prefix, size_per_shard)
+        else:
+            self.save_unsharded_optimizer(optimizer, checkpoint)
+
+    # ========================================================
+    # Abstract methods for model loading/saving implementation
+    # ========================================================
+    @abstractmethod
+    def load_sharded_model(self, model: nn.Module, checkpoint: Path, strict: bool):
+        """
+        Load model from sharded checkpoint.
+
+        Args:
+            model (nn.Module): model to be loaded.
+            checkpoint (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
+        """
+        pass
+
+    @abstractmethod
+    def load_unsharded_model(self, model: nn.Module, checkpoint: Path, strict: bool):
+        """
+        Load model from unsharded checkpoint.
+
+        Args:
+            model (nn.Module): model to be loaded.
+            checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
+            strict (bool): whether to strictly enforce that the param name in
+                the checkpoint match the keys returned by this module's.
+        """
+        pass
+
+    @abstractmethod
+    def save_sharded_model(self, model: nn.Module, checkpoint: Path, prefix: str, size_per_shard: int):
+        """
+        Save model to sharded checkpoint.
+
+        Args:
+            model (nn.Module): model to be saved.
+            checkpoint (Path): checkpoint path. It should be a directory path.
+            prefix (str): prefix for the model checkpoint.
+            size_per_shard (int): size per shard in MB.
+        """
+        pass
+
+    @abstractmethod
+    def save_unsharded_model(self, model: nn.Module, checkpoint: Path):
+        """
+        Save model to unsharded checkpoint.
+
+        Args:
+            model (nn.Module): model to be saved.
+            checkpoint (Path): checkpoint path. It should be a single file path pointing to a model weight binary.
+        """
+        pass
+
+    # ========================================================
+    # Abstract methods for optimizer loading/saving implementation
+    # ========================================================
+
+    @abstractmethod
+    def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
+        """
+        Load optimizer from sharded checkpoint.
+
+        Args:
+            optimizer (Optimizer): optimizer to be loaded.
+            checkpoint (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
+            prefix (str): prefix for the optimizer checkpoint.
+            size_per_shard (int): size per shard in MB.
+        """
+        pass
+
+    @abstractmethod
+    def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
+        """
+        Load optimizer from unsharded checkpoint.
+
+        Args:
+            optimizer (Optimizer): optimizer to be loaded.
+            checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
+        """
+        pass
+
+    @abstractmethod
+    def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
+        """
+        Save optimizer to sharded checkpoint.
+
+        Args:
+            optimizer (Optimizer): optimizer to be saved.
+            checkpoint (Path): checkpoint path. It should be a directory path.
+            prefix (str): prefix for the optimizer checkpoint.
+            size_per_shard (int): size per shard in MB.
+        """
+        pass
+
+    @abstractmethod
+    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
+        """
+        Save optimizer to unsharded checkpoint.
+
+        Args:
+            optimizer (Optimizer): optimizer to be saved.
+            checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
         """
         pass
 
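After this refactor the public load_*/save_* methods own the sharded-vs-unsharded dispatch, and subclasses only implement the narrow abstract hooks. A hedged usage sketch with the general implementation (the file names are illustrative only):

import torch.nn as nn
from torch.optim import Adam

from colossalai.checkpoint_io import GeneralCheckpointIO

model = nn.Linear(16, 4)
optimizer = Adam(model.parameters(), lr=1e-3)

ckpt_io = GeneralCheckpointIO()

# the public entry points take the unsharded code path because shard defaults to False
ckpt_io.save_model(model, 'model.pt')
ckpt_io.save_optimizer(optimizer, 'optimizer.pt')

# load_model/load_optimizer inspect the checkpoint path and dispatch to the
# *_unsharded_* implementations for a single-file checkpoint
ckpt_io.load_model(model, 'model.pt')
ckpt_io.load_optimizer(optimizer, 'optimizer.pt')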
@@ -10,57 +10,36 @@ __all__ = ['GeneralCheckpointIO']
 
 class GeneralCheckpointIO(CheckpointIO):
 
-    def load_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
-        checkpoint = Path(checkpoint)
-        is_sharded = self.is_sharded_checkpoint(checkpoint)
-
-        if not is_sharded:
-            checkpoint = self.load_state_dict(checkpoint)
-            model.load_state_dict(checkpoint, strict=strict)
-        else:
-            # find the index file
-            checkpoint_path = Path(checkpoint)
-            index_file_path = self.get_sharded_checkpoint_index_file(checkpoint_path)
-
-            # iterate over the shard checkpoint files
-            # and load each
-            shard_files = self.get_checkpoint_shard_filenames(index_file_path)
-            for shard_file in shard_files:
-                shard_checkpoint = self.load_state_dict(shard_file)
-                model.load_state_dict(shard_checkpoint, strict=strict)
-
-        return model
+    def load_sharded_model(self, model: nn.Module, checkpoint: Path, strict: bool):
+        index_file_path = self.get_sharded_checkpoint_index_file(checkpoint)
+
+        # iterate over the shard checkpoint files
+        # and load each
+        shard_files = self.get_checkpoint_shard_filenames(index_file_path)
+        for shard_file in shard_files:
+            shard_checkpoint = self.load_state_dict(shard_file)
+            model.load_state_dict(shard_checkpoint, strict=strict)
+
+    def load_unsharded_model(self, model: nn.Module, checkpoint: Path, strict: bool):
+        checkpoint = self.load_state_dict(str(checkpoint))
+        model.load_state_dict(checkpoint, strict=strict)
+
+    def save_sharded_model(self, model: nn.Module, checkpoint: Path, prefix: str, size_per_shard: int):
+        # TODO(FrankLeeeee): implement this method as it can be supported by Huggingface model
+        raise NotImplementedError("Sharded model checkpoint is not supported yet.")
 
-    def save_model(self,
-                   model: nn.Module,
-                   checkpoint: str,
-                   prefix: str = None,
-                   shard: bool = False,
-                   size_per_shard: int = 1024):
-        checkpoint = Path(checkpoint)
-        if shard:
-            # TODO(FrankLeeeee): implement checkpoint saving to sharded checkpoint
-            raise NotImplementedError("Not implemented yet")
-        else:
-            self.save_checkpoint(model.state_dict(), checkpoint)
+    def save_unsharded_model(self, model: nn.Module, checkpoint: Path):
+        self.save_checkpoint(model.state_dict(), checkpoint)
 
-    def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
-        checkpoint = Path(checkpoint)
-        is_sharded = self.is_sharded_checkpoint(checkpoint)
-
-        if not is_sharded:
-            checkpoint = self.load_state_dict(checkpoint)
-            optimizer.load_state_dict(checkpoint)
-        else:
-            # TODO(FrankLeeeee): implement checkpoint loading from sharded checkpoint
-            # This is not an urgent feature, so we can leave it for later
-            # let's implement this when we test large-scale models
-            pass
-        return optimizer
+    def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
+        raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
+
+    def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
+        checkpoint = self.load_state_dict(checkpoint)
+        optimizer.load_state_dict(checkpoint)
 
-    def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, size_per_shard: int = 1024):
-        if shard:
-            # TODO(FrankLeeeee): implement checkpoint saving to sharded checkpoint
-            pass
-        else:
-            self.save_checkpoint(optimizer.state_dict(), checkpoint)
+    def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
+        raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
+
+    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
+        self.save_checkpoint(optimizer.state_dict(), checkpoint)
@@ -1,3 +1,4 @@
+import functools
 import os
 from contextlib import contextmanager
 
@@ -141,12 +142,12 @@ class DistCoordinator(metaclass=SingletonMeta):
         should_block = rank != executor_rank
 
         if should_block:
-            dist.barrier(group=process_group)
+            self.block_all(process_group)
 
         yield
 
         if not should_block:
-            dist.barrier(group=process_group)
+            self.block_all(process_group)
 
     def destroy(self, process_group: ProcessGroup = None):
         """
@@ -156,3 +157,38 @@
             process_group (ProcessGroup, optional): process group to destroy. Defaults to None, which refers to the default process group.
         """
         dist.destroy_process_group(process_group)
+
+    def block_all(self, process_group: ProcessGroup = None):
+        """
+        Block all processes in the process group.
+
+        Args:
+            process_group (ProcessGroup, optional): process group to block. Defaults to None, which refers to the default process group.
+        """
+        dist.barrier(group=process_group)
+
+    def on_master_only(self, process_group: ProcessGroup = None):
+        """
+        A function wrapper that only executes the wrapped function on the master process (rank 0).
+
+        Example:
+            >>> from colossalai.cluster import DistCoordinator
+            >>> dist_coordinator = DistCoordinator()
+            >>>
+            >>> @dist_coordinator.on_master_only()
+            >>> def print_on_master(msg):
+            >>>     print(msg)
+        """
+        is_master = self.is_master(process_group)
+
+        # define an inner function
+        def decorator(func):
+
+            @functools.wraps(func)
+            def wrapper(*args, **kwargs):
+                if is_master:
+                    return func(*args, **kwargs)
+
+            return wrapper
+
+        return decorator
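The new coordinator helpers are thin wrappers around `dist.barrier` and the rank check. A hedged usage sketch, assuming the distributed environment was already launched (e.g. with `colossalai.launch_from_torch`); `prepare_dataset` is a hypothetical helper:

from colossalai.cluster import DistCoordinator

coordinator = DistCoordinator()

# run the body on rank 0 first (e.g. dataset download), everyone else waits
with coordinator.priority_execution():
    prepare_dataset()          # hypothetical helper, not part of the commit

# explicit synchronization point for every rank in the default process group
coordinator.block_all()


# decorate a function so that it only runs on the master process
@coordinator.on_master_only()
def log(msg):
    print(msg)

log('only printed once, on rank 0')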
colossalai/interface/__init__.py (new file)

from .model import ModelWrapper
from .optimizer import OptimizerWrapper

__all__ = ['OptimizerWrapper', 'ModelWrapper']
colossalai/interface/model.py (new file)

import torch.nn as nn


class ModelWrapper(nn.Module):
    """
    A wrapper class to define the common interface used by booster.

    Args:
        module (nn.Module): The model to be wrapped.
    """

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def unwrap(self):
        """
        Unwrap the model to return the original model for checkpoint saving/loading.
        """
        if isinstance(self.module, ModelWrapper):
            return self.module.unwrap()
        return self.module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)
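`unwrap()` recurses through nested wrappers, which is what lets the checkpoint utilities reach the bare `nn.Module` underneath wrappers such as `TorchDDPModel`. A small sketch of the behaviour:

import torch
import torch.nn as nn

from colossalai.interface import ModelWrapper

inner = nn.Linear(8, 8)

# wrapping twice still unwraps down to the original module
wrapped = ModelWrapper(ModelWrapper(inner))
assert wrapped.unwrap() is inner

# forward calls are transparently delegated through every layer of wrapping
out = wrapped(torch.randn(2, 8))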
examples/tutorial/new_api/README.md (new file)

# New API Features

**The New API is not officially released yet.**

This folder contains demonstrations of the new API. The new API is still under intensive development and will be released soon.
examples/tutorial/new_api/test_ci.sh (new file)

#!/usr/bin/env
echo "The CI integration will be completed when the API is stable"
examples/tutorial/new_api/torch_ddp/.gitignore (new file)

data
checkpoint
ckpt-fp16
ckpt-fp32
examples/tutorial/new_api/torch_ddp/README.md (new file)

# Distributed Data Parallel

## 🚀 Quick Start

This example provides a training script and an evaluation script. The training script shows how to train ResNet on the CIFAR10 dataset from scratch.

- Training Arguments
  - `-r`, `--resume`: resume from checkpoint file path
  - `-c`, `--checkpoint`: the folder to save checkpoints
  - `-i`, `--interval`: epoch interval to save checkpoints
  - `-f`, `--fp16`: use fp16

- Eval Arguments
  - `-e`, `--epoch`: select the epoch to evaluate
  - `-c`, `--checkpoint`: the folder where checkpoints are found

### Train

```bash
# train with torch DDP with fp32
colossalai run --nproc_per_node 2 train.py -c ./ckpt-fp32

# train with torch DDP with mixed precision training
colossalai run --nproc_per_node 2 train.py -c ./ckpt-fp16 --fp16
```

### Eval

```bash
# evaluate fp32 training
python eval.py -c ./ckpt-fp32 -e 80

# evaluate fp16 mixed precision training
python eval.py -c ./ckpt-fp16 -e 80
```

Expected accuracy performance will be:

| Model     | Single-GPU Baseline FP32 | Booster DDP with FP32 | Booster DDP with FP16 |
| --------- | ------------------------ | --------------------- | --------------------- |
| ResNet-18 | 85.85%                   | 85.03%                | 85.12%                |

**Note: the baseline is adapted from the [script](https://pytorch-tutorial.readthedocs.io/en/latest/tutorial/chapter03_intermediate/3_2_2_cnn_resnet_cifar10/) to use `torchvision.models.resnet18`**
examples/tutorial/new_api/torch_ddp/eval.py (new file)

import argparse

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

# ==============================
# Parse Arguments
# ==============================
parser = argparse.ArgumentParser()
parser.add_argument('-e', '--epoch', type=int, default=80, help="resume from the epoch's checkpoint")
parser.add_argument('-c', '--checkpoint', type=str, default='./checkpoint', help="checkpoint directory")
args = parser.parse_args()

# ==============================
# Prepare Test Dataset
# ==============================
# CIFAR-10 dataset
test_dataset = torchvision.datasets.CIFAR10(root='./data/', train=False, transform=transforms.ToTensor())

# Data loader
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=128, shuffle=False)

# ==============================
# Load Model
# ==============================
model = torchvision.models.resnet18(num_classes=10).cuda()
state_dict = torch.load(f'{args.checkpoint}/model_{args.epoch}.pth')
model.load_state_dict(state_dict)

# ==============================
# Run Evaluation
# ==============================
model.eval()

with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.cuda()
        labels = labels.cuda()
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Accuracy of the model on the test images: {} %'.format(100 * correct / total))
examples/tutorial/new_api/torch_ddp/train.py (new file)

import argparse
from pathlib import Path

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import MultiStepLR

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin
from colossalai.cluster import DistCoordinator

# ==============================
# Parse Arguments
# ==============================
parser = argparse.ArgumentParser()
parser.add_argument('-r', '--resume', type=int, default=-1, help="resume from the epoch's checkpoint")
parser.add_argument('-c', '--checkpoint', type=str, default='./checkpoint', help="checkpoint directory")
parser.add_argument('-i', '--interval', type=int, default=5, help="interval of saving checkpoint")
parser.add_argument('-f', '--fp16', action='store_true', help="use fp16")
args = parser.parse_args()

# ==============================
# Prepare Checkpoint Directory
# ==============================
Path(args.checkpoint).mkdir(parents=True, exist_ok=True)

# ==============================
# Prepare Hyperparameters
# ==============================
NUM_EPOCHS = 80
LEARNING_RATE = 1e-3
START_EPOCH = args.resume if args.resume >= 0 else 0

# ==============================
# Launch Distributed Environment
# ==============================
colossalai.launch_from_torch(config={})
coordinator = DistCoordinator()

# update the learning rate with linear scaling
# old_gpu_num / old_lr = new_gpu_num / new_lr
LEARNING_RATE *= coordinator.world_size

# ==============================
# Prepare Booster
# ==============================
plugin = TorchDDPPlugin()
if args.fp16:
    booster = Booster(mixed_precision='fp16', plugin=plugin)
else:
    booster = Booster(plugin=plugin)

# ==============================
# Prepare Train Dataset
# ==============================
transform = transforms.Compose(
    [transforms.Pad(4),
     transforms.RandomHorizontalFlip(),
     transforms.RandomCrop(32),
     transforms.ToTensor()])

# CIFAR-10 dataset
with coordinator.priority_execution():
    train_dataset = torchvision.datasets.CIFAR10(root='./data/', train=True, transform=transform, download=True)

# ====================================
# Prepare model, optimizer, criterion
# ====================================
# resnet18
model = torchvision.models.resnet18(num_classes=10).cuda()

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# lr scheduler
lr_scheduler = MultiStepLR(optimizer, milestones=[20, 40, 60, 80], gamma=1 / 3)

# prepare dataloader with torch ddp plugin
train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=100, shuffle=True)

# ==============================
# Resume from checkpoint
# ==============================
if args.resume >= 0:
    booster.load_model(model, f'{args.checkpoint}/model_{args.resume}.pth')
    booster.load_optimizer(optimizer, f'{args.checkpoint}/optimizer_{args.resume}.pth')
    booster.load_lr_scheduler(lr_scheduler, f'{args.checkpoint}/lr_scheduler_{args.resume}.pth')

# ==============================
# Boost with ColossalAI
# ==============================
model, optimizer, criterion, train_dataloader, lr_scheduler = booster.boost(model, optimizer, criterion,
                                                                            train_dataloader, lr_scheduler)

# ==============================
# Train model
# ==============================
total_step = len(train_dataloader)

for epoch in range(START_EPOCH, NUM_EPOCHS):
    for i, (images, labels) in enumerate(train_dataloader):
        images = images.cuda()
        labels = labels.cuda()

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        booster.backward(loss, optimizer)
        optimizer.step()

        if (i + 1) % 100 == 0:
            print("Epoch [{}/{}], Step [{}/{}] Loss: {:.4f}".format(epoch + 1, NUM_EPOCHS, i + 1, total_step,
                                                                    loss.item()))

    lr_scheduler.step()

    # save a checkpoint every args.interval epochs
    if (epoch + 1) % args.interval == 0:
        booster.save_model(model, f'{args.checkpoint}/model_{epoch + 1}.pth')
        booster.save_optimizer(optimizer, f'{args.checkpoint}/optimizer_{epoch + 1}.pth')
        booster.save_lr_scheduler(lr_scheduler, f'{args.checkpoint}/lr_scheduler_{epoch + 1}.pth')
@@ -8,8 +8,8 @@ from torch.optim import SGD
 
 import colossalai
 from colossalai.booster import Booster
-from colossalai.booster.interface import OptimizerWrapper
 from colossalai.booster.plugin import TorchDDPPlugin
+from colossalai.interface import OptimizerWrapper
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from tests.kit.model_zoo import model_zoo
@@ -34,7 +34,7 @@ def check_torch_ddp_plugin():
 
     model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
 
-    assert isinstance(model, DDP)
+    assert isinstance(model.module, DDP)
     assert isinstance(optimizer, OptimizerWrapper)
 
     output = model(**data)
@@ -42,8 +42,8 @@ def test_unsharded_checkpoint():
     new_optimizer = Adam(new_model.parameters(), lr=0.001)
 
     # load the model and optimizer
-    new_model = ckpt_io.load_model(new_model, model_ckpt_tempfile.name)
-    new_optimizer = ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name)
+    ckpt_io.load_model(new_model, model_ckpt_tempfile.name)
+    ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name)
 
     # do recursive check for the optimizer state dict
     # if the value is a dict, compare its values