From f03bcb359bfae1bf1aa420dfa3608b4e10c624a1 Mon Sep 17 00:00:00 2001
From: ver217
Date: Tue, 4 Jan 2022 20:35:33 +0800
Subject: [PATCH] update vit example for new API (#98) (#99)

---
 .../dataloader/imagenet_dali_dataloader.py       | 10 +++++-----
 examples/vit_b16_imagenet_data_parallel/mixup.py | 13 +++++++++++--
 examples/vit_b16_imagenet_data_parallel/train.py |  6 +++---
 3 files changed, 19 insertions(+), 10 deletions(-)

diff --git a/examples/vit_b16_imagenet_data_parallel/dataloader/imagenet_dali_dataloader.py b/examples/vit_b16_imagenet_data_parallel/dataloader/imagenet_dali_dataloader.py
index a39d73e26..c7032789a 100755
--- a/examples/vit_b16_imagenet_data_parallel/dataloader/imagenet_dali_dataloader.py
+++ b/examples/vit_b16_imagenet_data_parallel/dataloader/imagenet_dali_dataloader.py
@@ -104,9 +104,9 @@ class DaliDataloader(DALIClassificationIterator):
                 img = lam * img + (1 - lam) * img[idx, :]
                 label_a, label_b = label, label[idx]
                 lam = torch.tensor([lam], device=img.device, dtype=img.dtype)
-                label = (label_a, label_b, lam)
+                label = {'targets_a': label_a, 'targets_b': label_b, 'lam': lam}
             else:
-                label = (label, label, torch.ones(
-                    1, device=img.device, dtype=img.dtype))
-            return (img,), label
-        return (img,), (label,)
+                label = {'targets_a': label, 'targets_b': label,
+                         'lam': torch.ones(1, device=img.device, dtype=img.dtype)}
+            return img, label
+        return img, label
diff --git a/examples/vit_b16_imagenet_data_parallel/mixup.py b/examples/vit_b16_imagenet_data_parallel/mixup.py
index 822bc8659..af097ef3b 100644
--- a/examples/vit_b16_imagenet_data_parallel/mixup.py
+++ b/examples/vit_b16_imagenet_data_parallel/mixup.py
@@ -1,5 +1,7 @@
 import torch.nn as nn
 from colossalai.registry import LOSSES
+import torch
+
 
 @LOSSES.register_module
 class MixupLoss(nn.Module):
@@ -7,6 +9,13 @@ class MixupLoss(nn.Module):
         super().__init__()
         self.loss_fn = loss_fn_cls()
 
-    def forward(self, inputs, *args):
-        targets_a, targets_b, lam = args
+    def forward(self, inputs, targets_a, targets_b, lam):
         return lam * self.loss_fn(inputs, targets_a) + (1 - lam) * self.loss_fn(inputs, targets_b)
+
+
+class MixupAccuracy(nn.Module):
+    def forward(self, logits, targets):
+        targets = targets['targets_a']
+        preds = torch.argmax(logits, dim=-1)
+        correct = torch.sum(targets == preds)
+        return correct
diff --git a/examples/vit_b16_imagenet_data_parallel/train.py b/examples/vit_b16_imagenet_data_parallel/train.py
index bf5845218..d67a71203 100644
--- a/examples/vit_b16_imagenet_data_parallel/train.py
+++ b/examples/vit_b16_imagenet_data_parallel/train.py
@@ -11,7 +11,7 @@ from colossalai.logging import get_dist_logger
 from colossalai.trainer import Trainer, hooks
 from colossalai.nn.lr_scheduler import LinearWarmupLR
 from dataloader.imagenet_dali_dataloader import DaliDataloader
-from mixup import MixupLoss
+from mixup import MixupLoss, MixupAccuracy
 from timm.models import vit_base_patch16_224
 from myhooks import TotalBatchsizeHook
 
@@ -62,7 +62,7 @@ def main():
         port=args.port,
         backend=args.backend
     )
-    # launch from torch
+    # launch from torch
     # colossalai.launch_from_torch(config=args.config)
 
     # get logger
@@ -96,7 +96,7 @@
     # build hooks
    hook_list = [
         hooks.LossHook(),
-        hooks.AccuracyHook(accuracy_func=Accuracy()),
+        hooks.AccuracyHook(accuracy_func=MixupAccuracy()),
         hooks.LogMetricByEpochHook(logger),
         hooks.LRSchedulerHook(lr_scheduler, by_epoch=True),
         TotalBatchsizeHook(),
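
For context, a minimal standalone sketch of the mixup step that produces the dict-style labels introduced by this patch. The function name mixup_batch and the Beta(alpha, alpha) sampling of lam are illustrative assumptions, not code taken from this example:

    import torch

    def mixup_batch(img, label, alpha=0.2):
        # Assumption: lam is drawn from Beta(alpha, alpha), the usual mixup choice.
        lam = float(torch.distributions.Beta(alpha, alpha).sample())
        # Pair every sample with a random partner from the same batch.
        idx = torch.randperm(img.size(0), device=img.device)
        img = lam * img + (1 - lam) * img[idx, :]
        # The dict-style label format this patch introduces.
        label = {'targets_a': label,
                 'targets_b': label[idx],
                 'lam': torch.tensor([lam], device=img.device, dtype=img.dtype)}
        return img, label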
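The dict keys matter because the trainer is assumed to expand the label dict into keyword arguments of the criterion, which is why MixupLoss.forward now names targets_a, targets_b and lam explicitly instead of unpacking *args. A self-contained sketch of that contract, with the colossalai registry decorator dropped and dummy tensors standing in for real data:

    import torch
    import torch.nn as nn

    class MixupLoss(nn.Module):
        def __init__(self, loss_fn_cls):
            super().__init__()
            self.loss_fn = loss_fn_cls()

        def forward(self, inputs, targets_a, targets_b, lam):
            # Convex combination of the loss against both target sets.
            return lam * self.loss_fn(inputs, targets_a) \
                + (1 - lam) * self.loss_fn(inputs, targets_b)

    class MixupAccuracy(nn.Module):
        def forward(self, logits, targets):
            # Accuracy is scored against the unpermuted targets only.
            targets = targets['targets_a']
            preds = torch.argmax(logits, dim=-1)
            return torch.sum(targets == preds)  # a correct count, not a ratio

    logits = torch.randn(4, 10)                     # batch of 4, 10 classes
    label = {'targets_a': torch.randint(0, 10, (4,)),
             'targets_b': torch.randint(0, 10, (4,)),
             'lam': torch.tensor([0.7])}
    criterion = MixupLoss(nn.CrossEntropyLoss)
    loss = criterion(logits, **label)               # dict keys line up with kwargs
    correct = MixupAccuracy()(logits, label)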