[example] fix dreambooth format (#2315)

Fazzie-Maqianli 2023-01-04 16:20:00 +08:00 committed by GitHub
parent da1c47f060
commit a9b27b9265
5 changed files with 90 additions and 101 deletions

View File

@@ -108,7 +108,7 @@ lightning:
       params:
         use_chunk: True
         enable_distributed_storage: True
-        placement_policy: auto
+        placement_policy: cuda
         force_outputs_fp32: true
         log_every_n_steps: 2

View File

@@ -105,7 +105,7 @@ lightning:
       params:
         use_chunk: True
         enable_distributed_storage: True
-        placement_policy: auto
+        placement_policy: cuda
         force_outputs_fp32: true
         log_every_n_steps: 2

View File

@@ -109,7 +109,7 @@ lightning:
       params:
         use_chunk: True
         enable_distributed_storage: True
-        placement_policy: auto
+        placement_policy: cuda
         force_outputs_fp32: true
         log_every_n_steps: 2

View File

@@ -102,7 +102,7 @@ lightning:
       params:
         use_chunk: True
         enable_distributed_storage: True
-        placement_policy: auto
+        placement_policy: cuda
         force_outputs_fp32: true
         log_every_n_steps: 2
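All four config hunks make the same substitution: the Gemini placement policy moves from auto to cuda, so model data stays resident in GPU memory instead of being migrated between host and device by the chunk manager. A minimal sketch of where this string ends up, assuming the GeminiDDP API that the training-script hunks below rely on (the Linear module is a stand-in, and a distributed runtime initialized via colossalai.launch_from_torch is assumed):

    import torch
    from colossalai.nn.parallel import GeminiDDP
    from colossalai.utils import get_current_device
    from colossalai.utils.model.colo_init_context import ColoInitContext

    # parameters must be created under ColoInitContext, as the script below
    # does for the UNet
    with ColoInitContext():
        model = torch.nn.Linear(1024, 1024)    # stand-in for the real model

    model = GeminiDDP(
        model,
        device=get_current_device(),
        placement_policy="cuda",    # mirrors `placement_policy: cuda` above
        pin_memory=True,
        search_range_mb=32,
    )

With "cuda" everything lives on the GPU; "cpu" and "auto" trade GPU memory for host-device traffic by offloading or dynamically migrating parameter chunks.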

View File

@@ -1,38 +1,32 @@
 import argparse
 import hashlib
-import itertools
 import math
 import os
 from pathlib import Path
 from typing import Optional

-import numpy as np
 import torch
-import torch.distributed as dist
 import torch.nn.functional as F
 import torch.utils.checkpoint
-from copy import deepcopy
-from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
-from diffusers.optimization import get_scheduler
-from huggingface_hub import HfFolder, Repository, whoami
-from packaging import version
-from PIL import Image
-from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import Dataset
-from torchvision import transforms
-from tqdm.auto import tqdm
-from transformers import AutoTokenizer, PretrainedConfig

 import colossalai
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
-from colossalai.nn.parallel import ZeroDDP
 from colossalai.nn.parallel.utils import convert_to_torch_module
-from colossalai.tensor import ColoTensor, ProcessGroup
+from colossalai.tensor import ProcessGroup
 from colossalai.utils import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
+from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
+from diffusers.optimization import get_scheduler
+from huggingface_hub import HfFolder, Repository, whoami
+from PIL import Image
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig

 disable_existing_loggers()
 logger = get_dist_logger()
@@ -118,8 +112,10 @@ def parse_args(input_args=None):
         "--num_class_images",
         type=int,
         default=100,
-        help=("Minimal class images for prior preservation loss. If there are not enough images already present in"
-              " class_data_dir, additional images will be sampled with class_prompt."),
+        help=(
+            "Minimal class images for prior preservation loss. If there are not enough images already present in"
+            " class_data_dir, additional images will be sampled with class_prompt."
+        ),
     )
     parser.add_argument(
         "--output_dir",
@@ -132,23 +128,26 @@ def parse_args(input_args=None):
         "--resolution",
         type=int,
         default=512,
-        help=("The resolution for input images, all the images in the train/validation dataset will be resized to this"
-              " resolution"),
+        help=(
+            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+            " resolution"
+        ),
     )
     parser.add_argument(
         "--placement",
         type=str,
-        default='cpu',
+        default="cpu",
         help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
     )
-    parser.add_argument("--center_crop",
-                        action="store_true",
-                        help="Whether to center crop images before resizing to resolution")
-    parser.add_argument("--train_batch_size",
-                        type=int,
-                        default=4,
-                        help="Batch size (per device) for the training dataloader.")
-    parser.add_argument("--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images.")
+    parser.add_argument(
+        "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
+    )
+    parser.add_argument(
+        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+    )
+    parser.add_argument(
+        "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+    )
     parser.add_argument("--num_train_epochs", type=int, default=1)
     parser.add_argument(
         "--max_train_steps",
@@ -184,16 +183,17 @@ def parse_args(input_args=None):
         "--lr_scheduler",
         type=str,
         default="constant",
-        help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
-              ' "constant", "constant_with_warmup"]'),
+        help=(
+            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+            ' "constant", "constant_with_warmup"]'
+        ),
     )
-    parser.add_argument("--lr_warmup_steps",
-                        type=int,
-                        default=500,
-                        help="Number of steps for the warmup in the lr scheduler.")
-    parser.add_argument("--use_8bit_adam",
-                        action="store_true",
-                        help="Whether or not to use 8-bit Adam from bitsandbytes.")
+    parser.add_argument(
+        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+    )
+    parser.add_argument(
+        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+    )
     parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
     parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
@@ -208,8 +208,10 @@ def parse_args(input_args=None):
         "--logging_dir",
         type=str,
         default="logs",
-        help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
-              " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."),
+        help=(
+            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+        ),
     )
     parser.add_argument(
         "--mixed_precision",
@@ -219,7 +221,8 @@ def parse_args(input_args=None):
         help=(
             "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
             " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
-            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."),
+            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+        ),
     )

     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
@@ -285,12 +288,14 @@ class DreamBoothDataset(Dataset):
         else:
             self.class_data_root = None

-        self.image_transforms = transforms.Compose([
-            transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
-            transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
-            transforms.ToTensor(),
-            transforms.Normalize([0.5], [0.5]),
-        ])
+        self.image_transforms = transforms.Compose(
+            [
+                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+                transforms.ToTensor(),
+                transforms.Normalize([0.5], [0.5]),
+            ]
+        )

     def __len__(self):
         return self._length
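The re-indented Compose block behaves exactly as before. As a quick sketch of what each stage does, assuming a 512-pixel resolution and a hypothetical input file:

    from PIL import Image
    from torchvision import transforms

    transform = transforms.Compose(
        [
            transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(512),
            transforms.ToTensor(),                # HWC uint8 -> CHW float in [0, 1]
            transforms.Normalize([0.5], [0.5]),   # rescale to [-1, 1], as the VAE expects
        ]
    )
    img = Image.open("instance_image.jpg").convert("RGB")   # hypothetical path
    tensor = transform(img)                                 # shape: (3, 512, 512)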
@@ -352,26 +357,11 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
 # Gemini + ZeRO DDP
 def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
-    cai_version = colossalai.__version__
-    if version.parse(cai_version) > version.parse("0.1.10"):
-        from colossalai.nn.parallel import GeminiDDP
-        model = GeminiDDP(model,
-                          device=get_current_device(),
-                          placement_policy=placememt_policy,
-                          pin_memory=True,
-                          search_range_mb=32)
-    elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
-        from colossalai.gemini import ChunkManager, GeminiManager
-        chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
-        gemini_manager = GeminiManager(placememt_policy, chunk_manager)
-        chunk_manager = ChunkManager(chunk_size,
-                                     pg,
-                                     enable_distributed_storage=True,
-                                     init_device=GeminiManager.get_default_device(placememt_policy))
-        model = ZeroDDP(model, gemini_manager)
-    else:
-        raise NotImplemented(f"CAI version {cai_version} is not supported")
+    from colossalai.nn.parallel import GeminiDDP
+
+    model = GeminiDDP(
+        model, device=get_current_device(), placement_policy=placememt_policy, pin_memory=True, search_range_mb=32
+    )
     return model
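With the packaging.version gate removed (hence the dropped import in the first hunk), gemini_zero_dpp now assumes a ColossalAI release newer than 0.1.10 and always takes the GeminiDDP path. A minimal usage sketch, assuming colossalai.launch_from_torch has initialized the distributed context as in main() below; the learning-rate and scale values are placeholders:

    pg = ProcessGroup()
    unet = gemini_zero_dpp(unet, pg, placememt_policy="cuda")
    # chunk-aware Adam that works with Gemini-managed parameters
    optimizer = GeminiAdamOptimizer(unet, lr=5e-6, initial_scale=2**5)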
@@ -383,7 +373,7 @@ def main(args):
         "gradient_accumulation_steps": args.gradient_accumulation_steps,
         "clip_grad_norm": args.max_grad_norm,
     }

     colossalai.launch_from_torch(config=config)
     pg = ProcessGroup()
@@ -414,9 +404,11 @@ def main(args):
         pipeline.to(get_current_device())

-        for example in tqdm(sample_dataloader,
-                            desc="Generating class images",
-                            disable=not gpc.get_local_rank(ParallelMode.DATA) == 0):
+        for example in tqdm(
+            sample_dataloader,
+            desc="Generating class images",
+            disable=not gpc.get_local_rank(ParallelMode.DATA) == 0,
+        ):
             images = pipeline(example["prompt"]).images

             for i, image in enumerate(images):
@@ -466,23 +458,24 @@ def main(args):
     logger.info(f"Loading text_encoder from {args.pretrained_model_name_or_path}", ranks=[0])

-    text_encoder = text_encoder_cls.from_pretrained(args.pretrained_model_name_or_path,
-                                                    subfolder="text_encoder",
-                                                    revision=args.revision,)
+    text_encoder = text_encoder_cls.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="text_encoder",
+        revision=args.revision,
+    )

     logger.info(f"Loading AutoencoderKL from {args.pretrained_model_name_or_path}", ranks=[0])
-    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path,
-                                        subfolder="vae",
-                                        revision=args.revision,)
+    vae = AutoencoderKL.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="vae",
+        revision=args.revision,
+    )

     logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
     with ColoInitContext():
-        unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
-                                                    subfolder="unet",
-                                                    revision=args.revision,
-                                                    low_cpu_mem_usage=False)
+        unet = UNet2DConditionModel.from_pretrained(
+            args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False
+        )

     vae.requires_grad_(False)
     text_encoder.requires_grad_(False)
@@ -491,7 +484,7 @@ def main(args):
         unet.enable_gradient_checkpointing()

     if args.scale_lr:
-        args.learning_rate = (args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * 2)
+        args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * 2

     unet = gemini_zero_dpp(unet, pg, args.placement)
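Dropping the parentheses around the --scale_lr expression is purely cosmetic; the rule still multiplies the base rate by the gradient-accumulation steps, the per-device batch size, and a fixed factor of 2. With hypothetical values:

    # hypothetical: base lr 5e-6, gradient_accumulation_steps 2, train_batch_size 4
    # 5e-6 * 2 * 4 * 2 = 8e-05 becomes the effective learning rate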
@@ -502,7 +495,7 @@ def main(args):
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")

     # prepare dataset
-    logger.info(f"Prepare dataset", ranks=[0])
+    logger.info(f"Prepare dataset from {args.instance_data_dir}", ranks=[0])
     train_dataset = DreamBoothDataset(
         instance_data_root=args.instance_data_dir,
         instance_prompt=args.instance_prompt,
@@ -527,9 +520,7 @@ def main(args):
         pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

         input_ids = tokenizer.pad(
-            {
-                "input_ids": input_ids
-            },
+            {"input_ids": input_ids},
             padding="max_length",
             max_length=tokenizer.model_max_length,
             return_tensors="pt",
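The dict literal fed to tokenizer.pad is collapsed onto one line; the call itself is the standard Hugging Face padding API. A small self-contained sketch of what it returns, using a placeholder CLIP tokenizer:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-large-patch14")  # placeholder
    input_ids = [tokenizer("a photo of sks dog").input_ids]
    padded = tokenizer.pad(
        {"input_ids": input_ids},
        padding="max_length",
        max_length=tokenizer.model_max_length,
        return_tensors="pt",
    )
    print(padded.input_ids.shape)   # torch.Size([1, 77]) for CLIP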
@@ -541,11 +532,9 @@ def main(args):
         }
         return batch

-    train_dataloader = torch.utils.data.DataLoader(train_dataset,
-                                                   batch_size=args.train_batch_size,
-                                                   shuffle=True,
-                                                   collate_fn=collate_fn,
-                                                   num_workers=1)
+    train_dataloader = torch.utils.data.DataLoader(
+        train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1
+    )

     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
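The DataLoader construction is only reflowed; consuming it is unchanged. A minimal sketch, with key names assumed from the collate function above:

    for batch in train_dataloader:
        pixel_values = batch["pixel_values"]   # (B, 3, resolution, resolution), float32
        input_ids = batch["input_ids"]         # (B, tokenizer.model_max_length), padded
        break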
@@ -662,8 +651,8 @@ def main(args):
             global_step += 1
             logs = {
                 "loss": loss.detach().item(),
-                "lr": optimizer.param_groups[0]['lr']
-            }    #lr_scheduler.get_last_lr()[0]}
+                "lr": optimizer.param_groups[0]["lr"],
+            }  # lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)

             if global_step % args.save_steps == 0:
@@ -681,15 +670,15 @@ def main(args):
             break

     torch.cuda.synchronize()
-    unet=convert_to_torch_module(unet)
+    unet = convert_to_torch_module(unet)

     if gpc.get_local_rank(ParallelMode.DATA) == 0:
         pipeline = DiffusionPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
             unet=unet,
             revision=args.revision,
         )

         pipeline.save_pretrained(args.output_dir)
         logger.info(f"Saving model checkpoint to {args.output_dir}", ranks=[0])