[booster] implemented the torch ddd + resnet example (#3232)

* [booster] implemented the torch ddd + resnet example

* polish code
This commit is contained in:
Frank Lee
2023-03-27 10:24:14 +08:00
committed by GitHub
parent 1a229045af
commit 73d3e4d309
22 changed files with 608 additions and 128 deletions

View File

@@ -8,6 +8,8 @@ from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from colossalai.checkpoint_io import GeneralCheckpointIO
from .accelerator import Accelerator
from .mixed_precision import MixedPrecision, mixed_precision_factory
from .plugin import Plugin
@@ -61,19 +63,21 @@ class Booster:
self.plugin = plugin
# set accelerator
if self.plugin and self.plugin.control_device:
if self.plugin and self.plugin.control_device():
self.accelerator = None
warnings.warn('The plugin will control the accelerator, so the device argument will be ignored.')
else:
self.accelerator = Accelerator(device)
# set precision
if mixed_precision is None or (self.plugin and self.plugin.control_precision):
self.mixed_precision = None
if self.plugin and self.plugin.control_precision():
warnings.warn('The plugin will control the precision, so the mixed_precision argument will be ignored.')
self.mixed_precision = None
elif mixed_precision is None:
self.mixed_precision = None
else:
# validate and set precision
if isinstance(MixedPrecision, str):
if isinstance(mixed_precision, str):
# the user will take the default arguments for amp training
self.mixed_precision = mixed_precision_factory(mixed_precision)
elif isinstance(mixed_precision, MixedPrecision):
@@ -84,6 +88,11 @@ class Booster:
f'Expected the argument mixed_precision to be a string or an instance of Precision, but got {type(mixed_precision)}.'
)
if self.plugin is not None and self.plugin.control_checkpoint_io():
self.checkpoint_io = self.plugin.get_checkpoint_io()
else:
self.checkpoint_io = GeneralCheckpointIO()
def boost(
self,
model: nn.Module,
@@ -109,12 +118,13 @@ class Booster:
model, optimizer, criterion, dataloader, lr_scheduler = self.plugin.configure(
model, optimizer, criterion, dataloader, lr_scheduler)
if self.plugin and not self.plugin.control_device:
if self.plugin and not self.plugin.control_device():
# transform model for accelerator
model = self.accelerator.configure(model)
if self.mixed_precision and self.plugin and not self.plugin.control_precision:
if self.mixed_precision and (self.plugin is None or self.plugin and not self.plugin.control_precision()):
# transform model for mixed precision
# when mixed_precision is specified and the plugin is not given or does not control the precision
model, optimizer, criterion = self.mixed_precision.configure(model, optimizer, criterion)
return model, optimizer, criterion, dataloader, lr_scheduler
@@ -140,18 +150,25 @@ class Booster:
assert self.plugin.support_no_sync, f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
return self.plugin.no_sync(model)
def save(self,
obj: Union[nn.Module, Optimizer, LRScheduler],
path_like: str,
plan: str = 'torch',
**kwargs) -> None:
# TODO: implement this method
pass
def load_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
self.checkpoint_io.load_model(model, checkpoint, strict)
def load(self,
obj: Union[nn.Module, Optimizer, LRScheduler],
path_like: str,
plan: str = 'torch',
**kwargs) -> None:
# TODO: implement this method
pass
def save_model(self,
model: nn.Module,
checkpoint: str,
prefix: str = None,
shard: bool = False,
size_per_shard: int = 1024):
self.checkpoint_io.save_model(model, checkpoint, prefix, shard, size_per_shard)
def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
self.checkpoint_io.load_optimizer(optimizer, checkpoint)
def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, size_per_shard: int = 1024):
self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, size_per_shard)
def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
self.checkpoint_io.save_lr_scheduler(lr_scheduler, checkpoint)
def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
self.checkpoint_io.load_lr_scheduler(lr_scheduler, checkpoint)