Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-08 12:30:42 +00:00)
[booster] make optimizer argument optional for boost (#3993)
* feat: make optimizer optional in Booster.boost
* test: skip unet test if diffusers version > 0.10.2
@@ -115,10 +115,12 @@ class FP16TorchMixedPrecision(MixedPrecision):

     def configure(self,
                   model: nn.Module,
-                  optimizer: Optimizer,
-                  criterion: Callable = None) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
+                  optimizer: Optional[Optimizer] = None,
+                  criterion: Optional[Callable] = None,
+                  ) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
         model = TorchAMPModule(model)
-        optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs)
+        if optimizer is not None:
+            optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs)
         if criterion is not None:
             criterion = TorchAMPModule(criterion)
         return model, optimizer, criterion
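With this change, the optimizer is wrapped in TorchAMPOptimizer only when one is passed in, so Booster.boost can be called without an optimizer (for example for inference or evaluation). The snippet below is a minimal, hypothetical usage sketch, not taken from the commit; it assumes the standard Booster constructor with mixed_precision='fp16' and that any required ColossalAI distributed runtime has already been launched.

# Hypothetical usage sketch, not part of this commit. Assumes the standard
# Booster API and an already-initialized ColossalAI runtime
# (e.g. via colossalai.launch_from_torch).
import torch
import torch.nn as nn
from colossalai.booster import Booster

booster = Booster(mixed_precision='fp16')

# Training path (unchanged): model, optimizer and criterion are all wrapped.
train_model = nn.Linear(8, 2)
optimizer = torch.optim.Adam(train_model.parameters())
train_model, optimizer, criterion, *_ = booster.boost(
    train_model, optimizer, criterion=nn.CrossEntropyLoss())

# Inference path (enabled by this commit): the optimizer argument is omitted,
# so configure() skips the TorchAMPOptimizer wrapping entirely.
eval_model, *_ = booster.boost(nn.Linear(8, 2))

Wrapping the optimizer only when it is provided keeps the boost/configure return signature stable while letting inference-only callers skip optimizer construction altogether.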