mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 10:06:44 +00:00
[booster] make optimizer argument optional for boost (#3993)
* feat: make optimizer optional in Booster.boost * test: skip unet test if diffusers version > 0.10.2
This commit is contained in:
@@ -4,12 +4,15 @@ import pytest
|
||||
import torch
|
||||
|
||||
try:
|
||||
from diffusers import UNet2DModel
|
||||
MODELS = [UNet2DModel]
|
||||
import diffusers
|
||||
MODELS = [diffusers.UNet2DModel]
|
||||
HAS_REPO = True
|
||||
from packaging import version
|
||||
SKIP_UNET_TEST = version.parse(diffusers.__version__) > version.parse("0.10.2")
|
||||
except:
|
||||
MODELS = []
|
||||
HAS_REPO = False
|
||||
SKIP_UNET_TEST = False
|
||||
|
||||
from test_autochunk_diffuser_utils import run_test
|
||||
|
||||
@@ -32,6 +35,10 @@ def get_data(shape: tuple) -> Tuple[List, List]:
|
||||
return meta_args, concrete_args
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
SKIP_UNET_TEST,
|
||||
reason="diffusers version > 0.10.2",
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not (AUTOCHUNK_AVAILABLE and HAS_REPO),
|
||||
reason="torch version is lower than 1.12.0",
|
||||
|
Reference in New Issue
Block a user