[unittest] refactored unit tests for change in dependency (#838)

This commit is contained in:
Frank Lee
2022-04-22 15:39:07 +08:00
committed by GitHub
parent f271f34716
commit 943982d29a
3 changed files with 25 additions and 58 deletions

View File

@@ -9,34 +9,21 @@ import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.data import DataLoader
import colossalai
from colossalai.builder import build_dataset
from torchvision import transforms
from torchvision import transforms, datasets
from colossalai.context import ParallelMode, Config
from colossalai.core import global_context as gpc
from colossalai.utils import get_dataloader, free_port
from colossalai.testing import rerun_if_address_is_in_use
from torchvision.transforms import ToTensor
CONFIG = Config(
dict(
train_data=dict(
dataset=dict(
type='CIFAR10',
root=Path(os.environ['DATA']),
train=True,
download=True,
),
dataloader=dict(batch_size=8,),
),
parallel=dict(
pipeline=dict(size=1),
tensor=dict(size=1, mode=None),
),
seed=1024,
))
CONFIG = Config(dict(
parallel=dict(
pipeline=dict(size=1),
tensor=dict(size=1, mode=None),
),
seed=1024,
))
def run_data_sampler(rank, world_size, port):
@@ -44,11 +31,14 @@ def run_data_sampler(rank, world_size, port):
colossalai.launch(**dist_args)
print('finished initialization')
transform_pipeline = [ToTensor()]
# build dataset
transform_pipeline = [transforms.ToTensor()]
transform_pipeline = transforms.Compose(transform_pipeline)
gpc.config.train_data.dataset['transform'] = transform_pipeline
dataset = build_dataset(gpc.config.train_data.dataset)
dataloader = get_dataloader(dataset, **gpc.config.train_data.dataloader)
dataset = datasets.CIFAR10(root=Path(os.environ['DATA']), train=True, download=True, transform=transform_pipeline)
# build dataloader
dataloader = get_dataloader(dataset, batch_size=8, add_sampler=True)
data_iter = iter(dataloader)
img, label = data_iter.next()
img = img[0]