[example] updated the hybrid parallel tutorial (#2444)

* [example] updated the hybrid parallel tutorial

* polish code
Frank Lee
2023-01-11 15:17:17 +08:00
committed by GitHub
parent 5521af7877
commit 39163417a1
6 changed files with 82 additions and 65 deletions


@@ -1,7 +1,6 @@
 import os
 
 import torch
-from titans.dataloader.cifar10 import build_cifar
 from titans.model.vit.vit import _create_vit_model
 from tqdm import tqdm
 
@@ -12,7 +11,7 @@ from colossalai.logging import get_dist_logger
 from colossalai.nn import CrossEntropyLoss
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 from colossalai.pipeline.pipelinable import PipelinableContext
-from colossalai.utils import get_dataloader, is_using_pp
+from colossalai.utils import is_using_pp
 
 
 class DummyDataloader():
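Note: the diff shows only the first line of `DummyDataloader`. For readers following the tutorial, here is a minimal sketch of what such a synthetic-data iterator typically looks like; the class body, the 3x224x224 image shape, and the 10-class label range (matching CIFAR-10) are assumptions for illustration, not taken from this commit.

```python
import torch

# Hypothetical sketch of a synthetic dataloader; the real class body is
# elided by the diff. Tensor shapes and class count are assumptions.
class DummyDataloader:

    def __init__(self, length, batch_size):
        self.length = length          # number of batches yielded per epoch
        self.batch_size = batch_size
        self.step = 0

    def generate(self):
        # random images and labels standing in for a real dataset
        data = torch.randn(self.batch_size, 3, 224, 224)
        label = torch.randint(low=0, high=10, size=(self.batch_size,))
        return data, label

    def __iter__(self):
        self.step = 0
        return self

    def __next__(self):
        if self.step < self.length:
            self.step += 1
            return self.generate()
        raise StopIteration

    def __len__(self):
        return self.length
```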
@@ -42,12 +41,9 @@ class DummyDataloader():
 def main():
-    # initialize distributed setting
-    parser = colossalai.get_default_parser()
-    parser.add_argument('-s', '--synthetic', action="store_true", help="whether use synthetic data")
-    args = parser.parse_args()
-
     # launch from torch
+    parser = colossalai.get_default_parser()
+    args = parser.parse_args()
     colossalai.launch_from_torch(config=args.config)
 
     # get logger
     logger = get_dist_logger()
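Note: with the `--synthetic` flag gone, the parser carries only Colossal-AI's default arguments (notably `--config`), and `launch_from_torch` picks up rank and world size from the environment variables set by `torchrun`. A minimal usage sketch follows, assuming a hypothetical `config.py`; the process count and config values are illustrative only.

```python
import colossalai

# Launched under torchrun, e.g.:
#   torchrun --nproc_per_node=4 train.py --config config.py
# get_default_parser() already provides --config plus the distributed
# host/port/rank options, so no extra arguments are needed here.
parser = colossalai.get_default_parser()
args = parser.parse_args()
colossalai.launch_from_torch(config=args.config)
```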
@@ -94,15 +90,10 @@ def main():
     pipeline_stage = gpc.get_local_rank(ParallelMode.PIPELINE)
     logger.info(f"number of parameters: {total_numel} on pipeline stage {pipeline_stage}")
 
     # create dataloaders
-    root = os.environ.get('DATA', '../data')
-    if args.synthetic:
-        # if we use synthetic dataset
-        # we train for 10 steps and eval for 5 steps per epoch
-        train_dataloader = DummyDataloader(length=10, batch_size=gpc.config.BATCH_SIZE)
-        test_dataloader = DummyDataloader(length=5, batch_size=gpc.config.BATCH_SIZE)
-    else:
-        train_dataloader, test_dataloader = build_cifar(gpc.config.BATCH_SIZE, root, pad_if_needed=True)
+    # use synthetic dataset
+    # we train for 10 steps and eval for 5 steps per epoch
+    train_dataloader = DummyDataloader(length=10, batch_size=gpc.config.BATCH_SIZE)
+    test_dataloader = DummyDataloader(length=5, batch_size=gpc.config.BATCH_SIZE)
 
     # create loss function
     criterion = CrossEntropyLoss(label_smoothing=0.1)
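Note: since `build_cifar` and the real-data branch are removed, the example now always trains on synthetic batches. Readers who still want real CIFAR-10 can rebuild the loaders themselves; below is a plain torchvision sketch, not the titans implementation, and the resize size is an assumption standing in for the old `pad_if_needed=True` behaviour.

```python
import os

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Hedged stand-in for the removed build_cifar() path. Resizing CIFAR-10's
# 32x32 images to the model's input size loosely mirrors pad_if_needed=True;
# image_size is an assumption.
def build_cifar_loaders(batch_size, root=None, image_size=32):
    root = root or os.environ.get('DATA', '../data')
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
    ])
    train_set = datasets.CIFAR10(root=root, train=True, download=True, transform=transform)
    test_set = datasets.CIFAR10(root=root, train=False, download=True, transform=transform)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader
```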
@@ -139,6 +130,7 @@ def main():
             engine.execute_schedule(data_iter, return_output_label=False)
             engine.step()
             lr_scheduler.step()
+    gpc.destroy()
 
 
 if __name__ == '__main__':
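Note: the final hunk appends `gpc.destroy()` after the training loop so that the distributed process groups are torn down cleanly before the script exits. For context, here is a sketch of the surrounding loop; the names (`engine`, `data_iter`, `lr_scheduler`, `gpc.config.NUM_EPOCHS`) come from the diff and the tutorial's conventions, while the loop structure itself is an assumption about the elided code.

```python
# Sketch of the elided training loop; the structure is assumed, the
# variable names are taken from the diff above.
for epoch in range(gpc.config.NUM_EPOCHS):
    engine.train()
    data_iter = iter(train_dataloader)
    for _ in range(len(train_dataloader)):
        engine.zero_grad()
        engine.execute_schedule(data_iter, return_output_label=False)
        engine.step()
        lr_scheduler.step()

# release distributed resources once training is done
gpc.destroy()
```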