mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 10:06:44 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
This commit is contained in:
@@ -117,8 +117,10 @@ def parse_args(input_args=None):
|
||||
"--num_class_images",
|
||||
type=int,
|
||||
default=100,
|
||||
help=("Minimal class images for prior preservation loss. If there are not enough images already present in"
|
||||
" class_data_dir, additional images will be sampled with class_prompt."),
|
||||
help=(
|
||||
"Minimal class images for prior preservation loss. If there are not enough images already present in"
|
||||
" class_data_dir, additional images will be sampled with class_prompt."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
@@ -131,8 +133,10 @@ def parse_args(input_args=None):
|
||||
"--resolution",
|
||||
type=int,
|
||||
default=512,
|
||||
help=("The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
||||
" resolution"),
|
||||
help=(
|
||||
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
||||
" resolution"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--offload_optim_frac",
|
||||
@@ -144,13 +148,14 @@ def parse_args(input_args=None):
|
||||
"--center_crop",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help=("Whether to center crop the input images to the resolution. If not set, the images will be randomly"
|
||||
" cropped. The images will be resized to the resolution first before cropping."),
|
||||
help=(
|
||||
"Whether to center crop the input images to the resolution. If not set, the images will be randomly"
|
||||
" cropped. The images will be resized to the resolution first before cropping."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
|
||||
)
|
||||
parser.add_argument("--train_batch_size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Batch size (per device) for the training dataloader.")
|
||||
parser.add_argument("--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images.")
|
||||
parser.add_argument("--num_train_epochs", type=int, default=1)
|
||||
parser.add_argument(
|
||||
@@ -181,16 +186,17 @@ def parse_args(input_args=None):
|
||||
"--lr_scheduler",
|
||||
type=str,
|
||||
default="constant",
|
||||
help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
||||
' "constant", "constant_with_warmup"]'),
|
||||
help=(
|
||||
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
||||
' "constant", "constant_with_warmup"]'
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
|
||||
)
|
||||
parser.add_argument("--lr_warmup_steps",
|
||||
type=int,
|
||||
default=500,
|
||||
help="Number of steps for the warmup in the lr scheduler.")
|
||||
parser.add_argument("--use_8bit_adam",
|
||||
action="store_true",
|
||||
help="Whether or not to use 8-bit Adam from bitsandbytes.")
|
||||
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
||||
@@ -202,18 +208,22 @@ def parse_args(input_args=None):
|
||||
default=None,
|
||||
help="The name of the repository to keep in sync with the local `output_dir`.",
|
||||
)
|
||||
parser.add_argument('-p',
|
||||
'--plugin',
|
||||
type=str,
|
||||
default='torch_ddp',
|
||||
choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'],
|
||||
help="plugin to use")
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--plugin",
|
||||
type=str,
|
||||
default="torch_ddp",
|
||||
choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero"],
|
||||
help="plugin to use",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logging_dir",
|
||||
type=str,
|
||||
default="logs",
|
||||
help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
||||
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."),
|
||||
help=(
|
||||
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
||||
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mixed_precision",
|
||||
@@ -223,7 +233,8 @@ def parse_args(input_args=None):
|
||||
help=(
|
||||
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
||||
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
||||
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."),
|
||||
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
|
||||
@@ -292,12 +303,14 @@ class DreamBoothDataset(Dataset):
|
||||
else:
|
||||
self.class_data_root = None
|
||||
|
||||
self.image_transforms = transforms.Compose([
|
||||
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
])
|
||||
self.image_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
]
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return self._length
|
||||
@@ -391,9 +404,9 @@ def main(args):
|
||||
pipeline.to(get_current_device())
|
||||
|
||||
for example in tqdm(
|
||||
sample_dataloader,
|
||||
desc="Generating class images",
|
||||
disable=not local_rank == 0,
|
||||
sample_dataloader,
|
||||
desc="Generating class images",
|
||||
disable=not local_rank == 0,
|
||||
):
|
||||
images = pipeline(example["prompt"]).images
|
||||
|
||||
@@ -460,15 +473,14 @@ def main(args):
|
||||
|
||||
if args.externel_unet_path is None:
|
||||
logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
|
||||
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
|
||||
subfolder="unet",
|
||||
revision=args.revision,
|
||||
low_cpu_mem_usage=False)
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False
|
||||
)
|
||||
else:
|
||||
logger.info(f"Loading UNet2DConditionModel from {args.externel_unet_path}", ranks=[0])
|
||||
unet = UNet2DConditionModel.from_pretrained(args.externel_unet_path,
|
||||
revision=args.revision,
|
||||
low_cpu_mem_usage=False)
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
args.externel_unet_path, revision=args.revision, low_cpu_mem_usage=False
|
||||
)
|
||||
|
||||
vae.requires_grad_(False)
|
||||
text_encoder.requires_grad_(False)
|
||||
@@ -482,36 +494,37 @@ def main(args):
|
||||
# Use Booster API to use Gemini/Zero with ColossalAI
|
||||
|
||||
booster_kwargs = {}
|
||||
if args.plugin == 'torch_ddp_fp16':
|
||||
booster_kwargs['mixed_precision'] = 'fp16'
|
||||
if args.plugin.startswith('torch_ddp'):
|
||||
if args.plugin == "torch_ddp_fp16":
|
||||
booster_kwargs["mixed_precision"] = "fp16"
|
||||
if args.plugin.startswith("torch_ddp"):
|
||||
plugin = TorchDDPPlugin()
|
||||
elif args.plugin == 'gemini':
|
||||
elif args.plugin == "gemini":
|
||||
plugin = GeminiPlugin(offload_optim_frac=args.offload_optim_frac, strict_ddp_mode=True, initial_scale=2**5)
|
||||
elif args.plugin == 'low_level_zero':
|
||||
elif args.plugin == "low_level_zero":
|
||||
plugin = LowLevelZeroPlugin(initial_scale=2**5)
|
||||
|
||||
booster = Booster(plugin=plugin, **booster_kwargs)
|
||||
|
||||
# config optimizer for colossalai zero
|
||||
optimizer = HybridAdam(unet.parameters(),
|
||||
lr=args.learning_rate,
|
||||
initial_scale=2**5,
|
||||
clipping_norm=args.max_grad_norm)
|
||||
optimizer = HybridAdam(
|
||||
unet.parameters(), lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm
|
||||
)
|
||||
|
||||
# load noise_scheduler
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
||||
|
||||
# prepare dataset
|
||||
logger.info(f"Prepare dataset from {args.instance_data_dir}", ranks=[0])
|
||||
train_dataset = DreamBoothDataset(instance_data_root=args.instance_data_dir,
|
||||
instance_prompt=args.instance_prompt,
|
||||
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
|
||||
class_prompt=args.class_prompt,
|
||||
tokenizer=tokenizer,
|
||||
size=args.resolution,
|
||||
center_crop=args.center_crop,
|
||||
test=args.test_run)
|
||||
train_dataset = DreamBoothDataset(
|
||||
instance_data_root=args.instance_data_dir,
|
||||
instance_prompt=args.instance_prompt,
|
||||
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
|
||||
class_prompt=args.class_prompt,
|
||||
tokenizer=tokenizer,
|
||||
size=args.resolution,
|
||||
center_crop=args.center_crop,
|
||||
test=args.test_run,
|
||||
)
|
||||
|
||||
def collate_fn(examples):
|
||||
input_ids = [example["instance_prompt_ids"] for example in examples]
|
||||
@@ -527,9 +540,7 @@ def main(args):
|
||||
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
||||
|
||||
input_ids = tokenizer.pad(
|
||||
{
|
||||
"input_ids": input_ids
|
||||
},
|
||||
{"input_ids": input_ids},
|
||||
padding="max_length",
|
||||
max_length=tokenizer.model_max_length,
|
||||
return_tensors="pt",
|
||||
@@ -541,11 +552,9 @@ def main(args):
|
||||
}
|
||||
return batch
|
||||
|
||||
train_dataloader = torch.utils.data.DataLoader(train_dataset,
|
||||
batch_size=args.train_batch_size,
|
||||
shuffle=True,
|
||||
collate_fn=collate_fn,
|
||||
num_workers=1)
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1
|
||||
)
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
overrode_max_train_steps = False
|
||||
@@ -664,7 +673,7 @@ def main(args):
|
||||
logs = {
|
||||
"loss": loss.detach().item(),
|
||||
"lr": optimizer.param_groups[0]["lr"],
|
||||
} # lr_scheduler.get_last_lr()[0]}
|
||||
} # lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if global_step % args.save_steps == 0:
|
||||
|
Reference in New Issue
Block a user