Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-04 10:34:41 +00:00)
[booster] make optimizer argument optional for boost (#3993)
* feat: make optimizer optional in Booster.boost
* test: skip unet test if diffusers version > 0.10.2
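The second bullet refers to a version-gated test skip. Below is a minimal sketch of how such a guard typically looks with pytest, using packaging.version for the comparison; the test name, module, and exact condition are hypothetical and may differ from the repository's actual test.

```python
# Hypothetical sketch of a version-gated skip; the real test module and
# skip condition in ColossalAI may differ.
import pytest
from packaging import version

import diffusers


@pytest.mark.skipif(
    version.parse(diffusers.__version__) > version.parse("0.10.2"),
    reason="unet test only runs with diffusers <= 0.10.2",
)
def test_unet_forward():
    # Placeholder body: the actual test exercises a diffusers UNet model.
    ...
```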
@@ -195,23 +195,24 @@ class TorchFSDPPlugin(DPPluginBase):
     def configure(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
-    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
+    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
 
         # wrap the model with PyTorch FSDP
         fsdp_model = TorchFSDPModel(model, device_id=torch.cuda.current_device(), **self.fsdp_kwargs)
 
-        if len(optimizer.param_groups) > 1:
-            warnings.warn(
-                'TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used.'
-            )
-        optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults)
+        if optimizer is not None:
+            if len(optimizer.param_groups) > 1:
+                warnings.warn(
+                    'TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used.'
+                )
+            optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults)
 
-        if not isinstance(optimizer, FSDPOptimizerWrapper):
-            optimizer = FSDPOptimizerWrapper(optimizer, fsdp_model)
+            if not isinstance(optimizer, FSDPOptimizerWrapper):
+                optimizer = FSDPOptimizerWrapper(optimizer, fsdp_model)
 
         return fsdp_model, optimizer, criterion, dataloader, lr_scheduler
 
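With this change, the plugin's configure (and, per the commit title, Booster.boost) can be called without an optimizer, e.g. for inference-only use. A minimal usage sketch, assuming boost mirrors the configure signature shown above and that the distributed environment has already been initialized (for example via colossalai.launch_from_torch) on a machine with a CUDA device:

```python
# Minimal sketch: boosting a model without an optimizer under TorchFSDPPlugin.
# Assumes the distributed environment is already initialized and CUDA is available.
import torch.nn as nn

from colossalai.booster import Booster
from colossalai.booster.plugin import TorchFSDPPlugin

model = nn.Linear(1024, 1024)

plugin = TorchFSDPPlugin()
booster = Booster(plugin=plugin)

# Only the model is required now; optimizer, criterion, dataloader and
# lr_scheduler default to None and are returned unchanged when omitted.
model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(model)
```

Because the optimizer handling is now wrapped in `if optimizer is not None:`, the FSDP re-initialization of the optimizer's parameters and the FSDPOptimizerWrapper wrapping are simply skipped when no optimizer is passed.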