[booster] make optimizer argument optional for boost (#3993)

* feat: make optimizer optional in Booster.boost

* test: skip unet test if diffusers version > 0.10.2
This commit is contained in:
Wenhao Chen
2023-06-15 17:38:42 +08:00
committed by GitHub
parent c9cff7e7fa
commit 725af3eeeb
9 changed files with 70 additions and 50 deletions

View File

@@ -274,11 +274,11 @@ class GeminiPlugin(DPPluginBase):
def configure(
self,
model: nn.Module,
optimizer: Optimizer,
criterion: Callable = None,
dataloader: DataLoader = None,
lr_scheduler: LRScheduler = None,
) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
optimizer: Optional[Optimizer] = None,
criterion: Optional[Callable] = None,
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
if not isinstance(model, ModelWrapper):
# convert model to sync bn
@@ -293,8 +293,12 @@ class GeminiPlugin(DPPluginBase):
# wrap the model with Gemini
model = GeminiModel(model, self.gemini_config, self.verbose)
if not isinstance(optimizer, OptimizerWrapper):
optimizer = GeminiOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
if optimizer is not None and \
not isinstance(optimizer, OptimizerWrapper):
optimizer = GeminiOptimizer(model.unwrap(),
optimizer,
self.zero_optim_config,
self.optim_kwargs,
self.verbose)
return model, optimizer, criterion, dataloader, lr_scheduler