mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-02 08:16:48 +00:00
[tutorial] added synthetic dataset for auto parallel demo (#1918)
This commit is contained in:
parent
acd9abc5ca
commit
1b0dd05940
@ -1,3 +1,4 @@
|
|||||||
|
import argparse
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@ -29,11 +30,25 @@ BATCH_SIZE = 1024
|
|||||||
NUM_EPOCHS = 10
|
NUM_EPOCHS = 10
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('-s', '--synthetic', action="store_true", help="use synthetic dataset instead of CIFAR10")
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def synthesize_data():
|
||||||
|
img = torch.rand(BATCH_SIZE, 3, 32, 32)
|
||||||
|
label = torch.randint(low=0, high=10, size=(BATCH_SIZE,))
|
||||||
|
return img, label
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
args = parse_args()
|
||||||
colossalai.launch_from_torch(config={})
|
colossalai.launch_from_torch(config={})
|
||||||
|
|
||||||
logger = get_dist_logger()
|
logger = get_dist_logger()
|
||||||
|
|
||||||
|
if not args.synthetic:
|
||||||
with barrier_context():
|
with barrier_context():
|
||||||
# build dataloaders
|
# build dataloaders
|
||||||
train_dataset = CIFAR10(root=DATA_ROOT,
|
train_dataset = CIFAR10(root=DATA_ROOT,
|
||||||
@ -42,7 +57,8 @@ def main():
|
|||||||
transforms.RandomCrop(size=32, padding=4),
|
transforms.RandomCrop(size=32, padding=4),
|
||||||
transforms.RandomHorizontalFlip(),
|
transforms.RandomHorizontalFlip(),
|
||||||
transforms.ToTensor(),
|
transforms.ToTensor(),
|
||||||
transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
|
transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
|
||||||
|
std=[0.2023, 0.1994, 0.2010]),
|
||||||
]))
|
]))
|
||||||
|
|
||||||
test_dataset = CIFAR10(root=DATA_ROOT,
|
test_dataset = CIFAR10(root=DATA_ROOT,
|
||||||
@ -66,6 +82,8 @@ def main():
|
|||||||
batch_size=BATCH_SIZE,
|
batch_size=BATCH_SIZE,
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
train_dataloader, test_dataloader = None, None
|
||||||
|
|
||||||
# initialize device mesh
|
# initialize device mesh
|
||||||
physical_mesh_id = torch.arange(0, 4)
|
physical_mesh_id = torch.arange(0, 4)
|
||||||
@ -112,11 +130,26 @@ def main():
|
|||||||
|
|
||||||
for epoch in range(NUM_EPOCHS):
|
for epoch in range(NUM_EPOCHS):
|
||||||
gm.train()
|
gm.train()
|
||||||
if gpc.get_global_rank() == 0:
|
|
||||||
train_dl = tqdm(train_dataloader)
|
if args.synthetic:
|
||||||
|
# if we use synthetic data
|
||||||
|
# we assume it only has 30 steps per epoch
|
||||||
|
num_steps = range(30)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
train_dl = train_dataloader
|
# we use the actual number of steps for training
|
||||||
for img, label in train_dl:
|
num_steps = range(len(train_dataloader))
|
||||||
|
data_iter = iter(train_dataloader)
|
||||||
|
progress = tqdm(num_steps)
|
||||||
|
|
||||||
|
for _ in progress:
|
||||||
|
if args.synthetic:
|
||||||
|
# generate fake data
|
||||||
|
img, label = synthesize_data()
|
||||||
|
else:
|
||||||
|
# get the real data
|
||||||
|
img, label = next(data_iter)
|
||||||
|
|
||||||
img = img.cuda()
|
img = img.cuda()
|
||||||
label = label.cuda()
|
label = label.cuda()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
@ -126,10 +159,30 @@ def main():
|
|||||||
optimizer.step()
|
optimizer.step()
|
||||||
lr_scheduler.step()
|
lr_scheduler.step()
|
||||||
|
|
||||||
|
# run evaluation
|
||||||
gm.eval()
|
gm.eval()
|
||||||
correct = 0
|
correct = 0
|
||||||
total = 0
|
total = 0
|
||||||
for img, label in test_dataloader:
|
|
||||||
|
if args.synthetic:
|
||||||
|
# if we use synthetic data
|
||||||
|
# we assume it only has 10 steps for evaluation
|
||||||
|
num_steps = range(30)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# we use the actual number of steps for training
|
||||||
|
num_steps = range(len(test_dataloader))
|
||||||
|
data_iter = iter(test_dataloader)
|
||||||
|
progress = tqdm(num_steps)
|
||||||
|
|
||||||
|
for _ in progress:
|
||||||
|
if args.synthetic:
|
||||||
|
# generate fake data
|
||||||
|
img, label = synthesize_data()
|
||||||
|
else:
|
||||||
|
# get the real data
|
||||||
|
img, label = next(data_iter)
|
||||||
|
|
||||||
img = img.cuda()
|
img = img.cuda()
|
||||||
label = label.cuda()
|
label = label.cuda()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user