[tutorial] added synthetic data for hybrid parallel (#1919)

This commit is contained in:
Frank Lee 2022-11-12 17:49:48 +08:00 committed by GitHub
parent 1b0dd05940
commit 3c42fdbedc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 153 additions and 119 deletions

View File

@ -2,16 +2,23 @@
## Prepare Dataset ## Prepare Dataset
We use CIFAR10 dataset in this example. The dataset will be downloaded to `../data` by default. We use CIFAR10 dataset in this example. You should invoke the `donwload_cifar10.py` in the tutorial root directory or directly run the `auto_parallel_with_resnet.py`.
The dataset will be downloaded to `colossalai/examples/tutorials/data` by default.
If you wish to use customized directory for the dataset. You can set the environment variable `DATA` via the following command. If you wish to use customized directory for the dataset. You can set the environment variable `DATA` via the following command.
```bash ```bash
export DATA=/path/to/data export DATA=/path/to/data
``` ```
You can also use synthetic data for this tutorial if you don't wish to download the `CIFAR10` dataset by adding the `-s` or `--synthetic` flag to the command.
## Run on 2*2 device mesh ## Run on 2*2 device mesh
```bash ```bash
# run with cifar10
colossalai run --nproc_per_node 4 train.py --config config.py colossalai run --nproc_per_node 4 train.py --config config.py
```
# run with synthetic dataset
colossalai run --nproc_per_node 4 train.py --config config.py -s
```

View File

@ -1,117 +1,144 @@
import os import os
import colossalai
import torch import torch
from titans.dataloader.cifar10 import build_cifar
from tqdm import tqdm from titans.model.vit.vit import _create_vit_model
from colossalai.context import ParallelMode from tqdm import tqdm
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger import colossalai
from colossalai.nn import CrossEntropyLoss from colossalai.context import ParallelMode
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from colossalai.core import global_context as gpc
from colossalai.nn.optimizer import Lars, Lamb from colossalai.logging import get_dist_logger
from colossalai.utils import is_using_pp, get_dataloader from colossalai.nn import CrossEntropyLoss
from colossalai.pipeline.pipelinable import PipelinableContext from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from titans.model.vit.vit import _create_vit_model from colossalai.nn.optimizer import Lamb, Lars
from titans.dataloader.cifar10 import build_cifar from colossalai.pipeline.pipelinable import PipelinableContext
from colossalai.utils import get_dataloader, is_using_pp
def main():
# initialize distributed setting class DummyDataloader():
parser = colossalai.get_default_parser()
args = parser.parse_args() def __init__(self, length, batch_size):
self.length = length
# launch from torch self.batch_size = batch_size
colossalai.launch_from_torch(config=args.config)
def generate(self):
# get logger data = torch.rand(self.batch_size, 3, 224, 224)
logger = get_dist_logger() label = torch.randint(low=0, high=10, size=(self.batch_size,))
logger.info("initialized distributed environment", ranks=[0]) return data, label
if hasattr(gpc.config, 'LOG_PATH'): def __iter__(self):
if gpc.get_global_rank() == 0: self.step = 0
log_path = gpc.config.LOG_PATH return self
if not os.path.exists(log_path):
os.mkdir(log_path) def __next__(self):
logger.log_to_file(log_path) if self.step < self.length:
self.step += 1
use_pipeline = is_using_pp() return self.generate()
else:
# create model raise StopIteration
model_kwargs = dict(img_size=gpc.config.IMG_SIZE,
patch_size=gpc.config.PATCH_SIZE, def __len__(self):
hidden_size=gpc.config.HIDDEN_SIZE, return self.length
depth=gpc.config.DEPTH,
num_heads=gpc.config.NUM_HEADS,
mlp_ratio=gpc.config.MLP_RATIO, def main():
num_classes=10, # initialize distributed setting
init_method='jax', parser = colossalai.get_default_parser()
checkpoint=gpc.config.CHECKPOINT) parser.add_argument('-s', '--synthetic', action="store_true", help="whether use synthetic data")
args = parser.parse_args()
if use_pipeline:
pipelinable = PipelinableContext() # launch from torch
with pipelinable: colossalai.launch_from_torch(config=args.config)
model = _create_vit_model(**model_kwargs)
pipelinable.to_layer_list() # get logger
pipelinable.policy = "uniform" logger = get_dist_logger()
model = pipelinable.partition( logger.info("initialized distributed environment", ranks=[0])
1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE))
else: if hasattr(gpc.config, 'LOG_PATH'):
model = _create_vit_model(**model_kwargs) if gpc.get_global_rank() == 0:
log_path = gpc.config.LOG_PATH
# count number of parameters if not os.path.exists(log_path):
total_numel = 0 os.mkdir(log_path)
for p in model.parameters(): logger.log_to_file(log_path)
total_numel += p.numel()
if not gpc.is_initialized(ParallelMode.PIPELINE): use_pipeline = is_using_pp()
pipeline_stage = 0
else: # create model
pipeline_stage = gpc.get_local_rank(ParallelMode.PIPELINE) model_kwargs = dict(img_size=gpc.config.IMG_SIZE,
logger.info( patch_size=gpc.config.PATCH_SIZE,
f"number of parameters: {total_numel} on pipeline stage {pipeline_stage}") hidden_size=gpc.config.HIDDEN_SIZE,
depth=gpc.config.DEPTH,
# create dataloaders num_heads=gpc.config.NUM_HEADS,
root = os.environ.get('DATA', '../data/cifar10') mlp_ratio=gpc.config.MLP_RATIO,
train_dataloader, test_dataloader = build_cifar( num_classes=10,
gpc.config.BATCH_SIZE, root, pad_if_needed=True) init_method='jax',
checkpoint=gpc.config.CHECKPOINT)
# create loss function
criterion = CrossEntropyLoss(label_smoothing=0.1) if use_pipeline:
pipelinable = PipelinableContext()
# create optimizer with pipelinable:
optimizer = Lars(model.parameters(), lr=gpc.config.LEARNING_RATE, model = _create_vit_model(**model_kwargs)
weight_decay=gpc.config.WEIGHT_DECAY) pipelinable.to_layer_list()
pipelinable.policy = "uniform"
# create lr scheduler model = pipelinable.partition(1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE))
lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, else:
total_steps=gpc.config.NUM_EPOCHS, model = _create_vit_model(**model_kwargs)
warmup_steps=gpc.config.WARMUP_EPOCHS)
# count number of parameters
# initialize total_numel = 0
engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model, for p in model.parameters():
optimizer=optimizer, total_numel += p.numel()
criterion=criterion, if not gpc.is_initialized(ParallelMode.PIPELINE):
train_dataloader=train_dataloader, pipeline_stage = 0
test_dataloader=test_dataloader) else:
pipeline_stage = gpc.get_local_rank(ParallelMode.PIPELINE)
logger.info("Engine is built", ranks=[0]) logger.info(f"number of parameters: {total_numel} on pipeline stage {pipeline_stage}")
data_iter = iter(train_dataloader) # create dataloaders
root = os.environ.get('DATA', '../data/')
for epoch in range(gpc.config.NUM_EPOCHS): if args.synthetic:
# training train_dataloader = DummyDataloader(length=30, batch_size=gpc.config.BATCH_SIZE)
engine.train() test_dataloader = DummyDataloader(length=10, batch_size=gpc.config.BATCH_SIZE)
else:
if gpc.get_global_rank() == 0: train_dataloader, test_dataloader = build_cifar(gpc.config.BATCH_SIZE, root, pad_if_needed=True)
description = 'Epoch {} / {}'.format(epoch, gpc.config.NUM_EPOCHS)
progress = tqdm(range(len(train_dataloader)), desc=description) # create loss function
else: criterion = CrossEntropyLoss(label_smoothing=0.1)
progress = range(len(train_dataloader))
for _ in progress: # create optimizer
engine.zero_grad() optimizer = Lars(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)
engine.execute_schedule(data_iter, return_output_label=False)
engine.step() # create lr scheduler
lr_scheduler.step() lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
total_steps=gpc.config.NUM_EPOCHS,
warmup_steps=gpc.config.WARMUP_EPOCHS)
if __name__ == '__main__':
main() # initialize
engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model,
optimizer=optimizer,
criterion=criterion,
train_dataloader=train_dataloader,
test_dataloader=test_dataloader)
logger.info("Engine is built", ranks=[0])
for epoch in range(gpc.config.NUM_EPOCHS):
# training
engine.train()
data_iter = iter(train_dataloader)
if gpc.get_global_rank() == 0:
description = 'Epoch {} / {}'.format(epoch, gpc.config.NUM_EPOCHS)
progress = tqdm(range(len(train_dataloader)), desc=description)
else:
progress = range(len(train_dataloader))
for _ in progress:
engine.zero_grad()
engine.execute_schedule(data_iter, return_output_label=False)
engine.step()
lr_scheduler.step()
if __name__ == '__main__':
main()