Merge pull request #3905 from MaruyamaAya/dreambooth

[example] Adding an example of training dreambooth with the new booster API

Commit e277534a18
@@ -92,6 +92,29 @@

```
torchrun --nproc_per_node 2 train_dreambooth_colossalai.py \
  --placement="cuda"
```
## New API

We have modified our previous implementation of Dreambooth with our new Booster API, which offers a more flexible and efficient way to train your model. The new API is more user-friendly and easier to use; you can find it in `train_dreambooth_colossalai.py`.

We also offer a shell script, `test_ci.sh`, that runs through all of the Booster plugins; a minimal sketch of the new flow follows this section.

For more information about the Booster API, please refer to https://colossalai.org/docs/basics/booster_api/.
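As a rough illustration of what the Booster setup looks like (a minimal sketch, assuming only the plugin names and calls that appear in this PR; the toy model and learning rate are placeholders, not the real UNet training code):

```python
import torch.nn as nn
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch(config={})  # launch with torchrun, as in the scripts below

plugin_name = "gemini"  # one of: torch_ddp, torch_ddp_fp16, gemini, low_level_zero
booster_kwargs = {}
if plugin_name == "torch_ddp_fp16":
    booster_kwargs["mixed_precision"] = "fp16"
if plugin_name.startswith("torch_ddp"):
    plugin = TorchDDPPlugin()
elif plugin_name == "gemini":
    plugin = GeminiPlugin(placement_policy="auto", strict_ddp_mode=True, initial_scale=2**5)
elif plugin_name == "low_level_zero":
    plugin = LowLevelZeroPlugin(initial_scale=2**5)

model = nn.Linear(16, 4)  # placeholder for the diffusion UNet
optimizer = HybridAdam(model.parameters(), lr=5e-6)
booster = Booster(plugin=plugin, **booster_kwargs)
model, optimizer, _, _, _ = booster.boost(model, optimizer)
```

The same pattern appears in the training script diff below, where the plugin is selected from the `--plugin` argument and the boosted model is the UNet.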
## Performance

| Strategy       | #GPU | Batch Size | GPU RAM (GB) | Speedup |
|:--------------:|:----:|:----------:|:------------:|:-------:|
| Traditional    | 1    | 16         | OOM          | \       |
| Traditional    | 1    | 8          | 61.81        | 1       |
| torch_ddp      | 4    | 16         | OOM          | \       |
| torch_ddp      | 4    | 8          | 41.97        | 0.97    |
| gemini         | 4    | 16         | 53.29        | \       |
| gemini         | 4    | 8          | 29.36        | 2.00    |
| low_level_zero | 4    | 16         | 52.80        | \       |
| low_level_zero | 4    | 8          | 28.87        | 2.02    |

The evaluation was performed on 4 NVIDIA A100 GPUs with 80 GB of memory each, with GPUs 0 & 1 and 2 & 3 connected via NVLink. We fine-tuned the [stable-diffusion-v1-4](https://huggingface.co/stabilityai/stable-diffusion-v1-4) model at 512x512 resolution on the [Teyvat](https://huggingface.co/datasets/Fazzie/Teyvat) dataset and compared the memory cost and throughput of the plugins.
## Inference

Once you have trained a model using the above command, inference can be done simply with the `StableDiffusionPipeline`. Make sure to include the `identifier` (e.g. `--instance_prompt="a photo of sks dog"` in the above example) in your prompt.
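For instance (a hedged sketch, assuming the trained weights were exported as a full diffusers pipeline under `./weight_output`; the prompt and file names are placeholders):

```python
import torch
from diffusers import StableDiffusionPipeline

# Load the fine-tuned pipeline (placeholder path matching the training scripts below).
pipe = StableDiffusionPipeline.from_pretrained("./weight_output", torch_dtype=torch.float16).to("cuda")

# Include the identifier used during training (e.g. "sks") in the prompt.
image = pipe("a photo of sks dog in a bucket", num_inference_steps=50, guidance_scale=7.5).images[0]
image.save("dog-bucket.png")
```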
@@ -1,22 +1,18 @@
-export MODEL_NAME= <Your Pretrained Model Path>
-export INSTANCE_DIR= <Your Input Pics Path>
-export CLASS_DIR="path-to-class-images"
-export OUTPUT_DIR="path-to-save-model"
-
 HF_DATASETS_OFFLINE=1
 TRANSFORMERS_OFFLINE=1
 DIFFUSERS_OFFLINE=1
 
-torchrun --nproc_per_node 2 --master_port=25641 train_dreambooth_colossalai.py \
-  --pretrained_model_name_or_path=$MODEL_NAME \
-  --instance_data_dir=$INSTANCE_DIR \
-  --output_dir=$OUTPUT_DIR \
-  --instance_prompt="a photo of a dog" \
+torchrun --nproc_per_node 4 --standalone train_dreambooth_colossalai.py \
+  --pretrained_model_name_or_path="/data/dreambooth/diffuser/stable-diffusion-v1-4" \
+  --instance_data_dir="/data/dreambooth/Teyvat/data" \
+  --output_dir="./weight_output" \
+  --instance_prompt="a picture of a dog" \
   --resolution=512 \
+  --plugin="gemini" \
   --train_batch_size=1 \
-  --gradient_accumulation_steps=1 \
   --learning_rate=5e-6 \
   --lr_scheduler="constant" \
   --lr_warmup_steps=0 \
   --num_class_images=200 \
-  --placement="cuda" \
+  --test_run=True \
+  --placement="auto" \
@@ -1,7 +1,7 @@
 python train_dreambooth.py \
-  --pretrained_model_name_or_path= ## Your Model Path \
-  --instance_data_dir= ## Your Training Input Pics Path \
-  --output_dir="path-to-save-model" \
+  --pretrained_model_name_or_path="/data/dreambooth/diffuser/stable-diffusion-v1-4" \
+  --instance_data_dir="/data/dreambooth/Teyvat/data" \
+  --output_dir="./weight_output" \
   --instance_prompt="a photo of a dog" \
   --resolution=512 \
   --train_batch_size=1 \
@@ -0,0 +1,25 @@
+#!/bin/bash
+set -xe
+pip install -r requirements.txt
+
+HF_DATASETS_OFFLINE=1
+TRANSFORMERS_OFFLINE=1
+DIFFUSERS_OFFLINE=1
+
+# "torch_ddp" "torch_ddp_fp16" "low_level_zero"
+for plugin in "gemini"; do
+  torchrun --nproc_per_node 4 --standalone train_dreambooth_colossalai.py \
+    --pretrained_model_name_or_path="/data/dreambooth/diffuser/stable-diffusion-v1-4" \
+    --instance_data_dir="/data/dreambooth/Teyvat/data" \
+    --output_dir="./weight_output" \
+    --instance_prompt="a picture of a dog" \
+    --resolution=512 \
+    --plugin=$plugin \
+    --train_batch_size=1 \
+    --learning_rate=5e-6 \
+    --lr_scheduler="constant" \
+    --lr_warmup_steps=0 \
+    --test_run=True \
+    --num_class_images=200 \
+    --placement="auto" # "cuda"
+done
@@ -4,6 +4,7 @@ import math
 import os
 from pathlib import Path
 from typing import Optional
+import shutil
 
 import torch
 import torch.nn.functional as F
@@ -21,9 +22,12 @@ import colossalai
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.nn.optimizer import HybridAdam
 from colossalai.utils import get_current_device
-from colossalai.zero import ColoInitContext, GeminiAdamOptimizer
+from colossalai.zero import ColoInitContext
 from colossalai.zero.gemini import get_static_torch_model
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 
 disable_existing_loggers()
 logger = get_dist_logger()
@@ -58,6 +62,13 @@ def parse_args(input_args=None):
         required=True,
         help="Path to pretrained model or model identifier from huggingface.co/models.",
     )
+    parser.add_argument(
+        "--externel_unet_path",
+        type=str,
+        default=None,
+        required=False,
+        help="Path to the external UNet model.",
+    )
     parser.add_argument(
         "--revision",
         type=str,
@@ -187,12 +198,19 @@ def parse_args(input_args=None):
     parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
     parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
     parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+    parser.add_argument("--test_run", default=False, help="Whether to use a smaller dataset for a test run.")
     parser.add_argument(
         "--hub_model_id",
         type=str,
         default=None,
         help="The name of the repository to keep in sync with the local `output_dir`.",
     )
+    parser.add_argument('-p',
+                        '--plugin',
+                        type=str,
+                        default='torch_ddp',
+                        choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'],
+                        help="plugin to use")
     parser.add_argument(
         "--logging_dir",
         type=str,
@@ -250,6 +268,7 @@ class DreamBoothDataset(Dataset):
         class_prompt=None,
         size=512,
         center_crop=False,
+        test=False,
     ):
         self.size = size
         self.center_crop = center_crop
@@ -260,6 +279,8 @@ class DreamBoothDataset(Dataset):
             raise ValueError("Instance images root doesn't exist.")
 
         self.instance_images_path = list(Path(instance_data_root).iterdir())
+        if test:
+            self.instance_images_path = self.instance_images_path[:10]
         self.num_instance_images = len(self.instance_images_path)
         self.instance_prompt = instance_prompt
         self._length = self.num_instance_images
@@ -339,18 +360,6 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
     return f"{organization}/{model_id}"
 
 
-# Gemini + ZeRO DDP
-def gemini_zero_dpp(model: torch.nn.Module, placement_policy: str = "auto"):
-    from colossalai.nn.parallel import GeminiDDP
-
-    model = GeminiDDP(model,
-                      device=get_current_device(),
-                      placement_policy=placement_policy,
-                      pin_memory=True,
-                      search_range_mb=64)
-    return model
-
-
 def main(args):
     if args.seed is None:
         colossalai.launch_from_torch(config={})
@@ -392,7 +401,7 @@ def main(args):
             images = pipeline(example["prompt"]).images
 
             for i, image in enumerate(images):
-                hash_image = hashlib.sha1(image.tobytes()).hexdigest()
+                hash_image = hashlib.sha256(image.tobytes()).hexdigest()
                 image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                 image.save(image_filename)
@@ -452,12 +461,18 @@ def main(args):
         revision=args.revision,
     )
 
+    if args.externel_unet_path is None:
         logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
-    with ColoInitContext(device=get_current_device()):
         unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
                                                     subfolder="unet",
                                                     revision=args.revision,
                                                     low_cpu_mem_usage=False)
+    else:
+        logger.info(f"Loading UNet2DConditionModel from {args.externel_unet_path}", ranks=[0])
+        unet = UNet2DConditionModel.from_pretrained(args.externel_unet_path,
+                                                    revision=args.revision,
+                                                    low_cpu_mem_usage=False)
 
     vae.requires_grad_(False)
     text_encoder.requires_grad_(False)
@@ -468,10 +483,22 @@ def main(args):
     if args.scale_lr:
         args.learning_rate = args.learning_rate * args.train_batch_size * world_size
 
-    unet = gemini_zero_dpp(unet, args.placement)
+    # Use the Booster API to run Gemini/ZeRO with ColossalAI
+    booster_kwargs = {}
+    if args.plugin == 'torch_ddp_fp16':
+        booster_kwargs['mixed_precision'] = 'fp16'
+    if args.plugin.startswith('torch_ddp'):
+        plugin = TorchDDPPlugin()
+    elif args.plugin == 'gemini':
+        plugin = GeminiPlugin(placement_policy=args.placement, strict_ddp_mode=True, initial_scale=2 ** 5)
+    elif args.plugin == 'low_level_zero':
+        plugin = LowLevelZeroPlugin(initial_scale=2 ** 5)
+
+    booster = Booster(plugin=plugin, **booster_kwargs)
 
     # config optimizer for colossalai zero
-    optimizer = GeminiAdamOptimizer(unet, lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm)
+    optimizer = HybridAdam(unet.parameters(), lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm)
 
     # load noise_scheduler
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
@@ -486,6 +513,7 @@ def main(args):
         tokenizer=tokenizer,
         size=args.resolution,
         center_crop=args.center_crop,
+        test=args.test_run
     )
 
     def collate_fn(examples):
@@ -554,6 +582,8 @@ def main(args):
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
+    unet, optimizer, _, _, lr_scheduler = booster.boost(unet, optimizer, lr_scheduler=lr_scheduler)
+
     # Train!
     total_batch_size = args.train_batch_size * world_size
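For readers following the diff, the boost-train-save cycle introduced here condenses to the following self-contained sketch (the linear model and synthetic batch are stand-ins for the UNet and its inputs; only the Booster calls mirror this commit):

```python
import torch
import torch.nn as nn
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch(config={})  # run under torchrun, as in test_ci.sh

model = nn.Linear(16, 4)  # stand-in for the UNet
optimizer = HybridAdam(model.parameters(), lr=5e-6)
booster = Booster(plugin=TorchDDPPlugin())
model, optimizer, _, _, _ = booster.boost(model, optimizer)

batch = torch.randn(8, 16, device=torch.cuda.current_device())
loss = model(batch).pow(2).mean()
booster.backward(loss, optimizer)  # replaces loss.backward()
optimizer.step()
optimizer.zero_grad()

# Checkpointing also goes through the booster, matching the save logic below.
booster.save_model(model, "diffusion_pytorch_model.bin")
```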
@@ -642,36 +672,24 @@ def main(args):
             if global_step % args.save_steps == 0:
                 torch.cuda.synchronize()
-                torch_unet = get_static_torch_model(unet)
-                if local_rank == 0:
-                    pipeline = DiffusionPipeline.from_pretrained(
-                        args.pretrained_model_name_or_path,
-                        unet=torch_unet,
-                        revision=args.revision,
-                    )
                 save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
-                pipeline.save_pretrained(save_path)
+                booster.save_model(unet, os.path.join(save_path, "diffusion_pytorch_model.bin"))
+                if local_rank == 0:
+                    if not os.path.exists(os.path.join(save_path, "config.json")):
+                        shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), save_path)
                 logger.info(f"Saving model checkpoint to {save_path}", ranks=[0])
             if global_step >= args.max_train_steps:
                 break
 
     torch.cuda.synchronize()
-    unet = get_static_torch_model(unet)
+    booster.save_model(unet, os.path.join(args.output_dir, "diffusion_pytorch_model.bin"))
+    logger.info(f"Saving model checkpoint to {args.output_dir} on rank {local_rank}")
     if local_rank == 0:
-        pipeline = DiffusionPipeline.from_pretrained(
-            args.pretrained_model_name_or_path,
-            unet=unet,
-            revision=args.revision,
-        )
-        pipeline.save_pretrained(args.output_dir)
-        logger.info(f"Saving model checkpoint to {args.output_dir}", ranks=[0])
+        if not os.path.exists(os.path.join(args.output_dir, "config.json")):
+            shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), args.output_dir)
 
     if args.push_to_hub:
         repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
 
 
 if __name__ == "__main__":
     args = parse_args()
     main(args)
@@ -4,6 +4,7 @@ import math
 import os
 from pathlib import Path
 from typing import Optional
+import shutil
 
 import torch
 import torch.nn.functional as F
@@ -23,9 +24,12 @@ import colossalai
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.nn.optimizer import HybridAdam
 from colossalai.utils import get_current_device
 from colossalai.zero import ColoInitContext, GeminiAdamOptimizer
 from colossalai.zero.gemini import get_static_torch_model
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 
 disable_existing_loggers()
 logger = get_dist_logger()
@@ -60,6 +64,13 @@ def parse_args(input_args=None):
         required=True,
         help="Path to pretrained model or model identifier from huggingface.co/models.",
     )
+    parser.add_argument(
+        "--externel_unet_path",
+        type=str,
+        default=None,
+        required=False,
+        help="Path to the external UNet model.",
+    )
     parser.add_argument(
         "--revision",
         type=str,
@@ -195,6 +206,12 @@ def parse_args(input_args=None):
         default=None,
         help="The name of the repository to keep in sync with the local `output_dir`.",
     )
+    parser.add_argument('-p',
+                        '--plugin',
+                        type=str,
+                        default='torch_ddp',
+                        choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'],
+                        help="plugin to use")
     parser.add_argument(
         "--logging_dir",
         type=str,
@@ -341,18 +358,6 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
     return f"{organization}/{model_id}"
 
 
-# Gemini + ZeRO DDP
-def gemini_zero_dpp(model: torch.nn.Module, placement_policy: str = "auto"):
-    from colossalai.nn.parallel import GeminiDDP
-
-    model = GeminiDDP(model,
-                      device=get_current_device(),
-                      placement_policy=placement_policy,
-                      pin_memory=True,
-                      search_range_mb=64)
-    return model
-
-
 def main(args):
     if args.seed is None:
         colossalai.launch_from_torch(config={})
@@ -394,7 +399,7 @@ def main(args):
             images = pipeline(example["prompt"]).images
 
             for i, image in enumerate(images):
-                hash_image = hashlib.sha1(image.tobytes()).hexdigest()
+                hash_image = hashlib.sha256(image.tobytes()).hexdigest()
                 image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                 image.save(image_filename)
@@ -454,8 +459,18 @@ def main(args):
         revision=args.revision,
     )
 
+    if args.externel_unet_path is None:
         logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
-    with ColoInitContext(device=get_current_device()):
         unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
                                                     subfolder="unet",
                                                     revision=args.revision,
                                                     low_cpu_mem_usage=False)
+    else:
+        logger.info(f"Loading UNet2DConditionModel from {args.externel_unet_path}", ranks=[0])
+        unet = UNet2DConditionModel.from_pretrained(args.externel_unet_path,
+                                                    revision=args.revision,
+                                                    low_cpu_mem_usage=False)
@@ -490,10 +505,22 @@ def main(args):
     if args.scale_lr:
         args.learning_rate = args.learning_rate * args.train_batch_size * world_size
 
-    unet = gemini_zero_dpp(unet, args.placement)
+    # Use the Booster API to run Gemini/ZeRO with ColossalAI
+    booster_kwargs = {}
+    if args.plugin == 'torch_ddp_fp16':
+        booster_kwargs['mixed_precision'] = 'fp16'
+    if args.plugin.startswith('torch_ddp'):
+        plugin = TorchDDPPlugin()
+    elif args.plugin == 'gemini':
+        plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2 ** 5)
+    elif args.plugin == 'low_level_zero':
+        plugin = LowLevelZeroPlugin(initial_scale=2 ** 5)
+
+    booster = Booster(plugin=plugin, **booster_kwargs)
 
     # config optimizer for colossalai zero
-    optimizer = GeminiAdamOptimizer(unet, lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm)
+    optimizer = HybridAdam(unet.parameters(), lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm)
 
     # load noise_scheduler
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
||||||
@ -576,6 +603,8 @@ def main(args):
|
|||||||
# Afterwards we recalculate our number of training epochs
|
# Afterwards we recalculate our number of training epochs
|
||||||
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||||
|
|
||||||
|
unet, optimizer, _, _, lr_scheduler = booster.boost(unet, optimizer, lr_scheduler=lr_scheduler)
|
||||||
|
|
||||||
# Train!
|
# Train!
|
||||||
total_batch_size = args.train_batch_size * world_size
|
total_batch_size = args.train_batch_size * world_size
|
||||||
|
|
||||||
@@ -664,27 +693,24 @@ def main(args):
             if global_step % args.save_steps == 0:
                 torch.cuda.synchronize()
-                torch_unet = get_static_torch_model(unet)
-                if local_rank == 0:
-                    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
-                    torch_unet = torch_unet.to(torch.float32)
-                    torch_unet.save_attn_procs(save_path)
+                save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+                booster.save_model(unet, os.path.join(save_path, "diffusion_pytorch_model.bin"))
+                if local_rank == 0:
+                    if not os.path.exists(os.path.join(save_path, "config.json")):
+                        shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), save_path)
                 logger.info(f"Saving model checkpoint to {save_path}", ranks=[0])
             if global_step >= args.max_train_steps:
                 break
 
     torch.cuda.synchronize()
-    torch_unet = get_static_torch_model(unet)
+    booster.save_model(unet, os.path.join(args.output_dir, "diffusion_pytorch_model.bin"))
+    logger.info(f"Saving model checkpoint to {args.output_dir} on rank {local_rank}")
     if local_rank == 0:
-        torch_unet = torch_unet.to(torch.float32)
-        torch_unet.save_attn_procs(save_path)
-        logger.info(f"Saving model checkpoint to {args.output_dir}", ranks=[0])
+        if not os.path.exists(os.path.join(args.output_dir, "config.json")):
+            shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), args.output_dir)
 
     if args.push_to_hub:
         repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
 
 
 if __name__ == "__main__":
     args = parse_args()
     main(args)