[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
This commit is contained in:
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -93,7 +93,7 @@ torchrun --nproc_per_node 2 train_dreambooth_colossalai.py \
```
## New API
We have modified our previous implementation of Dreambooth with our new Booster API, which offers a more flexible and efficient way to train your model. The new API is more user-friendly and easy to use. You can find the new API in `train_dreambooth_colossalai.py`.
We have modified our previous implementation of Dreambooth with our new Booster API, which offers a more flexible and efficient way to train your model. The new API is more user-friendly and easy to use. You can find the new API in `train_dreambooth_colossalai.py`.
We have also offer a shell script `test_ci.sh` for you to go through all our plugins for the booster.
For more information about the booster API you can refer to https://colossalai.org/docs/basics/booster_api/.
@@ -111,7 +111,7 @@ For more information about the booster API you can refer to https://colossalai.o
| low_level_zero | 4 | 8 | 28.87 | 2.02 |
The evaluation is performed on 4 Nvidia A100 GPUs with 80GB memory each, with GPU 0 & 1, 2 & 3 connected with NVLink.
We finetuned the [stable-diffusion-v1-4](https://huggingface.co/stabilityai/stable-diffusion-v1-4) model with 512x512 resolution on the [Teyvat](https://huggingface.co/datasets/Fazzie/Teyvat) dataset and compared
We finetuned the [stable-diffusion-v1-4](https://huggingface.co/stabilityai/stable-diffusion-v1-4) model with 512x512 resolution on the [Teyvat](https://huggingface.co/datasets/Fazzie/Teyvat) dataset and compared
the memory cost and the throughput for the plugins.

View File

@@ -1,16 +1,16 @@
'''
"""
torchrun --standalone --nproc_per_node=1 debug.py
'''
"""
from diffusers import AutoencoderKL
import colossalai
from colossalai.zero import ColoInitContext, post_process_colo_init_ctx
from colossalai.zero import ColoInitContext
path = "/data/scratch/diffuser/stable-diffusion-v1-4"
colossalai.launch_from_torch(config={})
with ColoInitContext(device='cpu'):
with ColoInitContext(device="cpu"):
vae = AutoencoderKL.from_pretrained(
path,
subfolder="vae",

View File

@@ -1,7 +1,7 @@
from diffusers import StableDiffusionPipeline, DiffusionPipeline
import torch
from diffusers import DiffusionPipeline
model_id = <Your Model Path>
model_id = "<Your Model Path>"
print(f"Loading model... from{model_id}")
pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

View File

@@ -104,8 +104,10 @@ def parse_args(input_args=None):
"--num_class_images",
type=int,
default=100,
help=("Minimal class images for prior preservation loss. If there are not enough images already present in"
" class_data_dir, additional images will be sampled with class_prompt."),
help=(
"Minimal class images for prior preservation loss. If there are not enough images already present in"
" class_data_dir, additional images will be sampled with class_prompt."
),
)
parser.add_argument(
"--output_dir",
@@ -118,17 +120,18 @@ def parse_args(input_args=None):
"--resolution",
type=int,
default=512,
help=("The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"),
help=(
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"
),
)
parser.add_argument(
"--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
)
parser.add_argument("--center_crop",
action="store_true",
help="Whether to center crop images before resizing to resolution")
parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
parser.add_argument("--train_batch_size",
type=int,
default=4,
help="Batch size (per device) for the training dataloader.")
parser.add_argument(
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
)
parser.add_argument("--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images.")
parser.add_argument("--num_train_epochs", type=int, default=1)
parser.add_argument(
@@ -165,16 +168,17 @@ def parse_args(input_args=None):
"--lr_scheduler",
type=str,
default="constant",
help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'),
help=(
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'
),
)
parser.add_argument(
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument(
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
)
parser.add_argument("--lr_warmup_steps",
type=int,
default=500,
help="Number of steps for the warmup in the lr scheduler.")
parser.add_argument("--use_8bit_adam",
action="store_true",
help="Whether or not to use 8-bit Adam from bitsandbytes.")
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
@@ -192,8 +196,10 @@ def parse_args(input_args=None):
"--logging_dir",
type=str,
default="logs",
help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."),
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument(
"--mixed_precision",
@@ -203,7 +209,8 @@ def parse_args(input_args=None):
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."),
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
@@ -269,12 +276,14 @@ class DreamBoothDataset(Dataset):
else:
self.class_data_root = None
self.image_transforms = transforms.Compose([
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
])
self.image_transforms = transforms.Compose(
[
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
def __len__(self):
return self._length
@@ -350,7 +359,8 @@ def main(args):
if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
raise ValueError(
"Gradient accumulation is not supported when training the text encoder in distributed training. "
"Please set gradient_accumulation_steps to 1. This feature will be supported in the future.")
"Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
)
if args.seed is not None:
set_seed(args.seed)
@@ -380,9 +390,9 @@ def main(args):
sample_dataloader = accelerator.prepare(sample_dataloader)
pipeline.to(accelerator.device)
for example in tqdm(sample_dataloader,
desc="Generating class images",
disable=not accelerator.is_local_main_process):
for example in tqdm(
sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
):
images = pipeline(example["prompt"]).images
for i, image in enumerate(images):
@@ -456,8 +466,9 @@ def main(args):
text_encoder.gradient_checkpointing_enable()
if args.scale_lr:
args.learning_rate = (args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size *
accelerator.num_processes)
args.learning_rate = (
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
)
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
if args.use_8bit_adam:
@@ -470,8 +481,9 @@ def main(args):
else:
optimizer_class = torch.optim.AdamW
params_to_optimize = (itertools.chain(unet.parameters(), text_encoder.parameters())
if args.train_text_encoder else unet.parameters())
params_to_optimize = (
itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
)
optimizer = optimizer_class(
params_to_optimize,
lr=args.learning_rate,
@@ -506,9 +518,7 @@ def main(args):
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
input_ids = tokenizer.pad(
{
"input_ids": input_ids
},
{"input_ids": input_ids},
padding="max_length",
max_length=tokenizer.model_max_length,
return_tensors="pt",
@@ -520,11 +530,9 @@ def main(args):
}
return batch
train_dataloader = torch.utils.data.DataLoader(train_dataset,
batch_size=args.train_batch_size,
shuffle=True,
collate_fn=collate_fn,
num_workers=1)
train_dataloader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1
)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
@@ -542,10 +550,12 @@ def main(args):
if args.train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler)
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader,
lr_scheduler)
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, optimizer, train_dataloader, lr_scheduler
)
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
@@ -641,8 +651,11 @@ def main(args):
accelerator.backward(loss)
if accelerator.sync_gradients:
params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters())
if args.train_text_encoder else unet.parameters())
params_to_clip = (
itertools.chain(unet.parameters(), text_encoder.parameters())
if args.train_text_encoder
else unet.parameters()
)
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()

View File

@@ -117,8 +117,10 @@ def parse_args(input_args=None):
"--num_class_images",
type=int,
default=100,
help=("Minimal class images for prior preservation loss. If there are not enough images already present in"
" class_data_dir, additional images will be sampled with class_prompt."),
help=(
"Minimal class images for prior preservation loss. If there are not enough images already present in"
" class_data_dir, additional images will be sampled with class_prompt."
),
)
parser.add_argument(
"--output_dir",
@@ -131,8 +133,10 @@ def parse_args(input_args=None):
"--resolution",
type=int,
default=512,
help=("The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"),
help=(
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"
),
)
parser.add_argument(
"--offload_optim_frac",
@@ -144,13 +148,14 @@ def parse_args(input_args=None):
"--center_crop",
default=False,
action="store_true",
help=("Whether to center crop the input images to the resolution. If not set, the images will be randomly"
" cropped. The images will be resized to the resolution first before cropping."),
help=(
"Whether to center crop the input images to the resolution. If not set, the images will be randomly"
" cropped. The images will be resized to the resolution first before cropping."
),
)
parser.add_argument(
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
)
parser.add_argument("--train_batch_size",
type=int,
default=4,
help="Batch size (per device) for the training dataloader.")
parser.add_argument("--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images.")
parser.add_argument("--num_train_epochs", type=int, default=1)
parser.add_argument(
@@ -181,16 +186,17 @@ def parse_args(input_args=None):
"--lr_scheduler",
type=str,
default="constant",
help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'),
help=(
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'
),
)
parser.add_argument(
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument(
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
)
parser.add_argument("--lr_warmup_steps",
type=int,
default=500,
help="Number of steps for the warmup in the lr scheduler.")
parser.add_argument("--use_8bit_adam",
action="store_true",
help="Whether or not to use 8-bit Adam from bitsandbytes.")
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
@@ -202,18 +208,22 @@ def parse_args(input_args=None):
default=None,
help="The name of the repository to keep in sync with the local `output_dir`.",
)
parser.add_argument('-p',
'--plugin',
type=str,
default='torch_ddp',
choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'],
help="plugin to use")
parser.add_argument(
"-p",
"--plugin",
type=str,
default="torch_ddp",
choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero"],
help="plugin to use",
)
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."),
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument(
"--mixed_precision",
@@ -223,7 +233,8 @@ def parse_args(input_args=None):
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."),
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
@@ -292,12 +303,14 @@ class DreamBoothDataset(Dataset):
else:
self.class_data_root = None
self.image_transforms = transforms.Compose([
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
])
self.image_transforms = transforms.Compose(
[
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
def __len__(self):
return self._length
@@ -391,9 +404,9 @@ def main(args):
pipeline.to(get_current_device())
for example in tqdm(
sample_dataloader,
desc="Generating class images",
disable=not local_rank == 0,
sample_dataloader,
desc="Generating class images",
disable=not local_rank == 0,
):
images = pipeline(example["prompt"]).images
@@ -460,15 +473,14 @@ def main(args):
if args.externel_unet_path is None:
logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
subfolder="unet",
revision=args.revision,
low_cpu_mem_usage=False)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False
)
else:
logger.info(f"Loading UNet2DConditionModel from {args.externel_unet_path}", ranks=[0])
unet = UNet2DConditionModel.from_pretrained(args.externel_unet_path,
revision=args.revision,
low_cpu_mem_usage=False)
unet = UNet2DConditionModel.from_pretrained(
args.externel_unet_path, revision=args.revision, low_cpu_mem_usage=False
)
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
@@ -482,36 +494,37 @@ def main(args):
# Use Booster API to use Gemini/Zero with ColossalAI
booster_kwargs = {}
if args.plugin == 'torch_ddp_fp16':
booster_kwargs['mixed_precision'] = 'fp16'
if args.plugin.startswith('torch_ddp'):
if args.plugin == "torch_ddp_fp16":
booster_kwargs["mixed_precision"] = "fp16"
if args.plugin.startswith("torch_ddp"):
plugin = TorchDDPPlugin()
elif args.plugin == 'gemini':
elif args.plugin == "gemini":
plugin = GeminiPlugin(offload_optim_frac=args.offload_optim_frac, strict_ddp_mode=True, initial_scale=2**5)
elif args.plugin == 'low_level_zero':
elif args.plugin == "low_level_zero":
plugin = LowLevelZeroPlugin(initial_scale=2**5)
booster = Booster(plugin=plugin, **booster_kwargs)
# config optimizer for colossalai zero
optimizer = HybridAdam(unet.parameters(),
lr=args.learning_rate,
initial_scale=2**5,
clipping_norm=args.max_grad_norm)
optimizer = HybridAdam(
unet.parameters(), lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm
)
# load noise_scheduler
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
# prepare dataset
logger.info(f"Prepare dataset from {args.instance_data_dir}", ranks=[0])
train_dataset = DreamBoothDataset(instance_data_root=args.instance_data_dir,
instance_prompt=args.instance_prompt,
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
class_prompt=args.class_prompt,
tokenizer=tokenizer,
size=args.resolution,
center_crop=args.center_crop,
test=args.test_run)
train_dataset = DreamBoothDataset(
instance_data_root=args.instance_data_dir,
instance_prompt=args.instance_prompt,
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
class_prompt=args.class_prompt,
tokenizer=tokenizer,
size=args.resolution,
center_crop=args.center_crop,
test=args.test_run,
)
def collate_fn(examples):
input_ids = [example["instance_prompt_ids"] for example in examples]
@@ -527,9 +540,7 @@ def main(args):
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
input_ids = tokenizer.pad(
{
"input_ids": input_ids
},
{"input_ids": input_ids},
padding="max_length",
max_length=tokenizer.model_max_length,
return_tensors="pt",
@@ -541,11 +552,9 @@ def main(args):
}
return batch
train_dataloader = torch.utils.data.DataLoader(train_dataset,
batch_size=args.train_batch_size,
shuffle=True,
collate_fn=collate_fn,
num_workers=1)
train_dataloader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1
)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
@@ -664,7 +673,7 @@ def main(args):
logs = {
"loss": loss.detach().item(),
"lr": optimizer.param_groups[0]["lr"],
} # lr_scheduler.get_last_lr()[0]}
} # lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if global_step % args.save_steps == 0:

View File

@@ -28,8 +28,6 @@ from colossalai.legacy.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext, GeminiAdamOptimizer
from colossalai.zero.gemini import get_static_torch_model
disable_existing_loggers()
logger = get_dist_logger()
@@ -122,8 +120,10 @@ def parse_args(input_args=None):
"--num_class_images",
type=int,
default=100,
help=("Minimal class images for prior preservation loss. If there are not enough images already present in"
" class_data_dir, additional images will be sampled with class_prompt."),
help=(
"Minimal class images for prior preservation loss. If there are not enough images already present in"
" class_data_dir, additional images will be sampled with class_prompt."
),
)
parser.add_argument(
"--output_dir",
@@ -136,8 +136,10 @@ def parse_args(input_args=None):
"--resolution",
type=int,
default=512,
help=("The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"),
help=(
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"
),
)
parser.add_argument(
"--placement",
@@ -149,13 +151,14 @@ def parse_args(input_args=None):
"--center_crop",
default=False,
action="store_true",
help=("Whether to center crop the input images to the resolution. If not set, the images will be randomly"
" cropped. The images will be resized to the resolution first before cropping."),
help=(
"Whether to center crop the input images to the resolution. If not set, the images will be randomly"
" cropped. The images will be resized to the resolution first before cropping."
),
)
parser.add_argument(
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
)
parser.add_argument("--train_batch_size",
type=int,
default=4,
help="Batch size (per device) for the training dataloader.")
parser.add_argument("--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images.")
parser.add_argument("--num_train_epochs", type=int, default=1)
parser.add_argument(
@@ -186,16 +189,17 @@ def parse_args(input_args=None):
"--lr_scheduler",
type=str,
default="constant",
help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'),
help=(
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'
),
)
parser.add_argument(
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument(
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
)
parser.add_argument("--lr_warmup_steps",
type=int,
default=500,
help="Number of steps for the warmup in the lr scheduler.")
parser.add_argument("--use_8bit_adam",
action="store_true",
help="Whether or not to use 8-bit Adam from bitsandbytes.")
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
@@ -206,18 +210,22 @@ def parse_args(input_args=None):
default=None,
help="The name of the repository to keep in sync with the local `output_dir`.",
)
parser.add_argument('-p',
'--plugin',
type=str,
default='torch_ddp',
choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'],
help="plugin to use")
parser.add_argument(
"-p",
"--plugin",
type=str,
default="torch_ddp",
choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero"],
help="plugin to use",
)
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."),
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument(
"--mixed_precision",
@@ -227,7 +235,8 @@ def parse_args(input_args=None):
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."),
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
@@ -293,12 +302,14 @@ class DreamBoothDataset(Dataset):
else:
self.class_data_root = None
self.image_transforms = transforms.Compose([
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
])
self.image_transforms = transforms.Compose(
[
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
def __len__(self):
return self._length
@@ -392,9 +403,9 @@ def main(args):
pipeline.to(get_current_device())
for example in tqdm(
sample_dataloader,
desc="Generating class images",
disable=not local_rank == 0,
sample_dataloader,
desc="Generating class images",
disable=not local_rank == 0,
):
images = pipeline(example["prompt"]).images
@@ -461,19 +472,17 @@ def main(args):
if args.externel_unet_path is None:
logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
subfolder="unet",
revision=args.revision,
low_cpu_mem_usage=False)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False
)
else:
logger.info(f"Loading UNet2DConditionModel from {args.externel_unet_path}", ranks=[0])
unet = UNet2DConditionModel.from_pretrained(args.externel_unet_path,
revision=args.revision,
low_cpu_mem_usage=False)
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
subfolder="unet",
revision=args.revision,
low_cpu_mem_usage=False)
unet = UNet2DConditionModel.from_pretrained(
args.externel_unet_path, revision=args.revision, low_cpu_mem_usage=False
)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False
)
unet.requires_grad_(False)
# Set correct lora layers
@@ -492,7 +501,7 @@ def main(args):
lora_attn_procs[name] = LoRACrossAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
unet.set_attn_processor(lora_attn_procs)
lora_layers = AttnProcsLayers(unet.attn_processors)
AttnProcsLayers(unet.attn_processors)
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
@@ -506,22 +515,21 @@ def main(args):
# Use Booster API to use Gemini/Zero with ColossalAI
booster_kwargs = {}
if args.plugin == 'torch_ddp_fp16':
booster_kwargs['mixed_precision'] = 'fp16'
if args.plugin.startswith('torch_ddp'):
if args.plugin == "torch_ddp_fp16":
booster_kwargs["mixed_precision"] = "fp16"
if args.plugin.startswith("torch_ddp"):
plugin = TorchDDPPlugin()
elif args.plugin == 'gemini':
elif args.plugin == "gemini":
plugin = GeminiPlugin(strict_ddp_mode=True, initial_scale=2**5)
elif args.plugin == 'low_level_zero':
elif args.plugin == "low_level_zero":
plugin = LowLevelZeroPlugin(initial_scale=2**5)
booster = Booster(plugin=plugin, **booster_kwargs)
# config optimizer for colossalai zero
optimizer = HybridAdam(unet.parameters(),
lr=args.learning_rate,
initial_scale=2**5,
clipping_norm=args.max_grad_norm)
optimizer = HybridAdam(
unet.parameters(), lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm
)
# load noise_scheduler
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
@@ -552,9 +560,7 @@ def main(args):
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
input_ids = tokenizer.pad(
{
"input_ids": input_ids
},
{"input_ids": input_ids},
padding="max_length",
max_length=tokenizer.model_max_length,
return_tensors="pt",
@@ -566,11 +572,9 @@ def main(args):
}
return batch
train_dataloader = torch.utils.data.DataLoader(train_dataset,
batch_size=args.train_batch_size,
shuffle=True,
collate_fn=collate_fn,
num_workers=1)
train_dataloader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1
)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
@@ -689,7 +693,7 @@ def main(args):
logs = {
"loss": loss.detach().item(),
"lr": optimizer.param_groups[0]["lr"],
} # lr_scheduler.get_last_lr()[0]}
} # lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if global_step % args.save_steps == 0:

View File

@@ -126,8 +126,10 @@ def parse_args():
"--num_class_images",
type=int,
default=100,
help=("Minimal class images for prior preservation loss. If not have enough images, additional images will be"
" sampled with class_prompt."),
help=(
"Minimal class images for prior preservation loss. If not have enough images, additional images will be"
" sampled with class_prompt."
),
)
parser.add_argument(
"--output_dir",
@@ -140,17 +142,18 @@ def parse_args():
"--resolution",
type=int,
default=512,
help=("The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"),
help=(
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"
),
)
parser.add_argument(
"--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
)
parser.add_argument("--center_crop",
action="store_true",
help="Whether to center crop images before resizing to resolution")
parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
parser.add_argument("--train_batch_size",
type=int,
default=4,
help="Batch size (per device) for the training dataloader.")
parser.add_argument(
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
)
parser.add_argument("--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images.")
parser.add_argument("--num_train_epochs", type=int, default=1)
parser.add_argument(
@@ -186,16 +189,17 @@ def parse_args():
"--lr_scheduler",
type=str,
default="constant",
help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'),
help=(
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'
),
)
parser.add_argument(
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument(
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
)
parser.add_argument("--lr_warmup_steps",
type=int,
default=500,
help="Number of steps for the warmup in the lr scheduler.")
parser.add_argument("--use_8bit_adam",
action="store_true",
help="Whether or not to use 8-bit Adam from bitsandbytes.")
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
@@ -213,17 +217,21 @@ def parse_args():
"--logging_dir",
type=str,
default="logs",
help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."),
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument(
"--mixed_precision",
type=str,
default="no",
choices=["no", "fp16", "bf16"],
help=("Whether to use mixed precision. Choose"
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
"and an Nvidia Ampere GPU."),
help=(
"Whether to use mixed precision. Choose"
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
"and an Nvidia Ampere GPU."
),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
@@ -283,12 +291,14 @@ class DreamBoothDataset(Dataset):
else:
self.class_data_root = None
self.image_transforms = transforms.Compose([
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
])
self.image_transforms = transforms.Compose(
[
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
def __len__(self):
return self._length
@@ -369,7 +379,8 @@ def main():
if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
raise ValueError(
"Gradient accumulation is not supported when training the text encoder in distributed training. "
"Please set gradient_accumulation_steps to 1. This feature will be supported in the future.")
"Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
)
if args.seed is not None:
set_seed(args.seed)
@@ -382,25 +393,25 @@ def main():
if cur_class_images < args.num_class_images:
torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
pipeline = StableDiffusionInpaintPipeline.from_pretrained(args.pretrained_model_name_or_path,
torch_dtype=torch_dtype,
safety_checker=None)
pipeline = StableDiffusionInpaintPipeline.from_pretrained(
args.pretrained_model_name_or_path, torch_dtype=torch_dtype, safety_checker=None
)
pipeline.set_progress_bar_config(disable=True)
num_new_images = args.num_class_images - cur_class_images
logger.info(f"Number of class images to sample: {num_new_images}.")
sample_dataset = PromptDataset(args.class_prompt, num_new_images)
sample_dataloader = torch.utils.data.DataLoader(sample_dataset,
batch_size=args.sample_batch_size,
num_workers=1)
sample_dataloader = torch.utils.data.DataLoader(
sample_dataset, batch_size=args.sample_batch_size, num_workers=1
)
sample_dataloader = accelerator.prepare(sample_dataloader)
pipeline.to(accelerator.device)
transform_to_pil = transforms.ToPILImage()
for example in tqdm(sample_dataloader,
desc="Generating class images",
disable=not accelerator.is_local_main_process):
for example in tqdm(
sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
):
bsz = len(example["prompt"])
fake_images = torch.rand((3, args.resolution, args.resolution))
transform_to_pil = transforms.ToPILImage()
@@ -457,8 +468,9 @@ def main():
text_encoder.gradient_checkpointing_enable()
if args.scale_lr:
args.learning_rate = (args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size *
accelerator.num_processes)
args.learning_rate = (
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
)
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
if args.use_8bit_adam:
@@ -471,8 +483,9 @@ def main():
else:
optimizer_class = torch.optim.AdamW
params_to_optimize = (itertools.chain(unet.parameters(), text_encoder.parameters())
if args.train_text_encoder else unet.parameters())
params_to_optimize = (
itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
)
optimizer = optimizer_class(
params_to_optimize,
lr=args.learning_rate,
@@ -494,10 +507,12 @@ def main():
)
def collate_fn(examples):
image_transforms = transforms.Compose([
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
])
image_transforms = transforms.Compose(
[
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
]
)
input_ids = [example["instance_prompt_ids"] for example in examples]
pixel_values = [example["instance_images"] for example in examples]
@@ -545,10 +560,9 @@ def main():
batch = {"input_ids": input_ids, "pixel_values": pixel_values, "masks": masks, "masked_images": masked_images}
return batch
train_dataloader = torch.utils.data.DataLoader(train_dataset,
batch_size=args.train_batch_size,
shuffle=True,
collate_fn=collate_fn)
train_dataloader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn
)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
@@ -566,10 +580,12 @@ def main():
if args.train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler)
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader,
lr_scheduler)
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, optimizer, train_dataloader, lr_scheduler
)
weight_dtype = torch.float32
if args.mixed_precision == "fp16":
@@ -622,16 +638,19 @@ def main():
latents = latents * 0.18215
# Convert masked images to latent space
masked_latents = vae.encode(batch["masked_images"].reshape(
batch["pixel_values"].shape).to(dtype=weight_dtype)).latent_dist.sample()
masked_latents = vae.encode(
batch["masked_images"].reshape(batch["pixel_values"].shape).to(dtype=weight_dtype)
).latent_dist.sample()
masked_latents = masked_latents * 0.18215
masks = batch["masks"]
# resize the mask to latents shape as we concatenate the mask to the latents
mask = torch.stack([
torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8))
for mask in masks
])
mask = torch.stack(
[
torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8))
for mask in masks
]
)
mask = mask.reshape(-1, 1, args.resolution // 8, args.resolution // 8)
# Sample noise that we'll add to the latents
@@ -680,8 +699,11 @@ def main():
accelerator.backward(loss)
if accelerator.sync_gradients:
params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters())
if args.train_text_encoder else unet.parameters())
params_to_clip = (
itertools.chain(unet.parameters(), text_encoder.parameters())
if args.train_text_encoder
else unet.parameters()
)
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()