Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 19:13:01 +00:00)
[booster] make optimizer argument optional for boost (#3993)
* feat: make optimizer optional in Booster.boost
* test: skip unet test if diffusers version > 0.10.2
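For context, a minimal usage sketch of what the first change enables, assuming a distributed context has already been initialized (e.g. via colossalai.launch_from_torch); the argument and return shapes follow the public Booster API, but the snippet is illustrative and not copied from the repository. The diff below shows the plugin-side change.

import torch.nn as nn
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

booster = Booster(plugin=GeminiPlugin())
model = nn.Linear(8, 8)

# After this commit, boost can be called without an optimizer
# (e.g. for inference-only workloads); the optimizer slot comes back as None.
model, optimizer, _criterion, _dataloader, _lr_scheduler = booster.boost(model)
assert optimizer is None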
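The second bullet is a test-only change that is not part of the hunks below; a hedged sketch of such a version-gated skip (the test name and version threshold come from the commit message, everything else is illustrative):

import pytest
from packaging import version

diffusers = pytest.importorskip("diffusers")

@pytest.mark.skipif(
    version.parse(diffusers.__version__) > version.parse("0.10.2"),
    reason="unet test is only expected to pass with diffusers <= 0.10.2",
)
def test_unet():
    ...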
@@ -274,11 +274,11 @@ class GeminiPlugin(DPPluginBase):
     def configure(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
-    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
+    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:

         if not isinstance(model, ModelWrapper):
             # convert model to sync bn
@@ -293,8 +293,12 @@ class GeminiPlugin(DPPluginBase):
             # wrap the model with Gemini
             model = GeminiModel(model, self.gemini_config, self.verbose)

-        if not isinstance(optimizer, OptimizerWrapper):
-            optimizer = GeminiOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
+        if optimizer is not None and \
+                not isinstance(optimizer, OptimizerWrapper):
+            optimizer = GeminiOptimizer(model.unwrap(),
+                                        optimizer,
+                                        self.zero_optim_config,
+                                        self.optim_kwargs,
                                         self.verbose)

         return model, optimizer, criterion, dataloader, lr_scheduler