diff --git a/colossalai/booster/__init__.py b/colossalai/booster/__init__.py
index 3b3f45bb0..841054a9c 100644
--- a/colossalai/booster/__init__.py
+++ b/colossalai/booster/__init__.py
@@ -1,4 +1,3 @@
 from .accelerator import Accelerator
 from .booster import Booster
-from .environment_table import EnvironmentTable
 from .plugin import Plugin
diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py
index 230c65a9e..1ad9f7f20 100644
--- a/colossalai/booster/booster.py
+++ b/colossalai/booster/booster.py
@@ -8,6 +8,8 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
 
+from colossalai.checkpoint_io import GeneralCheckpointIO
+
 from .accelerator import Accelerator
 from .mixed_precision import MixedPrecision, mixed_precision_factory
 from .plugin import Plugin
@@ -61,19 +63,21 @@ class Booster:
         self.plugin = plugin
 
         # set accelerator
-        if self.plugin and self.plugin.control_device:
+        if self.plugin and self.plugin.control_device():
             self.accelerator = None
             warnings.warn('The plugin will control the accelerator, so the device argument will be ignored.')
         else:
             self.accelerator = Accelerator(device)
 
         # set precision
-        if mixed_precision is None or (self.plugin and self.plugin.control_precision):
-            self.mixed_precision = None
+        if self.plugin and self.plugin.control_precision():
             warnings.warn('The plugin will control the precision, so the mixed_precision argument will be ignored.')
+            self.mixed_precision = None
+        elif mixed_precision is None:
+            self.mixed_precision = None
         else:
             # validate and set precision
-            if isinstance(MixedPrecision, str):
+            if isinstance(mixed_precision, str):
                 # the user will take the default arguments for amp training
                 self.mixed_precision = mixed_precision_factory(mixed_precision)
             elif isinstance(mixed_precision, MixedPrecision):
@@ -84,6 +88,11 @@ class Booster:
                 f'Expected the argument mixed_precision to be a string or an instance of Precision, but got {type(mixed_precision)}.'
             )
 
+        if self.plugin is not None and self.plugin.control_checkpoint_io():
+            self.checkpoint_io = self.plugin.get_checkpoint_io()
+        else:
+            self.checkpoint_io = GeneralCheckpointIO()
+
     def boost(
         self,
         model: nn.Module,
@@ -109,12 +118,13 @@
             model, optimizer, criterion, dataloader, lr_scheduler = self.plugin.configure(
                 model, optimizer, criterion, dataloader, lr_scheduler)
 
-        if self.plugin and not self.plugin.control_device:
+        if self.plugin and not self.plugin.control_device():
             # transform model for accelerator
             model = self.accelerator.configure(model)
 
-        if self.mixed_precision and self.plugin and not self.plugin.control_precision:
+        if self.mixed_precision and (self.plugin is None or not self.plugin.control_precision()):
             # transform model for mixed precision
+            # when mixed_precision is specified and the plugin is not given or does not control the precision
             model, optimizer, criterion = self.mixed_precision.configure(model, optimizer, criterion)
 
         return model, optimizer, criterion, dataloader, lr_scheduler
@@ -140,18 +150,25 @@ class Booster:
-        assert self.plugin.support_no_sync, f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
+        assert self.plugin.support_no_sync(), f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
         return self.plugin.no_sync(model)
 
-    def save(self,
-             obj: Union[nn.Module, Optimizer, LRScheduler],
-             path_like: str,
-             plan: str = 'torch',
-             **kwargs) -> None:
-        # TODO: implement this method
-        pass
+    def load_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
+        self.checkpoint_io.load_model(model, checkpoint, strict)
 
-    def load(self,
-             obj: Union[nn.Module, Optimizer, LRScheduler],
-             path_like: str,
-             plan: str = 'torch',
-             **kwargs) -> None:
-        # TODO: implement this method
-        pass
+    def save_model(self,
+                   model: nn.Module,
+                   checkpoint: str,
+                   prefix: str = None,
+                   shard: bool = False,
+                   size_per_shard: int = 1024):
+        # pass by keyword as CheckpointIO.save_model declares shard before prefix
+        self.checkpoint_io.save_model(model, checkpoint, prefix=prefix, shard=shard, size_per_shard=size_per_shard)
+
+    def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
+        self.checkpoint_io.load_optimizer(optimizer, checkpoint)
+
+    def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, size_per_shard: int = 1024):
+        self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard=shard, size_per_shard=size_per_shard)
+
+    def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
+        self.checkpoint_io.save_lr_scheduler(lr_scheduler, checkpoint)
+
+    def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
+        self.checkpoint_io.load_lr_scheduler(lr_scheduler, checkpoint)
diff --git a/colossalai/booster/environment_table.py b/colossalai/booster/environment_table.py
deleted file mode 100644
index 4b16f120c..000000000
--- a/colossalai/booster/environment_table.py
+++ /dev/null
@@ -1,18 +0,0 @@
-from typing import List
-
-__all__ = ['EnvironmentTable']
-
-
-class EnvironmentTable:
-
-    def __init__(self, intra_op_world_sizes: List[int]):
-        # TODO: implement this method
-        pass
-
-    @property
-    def is_master(self) -> bool:
-        # TODO: implement this method
-        pass
-
-    # TODO: implement more utility methods as given in
-    # https://github.com/hpcaitech/ColossalAI/issues/3051
diff --git a/colossalai/booster/interface/__init__.py b/colossalai/booster/interface/__init__.py
deleted file mode 100644
index 8892a13e1..000000000
--- a/colossalai/booster/interface/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .optimizer import OptimizerWrapper
-
-__all__ = ['OptimizerWrapper']
diff --git a/colossalai/booster/mixed_precision/fp16_torch.py b/colossalai/booster/mixed_precision/fp16_torch.py
index 054f78d2e..9999aa5e0 100644
--- a/colossalai/booster/mixed_precision/fp16_torch.py
+++ b/colossalai/booster/mixed_precision/fp16_torch.py
@@ -5,7 +5,8 @@ import torch.nn as nn
 from torch import Tensor
 from torch.optim import Optimizer
 
-from ..interface import OptimizerWrapper
+from colossalai.interface import ModelWrapper, OptimizerWrapper
+
 from .mixed_precision_base import MixedPrecision
 
 __all__ = ['FP16_Torch_MixedPrecision', 'TorchAMPOptimizer', 'TorchAMPModule']
@@ -45,7 +46,9 @@ class TorchAMPOptimizer(OptimizerWrapper):
         scaled_loss.backward(*args, **kwargs)
 
     def step(self, *args, **kwargs) -> Optional[float]:
-        return self.scaler.step(self.optim, *args, **kwargs)
+        out = self.scaler.step(self.optim, *args, **kwargs)
+        self.scaler.update()
+        return out
 
     def scale_loss(self, loss: Tensor) -> Tensor:
         return self.scaler.scale(loss)
@@ -67,7 +70,7 @@ class TorchAMPOptimizer(OptimizerWrapper):
         super().clip_grad_by_norm(max_norm, norm_type, error_if_nonfinite, *args, **kwargs)
 
 
-class TorchAMPModule(nn.Module):
+class TorchAMPModule(ModelWrapper):
     """
     Module wrapper for mixed precision training in FP16 using PyTorch AMP.
@@ -76,8 +79,7 @@ class TorchAMPModule(nn.Module):
     """
 
     def __init__(self, module: nn.Module):
-        super().__init__()
-        self.module = module
+        super().__init__(module)
 
     def forward(self, *args, **kwargs):
         with torch.cuda.amp.autocast():
diff --git a/colossalai/booster/mixed_precision/mixed_precision_base.py b/colossalai/booster/mixed_precision/mixed_precision_base.py
index d1e8acc82..2490e9811 100644
--- a/colossalai/booster/mixed_precision/mixed_precision_base.py
+++ b/colossalai/booster/mixed_precision/mixed_precision_base.py
@@ -4,7 +4,7 @@ from typing import Callable, Tuple
 import torch.nn as nn
 from torch.optim import Optimizer
 
-from ..interface import OptimizerWrapper
+from colossalai.interface import OptimizerWrapper
 
 
 class MixedPrecision(ABC):
diff --git a/colossalai/booster/plugin/plugin_base.py b/colossalai/booster/plugin/plugin_base.py
index 3c347cb42..7a222022c 100644
--- a/colossalai/booster/plugin/plugin_base.py
+++ b/colossalai/booster/plugin/plugin_base.py
@@ -6,34 +6,30 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
 
-from colossalai.booster.interface import OptimizerWrapper
+from colossalai.checkpoint_io import CheckpointIO
+from colossalai.interface import OptimizerWrapper
 
 __all__ = ['Plugin']
 
 
 class Plugin(ABC):
 
-    @property
     @abstractmethod
     def supported_devices(self) -> List[str]:
         pass
 
-    @property
     @abstractmethod
     def supported_precisions(self) -> List[str]:
         pass
 
-    @property
     @abstractmethod
     def control_precision(self) -> bool:
         pass
 
-    @property
     @abstractmethod
     def control_device(self) -> bool:
         pass
 
-    @property
     @abstractmethod
     def support_no_sync(self) -> bool:
         pass
@@ -49,3 +45,17 @@ class Plugin(ABC):
     ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
         # implement this method
         pass
+
+    @abstractmethod
+    def control_checkpoint_io(self) -> bool:
+        """
+        Whether the plugin controls the checkpoint io
+        """
+        pass
+
+    @abstractmethod
+    def get_checkpoint_io(self) -> CheckpointIO:
+        """
+        Get checkpoint io object for this plugin, only invoked when control_checkpoint_io is True.
+        """
+        pass
diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py
index 07d6be8c7..d7f3d22d9 100644
--- a/colossalai/booster/plugin/torch_ddp_plugin.py
+++ b/colossalai/booster/plugin/torch_ddp_plugin.py
@@ -11,13 +11,61 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 
-from colossalai.booster.interface import OptimizerWrapper
+from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
+from colossalai.cluster import DistCoordinator
+from colossalai.interface import ModelWrapper, OptimizerWrapper
 
 from .plugin_base import Plugin
 
 __all__ = ['TorchDDPPlugin']
 
 
+class TorchDDPCheckpointIO(GeneralCheckpointIO):
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.coordinator = DistCoordinator()
+
+    def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
+        """
+        Load model from checkpoint with automatic unwrapping.
+        """
+        # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
+        return super().load_unsharded_model(model, checkpoint, strict=strict)
+
+    def save_unsharded_model(self, model: nn.Module, checkpoint: str):
+        """
+        Save model to checkpoint but only on master process.
+ """ + # the model should be unwrapped in self.load_model via ModelWrapper.unwrap + if self.coordinator.is_master(): + super().save_unsharded_model(model, checkpoint) + + def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str): + """ + Save optimizer to checkpoint but only on master process. + """ + if self.coordinator.is_master(): + super().save_unsharded_optimizer(optimizer, checkpoint) + + def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): + """ + Save model to checkpoint but only on master process. + """ + if self.coordinator.is_master(): + super().save_lr_scheduler(lr_scheduler, checkpoint) + + +class TorchDDPModel(ModelWrapper): + + def __init__(self, module: nn.Module, *args, **kwargs) -> None: + super().__init__(module) + self.module = DDP(module, *args, **kwargs) + + def unwrap(self): + return self.module.module + + class TorchDDPPlugin(Plugin): """ Plugin for PyTorch DDP. @@ -138,10 +186,19 @@ class TorchDDPPlugin(Plugin): # cast model to cuda model = model.cuda() + # convert model to sync bn + model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None) + # wrap the model with PyTorch DDP - model = DDP(model, **self.ddp_kwargs) + model = TorchDDPModel(model, **self.ddp_kwargs) if not isinstance(optimizer, OptimizerWrapper): optimizer = OptimizerWrapper(optimizer) return model, optimizer, criterion, dataloader, lr_scheduler + + def control_checkpoint_io(self) -> bool: + return True + + def get_checkpoint_io(self) -> CheckpointIO: + return TorchDDPCheckpointIO() diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py index 00a65424b..d6eef7a96 100644 --- a/colossalai/checkpoint_io/checkpoint_io_base.py +++ b/colossalai/checkpoint_io/checkpoint_io_base.py @@ -1,13 +1,15 @@ import json from abc import ABC, abstractmethod from pathlib import Path -from typing import Any +from typing import Any, Union import torch import torch.nn as nn from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from colossalai.interface import ModelWrapper + __all__ = ['CheckpointIO', 'ShardCheckpointIndexFile'] @@ -37,15 +39,15 @@ class CheckpointIO(ABC): >>> >>> # save optimizer to checkpoint >>> checkpoint_io.save_optimizer(optimizer, 'optimizer.pt') - """ # ====================================== - # Abstract methods for implementation + # Public methods # ====================================== - - @abstractmethod - def load_model(self, model: nn.Module, checkpoint: str, strict: bool = True): + def load_model(self, + model: Union[nn.Module, ModelWrapper], + checkpoint: str, + strict: bool = True) -> Union[nn.Module, ModelWrapper]: """ Load model from checkpoint. @@ -59,14 +61,26 @@ class CheckpointIO(ABC): strict (bool): whether to strictly enforce that the param name in the checkpoint match the keys returned by this module's. """ - pass + ckpt_path = Path(checkpoint) + is_sharded = self.is_sharded_checkpoint(ckpt_path) + + origin_model = model + + if isinstance(model, ModelWrapper): + model = model.unwrap() + + if is_sharded: + self.load_sharded_model(model, ckpt_path, strict) + else: + self.load_unsharded_model(model, ckpt_path, strict) + + return origin_model - @abstractmethod def save_model(self, - model: nn.Module, + model: Union[nn.Module, ModelWrapper], checkpoint: str, - prefix: str = None, shard: bool = False, + prefix: str = None, size_per_shard: int = 1024): """ Save model to checkpoint. 
@@ -83,17 +97,24 @@
         Args:
             model (nn.Module): model to be saved.
-            checkpoint: checkpoint path. The checkpoint path can be :
+            checkpoint (str): checkpoint path. The checkpoint path can be :
                 1. a file path, e.g. 'model.pt'
                 2. a directory path to save the sharded checkpoint, e.g. './checkpoints/' when shard = True.
-            shard: whether to shard the checkpoint. Default: False. If set to True, the checkpoint will be sharded into
+            shard (bool): whether to shard the checkpoint. Default: False. If set to True, the checkpoint will be sharded into
                 multiple files. The model shards will be specificed by a `model.index.json` file. When shard = True, please ensure
                 that the checkpoint path is a directory path instead of a file path.
-            size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
+            prefix (str): prefix for the model checkpoint file name when shard=True. Default: None.
+            size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True.
         """
-        pass
 
-    @abstractmethod
+
+        if isinstance(model, ModelWrapper):
+            model = model.unwrap()
+
+        if shard:
+            self.save_sharded_model(model, checkpoint, prefix, size_per_shard)
+        else:
+            self.save_unsharded_model(model, checkpoint)
+
     def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
         """
         Load optimizer from checkpoint.
 
         Args:
             optimizer (Optimizer): optimizer to be loaded.
-            checkpoint (str): checkpoint path. This value is made compatiblity with the model checkpoints in the
+            checkpoint (str): checkpoint path. The same checkpoint path formats as for the model checkpoints are supported.
         """
-        pass
+        ckpt_path = Path(checkpoint)
+        is_sharded = self.is_sharded_checkpoint(ckpt_path)
 
-    @abstractmethod
-    def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, size_per_shard: int = 1024):
+        if is_sharded:
+            self.load_sharded_optimizer(optimizer, ckpt_path)
+        else:
+            self.load_unsharded_optimizer(optimizer, ckpt_path)
+
+    def save_optimizer(self,
+                       optimizer: Optimizer,
+                       checkpoint: str,
+                       shard: bool = False,
+                       prefix: str = None,
+                       size_per_shard: int = 1024):
         """
         Save optimizer to checkpoint.
 
         Args:
             optimizer (Optimizer): optimizer to be saved.
-            checkpoint: checkpoint path. The checkpoint path can be :
+            checkpoint (str): checkpoint path. The checkpoint path can be :
                 1. a file path, e.g. 'model.pt'
                 2. a path to a json file which defines the index to the sharded checkpoint for the optimizer
                 3. a path to a folder containing a unique .index.json file for sharded checkpoint
+            shard (bool): whether to shard the checkpoint. Default: False. If set to True, the checkpoint will be sharded into
+                multiple files. The optimizer shards will be specified by an `optimizer.index.json` file.
+            prefix (str): prefix for the optimizer checkpoint when shard = True. Default: None.
+            size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
+        """
+        if shard:
+            self.save_sharded_optimizer(optimizer, checkpoint, prefix, size_per_shard)
+        else:
+            self.save_unsharded_optimizer(optimizer, checkpoint)
+
+    # ========================================================
+    # Abstract methods for model loading/saving implementation
+    # ========================================================
+    @abstractmethod
+    def load_sharded_model(self, model: nn.Module, checkpoint: Path, strict: bool):
+        """
+        Load model from sharded checkpoint.
+
+        Args:
+            model (nn.Module): model to be loaded.
+            checkpoint (Path): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
+        """
+        pass
+
+    @abstractmethod
+    def load_unsharded_model(self, model: nn.Module, checkpoint: Path, strict: bool):
+        """
+        Load model from unsharded checkpoint.
+
+        Args:
+            model (nn.Module): model to be loaded.
+            checkpoint (Path): checkpoint path. It should be a single file path pointing to a model weight binary.
+            strict (bool): whether to strictly enforce that the param names in
+                the checkpoint match the keys returned by this module's state_dict() function.
+        """
+        pass
+
+    @abstractmethod
+    def save_sharded_model(self, model: nn.Module, checkpoint: Path, prefix: str, size_per_shard: int):
+        """
+        Save model to sharded checkpoint.
+
+        Args:
+            model (nn.Module): model to be saved.
+            checkpoint (Path): checkpoint path. It should be a directory path.
+            prefix (str): prefix for the model checkpoint.
+            size_per_shard (int): size per shard in MB.
+        """
+        pass
+
+    @abstractmethod
+    def save_unsharded_model(self, model: nn.Module, checkpoint: Path):
+        """
+        Save model to unsharded checkpoint.
+
+        Args:
+            model (nn.Module): model to be saved.
+            checkpoint (Path): checkpoint path. It should be a single file path pointing to a model weight binary.
+        """
+        pass
+
+    # ========================================================
+    # Abstract methods for optimizer loading/saving implementation
+    # ========================================================
+
+    @abstractmethod
+    def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str = None, size_per_shard: int = None):
+        """
+        Load optimizer from sharded checkpoint.
+
+        Args:
+            optimizer (Optimizer): optimizer to be loaded.
+            checkpoint (Path): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
+            prefix (str): prefix for the optimizer checkpoint.
+            size_per_shard (int): size per shard in MB.
+        """
+        pass
+
+    @abstractmethod
+    def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
+        """
+        Load optimizer from unsharded checkpoint.
+
+        Args:
+            optimizer (Optimizer): optimizer to be loaded.
+            checkpoint (Path): checkpoint path. It should be a single file path pointing to an optimizer state binary.
+        """
+        pass
+
+    @abstractmethod
+    def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
+        """
+        Save optimizer to sharded checkpoint.
+
+        Args:
+            optimizer (Optimizer): optimizer to be saved.
+            checkpoint (Path): checkpoint path. It should be a directory path.
+            prefix (str): prefix for the optimizer checkpoint.
+            size_per_shard (int): size per shard in MB.
+        """
+        pass
+
+    @abstractmethod
+    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
+        """
+        Save optimizer to unsharded checkpoint.
+
+        Args:
+            optimizer (Optimizer): optimizer to be saved.
+            checkpoint (Path): checkpoint path. It should be a single file path pointing to an optimizer state binary.
""" pass diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index 0a3636655..cfabcfa55 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -10,57 +10,36 @@ __all__ = ['GeneralCheckpointIO'] class GeneralCheckpointIO(CheckpointIO): - def load_model(self, model: nn.Module, checkpoint: str, strict: bool = True): - checkpoint = Path(checkpoint) - is_sharded = self.is_sharded_checkpoint(checkpoint) + def load_sharded_model(self, model: nn.Module, checkpoint: Path, strict: bool): + index_file_path = self.get_sharded_checkpoint_index_file(checkpoint) - if not is_sharded: - checkpoint = self.load_state_dict(checkpoint) - model.load_state_dict(checkpoint, strict=strict) - else: - # find the index file - checkpoint_path = Path(checkpoint) - index_file_path = self.get_sharded_checkpoint_index_file(checkpoint_path) + # iterate over the shard checkpoint files + # and load each + shard_files = self.get_checkpoint_shard_filenames(index_file_path) + for shard_file in shard_files: + shard_checkpoint = self.load_state_dict(shard_file) + model.load_state_dict(shard_checkpoint, strict=strict) - # iterate over the shard checkpoint files - # and load each - shard_files = self.get_checkpoint_shard_filenames(index_file_path) - for shard_file in shard_files: - shard_checkpoint = self.load_state_dict(shard_file) - model.load_state_dict(shard_checkpoint, strict=strict) + def load_unsharded_model(self, model: nn.Module, checkpoint: Path, strict: bool): + checkpoint = self.load_state_dict(str(checkpoint)) + model.load_state_dict(checkpoint, strict=strict) - return model + def save_sharded_model(self, model: nn.Module, checkpoint: Path, prefix: str, size_per_shard: int): + # TODO(FrankLeeeee): implement this method as it can be supported by Huggingface model + raise NotImplementedError("Sharded model checkpoint is not supported yet.") - def save_model(self, - model: nn.Module, - checkpoint: str, - prefix: str = None, - shard: bool = False, - size_per_shard: int = 1024): - checkpoint = Path(checkpoint) - if shard: - # TODO(FrankLeeeee): implement checkpoint saving to sharded checkpoint - raise NotImplementedError("Not implemented yet") - else: - self.save_checkpoint(model.state_dict(), checkpoint) + def save_unsharded_model(self, model: nn.Module, checkpoint: Path): + self.save_checkpoint(model.state_dict(), checkpoint) - def load_optimizer(self, optimizer: Optimizer, checkpoint: str): - checkpoint = Path(checkpoint) - is_sharded = self.is_sharded_checkpoint(checkpoint) + def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int): + raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.") - if not is_sharded: - checkpoint = self.load_state_dict(checkpoint) - optimizer.load_state_dict(checkpoint) - else: - # TODO(FrankLeeeee): implement checkpoint loading from sharded checkpoint - # This is not an urgent feature, so we can leave it for later - # let's implement this when we test large-scale models - pass - return optimizer + def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path): + checkpoint = self.load_state_dict(checkpoint) + optimizer.load_state_dict(checkpoint) - def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, size_per_shard: int = 1024): - if shard: - # TODO(FrankLeeeee): implement checkpoint saving to sharded checkpoint - pass - else: - 
self.save_checkpoint(optimizer.state_dict(), checkpoint) + def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int): + raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.") + + def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path): + self.save_checkpoint(optimizer.state_dict(), checkpoint) diff --git a/colossalai/cluster/dist_coordinator.py b/colossalai/cluster/dist_coordinator.py index 6b48faf5b..99dde810e 100644 --- a/colossalai/cluster/dist_coordinator.py +++ b/colossalai/cluster/dist_coordinator.py @@ -1,3 +1,4 @@ +import functools import os from contextlib import contextmanager @@ -141,12 +142,12 @@ class DistCoordinator(metaclass=SingletonMeta): should_block = rank != executor_rank if should_block: - dist.barrier(group=process_group) + self.block_all(process_group) yield if not should_block: - dist.barrier(group=process_group) + self.block_all(process_group) def destroy(self, process_group: ProcessGroup = None): """ @@ -156,3 +157,38 @@ class DistCoordinator(metaclass=SingletonMeta): process_group (ProcessGroup, optional): process group to destroy. Defaults to None, which refers to the default process group. """ dist.destroy_process_group(process_group) + + def block_all(self, process_group: ProcessGroup = None): + """ + Block all processes in the process group. + + Args: + process_group (ProcessGroup, optional): process group to block. Defaults to None, which refers to the default process group. + """ + dist.barrier(group=process_group) + + def on_master_only(self, process_group: ProcessGroup = None): + """ + A function wrapper that only executes the wrapped function on the master process (rank 0). + + Example: + >>> from colossalai.cluster import DistCoordinator + >>> dist_coordinator = DistCoordinator() + >>> + >>> @dist_coordinator.on_master_only() + >>> def print_on_master(msg): + >>> print(msg) + """ + is_master = self.is_master(process_group) + + # define an inner functiuon + def decorator(func): + + @functools.wraps(func) + def wrapper(*args, **kwargs): + if is_master: + return func(*args, **kwargs) + + return wrapper + + return decorator diff --git a/colossalai/interface/__init__.py b/colossalai/interface/__init__.py new file mode 100644 index 000000000..8c658e375 --- /dev/null +++ b/colossalai/interface/__init__.py @@ -0,0 +1,4 @@ +from .model import ModelWrapper +from .optimizer import OptimizerWrapper + +__all__ = ['OptimizerWrapper', 'ModelWrapper'] diff --git a/colossalai/interface/model.py b/colossalai/interface/model.py new file mode 100644 index 000000000..a067d7671 --- /dev/null +++ b/colossalai/interface/model.py @@ -0,0 +1,25 @@ +import torch.nn as nn + + +class ModelWrapper(nn.Module): + """ + A wrapper class to define the common interface used by booster. + + Args: + module (nn.Module): The model to be wrapped. + """ + + def __init__(self, module: nn.Module) -> None: + super().__init__() + self.module = module + + def unwrap(self): + """ + Unwrap the model to return the original model for checkpoint saving/loading. 
+ """ + if isinstance(self.module, ModelWrapper): + return self.module.unwrap() + return self.module + + def forward(self, *args, **kwargs): + return self.module(*args, **kwargs) diff --git a/colossalai/booster/interface/optimizer.py b/colossalai/interface/optimizer.py similarity index 100% rename from colossalai/booster/interface/optimizer.py rename to colossalai/interface/optimizer.py diff --git a/examples/tutorial/new_api/README.md b/examples/tutorial/new_api/README.md new file mode 100644 index 000000000..cec88f41c --- /dev/null +++ b/examples/tutorial/new_api/README.md @@ -0,0 +1,5 @@ +# New API Features + +**The New API is not officially released yet.** + +This folder contains some of the demonstrations of the new API. The new API is still under intensive development and will be released soon. diff --git a/examples/tutorial/new_api/test_ci.sh b/examples/tutorial/new_api/test_ci.sh new file mode 100644 index 000000000..8b4475e9f --- /dev/null +++ b/examples/tutorial/new_api/test_ci.sh @@ -0,0 +1,2 @@ +#!/usr/bin/env +echo "The CI integration will be completed when the API is stable" diff --git a/examples/tutorial/new_api/torch_ddp/.gitignore b/examples/tutorial/new_api/torch_ddp/.gitignore new file mode 100644 index 000000000..a79cf5236 --- /dev/null +++ b/examples/tutorial/new_api/torch_ddp/.gitignore @@ -0,0 +1,4 @@ +data +checkpoint +ckpt-fp16 +ckpt-fp32 diff --git a/examples/tutorial/new_api/torch_ddp/README.md b/examples/tutorial/new_api/torch_ddp/README.md new file mode 100644 index 000000000..62d5a083d --- /dev/null +++ b/examples/tutorial/new_api/torch_ddp/README.md @@ -0,0 +1,44 @@ +# Distributed Data Parallel + +## 🚀 Quick Start + +This example provides a training script and and evaluation script. The training script provides a an example of training ResNet on CIFAR10 dataset from scratch. 
+
+- Training Arguments
+  - `-r`, `--resume`: resume from the checkpoint file path
+  - `-c`, `--checkpoint`: the folder to save checkpoints
+  - `-i`, `--interval`: epoch interval to save checkpoints
+  - `-f`, `--fp16`: use fp16
+
+- Eval Arguments
+  - `-e`, `--epoch`: select the epoch to evaluate
+  - `-c`, `--checkpoint`: the folder where checkpoints are found
+
+
+### Train
+
+```bash
+# train with torch DDP with fp32
+colossalai run --nproc_per_node 2 train.py -c ./ckpt-fp32
+
+# train with torch DDP with mixed precision training
+colossalai run --nproc_per_node 2 train.py -c ./ckpt-fp16 --fp16
+```
+
+### Eval
+
+```bash
+# evaluate fp32 training
+python eval.py -c ./ckpt-fp32 -e 80
+
+# evaluate fp16 mixed precision training
+python eval.py -c ./ckpt-fp16 -e 80
+```
+
+Expected accuracy:
+
+| Model     | Single-GPU Baseline FP32 | Booster DDP with FP32 | Booster DDP with FP16 |
+| --------- | ------------------------ | --------------------- | --------------------- |
+| ResNet-18 | 85.85%                   | 85.03%                | 85.12%                |
+
+**Note: the baseline is adapted from the [script](https://pytorch-tutorial.readthedocs.io/en/latest/tutorial/chapter03_intermediate/3_2_2_cnn_resnet_cifar10/) to use `torchvision.models.resnet18`**
diff --git a/examples/tutorial/new_api/torch_ddp/eval.py b/examples/tutorial/new_api/torch_ddp/eval.py
new file mode 100644
index 000000000..657708ec3
--- /dev/null
+++ b/examples/tutorial/new_api/torch_ddp/eval.py
@@ -0,0 +1,48 @@
+import argparse
+
+import torch
+import torch.nn as nn
+import torchvision
+import torchvision.transforms as transforms
+
+# ==============================
+# Parse Arguments
+# ==============================
+parser = argparse.ArgumentParser()
+parser.add_argument('-e', '--epoch', type=int, default=80, help="the epoch of the checkpoint to evaluate")
+parser.add_argument('-c', '--checkpoint', type=str, default='./checkpoint', help="checkpoint directory")
+args = parser.parse_args()
+
+# ==============================
+# Prepare Test Dataset
+# ==============================
+# CIFAR-10 dataset
+test_dataset = torchvision.datasets.CIFAR10(root='./data/', train=False, transform=transforms.ToTensor())
+
+# Data loader
+test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=128, shuffle=False)
+
+# ==============================
+# Load Model
+# ==============================
+model = torchvision.models.resnet18(num_classes=10).cuda()
+state_dict = torch.load(f'{args.checkpoint}/model_{args.epoch}.pth')
+model.load_state_dict(state_dict)
+
+# ==============================
+# Run Evaluation
+# ==============================
+model.eval()
+
+with torch.no_grad():
+    correct = 0
+    total = 0
+    for images, labels in test_loader:
+        images = images.cuda()
+        labels = labels.cuda()
+        outputs = model(images)
+        _, predicted = torch.max(outputs.data, 1)
+        total += labels.size(0)
+        correct += (predicted == labels).sum().item()
+
+    print('Accuracy of the model on the test images: {} %'.format(100 * correct / total))
diff --git a/examples/tutorial/new_api/torch_ddp/train.py b/examples/tutorial/new_api/torch_ddp/train.py
new file mode 100644
index 000000000..4741c3151
--- /dev/null
+++ b/examples/tutorial/new_api/torch_ddp/train.py
@@ -0,0 +1,128 @@
+import argparse
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+import torchvision
+import torchvision.transforms as transforms
+from torch.optim.lr_scheduler import MultiStepLR
+
+import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin import TorchDDPPlugin
+from colossalai.cluster import DistCoordinator
+
+# ==============================
+# Parse Arguments
+# ==============================
+parser = argparse.ArgumentParser()
+parser.add_argument('-r', '--resume', type=int, default=-1, help="resume from the epoch's checkpoint")
+parser.add_argument('-c', '--checkpoint', type=str, default='./checkpoint', help="checkpoint directory")
+parser.add_argument('-i', '--interval', type=int, default=5, help="interval of saving checkpoint")
+parser.add_argument('-f', '--fp16', action='store_true', help="use fp16")
+args = parser.parse_args()
+
+# ==============================
+# Prepare Checkpoint Directory
+# ==============================
+Path(args.checkpoint).mkdir(parents=True, exist_ok=True)
+
+# ==============================
+# Prepare Hyperparameters
+# ==============================
+NUM_EPOCHS = 80
+LEARNING_RATE = 1e-3
+START_EPOCH = args.resume if args.resume >= 0 else 0
+
+# ==============================
+# Launch Distributed Environment
+# ==============================
+colossalai.launch_from_torch(config={})
+coordinator = DistCoordinator()
+
+# update the learning rate with linear scaling
+# old_gpu_num / old_lr = new_gpu_num / new_lr
+LEARNING_RATE *= coordinator.world_size
+
+# ==============================
+# Prepare Booster
+# ==============================
+plugin = TorchDDPPlugin()
+if args.fp16:
+    booster = Booster(mixed_precision='fp16', plugin=plugin)
+else:
+    booster = Booster(plugin=plugin)
+
+# ==============================
+# Prepare Train Dataset
+# ==============================
+transform = transforms.Compose(
+    [transforms.Pad(4),
+     transforms.RandomHorizontalFlip(),
+     transforms.RandomCrop(32),
+     transforms.ToTensor()])
+
+# CIFAR-10 dataset
+with coordinator.priority_execution():
+    train_dataset = torchvision.datasets.CIFAR10(root='./data/', train=True, transform=transform, download=True)
+
+# ====================================
+# Prepare model, optimizer, criterion
+# ====================================
+# ResNet-18
+model = torchvision.models.resnet18(num_classes=10).cuda()
+
+# Loss and optimizer
+criterion = nn.CrossEntropyLoss()
+optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
+
+# lr scheduler
+lr_scheduler = MultiStepLR(optimizer, milestones=[20, 40, 60, 80], gamma=1 / 3)
+
+# prepare dataloader with torch ddp plugin
+train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=100, shuffle=True)
+
+# ==============================
+# Resume from checkpoint
+# ==============================
+if args.resume >= 0:
+    booster.load_model(model, f'{args.checkpoint}/model_{args.resume}.pth')
+    booster.load_optimizer(optimizer, f'{args.checkpoint}/optimizer_{args.resume}.pth')
+    booster.load_lr_scheduler(lr_scheduler, f'{args.checkpoint}/lr_scheduler_{args.resume}.pth')
+
+# ==============================
+# Boost with ColossalAI
+# ==============================
+model, optimizer, criterion, train_dataloader, lr_scheduler = booster.boost(model, optimizer, criterion,
+                                                                            train_dataloader, lr_scheduler)
+
+# ==============================
+# Train model
+# ==============================
+total_step = len(train_dataloader)
+
+for epoch in range(START_EPOCH, NUM_EPOCHS):
+    for i, (images, labels) in enumerate(train_dataloader):
+        images = images.cuda()
+        labels = labels.cuda()
+
+        # Forward pass
+        outputs = model(images)
+        loss = criterion(outputs, labels)
+
+        # Backward and optimize
+        optimizer.zero_grad()
+        booster.backward(loss, optimizer)
+        optimizer.step()
+
+        if (i + 1) % 100 == 0:
+            print("Epoch [{}/{}], Step [{}/{}] Loss: {:.4f}".format(epoch + 1, NUM_EPOCHS, i + 1, total_step,
+                                                                    loss.item()))
+
+    lr_scheduler.step()
+
+    # save a checkpoint every `interval` epochs
+    if (epoch + 1) % args.interval == 0:
+        booster.save_model(model, f'{args.checkpoint}/model_{epoch + 1}.pth')
+        booster.save_optimizer(optimizer, f'{args.checkpoint}/optimizer_{epoch + 1}.pth')
+        booster.save_lr_scheduler(lr_scheduler, f'{args.checkpoint}/lr_scheduler_{epoch + 1}.pth')
diff --git a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py
index 58aef54c4..2dcc5a5bb 100644
--- a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py
+++ b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py
@@ -8,8 +8,8 @@ from torch.optim import SGD
 
 import colossalai
 from colossalai.booster import Booster
-from colossalai.booster.interface import OptimizerWrapper
 from colossalai.booster.plugin import TorchDDPPlugin
+from colossalai.interface import OptimizerWrapper
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from tests.kit.model_zoo import model_zoo
@@ -34,7 +34,7 @@ def check_torch_ddp_plugin():
 
     model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
 
-    assert isinstance(model, DDP)
+    assert isinstance(model.module, DDP)
     assert isinstance(optimizer, OptimizerWrapper)
 
     output = model(**data)
diff --git a/tests/test_checkpoint_io/test_general_checkpoint_io.py b/tests/test_checkpoint_io/test_general_checkpoint_io.py
index 48376aaa8..f9f0e03c4 100644
--- a/tests/test_checkpoint_io/test_general_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_general_checkpoint_io.py
@@ -42,8 +42,8 @@ def test_unsharded_checkpoint():
     new_optimizer = Adam(new_model.parameters(), lr=0.001)
 
     # load the model and optimizer
-    new_model = ckpt_io.load_model(new_model, model_ckpt_tempfile.name)
-    new_optimizer = ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name)
+    ckpt_io.load_model(new_model, model_ckpt_tempfile.name)
+    ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name)
 
     # do recursive check for the optimizer state dict
     # if the value is a dict, compare its values
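A minimal end-to-end sketch of how the checkpoint API introduced in this diff fits together, condensed from the `torch_ddp` example above. The model, hyperparameters, and checkpoint paths are illustrative, and the snippet assumes a distributed launch (e.g. `colossalai run --nproc_per_node 2 demo.py`):

```python
import torch
import torchvision

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

# set up the distributed environment
colossalai.launch_from_torch(config={})

# TorchDDPPlugin controls checkpoint IO, so Booster routes all
# save/load calls through the TorchDDPCheckpointIO defined above
plugin = TorchDDPPlugin()
booster = Booster(mixed_precision='fp16', plugin=plugin)

model = torchvision.models.resnet18(num_classes=10).cuda()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# boost wraps the model in TorchDDPModel (a ModelWrapper)
# and the optimizer in OptimizerWrapper
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

# ... training loop: booster.backward(loss, optimizer); optimizer.step() ...

# saving unwraps the DDP module via ModelWrapper.unwrap and
# writes the file only on the master process
booster.save_model(model, './ckpt/model.pth')
booster.save_optimizer(optimizer, './ckpt/optimizer.pth')

# loading accepts the wrapped model directly
booster.load_model(model, './ckpt/model.pth')
booster.load_optimizer(optimizer, './ckpt/optimizer.pth')
```

Plugins that return `False` from `control_checkpoint_io()` fall back to the plain `GeneralCheckpointIO`, so the same `booster.save_*` / `booster.load_*` calls also work when no plugin is given.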