From d143396cac11811901066582b52d228cb9ab4192 Mon Sep 17 00:00:00 2001 From: LuGY_mac Date: Fri, 14 Jan 2022 19:22:17 +0800 Subject: [PATCH] Added rand augment and update the dataloader --- .../dataloader/imagenet_dali_dataloader.py | 46 ++-- .../dataloader/rand_augment.py | 209 ++++++++++++++++++ .../vit_b16_imagenet_data_parallel/train.py | 4 +- 3 files changed, 233 insertions(+), 26 deletions(-) mode change 100755 => 100644 examples/vit_b16_imagenet_data_parallel/dataloader/imagenet_dali_dataloader.py create mode 100644 examples/vit_b16_imagenet_data_parallel/dataloader/rand_augment.py diff --git a/examples/vit_b16_imagenet_data_parallel/dataloader/imagenet_dali_dataloader.py b/examples/vit_b16_imagenet_data_parallel/dataloader/imagenet_dali_dataloader.py old mode 100755 new mode 100644 index c7032789a..459266190 --- a/examples/vit_b16_imagenet_data_parallel/dataloader/imagenet_dali_dataloader.py +++ b/examples/vit_b16_imagenet_data_parallel/dataloader/imagenet_dali_dataloader.py @@ -5,6 +5,7 @@ import nvidia.dali.types as types import nvidia.dali.tfrecord as tfrec import torch import numpy as np +from .rand_augment import RandAugment class DaliDataloader(DALIClassificationIterator): @@ -21,13 +22,17 @@ class DaliDataloader(DALIClassificationIterator): training=True, gpu_aug=False, cuda=True, - mixup_alpha=0.0): + mixup_alpha=0.0, + randaug_magnitude=10, + randaug_num_layers=0): self.mixup_alpha = mixup_alpha self.training = training + self.randaug_magnitude = randaug_magnitude + self.randaug_num_layers = randaug_num_layers pipe = Pipeline(batch_size=batch_size, num_threads=num_threads, device_id=torch.cuda.current_device() if cuda else None, - seed=1024) + seed=42) with pipe: inputs = fn.readers.tfrecord( path=tfrec_filenames, @@ -44,38 +49,27 @@ class DaliDataloader(DALIClassificationIterator): 'image/class/label': tfrec.FixedLenFeature([1], tfrec.int64, -1), }) images = inputs["image/encoded"] - + images = fn.decoders.image(images, + device='mixed' if gpu_aug else 'cpu', + output_type=types.RGB) if training: - images = fn.decoders.image(images, - device='mixed' if gpu_aug else 'cpu', - output_type=types.RGB) images = fn.random_resized_crop(images, size=crop, device='gpu' if gpu_aug else 'cpu') - flip_lr = fn.random.coin_flip(probability=0.5) + if randaug_num_layers == 0: + flip_lr = fn.random.coin_flip(probability=0.5) + images = fn.flip(images, horizontal=flip_lr) else: - # decode jpeg and resize - images = fn.decoders.image(images, - device='mixed' if gpu_aug else 'cpu', - output_type=types.RGB) images = fn.resize(images, device='gpu' if gpu_aug else 'cpu', resize_x=resize, resize_y=resize, dtype=types.FLOAT, interp_type=types.INTERP_TRIANGULAR) - flip_lr = False - - # center crop and normalise - images = fn.crop_mirror_normalize(images, - dtype=types.FLOAT, - crop=(crop, crop), - mean=[127.5], - std=[127.5], - mirror=flip_lr) + images = fn.crop(images, + dtype=types.FLOAT, + crop=(crop, crop)) label = inputs["image/class/label"] - 1 # 0-999 - # LSG: element_extract will raise exception, let's flatten outside - # label = fn.element_extract(label, element_map=0) # Flatten if cuda: # transfer data to gpu pipe.set_outputs(images.gpu(), label.gpu()) else: @@ -96,6 +90,10 @@ class DaliDataloader(DALIClassificationIterator): def __next__(self): data = super().__next__() img, label = data[0]['data'], data[0]['label'] + img = img.permute(0, 3, 1, 2) + if self.randaug_num_layers > 0 and self.training: + img = RandAugment(img, num_layers=self.randaug_num_layers, magnitude=self.randaug_magnitude) + img = (img - 127.5) / 127.5 label = label.squeeze() if self.mixup_alpha > 0.0: if self.training: @@ -106,7 +104,7 @@ class DaliDataloader(DALIClassificationIterator): lam = torch.tensor([lam], device=img.device, dtype=img.dtype) label = {'targets_a': label_a, 'targets_b': label_b, 'lam': lam} else: - label = {'targets_a': label, 'targets_b': label, - 'lam': torch.ones(1, device=img.device, dtype=img.dtype)} + label = {'targets_a': label, 'targets_b': label, 'lam': torch.ones( + 1, device=img.device, dtype=img.dtype)} return img, label return img, label diff --git a/examples/vit_b16_imagenet_data_parallel/dataloader/rand_augment.py b/examples/vit_b16_imagenet_data_parallel/dataloader/rand_augment.py new file mode 100644 index 000000000..72118d5e8 --- /dev/null +++ b/examples/vit_b16_imagenet_data_parallel/dataloader/rand_augment.py @@ -0,0 +1,209 @@ +import torch +import numpy as np +import torchvision.transforms.functional as TF + +_MAX_LEVEL = 10 + +_HPARAMS = { + 'cutout_const': 40, + 'translate_const': 40, +} + +_FILL = tuple([128, 128, 128]) +# RGB + + +def blend(image0, image1, factor): + # blend image0 with image1 + # we only use this function in the 'color' function + if factor == 0.0: + return image0 + if factor == 1.0: + return image1 + image0 = image0.type(torch.float32) + image1 = image1.type(torch.float32) + scaled = (image1 - image0) * factor + image = image0 + scaled + + if factor > 0.0 and factor < 1.0: + return image.type(torch.uint8) + + image = torch.clamp(image, 0, 255).type(torch.uint8) + return image + + +def autocontrast(image): + image = TF.autocontrast(image) + return image + + +def equalize(image): + image = TF.equalize(image) + return image + + +def rotate(image, degree, fill=_FILL): + image = TF.rotate(image, angle=degree, fill=fill) + return image + + +def posterize(image, bits): + image = TF.posterize(image, bits) + return image + + +def sharpness(image, factor): + image = TF.adjust_sharpness(image, sharpness_factor=factor) + return image + + +def contrast(image, factor): + image = TF.adjust_contrast(image, factor) + return image + + +def brightness(image, factor): + image = TF.adjust_brightness(image, factor) + return image + + +def invert(image): + return 255-image + + +def solarize(image, threshold=128): + return torch.where(image < threshold, image, 255-image) + + +def solarize_add(image, addition=0, threshold=128): + add_image = image.long() + addition + add_image = torch.clamp(add_image, 0, 255).type(torch.uint8) + return torch.where(image < threshold, add_image, image) + + +def color(image, factor): + new_image = TF.rgb_to_grayscale(image, num_output_channels=3) + return blend(new_image, image, factor=factor) + + +def shear_x(image, level, fill=_FILL): + image = TF.affine(image, 0, [0, 0], 1.0, [level, 0], fill=fill) + return image + + +def shear_y(image, level, fill=_FILL): + image = TF.affine(image, 0, [0, 0], 1.0, [0, level], fill=fill) + return image + + +def translate_x(image, level, fill=_FILL): + image = TF.affine(image, 0, [level, 0], 1.0, [0, 0], fill=fill) + return image + + +def translate_y(image, level, fill=_FILL): + image = TF.affine(image, 0, [0, level], 1.0, [0, 0], fill=fill) + return image + + +def cutout(image, pad_size, fill=_FILL): + b, c, h, w = image.shape + mask = torch.ones((b, c, h, w), dtype=torch.uint8).cuda() + y = np.random.randint(pad_size, h-pad_size) + x = np.random.randint(pad_size, w-pad_size) + for i in range(c): + mask[:, i, (y-pad_size): (y+pad_size), (x-pad_size): (x+pad_size)] = fill[i] + image = torch.where(mask == 1, image, mask) + return image + + +def _randomly_negate_tensor(level): + # With 50% prob turn the tensor negative. + flip = np.random.randint(0, 2) + final_level = -level if flip else level + return final_level + + +def _rotate_level_to_arg(level): + level = (level/_MAX_LEVEL) * 30. + level = _randomly_negate_tensor(level) + return level + + +def _shear_level_to_arg(level): + level = (level/_MAX_LEVEL) * 0.3 + # Flip level to negative with 50% chance. + level = _randomly_negate_tensor(level) + return level + + +def _translate_level_to_arg(level, translate_const): + level = (level/_MAX_LEVEL) * float(translate_const) + # Flip level to negative with 50% chance. + level = _randomly_negate_tensor(level) + return level + + +def level(hparams): + return { + 'AutoContrast': lambda level: None, + 'Equalize': lambda level: None, + 'Invert': lambda level: None, + 'Rotate': _rotate_level_to_arg, + 'Posterize': lambda level: (int((level/_MAX_LEVEL) * 4)), + 'Solarize': lambda level: (int((level/_MAX_LEVEL) * 200)), + 'SolarizeAdd': lambda level: (int((level/_MAX_LEVEL) * 110)), + 'Color': lambda level: ((level/_MAX_LEVEL) * 1.8 + 0.1), + 'Contrast': lambda level: ((level/_MAX_LEVEL) * 1.8 + 0.1), + 'Brightness': lambda level: ((level/_MAX_LEVEL) * 1.8 + 0.1), + 'Sharpness': lambda level: ((level/_MAX_LEVEL) * 1.8 + 0.1), + 'ShearX': _shear_level_to_arg, + 'ShearY': _shear_level_to_arg, + 'Cutout': lambda level: (int((level/_MAX_LEVEL) * hparams['cutout_const'])), + 'TranslateX': lambda level: _translate_level_to_arg(level, hparams['translate_const']), + 'TranslateY': lambda level: _translate_level_to_arg(level, hparams['translate_const']), + } + + +AUGMENTS = { + 'AutoContrast': autocontrast, + 'Equalize': equalize, + 'Invert': invert, + 'Rotate': rotate, + 'Posterize': posterize, + 'Solarize': solarize, + 'SolarizeAdd': solarize_add, + 'Color': color, + 'Contrast': contrast, + 'Brightness': brightness, + 'Sharpness': sharpness, + 'ShearX': shear_x, + 'ShearY': shear_y, + 'TranslateX': translate_x, + 'TranslateY': translate_y, + 'Cutout': cutout, +} + + +def RandAugment(image, num_layers=2, magnitude=_MAX_LEVEL, augments=AUGMENTS): + """Random Augment for images, followed google randaug and the paper(https://arxiv.org/abs/2106.10270) + :param image: the input image, in tensor format with shape of C, H, W + :type image: uint8 Tensor + :num_layers: how many layers will the randaug do, default=2 + :type num_layers: int + :param magnitude: the magnitude of random augment, default=10 + :type magnitude: int + """ + if np.random.random() < 0.5: + return image + Choice_Augment = np.random.choice(a=list(augments.keys()), + size=num_layers, + replace=False) + magnitude = float(magnitude) + for i in range(num_layers): + arg = level(_HPARAMS)[Choice_Augment[i]](magnitude) + if arg is None: + image = augments[Choice_Augment[i]](image) + else: + image = augments[Choice_Augment[i]](image, arg) + return image diff --git a/examples/vit_b16_imagenet_data_parallel/train.py b/examples/vit_b16_imagenet_data_parallel/train.py index d67a71203..1d324a1b2 100644 --- a/examples/vit_b16_imagenet_data_parallel/train.py +++ b/examples/vit_b16_imagenet_data_parallel/train.py @@ -26,10 +26,10 @@ def build_dali_train(): batch_size=gpc.config.BATCH_SIZE, shard_id=gpc.get_local_rank(ParallelMode.DATA), num_shards=gpc.get_world_size(ParallelMode.DATA), - training=True, gpu_aug=gpc.config.dali.gpu_aug, cuda=True, - mixup_alpha=gpc.config.dali.mixup_alpha + mixup_alpha=gpc.config.dali.mixup_alpha, + randaug_num_layers=2 )