diff --git a/colossalai/booster/plugin/__init__.py b/colossalai/booster/plugin/__init__.py index 8e09b6cb2..aa45bcb59 100644 --- a/colossalai/booster/plugin/__init__.py +++ b/colossalai/booster/plugin/__init__.py @@ -1,5 +1,6 @@ from .gemini_plugin import GeminiPlugin +from .low_level_zero_plugin import LowLevelZeroPlugin from .plugin_base import Plugin from .torch_ddp_plugin import TorchDDPPlugin -__all__ = ['Plugin', 'TorchDDPPlugin', 'GeminiPlugin'] +__all__ = ['Plugin', 'TorchDDPPlugin', 'GeminiPlugin', 'LowLevelZeroPlugin'] diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py new file mode 100644 index 000000000..969c430bd --- /dev/null +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -0,0 +1,259 @@ +import random +import warnings +from typing import Callable, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +from torch import Tensor +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from torch.utils._pytree import tree_map +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +from colossalai.checkpoint_io import CheckpointIO +from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.utils import get_current_device +from colossalai.zero import zero_model_wrapper, zero_optim_wrapper + +from .plugin_base import Plugin +from .torch_ddp_plugin import TorchDDPCheckpointIO + +__all__ = ['LowLevelZeroPlugin'] + + +def _convert_to_fp16(x): + if isinstance(x, torch.Tensor) and torch.is_floating_point(x): + return x.half() + return x + + +class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): + + def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool): + """ + Save optimizer to checkpoint but only on master process. 
+ """ + # TODO(ver217): optimizer state dict is sharded + super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor) + + +class LowLevelZeroModel(ModelWrapper): + + def __init__(self, module: nn.Module, stage: int, precision: str) -> None: + super().__init__(module) + self.convert_inputs = (precision == 'fp16') + module = zero_model_wrapper(module, zero_stage=stage) + if precision == 'fp16': + module = module.half() + module = module.to(get_current_device()) + self.module = module + + def forward(self, *args, **kwargs): + if self.convert_inputs: + args = tree_map(_convert_to_fp16, args) + kwargs = tree_map(_convert_to_fp16, kwargs) + return super().forward(*args, **kwargs) + + +class LowLevelZeroOptimizer(OptimizerWrapper): + + def __init__(self, + module: nn.Module, + optimizer: Optimizer, + zero_optim_config: dict, + optim_kwargs: dict, + verbose: bool = False) -> None: + optimizer = zero_optim_wrapper(module, + optimizer, + optim_config=zero_optim_config, + **optim_kwargs, + verbose=verbose) + super().__init__(optimizer) + + def backward(self, loss: Tensor, *args, **kwargs): + self.optim.backward(loss) + + def clip_grad_by_norm(self, + max_norm: Union[float, int], + norm_type: Union[float, int] = 2, + error_if_nonfinite: bool = False, + *args, + **kwargs) -> Tensor: + warnings.warn(f'LowLevelZero controls grad clipping by itself, so you should not use clip_grad_by_norm') + + def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None: + raise NotImplementedError('LowLevelZero does not support clip_grad_by_value') + + +class LowLevelZeroPlugin(Plugin): + """ + Plugin for low level zero. + + Example: + >>> from colossalai.booster import Booster + >>> from colossalai.booster.plugin import LowLevelZeroPlugin + >>> + >>> model, train_dataset, optimizer, criterion = ... 
+ >>> plugin = LowLevelZeroPlugin() + + >>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8) + >>> booster = Booster(plugin=plugin) + >>> model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader) + + Args: + stage (int, optional): ZeRO stage. Only 1 and 2 are supported. Defaults to 1. + precision (str, optional): precision. Support 'fp16' and 'fp32'. Defaults to 'fp16'. + initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32. + min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1. + growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2. + backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5. + growth_interval (int, optional): growth_interval used by DynamicGradScaler. Defaults to 1000. + hysteresis (int, optional): hysteresis used by DynamicGradScaler. Defaults to 2. + max_scale (float, optional): max_scale used by DynamicGradScaler. Defaults to 2**32. + max_norm (float, optional): max_norm used for `clip_grad_norm`. Defaults to 0.0. Note that you should not call + clip_grad_norm yourself when using ZeRO. The ZeRO optimizer will take care of clip_grad_norm. + norm_type (float, optional): norm_type used for `clip_grad_norm`. Defaults to 2.0. + reduce_bucket_size_in_m (int, optional): grad reduce bucket size in M. Defaults to 12. + communication_dtype (torch.dtype, optional): communication dtype. If not specified, the dtype of param will be used. Defaults to None. + overlap_communication (bool, optional): whether to overlap communication and computation. Defaults to True. + cpu_offload (bool, optional): whether to offload grad, master weight and optimizer state to cpu. Defaults to False. + verbose (bool, optional): verbose mode. Debug info including grad overflow will be printed. Defaults to False.
+ """ + + def __init__( + self, + stage: int = 1, + precision: str = 'fp16', + initial_scale: float = 2**32, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + max_norm: float = 0.0, + norm_type: float = 2.0, + reduce_bucket_size_in_m: int = 12, + communication_dtype: Optional[torch.dtype] = None, + overlap_communication: bool = True, + cpu_offload: bool = False, + verbose: bool = False, + ) -> None: + + assert dist.is_initialized( + ), 'torch.distributed is not initialized, please use colossalai.launch to create the distributed environment' + assert stage in (1, 2), f'LowLevelZeroPlugin only supports stage 1/2 training' + assert precision in ('fp16', 'fp32'), f'LowLevelZeroPlugin only supports fp16/fp32 training' + + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() + + self.stage = stage + self.precision = precision + self.zero_optim_config = dict(reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024, + communication_dtype=communication_dtype, + overlap_communication=overlap_communication, + cpu_offload=cpu_offload) + self.optim_kwargs = dict(initial_scale=initial_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + min_scale=min_scale, + max_scale=max_scale, + max_norm=max_norm, + norm_type=norm_type) + self.verbose = verbose + + def support_no_sync(self) -> bool: + return False + + def control_precision(self) -> bool: + return True + + def supported_precisions(self) -> List[str]: + return ['fp16', 'fp32'] + + def control_device(self) -> bool: + return True + + def supported_devices(self) -> List[str]: + return ['cuda'] + + def prepare_train_dataloader(self, + dataset, + batch_size, + shuffle=False, + seed=1024, + drop_last=False, + pin_memory=False, + num_workers=0, + **kwargs): + r""" + Prepare a dataloader for distributed training. 
The dataloader will be wrapped by + `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`. + + Note: + 1. Evaluation datasets should not be passed to this function. + + Args: + dataset (`torch.utils.data.Dataset`): The dataset to be loaded. + batch_size (int): Number of samples per batch on each rank. + shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. + seed (int, optional): Random worker seed for sampling, defaults to 1024. + drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size + is not divisible by the batch size. If False and the size of dataset is not divisible by + the batch size, then the last batch will be smaller, defaults to False. + pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False. + num_workers (int, optional): Number of worker subprocesses for this dataloader. Defaults to 0. + kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in + `DataLoader <https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader>`_. + + Returns: + :class:`torch.utils.data.DataLoader`: A DataLoader used for training.
+ """ + _kwargs = kwargs.copy() + sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle) + + # Deterministic dataloader + def seed_worker(worker_id): + worker_seed = seed + np.random.seed(worker_seed) + torch.manual_seed(worker_seed) + random.seed(worker_seed) + + return DataLoader(dataset, + batch_size=batch_size, + sampler=sampler, + worker_init_fn=seed_worker, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + **_kwargs) + + def configure( + self, + model: nn.Module, + optimizer: Optimizer, + criterion: Callable = None, + dataloader: DataLoader = None, + lr_scheduler: LRScheduler = None, + ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]: + + if not isinstance(model, ModelWrapper): + model = LowLevelZeroModel(model, self.stage, self.precision) + + if not isinstance(optimizer, OptimizerWrapper): + optimizer = LowLevelZeroOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs, + self.verbose) + + return model, optimizer, criterion, dataloader, lr_scheduler + + def control_checkpoint_io(self) -> bool: + return True + + def get_checkpoint_io(self) -> CheckpointIO: + return LowLevelZeroCheckpointIO() diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 39ade27b9..59c99113e 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -55,6 +55,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): # 2. contiguous gradients # 3. cpu offload # 4. support when some parameters requires_grad = False + # 5. 
support layer drop super(LowLevelZeroOptimizer, self).__init__(optim=optimizer) self._dtype = self.optim.param_groups[0]['params'][0].dtype self._logger = get_dist_logger() diff --git a/tests/kit/model_zoo/diffusers/diffusers.py b/tests/kit/model_zoo/diffusers/diffusers.py index 8aa3f4c67..204c1d777 100644 --- a/tests/kit/model_zoo/diffusers/diffusers.py +++ b/tests/kit/model_zoo/diffusers/diffusers.py @@ -18,6 +18,7 @@ data_vae_fn = lambda: dict(sample=torch.randn(2, 3, 32, 32)) data_unet_fn = lambda: dict(sample=torch.randn(2, 3, 32, 32), timestep=3) identity_output = lambda x: x +clip_vision_model_output = lambda x: dict(pooler_output=x[1]) def data_clip_model(): @@ -65,7 +66,7 @@ model_zoo.register(name='diffusers_clip_text_model', model_zoo.register(name='diffusers_clip_vision_model', model_fn=partial(transformers.CLIPVisionModel, config=transformers.CLIPVisionConfig()), data_gen_fn=data_clip_vision, - output_transform_fn=identity_output) + output_transform_fn=clip_vision_model_output) model_zoo.register(name='diffusers_unet2d_model', model_fn=diffusers.UNet2DModel, diff --git a/tests/kit/model_zoo/torchaudio/torchaudio.py b/tests/kit/model_zoo/torchaudio/torchaudio.py index 746117202..9a244ac31 100644 --- a/tests/kit/model_zoo/torchaudio/torchaudio.py +++ b/tests/kit/model_zoo/torchaudio/torchaudio.py @@ -1,3 +1,5 @@ +from functools import partial + import torch import torchaudio.models as tm @@ -101,13 +103,11 @@ def tacotron_data_gen_fn(): mel_specgram_lengths=mel_specgram_lengths) -model_zoo.register( - name='torchaudio_tacotron', - model_fn=lambda: tm.Tacotron2(n_mels=N_MELS), - data_gen_fn=tacotron_data_gen_fn, - output_transform_fn=lambda outputs: dict( - spectrogram_before=outputs[0], spectrogram_after=outputs[1], stop_tokens=outputs[2], attn_weights=outputs[3]), - model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='torchaudio_tacotron', + model_fn=lambda: tm.Tacotron2(n_mels=N_MELS), + data_gen_fn=tacotron_data_gen_fn, + 
output_transform_fn=lambda outputs: dict(summed_output=sum(x.sum() for x in outputs)), + model_attribute=ModelAttribute(has_control_flow=True)) def wav2vec_data_gen_fn(): @@ -118,7 +118,7 @@ def wav2vec_data_gen_fn(): model_zoo.register(name='torchaudio_wav2vec2_base', - model_fn=tm.wav2vec2_base, + model_fn=partial(tm.wav2vec2_base, encoder_layer_drop=0.0), data_gen_fn=wav2vec_data_gen_fn, output_transform_fn=transformer_output_transform_fn, model_attribute=ModelAttribute(has_control_flow=True)) diff --git a/tests/kit/model_zoo/torchvision/torchvision.py b/tests/kit/model_zoo/torchvision/torchvision.py index 62bda93d5..ddc3ec24b 100644 --- a/tests/kit/model_zoo/torchvision/torchvision.py +++ b/tests/kit/model_zoo/torchvision/torchvision.py @@ -36,12 +36,12 @@ def swin_s(): # special output transform fn -google_net_output_transform_fn = lambda x: dict(output=x.logits) if isinstance(x, torchvision.models.GoogLeNetOutputs - ) else dict(output=x) +google_net_output_transform_fn = lambda x: dict(output=sum(x)) if isinstance(x, torchvision.models.GoogLeNetOutputs + ) else dict(output=x) swin_s_output_output_transform_fn = lambda x: {f'output{idx}': val for idx, val in enumerate(x)} if isinstance(x, tuple) else dict(output=x) -inception_v3_output_transform_fn = lambda x: dict(output=x.logits) if isinstance(x, torchvision.models.InceptionOutputs - ) else dict(output=x) +inception_v3_output_transform_fn = lambda x: dict(output=sum(x)) if isinstance(x, torchvision.models.InceptionOutputs + ) else dict(output=x) model_zoo.register(name='torchvision_alexnet', model_fn=tm.alexnet, diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py index d804c727a..985d7989f 100644 --- a/tests/test_booster/test_plugin/test_gemini_plugin.py +++ b/tests/test_booster/test_plugin/test_gemini_plugin.py @@ -1,4 +1,5 @@ from contextlib import nullcontext +from typing import Optional import torch import torch.distributed as dist @@ 
-10,11 +11,53 @@ from colossalai.fx import is_compatible_with_meta from colossalai.nn.optimizer import HybridAdam from colossalai.tensor.colo_parameter import ColoParameter from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils.model.experimental import LazyInitContext from colossalai.zero import ColoInitContext from tests.kit.model_zoo import model_zoo -@parameterize('init_method', ['lazy', 'none', 'colo']) +def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]: + try: + if init_method == 'colo': + ctx = ColoInitContext() + elif init_method == 'lazy': + ctx = LazyInitContext() + else: + ctx = nullcontext() + plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, max_norm=1.0, initial_scale=2**5) + booster = Booster(plugin=plugin) + with ctx: + model = model_fn() + optimizer = HybridAdam(model.parameters(), lr=1e-3) + criterion = lambda x: x.mean() + data = data_gen_fn() + + data = { + k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items() + } + + model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) + + for n, p in model.named_parameters(): + assert isinstance(p, ColoParameter), f'{n} is not a ColoParameter' + + output = model(**data) + output = output_transform_fn(output) + output_key = list(output.keys())[0] + loss = criterion(output[output_key]) + + booster.backward(loss, optimizer) + optimizer.step() + + except Exception as e: + return repr(e) + + +# TODO(ver217): CI does not support lazy now +# @parameterize('init_method', ['lazy', 'none', 'colo']) + + +@parameterize('init_method', ['none']) def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True): """check gemini plugin over model zoo @@ -25,7 +68,6 @@ def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True): if not is_support_meta and init_method == 'lazy': return - from 
colossalai.utils.model.experimental import LazyInitContext passed_models = [] failed_info = {} # (model_name, error) pair @@ -58,48 +100,16 @@ def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True): ]: continue - try: - if init_method == 'colo': - ctx = ColoInitContext() - elif init_method == 'lazy': - ctx = LazyInitContext() - else: - ctx = nullcontext() - plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, max_norm=1.0, initial_scale=2**5) - booster = Booster(plugin=plugin) - with ctx: - model = model_fn() - optimizer = HybridAdam(model.parameters(), lr=1e-3) - criterion = lambda x: x.mean() - data = data_gen_fn() - - data = { - k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v - for k, v in data.items() - } - - model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) - - for n, p in model.named_parameters(): - assert isinstance(p, ColoParameter), f'{n} is not a ColoParameter' - - output = model(**data) - output = output_transform_fn(output) - output_key = list(output.keys())[0] - loss = criterion(output[output_key]) - - booster.backward(loss, optimizer) - optimizer.step() - passed_models.append(name) - - del booster, plugin, model, optimizer, criterion, data, output, loss - except Exception as e: - failed_info[name] = e - if early_stop: - raise e - + err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) torch.cuda.empty_cache() + if err is None: + passed_models.append(name) + else: + failed_info[name] = err + if early_stop: + break + if dist.get_rank() == 0: print(f'Init method: {init_method}') print(f'Passed models({len(passed_models)}): {passed_models}\n\n') @@ -140,7 +150,7 @@ def run_dist(rank, world_size, port, early_stop: bool = True): @rerun_if_address_is_in_use() def test_gemini_plugin(early_stop: bool = True): - spawn(run_dist, 2, early_stop=early_stop) + spawn(run_dist, 4, early_stop=early_stop) if __name__ == '__main__': diff --git 
a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py new file mode 100644 index 000000000..e24196a14 --- /dev/null +++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py @@ -0,0 +1,122 @@ +from typing import Optional + +import torch +import torch.distributed as dist + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import LowLevelZeroPlugin +from colossalai.nn.optimizer import HybridAdam +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo + +# These models are not compatible with AMP +_AMP_ERR_MODELS = ['timm_convit', 'dlrm', 'deepfm_interactionarch', 'deepfm_simpledeepfmnn`'] +# These models have no parameters +_LOW_LEVEL_ZERO_ERR_MODELS = ['dlrm_interactionarch'] +# These models will get stuck +_STUCK_MODELS = [ + 'diffusers_vq_model', 'transformers_albert', 'transformers_albert_for_pretraining', 'transformers_bert', + 'transformers_bert_for_pretraining', 'transformers_gpt_double_heads' +] + + +def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]: + try: + plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=2**5) + booster = Booster(plugin=plugin) + model = model_fn() + optimizer = HybridAdam(model.parameters(), lr=1e-3) + criterion = lambda x: x.mean() + data = data_gen_fn() + + data = { + k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items() + } + + model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) + + output = model(**data) + output = output_transform_fn(output) + output_key = list(output.keys())[0] + loss = criterion(output[output_key]) + + booster.backward(loss, optimizer) + optimizer.step() + + except Exception as e: + return repr(e) + + +@parameterize('stage', [2]) +def check_low_level_zero_plugin(stage: int, early_stop: bool = True): + 
"""check low level zero plugin over model zoo + + Args: + stage (int), stage of low level zero plugin + early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True. + """ + passed_models = [] + failed_info = {} # (model_name, error) pair + ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS + _STUCK_MODELS + skipped_models = [] + + for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items(): + # FIXME(ver217): fix these models + if name in ignore_models: + skipped_models.append(name) + continue + err = run_fn(stage, model_fn, data_gen_fn, output_transform_fn) + torch.cuda.empty_cache() + + if err is None: + passed_models.append(name) + else: + failed_info[name] = err + if early_stop: + break + + if dist.get_rank() == 0: + print(f'Passed models({len(passed_models)}): {passed_models}\n\n') + print(f'Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n') + print(f'Skipped models({len(skipped_models)}): {skipped_models}\n\n') + assert len(failed_info) == 0, '\n'.join([f'{k}: {v}' for k, v in failed_info.items()]) + + +def check_dataloader_sharding(): + plugin = LowLevelZeroPlugin() + + # create a custom dasetset with 0 to 10 + dataset = torch.utils.data.TensorDataset(torch.arange(0, 10)) + train_dataloader = plugin.prepare_train_dataloader(dataset, batch_size=2) + + # get the first batch of data + batch = next(iter(train_dataloader))[0].cuda() + is_rank_0 = dist.get_rank() == 0 + + if is_rank_0: + batch_to_compare = batch.clone() + else: + batch_to_compare = batch + # pass to the rank 1 value to rank 0 + dist.broadcast(batch_to_compare, src=1) + + # compare on rank 0 + if is_rank_0: + assert not torch.equal(batch, + batch_to_compare), 'Same number was found across ranks but expected it to be different' + + +def run_dist(rank, world_size, port, early_stop: bool = True): + # init dist env + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + 
check_low_level_zero_plugin(early_stop=early_stop) + + +@rerun_if_address_is_in_use() +def test_low_level_zero_plugin(early_stop: bool = True): + spawn(run_dist, 2, early_stop=early_stop) + + +if __name__ == '__main__': + test_low_level_zero_plugin(early_stop=False) diff --git a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py index c225a1a06..5354eae01 100644 --- a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py +++ b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py @@ -11,36 +11,37 @@ from colossalai.testing import rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo -def check_torch_ddp_plugin(): +def run_fn(model_fn, data_gen_fn, output_transform_fn): plugin = TorchDDPPlugin() booster = Booster(plugin=plugin) + model = model_fn() + optimizer = SGD(model.parameters(), lr=1e-3) + criterion = lambda x: x.mean() + data = data_gen_fn() + data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()} + + model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) + + assert isinstance(model.module, DDP) + assert isinstance(optimizer, OptimizerWrapper) + + output = model(**data) + output = output_transform_fn(output) + output_key = list(output.keys())[0] + loss = criterion(output[output_key]) + + booster.backward(loss, optimizer) + optimizer.clip_grad_by_norm(1.0) + optimizer.step() + + +def check_torch_ddp_plugin(): for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items(): if name == 'dlrm_interactionarch': continue - - model = model_fn() - optimizer = SGD(model.parameters(), lr=1e-3) - criterion = lambda x: x.mean() - data = data_gen_fn() - - data = { - k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items() - } - - model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) - - assert 
isinstance(model.module, DDP) - assert isinstance(optimizer, OptimizerWrapper) - - output = model(**data) - output = output_transform_fn(output) - output_key = list(output.keys())[0] - loss = criterion(output[output_key]) - - booster.backward(loss, optimizer) - optimizer.clip_grad_by_norm(1.0) - optimizer.step() + run_fn(model_fn, data_gen_fn, output_transform_fn) + torch.cuda.empty_cache() def check_dataloader_sharding():