diff --git a/examples/images/dreambooth/README.md b/examples/images/dreambooth/README.md
index 5b350bc95..ba4c1a710 100644
--- a/examples/images/dreambooth/README.md
+++ b/examples/images/dreambooth/README.md
@@ -92,6 +92,29 @@ torchrun --nproc_per_node 2 train_dreambooth_colossalai.py \
   --placement="cuda"
 ```
 
+## New API
+We have reworked our previous DreamBooth implementation around the new Booster API, which offers a more flexible and efficient way to train your model. The updated training entry point is `train_dreambooth_colossalai.py`.
+We also provide a shell script, `test_ci.sh`, which runs the training once with each of the Booster plugins.
+For more information about the Booster API, please refer to https://colossalai.org/docs/basics/booster_api/.
+
+## Performance
+
+| Strategy       | #GPU | Batch Size | GPU RAM (GB) | Speedup |
+|:--------------:|:----:|:----------:|:------------:|:-------:|
+| Traditional    | 1    | 16         | OOM          | \       |
+| Traditional    | 1    | 8          | 61.81        | 1       |
+| torch_ddp      | 4    | 16         | OOM          | \       |
+| torch_ddp      | 4    | 8          | 41.97        | 0.97    |
+| gemini         | 4    | 16         | 53.29        | \       |
+| gemini         | 4    | 8          | 29.36        | 2.00    |
+| low_level_zero | 4    | 16         | 52.80        | \       |
+| low_level_zero | 4    | 8          | 28.87        | 2.02    |
+
+The evaluation was performed on 4 NVIDIA A100 GPUs with 80 GB of memory each, with GPUs 0 & 1 and GPUs 2 & 3 connected via NVLink.
+We fine-tuned the [stable-diffusion-v1-4](https://huggingface.co/stabilityai/stable-diffusion-v1-4) model at 512x512 resolution on the [Teyvat](https://huggingface.co/datasets/Fazzie/Teyvat) dataset and compared
+the memory cost and throughput of the plugins.
+
+
 ## Inference
 
 Once you have trained a model using above command, the inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `identifier`(e.g. `--instance_prompt="a photo of sks dog" ` in the above example) in your prompt.
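For readers who want the shape of the Booster change without walking the full Python diff below, the pattern added to `train_dreambooth_colossalai.py` condenses to roughly the sketch that follows. The `build_booster` helper and its argument names are illustrative only and are not part of the patch; the plugin constructors, `HybridAdam` optimizer, `Booster.boost`, and `Booster.save_model` calls mirror the ones used in the script.

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.nn.optimizer import HybridAdam


def build_booster(plugin_name: str, placement: str = "auto") -> Booster:
    # Map the --plugin flag onto a Booster plugin, mirroring the training script.
    booster_kwargs = {}
    if plugin_name == "torch_ddp_fp16":
        booster_kwargs["mixed_precision"] = "fp16"
    if plugin_name.startswith("torch_ddp"):
        plugin = TorchDDPPlugin()
    elif plugin_name == "gemini":
        plugin = GeminiPlugin(placement_policy=placement, strict_ddp_mode=True, initial_scale=2**5)
    elif plugin_name == "low_level_zero":
        plugin = LowLevelZeroPlugin(initial_scale=2**5)
    else:
        raise ValueError(f"unsupported plugin: {plugin_name}")
    return Booster(plugin=plugin, **booster_kwargs)


# Inside main(), after the model and dataloaders are built (simplified):
#   booster = build_booster(args.plugin, args.placement)
#   optimizer = HybridAdam(unet.parameters(), lr=args.learning_rate,
#                          initial_scale=2**5, clipping_norm=args.max_grad_norm)
#   unet, optimizer, _, _, lr_scheduler = booster.boost(unet, optimizer, lr_scheduler=lr_scheduler)
#   ...training loop...
#   booster.save_model(unet, os.path.join(args.output_dir, "diffusion_pytorch_model.bin"))
```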
diff --git a/examples/images/dreambooth/colossalai.sh b/examples/images/dreambooth/colossalai.sh
index 227d8b8bd..db4562dbc 100755
--- a/examples/images/dreambooth/colossalai.sh
+++ b/examples/images/dreambooth/colossalai.sh
@@ -1,22 +1,18 @@
-export MODEL_NAME=
-export INSTANCE_DIR=
-export CLASS_DIR="path-to-class-images"
-export OUTPUT_DIR="path-to-save-model"
-
-HF_DATASETS_OFFLINE=1
-TRANSFORMERS_OFFLINE=1
+HF_DATASETS_OFFLINE=1
+TRANSFORMERS_OFFLINE=1
 DIFFUSERS_OFFLINE=1
 
-torchrun --nproc_per_node 2 --master_port=25641 train_dreambooth_colossalai.py \
-  --pretrained_model_name_or_path=$MODEL_NAME \
-  --instance_data_dir=$INSTANCE_DIR \
-  --output_dir=$OUTPUT_DIR \
-  --instance_prompt="a photo of a dog" \
+torchrun --nproc_per_node 4 --standalone train_dreambooth_colossalai.py \
+  --pretrained_model_name_or_path="/data/dreambooth/diffuser/stable-diffusion-v1-4" \
+  --instance_data_dir="/data/dreambooth/Teyvat/data" \
+  --output_dir="./weight_output" \
+  --instance_prompt="a picture of a dog" \
   --resolution=512 \
+  --plugin="gemini" \
   --train_batch_size=1 \
-  --gradient_accumulation_steps=1 \
   --learning_rate=5e-6 \
   --lr_scheduler="constant" \
   --lr_warmup_steps=0 \
   --num_class_images=200 \
-  --placement="cuda" \
+  --test_run=True \
+  --placement="auto" \
diff --git a/examples/images/dreambooth/dreambooth.sh b/examples/images/dreambooth/dreambooth.sh
index e063bc827..f6b8f5e1b 100644
--- a/examples/images/dreambooth/dreambooth.sh
+++ b/examples/images/dreambooth/dreambooth.sh
@@ -1,7 +1,7 @@
 python train_dreambooth.py \
-  --pretrained_model_name_or_path= ## Your Model Path \
-  --instance_data_dir= ## Your Training Input Pics Path \
-  --output_dir="path-to-save-model" \
+  --pretrained_model_name_or_path="/data/dreambooth/diffuser/stable-diffusion-v1-4" \
+  --instance_data_dir="/data/dreambooth/Teyvat/data" \
+  --output_dir="./weight_output" \
   --instance_prompt="a photo of a dog" \
   --resolution=512 \
   --train_batch_size=1 \
diff --git a/examples/images/dreambooth/test_ci.sh b/examples/images/dreambooth/test_ci.sh
index e69de29bb..21f45adae 100644
--- a/examples/images/dreambooth/test_ci.sh
+++ b/examples/images/dreambooth/test_ci.sh
@@ -0,0 +1,25 @@
+#!/bin/bash
+set -xe
+pip install -r requirements.txt
+
+HF_DATASETS_OFFLINE=1
+TRANSFORMERS_OFFLINE=1
+DIFFUSERS_OFFLINE=1
+
+# "torch_ddp" "torch_ddp_fp16" "low_level_zero"
+for plugin in "gemini"; do
+  torchrun --nproc_per_node 4 --standalone train_dreambooth_colossalai.py \
+    --pretrained_model_name_or_path="/data/dreambooth/diffuser/stable-diffusion-v1-4" \
+    --instance_data_dir="/data/dreambooth/Teyvat/data" \
+    --output_dir="./weight_output" \
+    --instance_prompt="a picture of a dog" \
+    --resolution=512 \
+    --plugin=$plugin \
+    --train_batch_size=1 \
+    --learning_rate=5e-6 \
+    --lr_scheduler="constant" \
+    --lr_warmup_steps=0 \
+    --test_run=True \
+    --num_class_images=200 \
+    --placement="auto" # "cuda"
+done
diff --git a/examples/images/dreambooth/train_dreambooth_colossalai.py b/examples/images/dreambooth/train_dreambooth_colossalai.py
index d07febea0..888b28de8 100644
--- a/examples/images/dreambooth/train_dreambooth_colossalai.py
+++ b/examples/images/dreambooth/train_dreambooth_colossalai.py
@@ -4,6 +4,7 @@ import math
 import os
 from pathlib import Path
 from typing import Optional
+import shutil
 
 import torch
 import torch.nn.functional as F
@@ -21,9 +22,12 @@ import colossalai
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.nn.optimizer import HybridAdam
 from colossalai.utils import get_current_device
-from colossalai.zero import ColoInitContext, GeminiAdamOptimizer
+from colossalai.zero import ColoInitContext
 from colossalai.zero.gemini import get_static_torch_model
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 
 disable_existing_loggers()
 logger = get_dist_logger()
@@ -58,6 +62,13 @@ def parse_args(input_args=None):
         required=True,
         help="Path to pretrained model or model identifier from huggingface.co/models.",
     )
+    parser.add_argument(
+        "--externel_unet_path",
+        type=str,
+        default=None,
+        required=False,
+        help="Path to the external UNet model.",
+    )
     parser.add_argument(
         "--revision",
         type=str,
@@ -187,12 +198,19 @@ def parse_args(input_args=None):
     parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
     parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
     parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+    parser.add_argument("--test_run", default=False, help="Whether to use a smaller dataset for test run.")
     parser.add_argument(
         "--hub_model_id",
         type=str,
         default=None,
         help="The name of the repository to keep in sync with the local `output_dir`.",
     )
+    parser.add_argument('-p',
+                        '--plugin',
+                        type=str,
+                        default='torch_ddp',
+                        choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'],
+                        help="plugin to use")
     parser.add_argument(
         "--logging_dir",
         type=str,
@@ -250,6 +268,7 @@ class DreamBoothDataset(Dataset):
         class_prompt=None,
         size=512,
         center_crop=False,
+        test=False,
     ):
         self.size = size
         self.center_crop = center_crop
@@ -260,6 +279,8 @@ class DreamBoothDataset(Dataset):
             raise ValueError("Instance images root doesn't exists.")
 
         self.instance_images_path = list(Path(instance_data_root).iterdir())
+        if test:
+            self.instance_images_path = self.instance_images_path[:10]
         self.num_instance_images = len(self.instance_images_path)
         self.instance_prompt = instance_prompt
         self._length = self.num_instance_images
@@ -339,18 +360,6 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token
         return f"{organization}/{model_id}"
 
 
-# Gemini + ZeRO DDP
-def gemini_zero_dpp(model: torch.nn.Module, placement_policy: str = "auto"):
-    from colossalai.nn.parallel import GeminiDDP
-
-    model = GeminiDDP(model,
-                      device=get_current_device(),
-                      placement_policy=placement_policy,
-                      pin_memory=True,
-                      search_range_mb=64)
-    return model
-
-
 def main(args):
     if args.seed is None:
         colossalai.launch_from_torch(config={})
@@ -392,7 +401,7 @@ def main(args):
                 images = pipeline(example["prompt"]).images
 
                 for i, image in enumerate(images):
-                    hash_image = hashlib.sha1(image.tobytes()).hexdigest()
+                    hash_image = hashlib.sha256(image.tobytes()).hexdigest()
                     image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                     image.save(image_filename)
 
@@ -452,12 +461,18 @@ def main(args):
             revision=args.revision,
         )
 
-    logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
-    with ColoInitContext(device=get_current_device()):
+
+    if args.externel_unet_path is None:
+        logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
         unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
-                                                subfolder="unet",
-                                                revision=args.revision,
-                                                low_cpu_mem_usage=False)
+                                                    subfolder="unet",
+                                                    revision=args.revision,
+                                                    low_cpu_mem_usage=False)
+    else:
+        logger.info(f"Loading UNet2DConditionModel from {args.externel_unet_path}", ranks=[0])
+        unet = UNet2DConditionModel.from_pretrained(args.externel_unet_path,
+                                                    revision=args.revision,
+                                                    low_cpu_mem_usage=False)
 
     vae.requires_grad_(False)
     text_encoder.requires_grad_(False)
@@ -468,10 +483,22 @@ def main(args):
     if args.scale_lr:
         args.learning_rate = args.learning_rate * args.train_batch_size * world_size
 
-    unet = gemini_zero_dpp(unet, args.placement)
+    # Use Booster API to use Gemini/Zero with ColossalAI
+
+    booster_kwargs = {}
+    if args.plugin == 'torch_ddp_fp16':
+        booster_kwargs['mixed_precision'] = 'fp16'
+    if args.plugin.startswith('torch_ddp'):
+        plugin = TorchDDPPlugin()
+    elif args.plugin == 'gemini':
+        plugin = GeminiPlugin(placement_policy=args.placement, strict_ddp_mode=True, initial_scale=2 ** 5)
+    elif args.plugin == 'low_level_zero':
+        plugin = LowLevelZeroPlugin(initial_scale=2 ** 5)
+
+    booster = Booster(plugin=plugin, **booster_kwargs)
 
     # config optimizer for colossalai zero
-    optimizer = GeminiAdamOptimizer(unet, lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm)
+    optimizer = HybridAdam(unet.parameters(), lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm)
 
     # load noise_scheduler
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
@@ -486,6 +513,7 @@ def main(args):
         tokenizer=tokenizer,
         size=args.resolution,
         center_crop=args.center_crop,
+        test=args.test_run
     )
 
     def collate_fn(examples):
@@ -554,6 +582,8 @@ def main(args):
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
+    unet, optimizer, _, _, lr_scheduler = booster.boost(unet, optimizer, lr_scheduler=lr_scheduler)
+
     # Train!
     total_batch_size = args.train_batch_size * world_size
 
@@ -642,36 +672,24 @@ def main(args):
             if global_step % args.save_steps == 0:
                 torch.cuda.synchronize()
-                torch_unet = get_static_torch_model(unet)
+                save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+                booster.save_model(unet, os.path.join(save_path, "diffusion_pytorch_model.bin"))
                 if local_rank == 0:
-                    pipeline = DiffusionPipeline.from_pretrained(
-                        args.pretrained_model_name_or_path,
-                        unet=torch_unet,
-                        revision=args.revision,
-                    )
-                    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
-                    pipeline.save_pretrained(save_path)
+                    if not os.path.exists(os.path.join(save_path, "config.json")):
+                        shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), save_path)
                     logger.info(f"Saving model checkpoint to {save_path}", ranks=[0])
 
             if global_step >= args.max_train_steps:
                 break
 
-    torch.cuda.synchronize()
-    unet = get_static_torch_model(unet)
+    booster.save_model(unet, os.path.join(args.output_dir, "diffusion_pytorch_model.bin"))
+    logger.info(f"Saving model checkpoint to {args.output_dir} on rank {local_rank}")
     if local_rank == 0:
-        pipeline = DiffusionPipeline.from_pretrained(
-            args.pretrained_model_name_or_path,
-            unet=unet,
-            revision=args.revision,
-        )
-
-        pipeline.save_pretrained(args.output_dir)
-        logger.info(f"Saving model checkpoint to {args.output_dir}", ranks=[0])
-
+        if not os.path.exists(os.path.join(args.output_dir, "config.json")):
+            shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), args.output_dir)
     if args.push_to_hub:
         repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
 
-
 if __name__ == "__main__":
     args = parse_args()
     main(args)
diff --git a/examples/images/dreambooth/train_dreambooth_colossalai_lora.py b/examples/images/dreambooth/train_dreambooth_colossalai_lora.py
index 6715b473a..dce65ff51 100644
--- a/examples/images/dreambooth/train_dreambooth_colossalai_lora.py
+++ b/examples/images/dreambooth/train_dreambooth_colossalai_lora.py
@@ -4,6 +4,7 @@ import math
 import os
 from pathlib import Path
 from typing import Optional
+import shutil
 
 import torch
 import torch.nn.functional as F
@@ -23,9 +24,12 @@ import colossalai
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.nn.optimizer import HybridAdam
 from colossalai.utils import get_current_device
 from colossalai.zero import ColoInitContext, GeminiAdamOptimizer
 from colossalai.zero.gemini import get_static_torch_model
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 
 disable_existing_loggers()
 logger = get_dist_logger()
@@ -60,6 +64,13 @@ def parse_args(input_args=None):
         required=True,
         help="Path to pretrained model or model identifier from huggingface.co/models.",
     )
+    parser.add_argument(
+        "--externel_unet_path",
+        type=str,
+        default=None,
+        required=False,
+        help="Path to the external UNet model.",
+    )
     parser.add_argument(
         "--revision",
         type=str,
@@ -195,6 +206,12 @@ def parse_args(input_args=None):
         default=None,
         help="The name of the repository to keep in sync with the local `output_dir`.",
     )
+    parser.add_argument('-p',
+                        '--plugin',
+                        type=str,
+                        default='torch_ddp',
+                        choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'],
+                        help="plugin to use")
     parser.add_argument(
         "--logging_dir",
         type=str,
@@ -341,18 +358,6 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token
         return f"{organization}/{model_id}"
 
 
-# Gemini + ZeRO DDP
-def gemini_zero_dpp(model: torch.nn.Module, placement_policy: str = "auto"):
-    from colossalai.nn.parallel import GeminiDDP
-
-    model = GeminiDDP(model,
-                      device=get_current_device(),
-                      placement_policy=placement_policy,
-                      pin_memory=True,
-                      search_range_mb=64)
-    return model
-
-
 def main(args):
     if args.seed is None:
         colossalai.launch_from_torch(config={})
@@ -394,7 +399,7 @@ def main(args):
                 images = pipeline(example["prompt"]).images
 
                 for i, image in enumerate(images):
-                    hash_image = hashlib.sha1(image.tobytes()).hexdigest()
+                    hash_image = hashlib.sha256(image.tobytes()).hexdigest()
                     image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                     image.save(image_filename)
 
@@ -454,32 +459,42 @@ def main(args):
             revision=args.revision,
         )
 
-    logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
-    with ColoInitContext(device=get_current_device()):
+
+    if args.externel_unet_path is None:
+        logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
         unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
-                                                subfolder="unet",
-                                                revision=args.revision,
-                                                low_cpu_mem_usage=False)
-        unet.requires_grad_(False)
+                                                    subfolder="unet",
+                                                    revision=args.revision,
+                                                    low_cpu_mem_usage=False)
+    else:
+        logger.info(f"Loading UNet2DConditionModel from {args.externel_unet_path}", ranks=[0])
+        unet = UNet2DConditionModel.from_pretrained(args.externel_unet_path,
+                                                    revision=args.revision,
+                                                    low_cpu_mem_usage=False)
+    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
+                                                subfolder="unet",
+                                                revision=args.revision,
+                                                low_cpu_mem_usage=False)
+    unet.requires_grad_(False)
 
-        # Set correct lora layers
-        lora_attn_procs = {}
-        for name in unet.attn_processors.keys():
-            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
-            if name.startswith("mid_block"):
-                hidden_size = unet.config.block_out_channels[-1]
-            elif name.startswith("up_blocks"):
-                block_id = int(name[len("up_blocks.")])
-                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
-            elif name.startswith("down_blocks"):
-                block_id = int(name[len("down_blocks.")])
-                hidden_size = unet.config.block_out_channels[block_id]
+    # Set correct lora layers
+    lora_attn_procs = {}
+    for name in unet.attn_processors.keys():
+        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+        if name.startswith("mid_block"):
+            hidden_size = unet.config.block_out_channels[-1]
+        elif name.startswith("up_blocks"):
+            block_id = int(name[len("up_blocks.")])
+            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+        elif name.startswith("down_blocks"):
+            block_id = int(name[len("down_blocks.")])
+            hidden_size = unet.config.block_out_channels[block_id]
 
-            lora_attn_procs[name] = LoRACrossAttnProcessor(hidden_size=hidden_size,
-                                                           cross_attention_dim=cross_attention_dim)
+        lora_attn_procs[name] = LoRACrossAttnProcessor(hidden_size=hidden_size,
+                                                       cross_attention_dim=cross_attention_dim)
 
-        unet.set_attn_processor(lora_attn_procs)
-        lora_layers = AttnProcsLayers(unet.attn_processors)
+    unet.set_attn_processor(lora_attn_procs)
+    lora_layers = AttnProcsLayers(unet.attn_processors)
 
     vae.requires_grad_(False)
     text_encoder.requires_grad_(False)
@@ -490,10 +505,22 @@ def main(args):
     if args.scale_lr:
         args.learning_rate = args.learning_rate * args.train_batch_size * world_size
 
-    unet = gemini_zero_dpp(unet, args.placement)
+    # Use Booster API to use Gemini/Zero with ColossalAI
+
+    booster_kwargs = {}
+    if args.plugin == 'torch_ddp_fp16':
+        booster_kwargs['mixed_precision'] = 'fp16'
+    if args.plugin.startswith('torch_ddp'):
+        plugin = TorchDDPPlugin()
+    elif args.plugin == 'gemini':
+        plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2 ** 5)
+    elif args.plugin == 'low_level_zero':
+        plugin = LowLevelZeroPlugin(initial_scale=2 ** 5)
+
+    booster = Booster(plugin=plugin, **booster_kwargs)
 
     # config optimizer for colossalai zero
-    optimizer = GeminiAdamOptimizer(unet, lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm)
+    optimizer = HybridAdam(unet.parameters(), lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm)
 
     # load noise_scheduler
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
@@ -576,6 +603,8 @@ def main(args):
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
+    unet, optimizer, _, _, lr_scheduler = booster.boost(unet, optimizer, lr_scheduler=lr_scheduler)
+
     # Train!
     total_batch_size = args.train_batch_size * world_size
 
@@ -664,27 +693,24 @@ def main(args):
             if global_step % args.save_steps == 0:
                 torch.cuda.synchronize()
-                torch_unet = get_static_torch_model(unet)
+                save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+                booster.save_model(unet, os.path.join(save_path, "diffusion_pytorch_model.bin"))
                 if local_rank == 0:
-                    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
-                    torch_unet = torch_unet.to(torch.float32)
-                    torch_unet.save_attn_procs(save_path)
+                    if not os.path.exists(os.path.join(save_path, "config.json")):
+                        shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), save_path)
                     logger.info(f"Saving model checkpoint to {save_path}", ranks=[0])
 
             if global_step >= args.max_train_steps:
                 break
 
-    torch.cuda.synchronize()
-    torch_unet = get_static_torch_model(unet)
+    booster.save_model(unet, os.path.join(args.output_dir, "diffusion_pytorch_model.bin"))
+    logger.info(f"Saving model checkpoint to {args.output_dir} on rank {local_rank}")
     if local_rank == 0:
-        torch_unet = torch_unet.to(torch.float32)
-        torch_unet.save_attn_procs(save_path)
-        logger.info(f"Saving model checkpoint to {args.output_dir}", ranks=[0])
-
+        if not os.path.exists(os.path.join(args.output_dir, "config.json")):
+            shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), args.output_dir)
     if args.push_to_hub:
         repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
 
-
 if __name__ == "__main__":
     args = parse_args()
     main(args)
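A note on consuming the checkpoints produced by the updated scripts: `booster.save_model` writes a plain `diffusion_pytorch_model.bin`, and the scripts copy the UNet's `config.json` next to it, so the output directory should be loadable as a `UNet2DConditionModel`. The sketch below is an assumption based on that layout and is not part of the patch; the paths mirror the defaults used in the scripts above, and the LoRA variant would instead need its attention-processor weights loaded separately.

```python
from diffusers import StableDiffusionPipeline, UNet2DConditionModel

# Illustrative paths only: they mirror the defaults used in the scripts above.
base_model = "/data/dreambooth/diffuser/stable-diffusion-v1-4"
finetuned_unet_dir = "./weight_output"  # diffusion_pytorch_model.bin + copied config.json

# Load the fine-tuned UNet saved by booster.save_model(), then swap it into the pipeline.
unet = UNet2DConditionModel.from_pretrained(finetuned_unet_dir)
pipe = StableDiffusionPipeline.from_pretrained(base_model, unet=unet).to("cuda")

image = pipe("a picture of a dog", num_inference_steps=50, guidance_scale=7.5).images[0]
image.save("dog.png")
```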