diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py
index 4a42e2049..6e480d0db 100644
--- a/colossalai/booster/booster.py
+++ b/colossalai/booster/booster.py
@@ -97,10 +97,10 @@ class Booster:
     def boost(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
     ) -> List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]:
         """
         Boost the model, optimizer, criterion, lr_scheduler, and dataloader.
diff --git a/colossalai/booster/mixed_precision/fp16_torch.py b/colossalai/booster/mixed_precision/fp16_torch.py
index 9999aa5e0..26fd92bd5 100644
--- a/colossalai/booster/mixed_precision/fp16_torch.py
+++ b/colossalai/booster/mixed_precision/fp16_torch.py
@@ -115,10 +115,12 @@ class FP16TorchMixedPrecision(MixedPrecision):
 
     def configure(self,
                   model: nn.Module,
-                  optimizer: Optimizer,
-                  criterion: Callable = None) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
+                  optimizer: Optional[Optimizer] = None,
+                  criterion: Optional[Callable] = None,
+                  ) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
         model = TorchAMPModule(model)
-        optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs)
+        if optimizer is not None:
+            optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs)
         if criterion is not None:
             criterion = TorchAMPModule(criterion)
         return model, optimizer, criterion
diff --git a/colossalai/booster/mixed_precision/mixed_precision_base.py b/colossalai/booster/mixed_precision/mixed_precision_base.py
index 2490e9811..8caa34e50 100644
--- a/colossalai/booster/mixed_precision/mixed_precision_base.py
+++ b/colossalai/booster/mixed_precision/mixed_precision_base.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Callable, Tuple
+from typing import Callable, Optional, Tuple
 
 import torch.nn as nn
 from torch.optim import Optimizer
@@ -15,7 +15,8 @@ class MixedPrecision(ABC):
     @abstractmethod
     def configure(self,
                   model: nn.Module,
-                  optimizer: Optimizer,
-                  criterion: Callable = None) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
+                  optimizer: Optional[Optimizer] = None,
+                  criterion: Optional[Callable] = None,
+                  ) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
         # TODO: implement this method
         pass
diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py
index ce01ad111..60b25b2c4 100644
--- a/colossalai/booster/plugin/gemini_plugin.py
+++ b/colossalai/booster/plugin/gemini_plugin.py
@@ -274,11 +274,11 @@ class GeminiPlugin(DPPluginBase):
     def configure(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
-    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
+    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
 
         if not isinstance(model, ModelWrapper):
             # convert model to sync bn
@@ -293,8 +293,12 @@ class GeminiPlugin(DPPluginBase):
             # wrap the model with Gemini
             model = GeminiModel(model, self.gemini_config, self.verbose)
 
-        if not isinstance(optimizer, OptimizerWrapper):
-            optimizer = GeminiOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
+        if optimizer is not None and \
+           not isinstance(optimizer, OptimizerWrapper):
+            optimizer = GeminiOptimizer(model.unwrap(),
+                                        optimizer,
+                                        self.zero_optim_config,
+                                        self.optim_kwargs,
                                         self.verbose)
 
         return model, optimizer, criterion, dataloader, lr_scheduler
diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index 2b312d0f9..94d722080 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -197,17 +197,21 @@ class LowLevelZeroPlugin(DPPluginBase):
     def configure(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
-    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
+    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
 
         if not isinstance(model, ModelWrapper):
             model = LowLevelZeroModel(model, self.stage, self.precision)
 
-        if not isinstance(optimizer, OptimizerWrapper):
-            optimizer = LowLevelZeroOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
+        if optimizer is not None and \
+           not isinstance(optimizer, OptimizerWrapper):
+            optimizer = LowLevelZeroOptimizer(model.unwrap(),
+                                              optimizer,
+                                              self.zero_optim_config,
+                                              self.optim_kwargs,
                                               self.verbose)
 
         return model, optimizer, criterion, dataloader, lr_scheduler
diff --git a/colossalai/booster/plugin/plugin_base.py b/colossalai/booster/plugin/plugin_base.py
index 561f58bc5..aa78f6827 100644
--- a/colossalai/booster/plugin/plugin_base.py
+++ b/colossalai/booster/plugin/plugin_base.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Callable, Iterator, List, Tuple, Union
+from typing import Callable, Iterator, List, Optional, Tuple, Union
 
 import torch.nn as nn
 from torch.optim import Optimizer
@@ -38,11 +38,11 @@ class Plugin(ABC):
     def configure(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
-    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
+    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
         # implement this method
         pass
 
diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py
index a18073db6..4bfd61af3 100644
--- a/colossalai/booster/plugin/torch_ddp_plugin.py
+++ b/colossalai/booster/plugin/torch_ddp_plugin.py
@@ -138,11 +138,11 @@ class TorchDDPPlugin(DPPluginBase):
     def configure(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
-    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
+    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
         # cast model to cuda
         model = model.cuda()
 
@@ -152,7 +152,8 @@ class TorchDDPPlugin(DPPluginBase):
         # wrap the model with PyTorch DDP
         model = TorchDDPModel(model, **self.ddp_kwargs)
 
-        if not isinstance(optimizer, OptimizerWrapper):
+        if optimizer is not None and \
+           not isinstance(optimizer, OptimizerWrapper):
             optimizer = OptimizerWrapper(optimizer)
 
         return model, optimizer, criterion, dataloader, lr_scheduler
diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py
index ebd03b6ea..abfffa9b0 100644
--- a/colossalai/booster/plugin/torch_fsdp_plugin.py
+++ b/colossalai/booster/plugin/torch_fsdp_plugin.py
@@ -195,23 +195,24 @@ class TorchFSDPPlugin(DPPluginBase):
     def configure(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
-    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
+    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
 
         # wrap the model with PyTorch FSDP
         fsdp_model = TorchFSDPModel(model, device_id=torch.cuda.current_device(), **self.fsdp_kwargs)
 
-        if len(optimizer.param_groups) > 1:
-            warnings.warn(
-                'TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used.'
-            )
-        optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults)
+        if optimizer is not None:
+            if len(optimizer.param_groups) > 1:
+                warnings.warn(
+                    'TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used.'
+                )
+            optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults)
 
-        if not isinstance(optimizer, FSDPOptimizerWrapper):
-            optimizer = FSDPOptimizerWrapper(optimizer, fsdp_model)
+            if not isinstance(optimizer, FSDPOptimizerWrapper):
+                optimizer = FSDPOptimizerWrapper(optimizer, fsdp_model)
 
         return fsdp_model, optimizer, criterion, dataloader, lr_scheduler
 
diff --git a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py
index ff0d4a1b5..fc9d8455e 100644
--- a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py
+++ b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py
@@ -4,12 +4,15 @@ import pytest
 import torch
 
 try:
-    from diffusers import UNet2DModel
-    MODELS = [UNet2DModel]
+    import diffusers
+    MODELS = [diffusers.UNet2DModel]
     HAS_REPO = True
+    from packaging import version
+    SKIP_UNET_TEST = version.parse(diffusers.__version__) > version.parse("0.10.2")
 except:
     MODELS = []
     HAS_REPO = False
+    SKIP_UNET_TEST = False
 
 from test_autochunk_diffuser_utils import run_test
 
@@ -32,6 +35,10 @@ def get_data(shape: tuple) -> Tuple[List, List]:
     return meta_args, concrete_args
 
 
+@pytest.mark.skipif(
+    SKIP_UNET_TEST,
+    reason="diffusers version > 0.10.2",
+)
 @pytest.mark.skipif(
     not (AUTOCHUNK_AVAILABLE and HAS_REPO),
     reason="torch version is lower than 1.12.0",
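
The booster changes above make `optimizer`, `criterion`, `dataloader`, and `lr_scheduler` optional in `Booster.boost` and in every plugin's `configure`, so a model can now be boosted on its own, e.g. for inference or evaluation. Below is a minimal usage sketch of that flow, assuming a single-process `torchrun` launch, an available CUDA device, and `TorchDDPPlugin`; the toy model and the empty launch config are illustrative only and not part of this patch.

```python
# Hypothetical inference-only sketch of the API change above; not part of the patch.
# Assumes launch via `torchrun --nproc_per_node=1 infer.py` and a CUDA device.
import colossalai
import torch.nn as nn

from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

colossalai.launch_from_torch(config={})    # distributed env vars come from torchrun

model = nn.Linear(16, 8)                   # placeholder model
booster = Booster(plugin=TorchDDPPlugin())

# With this patch only `model` is required; optimizer, criterion, dataloader and
# lr_scheduler default to None and are passed through unwrapped.
model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(model)
assert optimizer is None and criterion is None    # only the model was wrapped (DDP)
```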