[booster] make optimizer argument optional for boost (#3993)

* feat: make optimizer optional in Booster.boost

* test: skip unet test if diffusers version > 0.10.2
This commit is contained in:
Wenhao Chen
2023-06-15 17:38:42 +08:00
committed by GitHub
parent c9cff7e7fa
commit 725af3eeeb
9 changed files with 70 additions and 50 deletions

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Callable, Tuple
from typing import Callable, Optional, Tuple
import torch.nn as nn
from torch.optim import Optimizer
@@ -15,7 +15,8 @@ class MixedPrecision(ABC):
@abstractmethod
def configure(self,
model: nn.Module,
optimizer: Optimizer,
criterion: Callable = None) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
optimizer: Optional[Optimizer] = None,
criterion: Optional[Callable] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
# TODO: implement this method
pass