mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-10-29 04:40:36 +00:00
[example] updated the hybrid parallel tutorial (#2444)
* [example] updated the hybrid parallel tutorial * polish code
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
from titans.dataloader.cifar10 import build_cifar
|
||||
from titans.model.vit.vit import _create_vit_model
|
||||
from tqdm import tqdm
|
||||
|
||||
@@ -12,7 +11,7 @@ from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn import CrossEntropyLoss
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.pipeline.pipelinable import PipelinableContext
|
||||
from colossalai.utils import get_dataloader, is_using_pp
|
||||
from colossalai.utils import is_using_pp
|
||||
|
||||
|
||||
class DummyDataloader():
|
||||
@@ -42,12 +41,9 @@ class DummyDataloader():
|
||||
|
||||
|
||||
def main():
|
||||
# initialize distributed setting
|
||||
parser = colossalai.get_default_parser()
|
||||
parser.add_argument('-s', '--synthetic', action="store_true", help="whether use synthetic data")
|
||||
args = parser.parse_args()
|
||||
|
||||
# launch from torch
|
||||
parser = colossalai.get_default_parser()
|
||||
args = parser.parse_args()
|
||||
colossalai.launch_from_torch(config=args.config)
|
||||
|
||||
# get logger
|
||||
@@ -94,15 +90,10 @@ def main():
|
||||
pipeline_stage = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||
logger.info(f"number of parameters: {total_numel} on pipeline stage {pipeline_stage}")
|
||||
|
||||
# create dataloaders
|
||||
root = os.environ.get('DATA', '../data')
|
||||
if args.synthetic:
|
||||
# if we use synthetic dataset
|
||||
# we train for 10 steps and eval for 5 steps per epoch
|
||||
train_dataloader = DummyDataloader(length=10, batch_size=gpc.config.BATCH_SIZE)
|
||||
test_dataloader = DummyDataloader(length=5, batch_size=gpc.config.BATCH_SIZE)
|
||||
else:
|
||||
train_dataloader, test_dataloader = build_cifar(gpc.config.BATCH_SIZE, root, pad_if_needed=True)
|
||||
# use synthetic dataset
|
||||
# we train for 10 steps and eval for 5 steps per epoch
|
||||
train_dataloader = DummyDataloader(length=10, batch_size=gpc.config.BATCH_SIZE)
|
||||
test_dataloader = DummyDataloader(length=5, batch_size=gpc.config.BATCH_SIZE)
|
||||
|
||||
# create loss function
|
||||
criterion = CrossEntropyLoss(label_smoothing=0.1)
|
||||
@@ -139,6 +130,7 @@ def main():
|
||||
engine.execute_schedule(data_iter, return_output_label=False)
|
||||
engine.step()
|
||||
lr_scheduler.step()
|
||||
gpc.destroy()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user