Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-05 11:02:05 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -126,8 +126,10 @@ def parse_args():
         "--num_class_images",
         type=int,
         default=100,
-        help=("Minimal class images for prior preservation loss. If not have enough images, additional images will be"
-              " sampled with class_prompt."),
+        help=(
+            "Minimal class images for prior preservation loss. If not have enough images, additional images will be"
+            " sampled with class_prompt."
+        ),
     )
     parser.add_argument(
         "--output_dir",
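Every hunk in this commit applies the same mechanical change: from yapf's parenthesis-aligned continuations to black's hanging indent with the closing bracket on its own line. A minimal sketch of the two conventions, using hypothetical flags that are not from this script:

# Hypothetical argparse flags illustrating the reformat; black splits a call
# that exceeds the line length onto a 4-space hanging indent and gives the
# closing bracket its own line.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--example_flag", type=int, default=1, help="A short help string fits on one line."
)
parser.add_argument(
    "--long_help_flag",
    type=str,
    default="x",
    help=(
        "Longer help text is written as adjacent string literals,"
        " which Python concatenates into a single string at compile time."
    ),
)
print(parser.parse_args([]))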
@@ -140,17 +142,18 @@ def parse_args():
         "--resolution",
         type=int,
         default=512,
-        help=("The resolution for input images, all the images in the train/validation dataset will be resized to this"
-              " resolution"),
+        help=(
+            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+            " resolution"
+        ),
     )
-    parser.add_argument("--center_crop",
-                        action="store_true",
-                        help="Whether to center crop images before resizing to resolution")
+    parser.add_argument(
+        "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
+    )
     parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
-    parser.add_argument("--train_batch_size",
-                        type=int,
-                        default=4,
-                        help="Batch size (per device) for the training dataloader.")
+    parser.add_argument(
+        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+    )
     parser.add_argument("--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images.")
     parser.add_argument("--num_train_epochs", type=int, default=1)
     parser.add_argument(
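For reference, --center_crop and --train_text_encoder above are store_true flags: they default to False and flip to True only when passed. A quick check:

# store_true flags default to False; passing the flag sets True.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--center_crop", action="store_true")
print(parser.parse_args([]).center_crop)                  # False
print(parser.parse_args(["--center_crop"]).center_crop)   # True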
@@ -186,16 +189,17 @@ def parse_args():
         "--lr_scheduler",
         type=str,
         default="constant",
-        help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
-              ' "constant", "constant_with_warmup"]'),
+        help=(
+            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+            ' "constant", "constant_with_warmup"]'
+        ),
     )
-    parser.add_argument("--lr_warmup_steps",
-                        type=int,
-                        default=500,
-                        help="Number of steps for the warmup in the lr scheduler.")
-    parser.add_argument("--use_8bit_adam",
-                        action="store_true",
-                        help="Whether or not to use 8-bit Adam from bitsandbytes.")
+    parser.add_argument(
+        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+    )
+    parser.add_argument(
+        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+    )
     parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
     parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
     parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
@@ -213,17 +217,21 @@ def parse_args():
         "--logging_dir",
         type=str,
         default="logs",
-        help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
-              " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."),
+        help=(
+            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+        ),
     )
     parser.add_argument(
         "--mixed_precision",
         type=str,
         default="no",
         choices=["no", "fp16", "bf16"],
-        help=("Whether to use mixed precision. Choose"
-              "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
-              "and an Nvidia Ampere GPU."),
+        help=(
+            "Whether to use mixed precision. Choose"
+            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
+            "and an Nvidia Ampere GPU."
+        ),
     )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
 
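One wording bug survives the reformat untouched: the --mixed_precision help strings have no separating spaces, so Python's implicit literal concatenation renders "Choosebetween" and "1.10.and". A quick demonstration:

# The adjacent literals in the --mixed_precision help concatenate without spaces:
text = (
    "Whether to use mixed precision. Choose"
    "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
    "and an Nvidia Ampere GPU."
)
print(text)  # ...Choosebetween fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10.and an Nvidia Ampere GPU.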
@@ -283,12 +291,14 @@ class DreamBoothDataset(Dataset):
         else:
             self.class_data_root = None
 
-        self.image_transforms = transforms.Compose([
-            transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
-            transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
-            transforms.ToTensor(),
-            transforms.Normalize([0.5], [0.5]),
-        ])
+        self.image_transforms = transforms.Compose(
+            [
+                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+                transforms.ToTensor(),
+                transforms.Normalize([0.5], [0.5]),
+            ]
+        )
 
     def __len__(self):
         return self._length
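This Compose resizes with bilinear interpolation, crops to a square, converts to a [0, 1] tensor, and normalizes to [-1, 1], the input range the VAE expects. A minimal sketch of what it produces, assuming Pillow is available:

# The transforms above map a PIL image to a CHW float tensor in [-1, 1].
from PIL import Image
import torch
from torchvision import transforms

size, center_crop = 512, False
image_transforms = transforms.Compose(
    [
        transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
        transforms.ToTensor(),               # HWC uint8 [0, 255] -> CHW float [0, 1]
        transforms.Normalize([0.5], [0.5]),  # [0, 1] -> [-1, 1]
    ]
)
img = Image.new("RGB", (768, 768))           # stand-in for an instance image
tensor = image_transforms(img)
assert tensor.shape == (3, size, size) and tensor.min() >= -1.0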
@@ -369,7 +379,8 @@ def main():
     if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
         raise ValueError(
             "Gradient accumulation is not supported when training the text encoder in distributed training. "
-            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future.")
+            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
+        )
 
     if args.seed is not None:
         set_seed(args.seed)
@@ -382,25 +393,25 @@ def main():
 
         if cur_class_images < args.num_class_images:
             torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
-            pipeline = StableDiffusionInpaintPipeline.from_pretrained(args.pretrained_model_name_or_path,
-                                                                      torch_dtype=torch_dtype,
-                                                                      safety_checker=None)
+            pipeline = StableDiffusionInpaintPipeline.from_pretrained(
+                args.pretrained_model_name_or_path, torch_dtype=torch_dtype, safety_checker=None
+            )
             pipeline.set_progress_bar_config(disable=True)
 
             num_new_images = args.num_class_images - cur_class_images
             logger.info(f"Number of class images to sample: {num_new_images}.")
 
             sample_dataset = PromptDataset(args.class_prompt, num_new_images)
-            sample_dataloader = torch.utils.data.DataLoader(sample_dataset,
-                                                            batch_size=args.sample_batch_size,
-                                                            num_workers=1)
+            sample_dataloader = torch.utils.data.DataLoader(
+                sample_dataset, batch_size=args.sample_batch_size, num_workers=1
+            )
 
             sample_dataloader = accelerator.prepare(sample_dataloader)
             pipeline.to(accelerator.device)
             transform_to_pil = transforms.ToPILImage()
-            for example in tqdm(sample_dataloader,
-                                desc="Generating class images",
-                                disable=not accelerator.is_local_main_process):
+            for example in tqdm(
+                sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
+            ):
                 bsz = len(example["prompt"])
                 fake_images = torch.rand((3, args.resolution, args.resolution))
                 transform_to_pil = transforms.ToPILImage()
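This block tops up the class images used by the prior-preservation loss whenever the class data directory holds fewer than --num_class_images. PromptDataset itself is not shown in this diff; a minimal sketch of the pattern it presumably follows (one fixed prompt repeated num_samples times, so a DataLoader can batch prompts for sampling):

# Sketch of the assumed PromptDataset pattern: a fixed prompt paired with an index.
import torch

class PromptDatasetSketch(torch.utils.data.Dataset):
    def __init__(self, prompt, num_samples):
        self.prompt = prompt
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        return {"prompt": self.prompt, "index": index}

loader = torch.utils.data.DataLoader(PromptDatasetSketch("a photo of a dog", 6), batch_size=4)
for batch in loader:
    print(len(batch["prompt"]), batch["index"])  # 4 prompts per batch, then the remaining 2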
@@ -457,8 +468,9 @@ def main():
         text_encoder.gradient_checkpointing_enable()
 
     if args.scale_lr:
-        args.learning_rate = (args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size *
-                              accelerator.num_processes)
+        args.learning_rate = (
+            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+        )
 
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
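The --scale_lr rule multiplies the base rate by every factor that enlarges the effective batch: accumulation steps, per-device batch size, and process count. A worked example with illustrative values:

# Worked example of the --scale_lr rule above (values are illustrative):
learning_rate = 5e-6
gradient_accumulation_steps = 2
train_batch_size = 4
num_processes = 8                                  # e.g. 8 GPUs under accelerate

effective_batch = gradient_accumulation_steps * train_batch_size * num_processes
scaled_lr = learning_rate * effective_batch
print(effective_batch, scaled_lr)                  # 64 0.00032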
@@ -471,8 +483,9 @@ def main():
     else:
         optimizer_class = torch.optim.AdamW
 
-    params_to_optimize = (itertools.chain(unet.parameters(), text_encoder.parameters())
-                          if args.train_text_encoder else unet.parameters())
+    params_to_optimize = (
+        itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
+    )
     optimizer = optimizer_class(
         params_to_optimize,
         lr=args.learning_rate,
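The else branch above is the fallback of an import guard earlier in the script. A sketch of the full pattern, assuming the bitsandbytes package provides AdamW8bit:

# Sketch of the optimizer selection implied above; bitsandbytes is optional,
# so the import is guarded and torch.optim.AdamW is the fallback.
import torch

try:
    import bitsandbytes as bnb
    optimizer_class = bnb.optim.AdamW8bit  # 8-bit optimizer states cut memory use
except ImportError:
    optimizer_class = torch.optim.AdamW
print(optimizer_class)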
@@ -494,10 +507,12 @@ def main():
     )
 
     def collate_fn(examples):
-        image_transforms = transforms.Compose([
-            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
-            transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
-        ])
+        image_transforms = transforms.Compose(
+            [
+                transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+                transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
+            ]
+        )
         input_ids = [example["instance_prompt_ids"] for example in examples]
         pixel_values = [example["instance_images"] for example in examples]
 
@@ -545,10 +560,9 @@ def main():
         batch = {"input_ids": input_ids, "pixel_values": pixel_values, "masks": masks, "masked_images": masked_images}
         return batch
 
-    train_dataloader = torch.utils.data.DataLoader(train_dataset,
-                                                   batch_size=args.train_batch_size,
-                                                   shuffle=True,
-                                                   collate_fn=collate_fn)
+    train_dataloader = torch.utils.data.DataLoader(
+        train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn
+    )
 
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
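The collate_fn above receives the list of per-example dicts for one batch and must return a single batched structure; the DataLoader calls it instead of the default collation. A minimal runnable sketch of the mechanism:

# Minimal sketch of how DataLoader uses a collate_fn (toy data, not the script's).
import torch

def collate_fn_sketch(examples):
    pixel_values = torch.stack([e["instance_images"] for e in examples])
    return {"pixel_values": pixel_values}

dataset = [{"instance_images": torch.rand(3, 8, 8)} for _ in range(6)]
loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn_sketch)
print(next(iter(loader))["pixel_values"].shape)  # torch.Size([4, 3, 8, 8])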
@@ -566,10 +580,12 @@ def main():
 
     if args.train_text_encoder:
         unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            unet, text_encoder, optimizer, train_dataloader, lr_scheduler)
+            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
+        )
     else:
-        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader,
-                                                                              lr_scheduler)
+        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            unet, optimizer, train_dataloader, lr_scheduler
+        )
 
     weight_dtype = torch.float32
     if args.mixed_precision == "fp16":
@@ -622,16 +638,19 @@ def main():
                 latents = latents * 0.18215
 
                 # Convert masked images to latent space
-                masked_latents = vae.encode(batch["masked_images"].reshape(
-                    batch["pixel_values"].shape).to(dtype=weight_dtype)).latent_dist.sample()
+                masked_latents = vae.encode(
+                    batch["masked_images"].reshape(batch["pixel_values"].shape).to(dtype=weight_dtype)
+                ).latent_dist.sample()
                 masked_latents = masked_latents * 0.18215
 
                 masks = batch["masks"]
                 # resize the mask to latents shape as we concatenate the mask to the latents
-                mask = torch.stack([
-                    torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8))
-                    for mask in masks
-                ])
+                mask = torch.stack(
+                    [
+                        torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8))
+                        for mask in masks
+                    ]
+                )
                 mask = mask.reshape(-1, 1, args.resolution // 8, args.resolution // 8)
 
                 # Sample noise that we'll add to the latents
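The // 8 factor reflects the Stable Diffusion VAE's 8x spatial downsampling, and 0.18215 is that model's latent scaling factor; shrinking the mask to the latent grid lets it be concatenated channel-wise with the latents. A shape check under those assumptions:

# Shape check for the mask resizing above, assuming the SD VAE's 8x downsampling
# (a 512px image yields a 64x64 latent grid).
import torch

resolution = 512
mask = torch.rand(1, 1, resolution, resolution)  # one mask as a 4D tensor
small = torch.nn.functional.interpolate(mask, size=(resolution // 8, resolution // 8))
print(small.shape)                               # torch.Size([1, 1, 64, 64])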
@@ -680,8 +699,11 @@ def main():
 
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
-                    params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters())
-                                      if args.train_text_encoder else unet.parameters())
+                    params_to_clip = (
+                        itertools.chain(unet.parameters(), text_encoder.parameters())
+                        if args.train_text_encoder
+                        else unet.parameters()
+                    )
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                     optimizer.step()
                     lr_scheduler.step()
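When the text encoder is also trained, both parameter sets are clipped under one global norm. A plain-PyTorch sketch of the same conditional, with stand-in modules instead of the real models (accelerator.clip_grad_norm_ wraps the equivalent torch utility):

# Sketch of the clipping condition above with illustrative stand-in modules.
import itertools
import torch

unet = torch.nn.Linear(4, 4)          # stand-ins for the real UNet / text encoder
text_encoder = torch.nn.Linear(4, 4)
train_text_encoder, max_grad_norm = True, 1.0

params_to_clip = (
    itertools.chain(unet.parameters(), text_encoder.parameters())
    if train_text_encoder
    else unet.parameters()
)
torch.nn.utils.clip_grad_norm_(list(params_to_clip), max_grad_norm)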