mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-04 09:40:11 +00:00
polish unitest test with titans (#1152)
This commit is contained in:
parent
f1f51990b9
commit
ff644ee5e4
@ -16,12 +16,10 @@ from colossalai.core import global_context as gpc
|
|||||||
from colossalai.logging import get_dist_logger
|
from colossalai.logging import get_dist_logger
|
||||||
from colossalai.nn import CrossEntropyLoss
|
from colossalai.nn import CrossEntropyLoss
|
||||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||||
from colossalai.utils import is_using_pp, get_dataloader
|
from colossalai.utils import get_dataloader
|
||||||
from colossalai.pipeline.pipelinable import PipelinableContext
|
from colossalai.pipeline.pipelinable import PipelinableContext
|
||||||
from tqdm import tqdm
|
|
||||||
from torchvision.datasets import CIFAR10
|
from torchvision.datasets import CIFAR10
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from titans.model.vit import vit_tiny_patch4_32
|
|
||||||
|
|
||||||
BATCH_SIZE = 4
|
BATCH_SIZE = 4
|
||||||
NUM_EPOCHS = 60
|
NUM_EPOCHS = 60
|
||||||
@ -41,6 +39,12 @@ def run_trainer(rank, world_size, port):
|
|||||||
logger = get_dist_logger()
|
logger = get_dist_logger()
|
||||||
|
|
||||||
pipelinable = PipelinableContext()
|
pipelinable = PipelinableContext()
|
||||||
|
try:
|
||||||
|
from titans.model.vit import vit_tiny_patch4_32
|
||||||
|
except ImportError:
|
||||||
|
logger.warning('skip the test_cifar_with_data_pipeline_tensor test because titan is not installed')
|
||||||
|
logger.warning('please install titan from https://github.com/hpcaitech/Titans')
|
||||||
|
return
|
||||||
with pipelinable:
|
with pipelinable:
|
||||||
model = vit_tiny_patch4_32()
|
model = vit_tiny_patch4_32()
|
||||||
pipelinable.to_layer_list()
|
pipelinable.to_layer_list()
|
||||||
|
Loading…
Reference in New Issue
Block a user