mirror of https://github.com/hpcaitech/ColossalAI.git synced 2025-05-03 05:58:09 +00:00

[booster] implemented the torch ddp + resnet example ()

* [booster] implemented the torch ddp + resnet example

* polish code
Frank Lee 2023-03-27 10:24:14 +08:00 committed by GitHub
parent 1a229045af
commit 73d3e4d309
22 changed files with 608 additions and 128 deletions
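
For orientation, the hunks below wire checkpoint I/O into the `Booster` API and demonstrate it with a DDP ResNet example. The following is a condensed, editor-added usage sketch distilled from the `train.py` added in this commit; it assumes a multi-process launch such as `colossalai run --nproc_per_node 2` and the usual CIFAR-10 download, and is not part of the diff itself:

```python
import colossalai
import torch
import torchvision
import torchvision.transforms as transforms
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

colossalai.launch_from_torch(config={})

# the plugin supplies both the DDP wrapper and a TorchDDPCheckpointIO for saving/loading
plugin = TorchDDPPlugin()
booster = Booster(mixed_precision='fp16', plugin=plugin)

model = torchvision.models.resnet18(num_classes=10).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()

train_dataset = torchvision.datasets.CIFAR10(root='./data/', train=True, download=True,
                                             transform=transforms.ToTensor())
train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=100, shuffle=True)

model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)

for images, labels in train_dataloader:
    loss = criterion(model(images.cuda()), labels.cuda())
    optimizer.zero_grad()
    booster.backward(loss, optimizer)
    optimizer.step()

# checkpointing is now routed through the plugin-provided CheckpointIO (only rank 0 writes the files)
booster.save_model(model, 'model.pth')
booster.save_optimizer(optimizer, 'optimizer.pth')
```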

View File

@@ -1,4 +1,3 @@
 from .accelerator import Accelerator
 from .booster import Booster
-from .environment_table import EnvironmentTable
 from .plugin import Plugin

View File

@@ -8,6 +8,8 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
 
+from colossalai.checkpoint_io import GeneralCheckpointIO
+
 from .accelerator import Accelerator
 from .mixed_precision import MixedPrecision, mixed_precision_factory
 from .plugin import Plugin
@@ -61,19 +63,21 @@ class Booster:
         self.plugin = plugin
 
         # set accelerator
-        if self.plugin and self.plugin.control_device:
+        if self.plugin and self.plugin.control_device():
             self.accelerator = None
             warnings.warn('The plugin will control the accelerator, so the device argument will be ignored.')
         else:
             self.accelerator = Accelerator(device)
 
         # set precision
-        if mixed_precision is None or (self.plugin and self.plugin.control_precision):
-            self.mixed_precision = None
+        if self.plugin and self.plugin.control_precision():
             warnings.warn('The plugin will control the precision, so the mixed_precision argument will be ignored.')
+            self.mixed_precision = None
+        elif mixed_precision is None:
+            self.mixed_precision = None
         else:
             # validate and set precision
-            if isinstance(MixedPrecision, str):
+            if isinstance(mixed_precision, str):
                 # the user will take the default arguments for amp training
                 self.mixed_precision = mixed_precision_factory(mixed_precision)
             elif isinstance(mixed_precision, MixedPrecision):
@@ -84,6 +88,11 @@ class Booster:
                     f'Expected the argument mixed_precision to be a string or an instance of Precision, but got {type(mixed_precision)}.'
                 )
 
+        if self.plugin is not None and self.plugin.control_checkpoint_io():
+            self.checkpoint_io = self.plugin.get_checkpoint_io()
+        else:
+            self.checkpoint_io = GeneralCheckpointIO()
+
     def boost(
         self,
         model: nn.Module,
@@ -109,12 +118,13 @@ class Booster:
             model, optimizer, criterion, dataloader, lr_scheduler = self.plugin.configure(
                 model, optimizer, criterion, dataloader, lr_scheduler)
 
-        if self.plugin and not self.plugin.control_device:
+        if self.plugin and not self.plugin.control_device():
             # transform model for accelerator
             model = self.accelerator.configure(model)
 
-        if self.mixed_precision and self.plugin and not self.plugin.control_precision:
+        if self.mixed_precision and (self.plugin is None or self.plugin and not self.plugin.control_precision()):
             # transform model for mixed precision
+            # when mixed_precision is specified and the plugin is not given or does not control the precision
             model, optimizer, criterion = self.mixed_precision.configure(model, optimizer, criterion)
 
         return model, optimizer, criterion, dataloader, lr_scheduler
@@ -140,18 +150,25 @@ class Booster:
         assert self.plugin.support_no_sync, f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
         return self.plugin.no_sync(model)
 
-    def save(self,
-             obj: Union[nn.Module, Optimizer, LRScheduler],
-             path_like: str,
-             plan: str = 'torch',
-             **kwargs) -> None:
-        # TODO: implement this method
-        pass
+    def load_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
+        self.checkpoint_io.load_model(model, checkpoint, strict)
 
-    def load(self,
-             obj: Union[nn.Module, Optimizer, LRScheduler],
-             path_like: str,
-             plan: str = 'torch',
-             **kwargs) -> None:
-        # TODO: implement this method
-        pass
+    def save_model(self,
+                   model: nn.Module,
+                   checkpoint: str,
+                   prefix: str = None,
+                   shard: bool = False,
+                   size_per_shard: int = 1024):
+        self.checkpoint_io.save_model(model, checkpoint, prefix, shard, size_per_shard)
+
+    def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
+        self.checkpoint_io.load_optimizer(optimizer, checkpoint)
+
+    def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, size_per_shard: int = 1024):
+        self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, size_per_shard)
+
+    def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
+        self.checkpoint_io.save_lr_scheduler(lr_scheduler, checkpoint)
+
+    def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
+        self.checkpoint_io.load_lr_scheduler(lr_scheduler, checkpoint)

View File

@@ -1,18 +0,0 @@
from typing import List

__all__ = ['EnvironmentTable']


class EnvironmentTable:

    def __init__(self, intra_op_world_sizes: List[int]):
        # TODO: implement this method
        pass

    @property
    def is_master(self) -> bool:
        # TODO: implement this method
        pass

    # TODO: implement more utility methods as given in
    # https://github.com/hpcaitech/ColossalAI/issues/3051

View File

@@ -1,3 +0,0 @@
from .optimizer import OptimizerWrapper

__all__ = ['OptimizerWrapper']

View File

@@ -5,7 +5,8 @@ import torch.nn as nn
 from torch import Tensor
 from torch.optim import Optimizer
 
-from ..interface import OptimizerWrapper
+from colossalai.interface import ModelWrapper, OptimizerWrapper
+
 from .mixed_precision_base import MixedPrecision
 
 __all__ = ['FP16_Torch_MixedPrecision', 'TorchAMPOptimizer', 'TorchAMPModule']
@@ -45,7 +46,9 @@ class TorchAMPOptimizer(OptimizerWrapper):
         scaled_loss.backward(*args, **kwargs)
 
     def step(self, *args, **kwargs) -> Optional[float]:
-        return self.scaler.step(self.optim, *args, **kwargs)
+        out = self.scaler.step(self.optim, *args, **kwargs)
+        self.scaler.update()
+        return out
 
     def scale_loss(self, loss: Tensor) -> Tensor:
         return self.scaler.scale(loss)
@@ -67,7 +70,7 @@ class TorchAMPOptimizer(OptimizerWrapper):
         super().clip_grad_by_norm(max_norm, norm_type, error_if_nonfinite, *args, **kwargs)
 
 
-class TorchAMPModule(nn.Module):
+class TorchAMPModule(ModelWrapper):
     """
     Module wrapper for mixed precision training in FP16 using PyTorch AMP.
@@ -76,8 +79,7 @@ class TorchAMPModule(nn.Module):
     """
 
     def __init__(self, module: nn.Module):
-        super().__init__()
-        self.module = module
+        super().__init__(module)
 
     def forward(self, *args, **kwargs):
         with torch.cuda.amp.autocast():

View File

@@ -4,7 +4,7 @@ from typing import Callable, Tuple
 import torch.nn as nn
 from torch.optim import Optimizer
 
-from ..interface import OptimizerWrapper
+from colossalai.interface import OptimizerWrapper
 
 
 class MixedPrecision(ABC):

View File

@@ -6,34 +6,30 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
 
-from colossalai.booster.interface import OptimizerWrapper
+from colossalai.checkpoint_io import CheckpointIO
+from colossalai.interface import OptimizerWrapper
 
 __all__ = ['Plugin']
 
 
 class Plugin(ABC):
 
-    @property
     @abstractmethod
     def supported_devices(self) -> List[str]:
         pass
 
-    @property
     @abstractmethod
     def supported_precisions(self) -> List[str]:
         pass
 
-    @property
     @abstractmethod
     def control_precision(self) -> bool:
         pass
 
-    @property
     @abstractmethod
     def control_device(self) -> bool:
         pass
 
-    @property
     @abstractmethod
     def support_no_sync(self) -> bool:
         pass
@@ -49,3 +45,17 @@ class Plugin(ABC):
     ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
         # implement this method
         pass
+
+    @abstractmethod
+    def control_checkpoint_io(self) -> bool:
+        """
+        Whether the plugin controls the checkpoint io
+        """
+        pass
+
+    @abstractmethod
+    def get_checkpoint_io(self) -> CheckpointIO:
+        """
+        Get checkpoint io object for this plugin, only invoked when control_checkpoint_io is True.
+        """
+        pass
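
To make the new extension point concrete, here is an editor-added, hypothetical sketch of a plugin that only customises checkpoint I/O. It subclasses the concrete `TorchDDPPlugin` introduced further down in this diff and swaps in a logging `CheckpointIO`; class names and behaviour here are illustrative assumptions, not part of the commit:

```python
from colossalai.booster.plugin import TorchDDPPlugin
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO


class VerboseCheckpointIO(GeneralCheckpointIO):
    """Hypothetical CheckpointIO that logs every unsharded model save."""

    def save_unsharded_model(self, model, checkpoint):
        print(f'saving model state dict to {checkpoint}')
        # extending TorchDDPCheckpointIO instead would keep the rank-0-only writes
        super().save_unsharded_model(model, checkpoint)


class VerboseDDPPlugin(TorchDDPPlugin):
    """Hypothetical plugin that keeps TorchDDP behaviour but swaps in the logging CheckpointIO."""

    def control_checkpoint_io(self) -> bool:
        return True

    def get_checkpoint_io(self) -> CheckpointIO:
        # Booster calls this during __init__ and then routes save_model/load_model through it
        return VerboseCheckpointIO()
```

With `booster = Booster(plugin=VerboseDDPPlugin())`, every `booster.save_model(...)` would go through the custom `CheckpointIO` instead of the default `GeneralCheckpointIO`.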

View File

@@ -11,13 +11,61 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 
-from colossalai.booster.interface import OptimizerWrapper
+from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
+from colossalai.cluster import DistCoordinator
+from colossalai.interface import ModelWrapper, OptimizerWrapper
 
 from .plugin_base import Plugin
 
 __all__ = ['TorchDDPPlugin']
 
 
+class TorchDDPCheckpointIO(GeneralCheckpointIO):
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.coordinator = DistCoordinator()
+
+    def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
+        """
+        Load model from checkpoint with automatic unwrapping.
+        """
+        # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
+        return super().load_unsharded_model(model, checkpoint, strict=strict)
+
+    def save_unsharded_model(self, model: nn.Module, checkpoint: str):
+        """
+        Save model to checkpoint but only on master process.
+        """
+        # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
+        if self.coordinator.is_master():
+            super().save_unsharded_model(model, checkpoint)
+
+    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str):
+        """
+        Save optimizer to checkpoint but only on master process.
+        """
+        if self.coordinator.is_master():
+            super().save_unsharded_optimizer(optimizer, checkpoint)
+
+    def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
+        """
+        Save model to checkpoint but only on master process.
+        """
+        if self.coordinator.is_master():
+            super().save_lr_scheduler(lr_scheduler, checkpoint)
+
+
+class TorchDDPModel(ModelWrapper):
+
+    def __init__(self, module: nn.Module, *args, **kwargs) -> None:
+        super().__init__(module)
+        self.module = DDP(module, *args, **kwargs)
+
+    def unwrap(self):
+        return self.module.module
+
+
 class TorchDDPPlugin(Plugin):
     """
     Plugin for PyTorch DDP.
@@ -138,10 +186,19 @@ class TorchDDPPlugin(Plugin):
         # cast model to cuda
         model = model.cuda()
 
+        # convert model to sync bn
+        model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None)
+
         # wrap the model with PyTorch DDP
-        model = DDP(model, **self.ddp_kwargs)
+        model = TorchDDPModel(model, **self.ddp_kwargs)
 
         if not isinstance(optimizer, OptimizerWrapper):
             optimizer = OptimizerWrapper(optimizer)
 
         return model, optimizer, criterion, dataloader, lr_scheduler
+
+    def control_checkpoint_io(self) -> bool:
+        return True
+
+    def get_checkpoint_io(self) -> CheckpointIO:
+        return TorchDDPCheckpointIO()

View File

@@ -1,13 +1,15 @@
 import json
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Any
+from typing import Any, Union
 
 import torch
 import torch.nn as nn
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 
+from colossalai.interface import ModelWrapper
+
 __all__ = ['CheckpointIO', 'ShardCheckpointIndexFile']
@@ -37,15 +39,15 @@ class CheckpointIO(ABC):
         >>>
         >>> # save optimizer to checkpoint
         >>> checkpoint_io.save_optimizer(optimizer, 'optimizer.pt')
     """
 
     # ======================================
-    # Abstract methods for implementation
+    # Public methods
     # ======================================
-
-    @abstractmethod
-    def load_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
+    def load_model(self,
+                   model: Union[nn.Module, ModelWrapper],
+                   checkpoint: str,
+                   strict: bool = True) -> Union[nn.Module, ModelWrapper]:
         """
         Load model from checkpoint.
@@ -59,14 +61,26 @@ class CheckpointIO(ABC):
             strict (bool): whether to strictly enforce that the param name in
                 the checkpoint match the keys returned by this module's.
         """
-        pass
+        ckpt_path = Path(checkpoint)
+        is_sharded = self.is_sharded_checkpoint(ckpt_path)
+        origin_model = model
+
+        if isinstance(model, ModelWrapper):
+            model = model.unwrap()
+
+        if is_sharded:
+            self.load_sharded_model(model, ckpt_path, strict)
+        else:
+            self.load_unsharded_model(model, ckpt_path, strict)
+        return origin_model
 
-    @abstractmethod
     def save_model(self,
-                   model: nn.Module,
+                   model: Union[nn.Module, ModelWrapper],
                    checkpoint: str,
+                   prefix: str = None,
                    shard: bool = False,
-                   prefix: str = None,
                    size_per_shard: int = 1024):
         """
         Save model to checkpoint.
@@ -83,17 +97,24 @@ class CheckpointIO(ABC):
         Args:
             model (nn.Module): model to be saved.
-            checkpoint: checkpoint path. The checkpoint path can be :
+            checkpoint (str): checkpoint path. The checkpoint path can be :
                 1. a file path, e.g. 'model.pt'
                 2. a directory path to save the sharded checkpoint, e.g. './checkpoints/' when shard = True.
-            shard: whether to shard the checkpoint. Default: False. If set to True, the checkpoint will be sharded into
+            shard (bool): whether to shard the checkpoint. Default: False. If set to True, the checkpoint will be sharded into
                 multiple files. The model shards will be specificed by a `model.index.json` file. When shard = True, please ensure
                 that the checkpoint path is a directory path instead of a file path.
-            size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
+            prefix (str): prefix for the model checkpoint file name when shard=True. Default: None.
+            size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True.
         """
-        pass
 
-    @abstractmethod
+        if isinstance(model, ModelWrapper):
+            model = model.unwrap()
+
+        if shard:
+            self.save_sharded_model(model, checkpoint, prefix, size_per_shard)
+        else:
+            self.save_unsharded_model(model, checkpoint)
+
     def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
         """
         Load optimizer from checkpoint.
@@ -102,19 +123,139 @@ class CheckpointIO(ABC):
             optimizer (Optimizer): optimizer to be loaded.
             checkpoint (str): checkpoint path. This value is made compatiblity with the model checkpoints in the
         """
-        pass
+        ckpt_path = Path(checkpoint)
+        is_sharded = self.is_sharded_checkpoint(ckpt_path)
 
-    @abstractmethod
-    def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, size_per_shard: int = 1024):
+        if is_sharded:
+            self.load_sharded_optimizer(optimizer, ckpt_path)
+        else:
+            self.load_unsharded_optimizer(optimizer, ckpt_path)
+
+    def save_optimizer(self,
+                       optimizer: Optimizer,
+                       checkpoint: str,
+                       shard: bool = False,
+                       prefix: str = None,
+                       size_per_shard: int = 1024):
         """
         Save optimizer to checkpoint.
 
         Args:
             optimizer (Optimizer): optimizer to be saved.
-            checkpoint: checkpoint path. The checkpoint path can be :
+            checkpoint (str): checkpoint path. The checkpoint path can be :
                 1. a file path, e.g. 'model.pt'
                 2. a path to a json file which defines the index to the sharded checkpoint for the optimizer
                 3. a path to a folder containing a unique .index.json file for sharded checkpoint
+            shard (bool): whether to shard the checkpoint. Default: False. If set to True, the checkpoint will be sharded into
+                multiple files. The optimizer shards will be specificed by a `optimizer.index.json` file.
+            prefix (str): prefix for the optimizer checkpoint when shard = True. Default: None.
+            size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
+        """
+        if shard:
+            self.save_sharded_optimizer(optimizer, checkpoint, prefix, size_per_shard)
+        else:
+            self.save_unsharded_optimizer(optimizer, checkpoint)
+
+    # ========================================================
+    # Abstract methods for model loading/saving implementation
+    # ========================================================
+
+    @abstractmethod
+    def load_sharded_model(self, model: nn.Module, checkpoint: Path, strict: bool):
+        """
+        Load model from sharded checkpoint.
+
+        Args:
+            model (nn.Module): model to be loaded.
+            checkpoint (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
+        """
+        pass
+
+    @abstractmethod
+    def load_unsharded_model(self, model: nn.Module, checkpoint: Path, strict: bool):
+        """
+        Load model from unsharded checkpoint.
+
+        Args:
+            model (nn.Module): model to be loaded.
+            checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
+            strict (bool): whether to strictly enforce that the param name in
+                the checkpoint match the keys returned by this module's.
+        """
+        pass
+
+    @abstractmethod
+    def save_sharded_model(self, model: nn.Module, checkpoint: Path, prefix: str, size_per_shard: int):
+        """
+        Save model to sharded checkpoint.
+
+        Args:
+            model (nn.Module): model to be saved.
+            checkpoint (Path): checkpoint path. It should be a directory path.
+            prefix (str): prefix for the model checkpoint.
+            size_per_shard (int): size per shard in MB.
+        """
+        pass
+
+    @abstractmethod
+    def save_unsharded_model(self, model: nn.Module, checkpoint: Path):
+        """
+        Save model to unsharded checkpoint.
+
+        Args:
+            model (nn.Module): model to be saved.
+            checkpoint (Path): checkpoint path. It should be a single file path pointing to a model weight binary.
+        """
+        pass
+
+    # ========================================================
+    # Abstract methods for optimizer loading/saving implementation
+    # ========================================================
+
+    @abstractmethod
+    def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
+        """
+        Load optimizer from sharded checkpoint.
+
+        Args:
+            optimizer (Optimizer): optimizer to be loaded.
+            checkpoint (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
+            prefix (str): prefix for the optimizer checkpoint.
+            size_per_shard (int): size per shard in MB.
+        """
+        pass
+
+    @abstractmethod
+    def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
+        """
+        Load optimizer from unsharded checkpoint.
+
+        Args:
+            optimizer (Optimizer): optimizer to be loaded.
+            checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
+        """
+        pass
+
+    @abstractmethod
+    def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
+        """
+        Save optimizer to sharded checkpoint.
+
+        Args:
+            optimizer (Optimizer): optimizer to be saved.
+            checkpoint (Path): checkpoint path. It should be a directory path.
+            prefix (str): prefix for the optimizer checkpoint.
+            size_per_shard (int): size per shard in MB.
+        """
+        pass
+
+    @abstractmethod
+    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
+        """
+        Save optimizer to unsharded checkpoint.
+
+        Args:
+            optimizer (Optimizer): optimizer to be saved.
+            checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
         """
         pass
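
As a sanity check of the new public API (mirroring the class docstring example above and the updated unit test at the end of this diff), here is a small editor-added sketch. It assumes the unshown base-class helpers (`save_checkpoint`, `load_state_dict`, `is_sharded_checkpoint`) behave as their names suggest, and uses the concrete `GeneralCheckpointIO` from the next file:

```python
import tempfile

import torch.nn as nn
from torch.optim import Adam

from colossalai.checkpoint_io import GeneralCheckpointIO

model = nn.Linear(8, 8)
optimizer = Adam(model.parameters(), lr=1e-3)
ckpt_io = GeneralCheckpointIO()

with tempfile.NamedTemporaryFile(suffix='.pt') as model_file, \
        tempfile.NamedTemporaryFile(suffix='.pt') as optim_file:
    # save_model/save_optimizer dispatch to the *_unsharded_* hooks because shard=False
    ckpt_io.save_model(model, model_file.name)
    ckpt_io.save_optimizer(optimizer, optim_file.name)

    # load_model now mutates and returns the passed-in model instead of producing a new object
    new_model = nn.Linear(8, 8)
    ckpt_io.load_model(new_model, model_file.name)
```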

View File

@@ -10,57 +10,36 @@ __all__ = ['GeneralCheckpointIO']
 class GeneralCheckpointIO(CheckpointIO):
 
-    def load_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
-        checkpoint = Path(checkpoint)
-        is_sharded = self.is_sharded_checkpoint(checkpoint)
-
-        if not is_sharded:
-            checkpoint = self.load_state_dict(checkpoint)
-            model.load_state_dict(checkpoint, strict=strict)
-        else:
-            # find the index file
-            checkpoint_path = Path(checkpoint)
-            index_file_path = self.get_sharded_checkpoint_index_file(checkpoint_path)
-
-            # iterate over the shard checkpoint files
-            # and load each
-            shard_files = self.get_checkpoint_shard_filenames(index_file_path)
-            for shard_file in shard_files:
-                shard_checkpoint = self.load_state_dict(shard_file)
-                model.load_state_dict(shard_checkpoint, strict=strict)
-
-        return model
-
-    def save_model(self,
-                   model: nn.Module,
-                   checkpoint: str,
-                   prefix: str = None,
-                   shard: bool = False,
-                   size_per_shard: int = 1024):
-        checkpoint = Path(checkpoint)
-        if shard:
-            # TODO(FrankLeeeee): implement checkpoint saving to sharded checkpoint
-            raise NotImplementedError("Not implemented yet")
-        else:
-            self.save_checkpoint(model.state_dict(), checkpoint)
-
-    def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
-        checkpoint = Path(checkpoint)
-        is_sharded = self.is_sharded_checkpoint(checkpoint)
-        if not is_sharded:
-            checkpoint = self.load_state_dict(checkpoint)
-            optimizer.load_state_dict(checkpoint)
-        else:
-            # TODO(FrankLeeeee): implement checkpoint loading from sharded checkpoint
-            # This is not an urgent feature, so we can leave it for later
-            # let's implement this when we test large-scale models
-            pass
-        return optimizer
-
-    def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, size_per_shard: int = 1024):
-        if shard:
-            # TODO(FrankLeeeee): implement checkpoint saving to sharded checkpoint
-            pass
-        else:
-            self.save_checkpoint(optimizer.state_dict(), checkpoint)
+    def load_sharded_model(self, model: nn.Module, checkpoint: Path, strict: bool):
+        index_file_path = self.get_sharded_checkpoint_index_file(checkpoint)
+
+        # iterate over the shard checkpoint files
+        # and load each
+        shard_files = self.get_checkpoint_shard_filenames(index_file_path)
+        for shard_file in shard_files:
+            shard_checkpoint = self.load_state_dict(shard_file)
+            model.load_state_dict(shard_checkpoint, strict=strict)
+
+    def load_unsharded_model(self, model: nn.Module, checkpoint: Path, strict: bool):
+        checkpoint = self.load_state_dict(str(checkpoint))
+        model.load_state_dict(checkpoint, strict=strict)
+
+    def save_sharded_model(self, model: nn.Module, checkpoint: Path, prefix: str, size_per_shard: int):
+        # TODO(FrankLeeeee): implement this method as it can be supported by Huggingface model
+        raise NotImplementedError("Sharded model checkpoint is not supported yet.")
+
+    def save_unsharded_model(self, model: nn.Module, checkpoint: Path):
+        self.save_checkpoint(model.state_dict(), checkpoint)
+
+    def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
+        raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
+
+    def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
+        checkpoint = self.load_state_dict(checkpoint)
+        optimizer.load_state_dict(checkpoint)
+
+    def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
+        raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
+
+    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
+        self.save_checkpoint(optimizer.state_dict(), checkpoint)

View File

@@ -1,3 +1,4 @@
+import functools
 import os
 from contextlib import contextmanager
@@ -141,12 +142,12 @@ class DistCoordinator(metaclass=SingletonMeta):
         should_block = rank != executor_rank
 
         if should_block:
-            dist.barrier(group=process_group)
+            self.block_all(process_group)
 
         yield
 
         if not should_block:
-            dist.barrier(group=process_group)
+            self.block_all(process_group)
 
     def destroy(self, process_group: ProcessGroup = None):
         """
@@ -156,3 +157,38 @@ class DistCoordinator(metaclass=SingletonMeta):
             process_group (ProcessGroup, optional): process group to destroy. Defaults to None, which refers to the default process group.
         """
         dist.destroy_process_group(process_group)
+
+    def block_all(self, process_group: ProcessGroup = None):
+        """
+        Block all processes in the process group.
+
+        Args:
+            process_group (ProcessGroup, optional): process group to block. Defaults to None, which refers to the default process group.
+        """
+        dist.barrier(group=process_group)
+
+    def on_master_only(self, process_group: ProcessGroup = None):
+        """
+        A function wrapper that only executes the wrapped function on the master process (rank 0).
+
+        Example:
+            >>> from colossalai.cluster import DistCoordinator
+            >>> dist_coordinator = DistCoordinator()
+            >>>
+            >>> @dist_coordinator.on_master_only()
+            >>> def print_on_master(msg):
+            >>>     print(msg)
+        """
+        is_master = self.is_master(process_group)
+
+        # define an inner functiuon
+        def decorator(func):
+
+            @functools.wraps(func)
+            def wrapper(*args, **kwargs):
+                if is_master:
+                    return func(*args, **kwargs)
+
+            return wrapper
+
+        return decorator

View File

@@ -0,0 +1,4 @@
from .model import ModelWrapper
from .optimizer import OptimizerWrapper

__all__ = ['OptimizerWrapper', 'ModelWrapper']

View File

@@ -0,0 +1,25 @@
import torch.nn as nn


class ModelWrapper(nn.Module):
    """
    A wrapper class to define the common interface used by booster.

    Args:
        module (nn.Module): The model to be wrapped.
    """

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def unwrap(self):
        """
        Unwrap the model to return the original model for checkpoint saving/loading.
        """
        if isinstance(self.module, ModelWrapper):
            return self.module.unwrap()
        return self.module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)
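
A tiny editor-added illustration of the unwrap semantics the checkpoint code above relies on (the toy module and nesting are hypothetical, not part of the diff):

```python
import torch
import torch.nn as nn

from colossalai.interface import ModelWrapper

inner = nn.Linear(4, 4)
wrapped = ModelWrapper(ModelWrapper(inner))   # wrappers may nest (e.g. an AMP wrapper around a DDP wrapper)

assert wrapped.unwrap() is inner              # unwrap() recurses until it reaches the bare module
assert wrapped(torch.randn(2, 4)).shape == (2, 4)   # forward() simply delegates to the wrapped module
```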

View File

@@ -0,0 +1,5 @@
# New API Features
**The New API is not officially released yet.**
This folder contains some of the demonstrations of the new API. The new API is still under intensive development and will be released soon.

View File

@@ -0,0 +1,2 @@
#!/usr/bin/env
echo "The CI integration will be completed when the API is stable"

View File

@@ -0,0 +1,4 @@
data
checkpoint
ckpt-fp16
ckpt-fp32

View File

@@ -0,0 +1,44 @@
# Distributed Data Parallel
## 🚀 Quick Start
This example provides a training script and an evaluation script. The training script provides an example of training ResNet on the CIFAR10 dataset from scratch.
- Training Arguments
- `-r`, `--resume`: resume from checkpoint file path
- `-c`, `--checkpoint`: the folder to save checkpoints
- `-i`, `--interval`: epoch interval to save checkpoints
- `-f`, `--fp16`: use fp16
- Eval Arguments
- `-e`, `--epoch`: select the epoch to evaluate
- `-c`, `--checkpoint`: the folder where checkpoints are found
### Train
```bash
# train with torch DDP with fp32
colossalai run --nproc_per_node 2 train.py -c ./ckpt-fp32
# train with torch DDP with mixed precision training
colossalai run --nproc_per_node 2 train.py -c ./ckpt-fp16 --fp16
```
### Eval
```bash
# evaluate fp32 training
python eval.py -c ./ckpt-fp32 -e 80
# evaluate fp16 mixed precision training
python eval.py -c ./ckpt-fp16 -e 80
```
The expected accuracy is:
| Model | Single-GPU Baseline FP32 | Booster DDP with FP32 | Booster DDP with FP16 |
| --------- | ------------------------ | --------------------- | --------------------- |
| ResNet-18 | 85.85% | 85.03% | 85.12% |
**Note: the baseline is adapted from the [script](https://pytorch-tutorial.readthedocs.io/en/latest/tutorial/chapter03_intermediate/3_2_2_cnn_resnet_cifar10/) to use `torchvision.models.resnet18`.**

View File

@@ -0,0 +1,48 @@
import argparse
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
# ==============================
# Parse Arguments
# ==============================
parser = argparse.ArgumentParser()
parser.add_argument('-e', '--epoch', type=int, default=80, help="resume from the epoch's checkpoint")
parser.add_argument('-c', '--checkpoint', type=str, default='./checkpoint', help="checkpoint directory")
args = parser.parse_args()
# ==============================
# Prepare Test Dataset
# ==============================
# CIFAR-10 dataset
test_dataset = torchvision.datasets.CIFAR10(root='./data/', train=False, transform=transforms.ToTensor())
# Data loader
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=128, shuffle=False)
# ==============================
# Load Model
# ==============================
model = torchvision.models.resnet18(num_classes=10).cuda()
state_dict = torch.load(f'{args.checkpoint}/model_{args.epoch}.pth')
model.load_state_dict(state_dict)
# ==============================
# Run Evaluation
# ==============================
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.cuda()
        labels = labels.cuda()
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Accuracy of the model on the test images: {} %'.format(100 * correct / total))

View File

@@ -0,0 +1,128 @@
import argparse
from pathlib import Path
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import MultiStepLR
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin
from colossalai.cluster import DistCoordinator
# ==============================
# Parse Arguments
# ==============================
parser = argparse.ArgumentParser()
parser.add_argument('-r', '--resume', type=int, default=-1, help="resume from the epoch's checkpoint")
parser.add_argument('-c', '--checkpoint', type=str, default='./checkpoint', help="checkpoint directory")
parser.add_argument('-i', '--interval', type=int, default=5, help="interval of saving checkpoint")
parser.add_argument('-f', '--fp16', action='store_true', help="use fp16")
args = parser.parse_args()
# ==============================
# Prepare Checkpoint Directory
# ==============================
Path(args.checkpoint).mkdir(parents=True, exist_ok=True)
# ==============================
# Prepare Hyperparameters
# ==============================
NUM_EPOCHS = 80
LEARNING_RATE = 1e-3
START_EPOCH = args.resume if args.resume >= 0 else 0
# ==============================
# Launch Distributed Environment
# ==============================
colossalai.launch_from_torch(config={})
coordinator = DistCoordinator()
# update the learning rate with linear scaling
# old_gpu_num / old_lr = new_gpu_num / new_lr
LEARNING_RATE *= coordinator.world_size
# ==============================
# Prepare Booster
# ==============================
plugin = TorchDDPPlugin()
if args.fp16:
    booster = Booster(mixed_precision='fp16', plugin=plugin)
else:
    booster = Booster(plugin=plugin)
# ==============================
# Prepare Train Dataset
# ==============================
transform = transforms.Compose([transforms.Pad(4),
                                transforms.RandomHorizontalFlip(),
                                transforms.RandomCrop(32),
                                transforms.ToTensor()])
# CIFAR-10 dataset
with coordinator.priority_execution():
    train_dataset = torchvision.datasets.CIFAR10(root='./data/', train=True, transform=transform, download=True)
# ====================================
# Prepare model, optimizer, criterion
# ====================================
# resent50
model = torchvision.models.resnet18(num_classes=10).cuda()
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# lr scheduler
lr_scheduler = MultiStepLR(optimizer, milestones=[20, 40, 60, 80], gamma=1 / 3)
# prepare dataloader with torch ddp plugin
train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=100, shuffle=True)
# ==============================
# Resume from checkpoint
# ==============================
if args.resume >= 0:
    booster.load_model(model, f'{args.checkpoint}/model_{args.resume}.pth')
    booster.load_optimizer(optimizer, f'{args.checkpoint}/optimizer_{args.resume}.pth')
    booster.load_lr_scheduler(lr_scheduler, f'{args.checkpoint}/lr_scheduler_{args.resume}.pth')
# ==============================
# Boost with ColossalAI
# ==============================
model, optimizer, criterion, train_dataloader, lr_scheduler = booster.boost(model, optimizer, criterion,
                                                                             train_dataloader, lr_scheduler)
# ==============================
# Train model
# ==============================
total_step = len(train_dataloader)
for epoch in range(START_EPOCH, NUM_EPOCHS):
    for i, (images, labels) in enumerate(train_dataloader):
        images = images.cuda()
        labels = labels.cuda()

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        booster.backward(loss, optimizer)
        optimizer.step()

        if (i + 1) % 100 == 0:
            print("Epoch [{}/{}], Step [{}/{}] Loss: {:.4f}".format(epoch + 1, NUM_EPOCHS, i + 1, total_step,
                                                                    loss.item()))

    lr_scheduler.step()

    # save checkpoint every 5 epoch
    if (epoch + 1) % args.interval == 0:
        booster.save_model(model, f'{args.checkpoint}/model_{epoch + 1}.pth')
        booster.save_optimizer(optimizer, f'{args.checkpoint}/optimizer_{epoch + 1}.pth')
        booster.save_lr_scheduler(lr_scheduler, f'{args.checkpoint}/lr_scheduler_{epoch + 1}.pth')

View File

@@ -8,8 +8,8 @@ from torch.optim import SGD
 import colossalai
 from colossalai.booster import Booster
-from colossalai.booster.interface import OptimizerWrapper
 from colossalai.booster.plugin import TorchDDPPlugin
+from colossalai.interface import OptimizerWrapper
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from tests.kit.model_zoo import model_zoo
@@ -34,7 +34,7 @@ def check_torch_ddp_plugin():
     model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
 
-    assert isinstance(model, DDP)
+    assert isinstance(model.module, DDP)
     assert isinstance(optimizer, OptimizerWrapper)
 
     output = model(**data)

View File

@@ -42,8 +42,8 @@ def test_unsharded_checkpoint():
     new_optimizer = Adam(new_model.parameters(), lr=0.001)
 
     # load the model and optimizer
-    new_model = ckpt_io.load_model(new_model, model_ckpt_tempfile.name)
-    new_optimizer = ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name)
+    ckpt_io.load_model(new_model, model_ckpt_tempfile.name)
+    ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name)
 
     # do recursive check for the optimizer state dict
     # if the value is a dict, compare its values