From a9b27b9265c31175192643e3974187e5ea112c1d Mon Sep 17 00:00:00 2001
From: Fazzie-Maqianli <55798671+Fazziekey@users.noreply.github.com>
Date: Wed, 4 Jan 2023 16:20:00 +0800
Subject: [PATCH] [example] fix dreambooth format (#2315)

---
 .../Teyvat/train_colossalai_teyvat.yaml       |   2 +-
 .../diffusion/configs/train_colossalai.yaml   |   2 +-
 .../configs/train_colossalai_cifar10.yaml     |   2 +-
 .../diffusion/configs/train_pokemon.yaml      |   2 +-
 .../dreambooth/train_dreambooth_colossalai.py | 183 ++++++++----------
 5 files changed, 90 insertions(+), 101 deletions(-)

diff --git a/examples/images/diffusion/configs/Teyvat/train_colossalai_teyvat.yaml b/examples/images/diffusion/configs/Teyvat/train_colossalai_teyvat.yaml
index 9048b3f80..d466c1c56 100644
--- a/examples/images/diffusion/configs/Teyvat/train_colossalai_teyvat.yaml
+++ b/examples/images/diffusion/configs/Teyvat/train_colossalai_teyvat.yaml
@@ -108,7 +108,7 @@ lightning:
       params:
         use_chunk: True
         enable_distributed_storage: True
-        placement_policy: auto
+        placement_policy: cuda
         force_outputs_fp32: true
 
     log_every_n_steps: 2

diff --git a/examples/images/diffusion/configs/train_colossalai.yaml b/examples/images/diffusion/configs/train_colossalai.yaml
index e8df63bf6..0354311f8 100644
--- a/examples/images/diffusion/configs/train_colossalai.yaml
+++ b/examples/images/diffusion/configs/train_colossalai.yaml
@@ -105,7 +105,7 @@ lightning:
       params:
         use_chunk: True
         enable_distributed_storage: True
-        placement_policy: auto
+        placement_policy: cuda
         force_outputs_fp32: true
 
     log_every_n_steps: 2

diff --git a/examples/images/diffusion/configs/train_colossalai_cifar10.yaml b/examples/images/diffusion/configs/train_colossalai_cifar10.yaml
index 5335bacbe..0273ca862 100644
--- a/examples/images/diffusion/configs/train_colossalai_cifar10.yaml
+++ b/examples/images/diffusion/configs/train_colossalai_cifar10.yaml
@@ -109,7 +109,7 @@ lightning:
       params:
         use_chunk: True
         enable_distributed_storage: True
-        placement_policy: auto
+        placement_policy: cuda
         force_outputs_fp32: true
 
     log_every_n_steps: 2

diff --git a/examples/images/diffusion/configs/train_pokemon.yaml b/examples/images/diffusion/configs/train_pokemon.yaml
index 38e8485a3..aadb5f2a0 100644
--- a/examples/images/diffusion/configs/train_pokemon.yaml
+++ b/examples/images/diffusion/configs/train_pokemon.yaml
@@ -102,7 +102,7 @@ lightning:
       params:
         use_chunk: True
         enable_distributed_storage: True
-        placement_policy: auto
+        placement_policy: cuda
         force_outputs_fp32: true
 
     log_every_n_steps: 2

diff --git a/examples/images/dreambooth/train_dreambooth_colossalai.py b/examples/images/dreambooth/train_dreambooth_colossalai.py
index 92a8aa28a..aff4d925d 100644
--- a/examples/images/dreambooth/train_dreambooth_colossalai.py
+++ b/examples/images/dreambooth/train_dreambooth_colossalai.py
@@ -1,38 +1,32 @@
 import argparse
 import hashlib
-import itertools
 import math
 import os
 from pathlib import Path
 from typing import Optional
 
-import numpy as np
 import torch
-import torch.distributed as dist
 import torch.nn.functional as F
 import torch.utils.checkpoint
-from copy import deepcopy
-from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
-from diffusers.optimization import get_scheduler
-from huggingface_hub import HfFolder, Repository, whoami
-from packaging import version
-from PIL import Image
-from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import Dataset
-from torchvision import transforms
-from tqdm.auto import tqdm
-from transformers import AutoTokenizer, PretrainedConfig
 
 import colossalai
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
-from colossalai.nn.parallel import ZeroDDP
 from colossalai.nn.parallel.utils import convert_to_torch_module
-from colossalai.tensor import ColoTensor, ProcessGroup
+from colossalai.tensor import ProcessGroup
 from colossalai.utils import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
+from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
+from diffusers.optimization import get_scheduler
+from huggingface_hub import HfFolder, Repository, whoami
+from PIL import Image
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
 
 disable_existing_loggers()
 logger = get_dist_logger()
@@ -118,8 +112,10 @@ def parse_args(input_args=None):
         "--num_class_images",
         type=int,
         default=100,
-        help=("Minimal class images for prior preservation loss. If there are not enough images already present in"
-              " class_data_dir, additional images will be sampled with class_prompt."),
+        help=(
+            "Minimal class images for prior preservation loss. If there are not enough images already present in"
+            " class_data_dir, additional images will be sampled with class_prompt."
+        ),
     )
     parser.add_argument(
         "--output_dir",
@@ -132,23 +128,26 @@
         "--resolution",
         type=int,
         default=512,
-        help=("The resolution for input images, all the images in the train/validation dataset will be resized to this"
-              " resolution"),
+        help=(
+            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+            " resolution"
+        ),
     )
     parser.add_argument(
         "--placement",
         type=str,
-        default='cpu',
+        default="cpu",
        help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
     )
-    parser.add_argument("--center_crop",
-                        action="store_true",
-                        help="Whether to center crop images before resizing to resolution")
-    parser.add_argument("--train_batch_size",
-                        type=int,
-                        default=4,
-                        help="Batch size (per device) for the training dataloader.")
-    parser.add_argument("--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images.")
+    parser.add_argument(
+        "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
+    )
+    parser.add_argument(
+        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+    )
+    parser.add_argument(
+        "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+    )
     parser.add_argument("--num_train_epochs", type=int, default=1)
     parser.add_argument(
         "--max_train_steps",
@@ -184,16 +183,17 @@
         "--lr_scheduler",
         type=str,
         default="constant",
-        help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
-              ' "constant", "constant_with_warmup"]'),
+        help=(
+            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+            ' "constant", "constant_with_warmup"]'
+        ),
+    )
+    parser.add_argument(
+        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+    )
+    parser.add_argument(
+        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
     )
-    parser.add_argument("--lr_warmup_steps",
-                        type=int,
-                        default=500,
-                        help="Number of steps for the warmup in the lr scheduler.")
-    parser.add_argument("--use_8bit_adam",
-                        action="store_true",
-                        help="Whether or not to use 8-bit Adam from bitsandbytes.")
     parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
 
     parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
@@ -208,8 +208,10 @@
         "--logging_dir",
         type=str,
         default="logs",
-        help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
-              " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."),
+        help=(
+            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+        ),
     )
     parser.add_argument(
         "--mixed_precision",
@@ -219,7 +221,8 @@
         help=(
             "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
             " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
-            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."),
+            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+        ),
     )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
 
@@ -285,12 +288,14 @@ class DreamBoothDataset(Dataset):
         else:
             self.class_data_root = None
 
-        self.image_transforms = transforms.Compose([
-            transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
-            transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
-            transforms.ToTensor(),
-            transforms.Normalize([0.5], [0.5]),
-        ])
+        self.image_transforms = transforms.Compose(
+            [
+                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+                transforms.ToTensor(),
+                transforms.Normalize([0.5], [0.5]),
+            ]
+        )
 
     def __len__(self):
         return self._length
@@ -352,26 +357,11 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
 
 # Gemini + ZeRO DDP
 def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
-    cai_version = colossalai.__version__
-    if version.parse(cai_version) > version.parse("0.1.10"):
-        from colossalai.nn.parallel import GeminiDDP
-        model = GeminiDDP(model,
-                          device=get_current_device(),
-                          placement_policy=placememt_policy,
-                          pin_memory=True,
-                          search_range_mb=32)
-
-    elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
-        from colossalai.gemini import ChunkManager, GeminiManager
-        chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
-        gemini_manager = GeminiManager(placememt_policy, chunk_manager)
-        chunk_manager = ChunkManager(chunk_size,
-                                     pg,
-                                     enable_distributed_storage=True,
-                                     init_device=GeminiManager.get_default_device(placememt_policy))
-        model = ZeroDDP(model, gemini_manager)
-    else:
-        raise NotImplemented(f"CAI version {cai_version} is not supported")
+    from colossalai.nn.parallel import GeminiDDP
+
+    model = GeminiDDP(
+        model, device=get_current_device(), placement_policy=placememt_policy, pin_memory=True, search_range_mb=32
+    )
     return model
 
 
@@ -383,7 +373,7 @@
         "gradient_accumulation_steps": args.gradient_accumulation_steps,
         "clip_grad_norm": args.max_grad_norm,
     }
-    
+
     colossalai.launch_from_torch(config=config)
 
     pg = ProcessGroup()
@@ -414,9 +404,11 @@
 
         pipeline.to(get_current_device())
 
-        for example in tqdm(sample_dataloader,
-                            desc="Generating class images",
-                            disable=not gpc.get_local_rank(ParallelMode.DATA) == 0):
+        for example in tqdm(
+            sample_dataloader,
+            desc="Generating class images",
+            disable=not gpc.get_local_rank(ParallelMode.DATA) == 0,
+        ):
             images = pipeline(example["prompt"]).images
 
             for i, image in enumerate(images):
@@ -466,23 +458,24 @@
     logger.info(f"Loading text_encoder from {args.pretrained_model_name_or_path}", ranks=[0])
 
-    text_encoder = text_encoder_cls.from_pretrained(args.pretrained_model_name_or_path,
-                                                    subfolder="text_encoder",
-                                                    revision=args.revision,)
+    text_encoder = text_encoder_cls.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="text_encoder",
+        revision=args.revision,
+    )
 
     logger.info(f"Loading AutoencoderKL from {args.pretrained_model_name_or_path}", ranks=[0])
 
-    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path,
-                                        subfolder="vae",
-                                        revision=args.revision,)
+    vae = AutoencoderKL.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="vae",
+        revision=args.revision,
+    )
 
-    logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
 
     with ColoInitContext():
-        unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
-                                                    subfolder="unet",
-                                                    revision=args.revision,
-                                                    low_cpu_mem_usage=False)
-
+        unet = UNet2DConditionModel.from_pretrained(
+            args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False
+        )
 
     vae.requires_grad_(False)
     text_encoder.requires_grad_(False)
@@ -491,7 +484,7 @@
         unet.enable_gradient_checkpointing()
 
     if args.scale_lr:
-        args.learning_rate = (args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * 2)
+        args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * 2
 
     unet = gemini_zero_dpp(unet, pg, args.placement)
 
@@ -502,7 +495,7 @@
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
 
     # prepare dataset
-    logger.info(f"Prepare dataset", ranks=[0])
+    logger.info(f"Prepare dataset from {args.instance_data_dir}", ranks=[0])
     train_dataset = DreamBoothDataset(
         instance_data_root=args.instance_data_dir,
         instance_prompt=args.instance_prompt,
@@ -527,9 +520,7 @@
         pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
 
         input_ids = tokenizer.pad(
-            {
-                "input_ids": input_ids
-            },
+            {"input_ids": input_ids},
             padding="max_length",
             max_length=tokenizer.model_max_length,
             return_tensors="pt",
@@ -541,11 +532,9 @@
         }
         return batch
 
-    train_dataloader = torch.utils.data.DataLoader(train_dataset,
-                                                   batch_size=args.train_batch_size,
-                                                   shuffle=True,
-                                                   collate_fn=collate_fn,
-                                                   num_workers=1)
+    train_dataloader = torch.utils.data.DataLoader(
+        train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1
+    )
 
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
@@ -662,8 +651,8 @@
                 global_step += 1
                 logs = {
                     "loss": loss.detach().item(),
-                    "lr": optimizer.param_groups[0]['lr']
-                }    #lr_scheduler.get_last_lr()[0]}
+                    "lr": optimizer.param_groups[0]["lr"],
+                }  # lr_scheduler.get_last_lr()[0]}
                 progress_bar.set_postfix(**logs)
 
                 if global_step % args.save_steps == 0:
@@ -681,15 +670,15 @@
                     break
 
     torch.cuda.synchronize()
-    unet=convert_to_torch_module(unet)
-
+    unet = convert_to_torch_module(unet)
+
     if gpc.get_local_rank(ParallelMode.DATA) == 0:
         pipeline = DiffusionPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
             unet=unet,
            revision=args.revision,
         )
-
+
         pipeline.save_pretrained(args.output_dir)
         logger.info(f"Saving model checkpoint to {args.output_dir}", ranks=[0])
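
For reference, the consolidated `gemini_zero_dpp` above is now the only ZeRO/Gemini code path: the old branching on `colossalai.__version__` (with its `ChunkManager`/`ZeroDDP` fallback) is removed, and every run wraps the UNet in `GeminiDDP`. Below is a minimal sketch of that pattern in isolation. It reuses only names that appear in this patch (`launch_from_torch`, `ColoInitContext`, `GeminiDDP`, `get_current_device`); the `torch.nn.Linear` stand-in model and the empty launch config are illustrative assumptions, not part of the patch.

    import torch

    import colossalai
    from colossalai.nn.parallel import GeminiDDP
    from colossalai.utils import get_current_device
    from colossalai.utils.model.colo_init_context import ColoInitContext

    # Set up distributed state the same way the patched script does,
    # e.g. when launched with `torchrun --nproc_per_node <N> sketch.py`.
    colossalai.launch_from_torch(config={})

    # Parameters must be created under ColoInitContext so Gemini can manage them.
    with ColoInitContext():
        model = torch.nn.Linear(512, 512)  # illustrative stand-in for the UNet

    # Single post-patch code path: wrap the model once with GeminiDDP.
    model = GeminiDDP(
        model,
        device=get_current_device(),
        placement_policy="cuda",  # the value the yaml configs switch to
        pin_memory=True,
        search_range_mb=32,
    )

The yaml change follows the same line: `placement_policy: cuda` keeps Gemini-managed parameters resident on the GPU, whereas the previous `auto` let Gemini shuttle them between CPU and GPU according to memory pressure.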