diff --git a/examples/images/resnet/.gitignore b/examples/images/resnet/.gitignore new file mode 100644 index 000000000..a79cf5236 --- /dev/null +++ b/examples/images/resnet/.gitignore @@ -0,0 +1,4 @@ +data +checkpoint +ckpt-fp16 +ckpt-fp32 diff --git a/examples/images/resnet/README.md b/examples/images/resnet/README.md new file mode 100644 index 000000000..c69828637 --- /dev/null +++ b/examples/images/resnet/README.md @@ -0,0 +1,56 @@ +# Train ResNet on CIFAR-10 from scratch + +## 🚀 Quick Start + +This example provides a training script and an evaluation script. The training script provides an example of training ResNet on CIFAR10 dataset from scratch. + +- Training Arguments + - `-p`, `--plugin`: Plugin to use. Choices: `torch_ddp`, `torch_ddp_fp16`, `low_level_zero`. Defaults to `torch_ddp`. + - `-r`, `--resume`: Resume from checkpoint file path. Defaults to `-1`, which means not resuming. + - `-c`, `--checkpoint`: The folder to save checkpoints. Defaults to `./checkpoint`. + - `-i`, `--interval`: Epoch interval to save checkpoints. Defaults to `5`. If set to `0`, no checkpoint will be saved. + - `--target_acc`: Target accuracy. Raise exception if not reached. Defaults to `None`. + +- Eval Arguments + - `-e`, `--epoch`: select the epoch to evaluate + - `-c`, `--checkpoint`: the folder where checkpoints are found + +### Install requirements + +```bash +pip install -r requirements.txt +``` + +### Train +The folders will be created automatically. +```bash +# train with torch DDP with fp32 +colossalai run --nproc_per_node 2 train.py -c ./ckpt-fp32 + +# train with torch DDP with mixed precision training +colossalai run --nproc_per_node 2 train.py -c ./ckpt-fp16 -p torch_ddp_fp16 + +# train with low level zero +colossalai run --nproc_per_node 2 train.py -c ./ckpt-low_level_zero -p low_level_zero +``` + +### Eval + +```bash +# evaluate fp32 training +python eval.py -c ./ckpt-fp32 -e 80 + +# evaluate fp16 mixed precision training +python eval.py -c ./ckpt-fp16 -e 80 + +# evaluate low level zero training +python eval.py -c ./ckpt-low_level_zero -e 80 +``` + +Expected accuracy performance will be: + +| Model | Single-GPU Baseline FP32 | Booster DDP with FP32 | Booster DDP with FP16 | Booster Low Level Zero | +| --------- | ------------------------ | --------------------- | --------------------- | ---------------------- | +| ResNet-18 | 85.85% | 84.91% | 85.46% | 84.50% | + +**Note: the baseline is adapted from the [script](https://pytorch-tutorial.readthedocs.io/en/latest/tutorial/chapter03_intermediate/3_2_2_cnn_resnet_cifar10/) to use `torchvision.models.resnet18`** diff --git a/examples/images/resnet/eval.py b/examples/images/resnet/eval.py new file mode 100644 index 000000000..657708ec3 --- /dev/null +++ b/examples/images/resnet/eval.py @@ -0,0 +1,48 @@ +import argparse + +import torch +import torch.nn as nn +import torchvision +import torchvision.transforms as transforms + +# ============================== +# Parse Arguments +# ============================== +parser = argparse.ArgumentParser() +parser.add_argument('-e', '--epoch', type=int, default=80, help="resume from the epoch's checkpoint") +parser.add_argument('-c', '--checkpoint', type=str, default='./checkpoint', help="checkpoint directory") +args = parser.parse_args() + +# ============================== +# Prepare Test Dataset +# ============================== +# CIFAR-10 dataset +test_dataset = torchvision.datasets.CIFAR10(root='./data/', train=False, transform=transforms.ToTensor()) + +# Data loader +test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=128, shuffle=False) + +# ============================== +# Load Model +# ============================== +model = torchvision.models.resnet18(num_classes=10).cuda() +state_dict = torch.load(f'{args.checkpoint}/model_{args.epoch}.pth') +model.load_state_dict(state_dict) + +# ============================== +# Run Evaluation +# ============================== +model.eval() + +with torch.no_grad(): + correct = 0 + total = 0 + for images, labels in test_loader: + images = images.cuda() + labels = labels.cuda() + outputs = model(images) + _, predicted = torch.max(outputs.data, 1) + total += labels.size(0) + correct += (predicted == labels).sum().item() + + print('Accuracy of the model on the test images: {} %'.format(100 * correct / total)) diff --git a/examples/images/resnet/requirements.txt b/examples/images/resnet/requirements.txt new file mode 100644 index 000000000..3c7da7743 --- /dev/null +++ b/examples/images/resnet/requirements.txt @@ -0,0 +1,5 @@ +colossalai +torch +torchvision +tqdm +pytest \ No newline at end of file diff --git a/examples/images/resnet/test_ci.sh b/examples/images/resnet/test_ci.sh new file mode 100755 index 000000000..b3fb67830 --- /dev/null +++ b/examples/images/resnet/test_ci.sh @@ -0,0 +1,12 @@ +#!/bin/bash +set -xe + +export DATA=/data/scratch/cifar-10 + +pip install -r requirements.txt + +# TODO: skip ci test due to time limits, train.py needs to be rewritten. + +# for plugin in "torch_ddp" "torch_ddp_fp16" "low_level_zero"; do +# colossalai run --nproc_per_node 4 train.py --interval 0 --target_acc 0.84 --plugin $plugin +# done diff --git a/examples/images/resnet/train.py b/examples/images/resnet/train.py new file mode 100644 index 000000000..fe0dabf08 --- /dev/null +++ b/examples/images/resnet/train.py @@ -0,0 +1,204 @@ +import argparse +import os +from pathlib import Path + +import torch +import torch.distributed as dist +import torch.nn as nn +import torchvision +import torchvision.transforms as transforms +from torch.optim import Optimizer +from torch.optim.lr_scheduler import MultiStepLR +from torch.utils.data import DataLoader +from tqdm import tqdm + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.booster.plugin.dp_plugin_base import DPPluginBase +from colossalai.cluster import DistCoordinator +from colossalai.nn.optimizer import HybridAdam +from colossalai.utils import get_current_device + +# ============================== +# Prepare Hyperparameters +# ============================== +NUM_EPOCHS = 80 +LEARNING_RATE = 1e-3 + + +def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPluginBase): + # transform + transform_train = transforms.Compose( + [transforms.Pad(4), + transforms.RandomHorizontalFlip(), + transforms.RandomCrop(32), + transforms.ToTensor()]) + transform_test = transforms.ToTensor() + + # CIFAR-10 dataset + data_path = os.environ.get('DATA', './data') + with coordinator.priority_execution(): + train_dataset = torchvision.datasets.CIFAR10(root=data_path, + train=True, + transform=transform_train, + download=True) + test_dataset = torchvision.datasets.CIFAR10(root=data_path, + train=False, + transform=transform_test, + download=True) + + # Data loader + train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True) + test_dataloader = plugin.prepare_dataloader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False) + return train_dataloader, test_dataloader + + +@torch.no_grad() +def evaluate(model: nn.Module, test_dataloader: DataLoader, coordinator: DistCoordinator) -> float: + model.eval() + correct = torch.zeros(1, dtype=torch.int64, device=get_current_device()) + total = torch.zeros(1, dtype=torch.int64, device=get_current_device()) + for images, labels in test_dataloader: + images = images.cuda() + labels = labels.cuda() + outputs = model(images) + _, predicted = torch.max(outputs.data, 1) + total += labels.size(0) + correct += (predicted == labels).sum().item() + dist.all_reduce(correct) + dist.all_reduce(total) + accuracy = correct.item() / total.item() + if coordinator.is_master(): + print(f'Accuracy of the model on the test images: {accuracy * 100:.2f} %') + return accuracy + + +def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, criterion: nn.Module, train_dataloader: DataLoader, + booster: Booster, coordinator: DistCoordinator): + model.train() + with tqdm(train_dataloader, desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', disable=not coordinator.is_master()) as pbar: + for images, labels in pbar: + images = images.cuda() + labels = labels.cuda() + # Forward pass + outputs = model(images) + loss = criterion(outputs, labels) + + # Backward and optimize + booster.backward(loss, optimizer) + optimizer.step() + optimizer.zero_grad() + + # Print log info + pbar.set_postfix({'loss': loss.item()}) + + +def main(): + # ============================== + # Parse Arguments + # ============================== + parser = argparse.ArgumentParser() + # FIXME(ver217): gemini is not supported resnet now + parser.add_argument('-p', + '--plugin', + type=str, + default='torch_ddp', + choices=['torch_ddp', 'torch_ddp_fp16', 'low_level_zero'], + help="plugin to use") + parser.add_argument('-r', '--resume', type=int, default=-1, help="resume from the epoch's checkpoint") + parser.add_argument('-c', '--checkpoint', type=str, default='./checkpoint', help="checkpoint directory") + parser.add_argument('-i', '--interval', type=int, default=5, help="interval of saving checkpoint") + parser.add_argument('--target_acc', + type=float, + default=None, + help="target accuracy. Raise exception if not reached") + args = parser.parse_args() + + # ============================== + # Prepare Checkpoint Directory + # ============================== + if args.interval > 0: + Path(args.checkpoint).mkdir(parents=True, exist_ok=True) + + # ============================== + # Launch Distributed Environment + # ============================== + colossalai.launch_from_torch(config={}) + coordinator = DistCoordinator() + + # update the learning rate with linear scaling + # old_gpu_num / old_lr = new_gpu_num / new_lr + global LEARNING_RATE + LEARNING_RATE *= coordinator.world_size + + # ============================== + # Instantiate Plugin and Booster + # ============================== + booster_kwargs = {} + if args.plugin == 'torch_ddp_fp16': + booster_kwargs['mixed_precision'] = 'fp16' + if args.plugin.startswith('torch_ddp'): + plugin = TorchDDPPlugin() + elif args.plugin == 'gemini': + plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2**5) + elif args.plugin == 'low_level_zero': + plugin = LowLevelZeroPlugin(initial_scale=2**5) + + booster = Booster(plugin=plugin, **booster_kwargs) + + # ============================== + # Prepare Dataloader + # ============================== + train_dataloader, test_dataloader = build_dataloader(100, coordinator, plugin) + + # ==================================== + # Prepare model, optimizer, criterion + # ==================================== + # resent50 + model = torchvision.models.resnet18(num_classes=10) + + # Loss and optimizer + criterion = nn.CrossEntropyLoss() + optimizer = HybridAdam(model.parameters(), lr=LEARNING_RATE) + + # lr scheduler + lr_scheduler = MultiStepLR(optimizer, milestones=[20, 40, 60, 80], gamma=1 / 3) + + # ============================== + # Boost with ColossalAI + # ============================== + model, optimizer, criterion, _, lr_scheduler = booster.boost(model, + optimizer, + criterion=criterion, + lr_scheduler=lr_scheduler) + + # ============================== + # Resume from checkpoint + # ============================== + if args.resume >= 0: + booster.load_model(model, f'{args.checkpoint}/model_{args.resume}.pth') + booster.load_optimizer(optimizer, f'{args.checkpoint}/optimizer_{args.resume}.pth') + booster.load_lr_scheduler(lr_scheduler, f'{args.checkpoint}/lr_scheduler_{args.resume}.pth') + + # ============================== + # Train model + # ============================== + start_epoch = args.resume if args.resume >= 0 else 0 + for epoch in range(start_epoch, NUM_EPOCHS): + train_epoch(epoch, model, optimizer, criterion, train_dataloader, booster, coordinator) + lr_scheduler.step() + + # save checkpoint + if args.interval > 0 and (epoch + 1) % args.interval == 0: + booster.save_model(model, f'{args.checkpoint}/model_{epoch + 1}.pth') + booster.save_optimizer(optimizer, f'{args.checkpoint}/optimizer_{epoch + 1}.pth') + booster.save_lr_scheduler(lr_scheduler, f'{args.checkpoint}/lr_scheduler_{epoch + 1}.pth') + + accuracy = evaluate(model, test_dataloader, coordinator) + if args.target_acc is not None: + assert accuracy >= args.target_acc, f'Accuracy {accuracy} is lower than target accuracy {args.target_acc}' + + +if __name__ == '__main__': + main()