Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-22 05:29:36 +00:00)
[example] fix dreambooth format (#2315)
This commit is contained in:
parent da1c47f060 · commit a9b27b9265
```diff
@@ -108,7 +108,7 @@ lightning:
       params:
         use_chunk: True
         enable_distributed_storage: True
-        placement_policy: auto
+        placement_policy: cuda
         force_outputs_fp32: true
 
     log_every_n_steps: 2
```
```diff
@@ -105,7 +105,7 @@ lightning:
       params:
         use_chunk: True
         enable_distributed_storage: True
-        placement_policy: auto
+        placement_policy: cuda
         force_outputs_fp32: true
 
     log_every_n_steps: 2
```
```diff
@@ -109,7 +109,7 @@ lightning:
       params:
         use_chunk: True
         enable_distributed_storage: True
-        placement_policy: auto
+        placement_policy: cuda
         force_outputs_fp32: true
 
     log_every_n_steps: 2
```
```diff
@@ -102,7 +102,7 @@ lightning:
       params:
         use_chunk: True
         enable_distributed_storage: True
-        placement_policy: auto
+        placement_policy: cuda
         force_outputs_fp32: true
 
     log_every_n_steps: 2
```
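These four hunks apply the same one-line change to four Gemini strategy configs: `placement_policy` moves from `auto` to `cuda`, which pins model data to GPU memory instead of letting the runtime migrate chunks between host and device on demand (`cpu` would offload to host RAM). The same knob appears programmatically later in this diff; below is a minimal sketch of that form, assuming a ColossalAI distributed context has already been launched and using a placeholder module:

```python
import torch
from colossalai.nn.parallel import GeminiDDP
from colossalai.utils import get_current_device

# Assumes colossalai.launch_from_torch(...) has already initialized the
# distributed context; the module here is a placeholder.
model = torch.nn.Linear(512, 512)
model = GeminiDDP(
    model,
    device=get_current_device(),
    placement_policy="cuda",  # "cuda": keep on GPU; "cpu": offload; "auto": decide per chunk
    pin_memory=True,
    search_range_mb=32,
)
```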
```diff
@@ -1,38 +1,32 @@
 import argparse
 import hashlib
-import itertools
 import math
 import os
 from pathlib import Path
 from typing import Optional
 
-import numpy as np
 import torch
-import torch.distributed as dist
 import torch.nn.functional as F
 import torch.utils.checkpoint
-from copy import deepcopy
-from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
-from diffusers.optimization import get_scheduler
-from huggingface_hub import HfFolder, Repository, whoami
-from packaging import version
-from PIL import Image
-from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import Dataset
-from torchvision import transforms
-from tqdm.auto import tqdm
-from transformers import AutoTokenizer, PretrainedConfig
 
 import colossalai
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
-from colossalai.nn.parallel import ZeroDDP
 from colossalai.nn.parallel.utils import convert_to_torch_module
-from colossalai.tensor import ColoTensor, ProcessGroup
+from colossalai.tensor import ProcessGroup
 from colossalai.utils import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
+from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
+from diffusers.optimization import get_scheduler
+from huggingface_hub import HfFolder, Repository, whoami
+from PIL import Image
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
 
 
 disable_existing_loggers()
 logger = get_dist_logger()
```
```diff
@@ -118,8 +112,10 @@ def parse_args(input_args=None):
         "--num_class_images",
         type=int,
         default=100,
-        help=("Minimal class images for prior preservation loss. If there are not enough images already present in"
-              " class_data_dir, additional images will be sampled with class_prompt."),
+        help=(
+            "Minimal class images for prior preservation loss. If there are not enough images already present in"
+            " class_data_dir, additional images will be sampled with class_prompt."
+        ),
     )
     parser.add_argument(
         "--output_dir",
```
```diff
@@ -132,23 +128,26 @@ def parse_args(input_args=None):
         "--resolution",
         type=int,
         default=512,
-        help=("The resolution for input images, all the images in the train/validation dataset will be resized to this"
-              " resolution"),
+        help=(
+            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+            " resolution"
+        ),
     )
     parser.add_argument(
         "--placement",
         type=str,
-        default='cpu',
+        default="cpu",
         help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
     )
-    parser.add_argument("--center_crop",
-                        action="store_true",
-                        help="Whether to center crop images before resizing to resolution")
-    parser.add_argument("--train_batch_size",
-                        type=int,
-                        default=4,
-                        help="Batch size (per device) for the training dataloader.")
-    parser.add_argument("--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images.")
+    parser.add_argument(
+        "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
+    )
+    parser.add_argument(
+        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+    )
+    parser.add_argument(
+        "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+    )
     parser.add_argument("--num_train_epochs", type=int, default=1)
     parser.add_argument(
         "--max_train_steps",
```
```diff
@@ -184,16 +183,17 @@ def parse_args(input_args=None):
         "--lr_scheduler",
         type=str,
         default="constant",
-        help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
-              ' "constant", "constant_with_warmup"]'),
+        help=(
+            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+            ' "constant", "constant_with_warmup"]'
+        ),
     )
-    parser.add_argument("--lr_warmup_steps",
-                        type=int,
-                        default=500,
-                        help="Number of steps for the warmup in the lr scheduler.")
-    parser.add_argument("--use_8bit_adam",
-                        action="store_true",
-                        help="Whether or not to use 8-bit Adam from bitsandbytes.")
+    parser.add_argument(
+        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+    )
+    parser.add_argument(
+        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+    )
 
     parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
     parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
```
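The scheduler names in the help string are the ones accepted by `diffusers.optimization.get_scheduler`, which is presumably fed the parsed warmup and step counts; a minimal sketch (the optimizer and step counts below are illustrative assumptions):

```python
import torch
from diffusers.optimization import get_scheduler

# Dummy optimizer standing in for the script's GeminiAdamOptimizer.
optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=5e-6)

lr_scheduler = get_scheduler(
    "constant",              # or: linear, cosine, cosine_with_restarts, polynomial, constant_with_warmup
    optimizer=optimizer,
    num_warmup_steps=500,    # --lr_warmup_steps default
    num_training_steps=800,  # --max_train_steps
)
```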
```diff
@@ -208,8 +208,10 @@ def parse_args(input_args=None):
         "--logging_dir",
         type=str,
         default="logs",
-        help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
-              " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."),
+        help=(
+            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+        ),
     )
     parser.add_argument(
         "--mixed_precision",
```
```diff
@@ -219,7 +221,8 @@ def parse_args(input_args=None):
         help=(
             "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
             " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
-            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."),
+            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+        ),
     )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
 
```
```diff
@@ -285,12 +288,14 @@ class DreamBoothDataset(Dataset):
         else:
             self.class_data_root = None
 
-        self.image_transforms = transforms.Compose([
-            transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
-            transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
-            transforms.ToTensor(),
-            transforms.Normalize([0.5], [0.5]),
-        ])
+        self.image_transforms = transforms.Compose(
+            [
+                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+                transforms.ToTensor(),
+                transforms.Normalize([0.5], [0.5]),
+            ]
+        )
 
     def __len__(self):
         return self._length
```
```diff
@@ -352,26 +357,11 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
 
 # Gemini + ZeRO DDP
 def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
-    cai_version = colossalai.__version__
-    if version.parse(cai_version) > version.parse("0.1.10"):
-        from colossalai.nn.parallel import GeminiDDP
-        model = GeminiDDP(model,
-                          device=get_current_device(),
-                          placement_policy=placememt_policy,
-                          pin_memory=True,
-                          search_range_mb=32)
-    elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
-        from colossalai.gemini import ChunkManager, GeminiManager
-        chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
-        gemini_manager = GeminiManager(placememt_policy, chunk_manager)
-        chunk_manager = ChunkManager(chunk_size,
-                                     pg,
-                                     enable_distributed_storage=True,
-                                     init_device=GeminiManager.get_default_device(placememt_policy))
-        model = ZeroDDP(model, gemini_manager)
-    else:
-        raise NotImplemented(f"CAI version {cai_version} is not supported")
+    from colossalai.nn.parallel import GeminiDDP
+
+    model = GeminiDDP(
+        model, device=get_current_device(), placement_policy=placememt_policy, pin_memory=True, search_range_mb=32
+    )
     return model
 
 
```
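The rewrite drops the version gate and calls `GeminiDDP` unconditionally, which is what lets the import hunk above remove `from packaging import version` and `ZeroDDP`. It also deletes dead code: the old 0.1.9–0.1.10 branch referenced `chunk_manager` before assigning it and raised `NotImplemented` (the value) instead of `NotImplementedError`. A minimal sketch of how the wrapped model pairs with its chunk-aware optimizer, as the script does further down (the learning rate and loss scale are illustrative assumptions):

```python
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer

# Assumes a launched ColossalAI context, a ProcessGroup `pg`, and a `unet`
# created under ColoInitContext, as elsewhere in this script.
unet = gemini_zero_dpp(unet, pg, "cuda")
optimizer = GeminiAdamOptimizer(unet, lr=5e-6, initial_scale=2**5)
```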
```diff
@@ -383,7 +373,7 @@ def main(args):
         "gradient_accumulation_steps": args.gradient_accumulation_steps,
         "clip_grad_norm": args.max_grad_norm,
     }
 
     colossalai.launch_from_torch(config=config)
     pg = ProcessGroup()
 
```
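`colossalai.launch_from_torch` initializes the distributed context from the rank and world-size environment variables that `torchrun` exports, so the surrounding lines are unchanged here; for orientation, a sketch of the launch step (the worker count and config values are assumptions):

```python
# Typically started as:
#   torchrun --nproc_per_node=2 <this training script> --placement="cuda" ...
import colossalai
from colossalai.tensor import ProcessGroup

config = {
    "gradient_accumulation_steps": 1,  # mirrors the keys visible in the hunk above
    "clip_grad_norm": 1.0,
}
colossalai.launch_from_torch(config=config)
pg = ProcessGroup()  # default process group spanning all launched ranks
```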
```diff
@@ -414,9 +404,11 @@ def main(args):
 
         pipeline.to(get_current_device())
 
-        for example in tqdm(sample_dataloader,
-                            desc="Generating class images",
-                            disable=not gpc.get_local_rank(ParallelMode.DATA) == 0):
+        for example in tqdm(
+            sample_dataloader,
+            desc="Generating class images",
+            disable=not gpc.get_local_rank(ParallelMode.DATA) == 0,
+        ):
             images = pipeline(example["prompt"]).images
 
             for i, image in enumerate(images):
```
```diff
@@ -466,23 +458,24 @@ def main(args):
 
     logger.info(f"Loading text_encoder from {args.pretrained_model_name_or_path}", ranks=[0])
 
-    text_encoder = text_encoder_cls.from_pretrained(args.pretrained_model_name_or_path,
-                                                    subfolder="text_encoder",
-                                                    revision=args.revision,)
+    text_encoder = text_encoder_cls.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="text_encoder",
+        revision=args.revision,
+    )
 
     logger.info(f"Loading AutoencoderKL from {args.pretrained_model_name_or_path}", ranks=[0])
-    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path,
-                                        subfolder="vae",
-                                        revision=args.revision,)
+    vae = AutoencoderKL.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="vae",
+        revision=args.revision,
+    )
 
 
     logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
     with ColoInitContext():
-        unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
-                                                    subfolder="unet",
-                                                    revision=args.revision,
-                                                    low_cpu_mem_usage=False)
+        unet = UNet2DConditionModel.from_pretrained(
+            args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False
+        )
 
     vae.requires_grad_(False)
     text_encoder.requires_grad_(False)
```
```diff
@@ -491,7 +484,7 @@ def main(args):
         unet.enable_gradient_checkpointing()
 
     if args.scale_lr:
-        args.learning_rate = (args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * 2)
+        args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * 2
 
     unet = gemini_zero_dpp(unet, pg, args.placement)
 
```
```diff
@@ -502,7 +495,7 @@ def main(args):
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
 
     # prepare dataset
-    logger.info(f"Prepare dataset", ranks=[0])
+    logger.info(f"Prepare dataset from {args.instance_data_dir}", ranks=[0])
     train_dataset = DreamBoothDataset(
         instance_data_root=args.instance_data_dir,
         instance_prompt=args.instance_prompt,
```
```diff
@@ -527,9 +520,7 @@ def main(args):
         pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
 
         input_ids = tokenizer.pad(
-            {
-                "input_ids": input_ids
-            },
+            {"input_ids": input_ids},
             padding="max_length",
             max_length=tokenizer.model_max_length,
             return_tensors="pt",
```
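`tokenizer.pad` takes the ragged `input_ids` lists gathered by the collate function and right-pads each one to the tokenizer's fixed context length; a standalone sketch (the checkpoint name is an assumption, and the ids are dummy CLIP tokens):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
batch = tokenizer.pad(
    {"input_ids": [[49406, 320, 49407], [49406, 320, 1929, 49407]]},  # two ragged examples
    padding="max_length",
    max_length=tokenizer.model_max_length,
    return_tensors="pt",
)
print(batch["input_ids"].shape)  # torch.Size([2, 77]) for CLIP's 77-token context
```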
```diff
@@ -541,11 +532,9 @@ def main(args):
         }
         return batch
 
-    train_dataloader = torch.utils.data.DataLoader(train_dataset,
-                                                   batch_size=args.train_batch_size,
-                                                   shuffle=True,
-                                                   collate_fn=collate_fn,
-                                                   num_workers=1)
+    train_dataloader = torch.utils.data.DataLoader(
+        train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1
+    )
 
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
```
```diff
@@ -662,8 +651,8 @@ def main(args):
                 global_step += 1
                 logs = {
                     "loss": loss.detach().item(),
-                    "lr": optimizer.param_groups[0]['lr']
-                }    #lr_scheduler.get_last_lr()[0]}
+                    "lr": optimizer.param_groups[0]["lr"],
+                }  # lr_scheduler.get_last_lr()[0]}
                 progress_bar.set_postfix(**logs)
 
                 if global_step % args.save_steps == 0:
```
```diff
@@ -681,15 +670,15 @@ def main(args):
             break
 
     torch.cuda.synchronize()
-    unet=convert_to_torch_module(unet)
+    unet = convert_to_torch_module(unet)
 
     if gpc.get_local_rank(ParallelMode.DATA) == 0:
         pipeline = DiffusionPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
             unet=unet,
             revision=args.revision,
         )
 
         pipeline.save_pretrained(args.output_dir)
         logger.info(f"Saving model checkpoint to {args.output_dir}", ranks=[0])
 
```
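Because `convert_to_torch_module` unwraps the Gemini-sharded UNet back into a plain `torch.nn.Module`, what rank 0 saves is an ordinary diffusers checkpoint; a hedged sketch of reloading it for inference (the output path and prompt are illustrative):

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("path/to/output_dir", torch_dtype=torch.float16).to("cuda")
image = pipe("a photo of sks dog in a bucket").images[0]
image.save("dreambooth-sample.png")
```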