Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-18 16:00:49 +00:00)
[booster] implemented the torch ddp + resnet example (#3232)

* [booster] implemented the torch ddp + resnet example
* polish code
@@ -8,6 +8,8 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
 
+from colossalai.checkpoint_io import GeneralCheckpointIO
+
 from .accelerator import Accelerator
 from .mixed_precision import MixedPrecision, mixed_precision_factory
 from .plugin import Plugin
@@ -61,19 +63,21 @@ class Booster:
         self.plugin = plugin
 
         # set accelerator
-        if self.plugin and self.plugin.control_device:
+        if self.plugin and self.plugin.control_device():
             self.accelerator = None
             warnings.warn('The plugin will control the accelerator, so the device argument will be ignored.')
         else:
             self.accelerator = Accelerator(device)
 
         # set precision
-        if mixed_precision is None or (self.plugin and self.plugin.control_precision):
-            self.mixed_precision = None
+        if self.plugin and self.plugin.control_precision():
+            warnings.warn('The plugin will control the precision, so the mixed_precision argument will be ignored.')
+            self.mixed_precision = None
+        elif mixed_precision is None:
+            self.mixed_precision = None
         else:
             # validate and set precision
-            if isinstance(MixedPrecision, str):
+            if isinstance(mixed_precision, str):
                 # the user will take the default arguments for amp training
                 self.mixed_precision = mixed_precision_factory(mixed_precision)
             elif isinstance(mixed_precision, MixedPrecision):
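
Note on the two fixes in this hunk: control_device and control_precision are methods, so the old attribute-style checks were always truthy whenever a plugin was passed, and isinstance(MixedPrecision, str) tested the class object rather than the user's argument. A standalone illustration (not ColossalAI code):

# Standalone illustration of why the call parentheses matter (not ColossalAI code).
class Plugin:
    def control_device(self) -> bool:
        return False

p = Plugin()
print(bool(p.control_device))     # True  -- a bound method object is always truthy
print(bool(p.control_device()))   # False -- the value the plugin actually reports

# Likewise, isinstance(MixedPrecision, str) checks the class itself and is always
# False; the fix checks the supplied argument: isinstance(mixed_precision, str).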
@@ -84,6 +88,11 @@ class Booster:
                     f'Expected the argument mixed_precision to be a string or an instance of Precision, but got {type(mixed_precision)}.'
                 )
 
+        if self.plugin is not None and self.plugin.control_checkpoint_io():
+            self.checkpoint_io = self.plugin.get_checkpoint_io()
+        else:
+            self.checkpoint_io = GeneralCheckpointIO()
+
     def boost(
         self,
         model: nn.Module,
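
The new block wires checkpoint I/O through the same plugin-first pattern: a plugin that reports control_checkpoint_io() supplies its own backend via get_checkpoint_io(), otherwise GeneralCheckpointIO is used. A hypothetical sketch of a plugin exercising that hook; only the two hook names and the GeneralCheckpointIO import come from this diff, the class names and bodies below are invented for illustration and a real plugin would implement the full Plugin interface:

# Hypothetical sketch of the checkpoint I/O hook; MyCheckpointIO and MyPlugin are invented names.
from colossalai.checkpoint_io import GeneralCheckpointIO

class MyCheckpointIO(GeneralCheckpointIO):
    """Illustrative subclass; a real backend would override the load/save methods."""

class MyPlugin:
    def control_checkpoint_io(self) -> bool:
        # tell the Booster to use the plugin's backend instead of GeneralCheckpointIO
        return True

    def get_checkpoint_io(self) -> GeneralCheckpointIO:
        return MyCheckpointIO()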
@@ -109,12 +118,13 @@ class Booster:
             model, optimizer, criterion, dataloader, lr_scheduler = self.plugin.configure(
                 model, optimizer, criterion, dataloader, lr_scheduler)
 
-        if self.plugin and not self.plugin.control_device:
+        if self.plugin and not self.plugin.control_device():
             # transform model for accelerator
             model = self.accelerator.configure(model)
 
-        if self.mixed_precision and self.plugin and not self.plugin.control_precision:
+        if self.mixed_precision and (self.plugin is None or self.plugin and not self.plugin.control_precision()):
             # transform model for mixed precision
+            # when mixed_precision is specified and the plugin is not given or does not control the precision
             model, optimizer, criterion = self.mixed_precision.configure(model, optimizer, criterion)
 
         return model, optimizer, criterion, dataloader, lr_scheduler
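
After this hunk, boost() applies the plugin first, then the accelerator (only when a plugin is present and does not control the device), then mixed precision (only when it is set and not controlled by a plugin). A minimal sketch of the plugin-free path, in the spirit of the ResNet example this commit adds; the import path colossalai.booster.Booster and the constructor keywords are assumptions inferred from this file, not shown verbatim in the diff:

# Minimal sketch of the plugin-free path through Booster.boost(); the import path
# and keyword names are assumptions, and torchvision is only used for a small model.
import torch
import torch.nn as nn
from torch.optim import SGD
from torch.utils.data import DataLoader, TensorDataset
from torchvision.models import resnet18

from colossalai.booster import Booster  # assumed import path for the class above

model = resnet18(num_classes=10)
criterion = nn.CrossEntropyLoss()
optimizer = SGD(model.parameters(), lr=0.1)
dataset = TensorDataset(torch.randn(64, 3, 32, 32), torch.randint(0, 10, (64,)))
dataloader = DataLoader(dataset, batch_size=16)

# no plugin and no mixed precision: on this path boost() returns the objects largely unchanged
booster = Booster(device='cpu')
model, optimizer, criterion, dataloader, _ = booster.boost(model, optimizer, criterion, dataloader)

for images, labels in dataloader:
    optimizer.zero_grad()
    loss = criterion(model(images), labels)
    loss.backward()
    optimizer.step()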
@@ -140,18 +150,25 @@ class Booster:
         assert self.plugin.support_no_sync, f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
         return self.plugin.no_sync(model)
 
-    def save(self,
-             obj: Union[nn.Module, Optimizer, LRScheduler],
-             path_like: str,
-             plan: str = 'torch',
-             **kwargs) -> None:
-        # TODO: implement this method
-        pass
+    def load_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
+        self.checkpoint_io.load_model(model, checkpoint, strict)
 
-    def load(self,
-             obj: Union[nn.Module, Optimizer, LRScheduler],
-             path_like: str,
-             plan: str = 'torch',
-             **kwargs) -> None:
-        # TODO: implement this method
-        pass
+    def save_model(self,
+                   model: nn.Module,
+                   checkpoint: str,
+                   prefix: str = None,
+                   shard: bool = False,
+                   size_per_shard: int = 1024):
+        self.checkpoint_io.save_model(model, checkpoint, prefix, shard, size_per_shard)
+
+    def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
+        self.checkpoint_io.load_optimizer(optimizer, checkpoint)
+
+    def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, size_per_shard: int = 1024):
+        self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, size_per_shard)
+
+    def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
+        self.checkpoint_io.save_lr_scheduler(lr_scheduler, checkpoint)
+
+    def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
+        self.checkpoint_io.load_lr_scheduler(lr_scheduler, checkpoint)
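
The generic save/load stubs are replaced by concrete per-object methods that delegate to the selected checkpoint_io backend. Continuing the sketch above; the paths are illustrative placeholders, and whether sharded saving works depends on the backend supporting it:

# Continuing the sketch above: persist and restore state through the Booster.
booster.save_model(model, 'resnet18.pt')
booster.save_optimizer(optimizer, 'sgd_optimizer.pt')

# sharded export, if the backend supports it (units of size_per_shard are not stated in this diff)
booster.save_model(model, 'resnet18_sharded', shard=True, size_per_shard=1024)

# later, restore in place
booster.load_model(model, 'resnet18.pt', strict=True)
booster.load_optimizer(optimizer, 'sgd_optimizer.pt')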