mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-13 05:01:44 +00:00)
Commit: change directory
@@ -4,6 +4,7 @@ import math
 import os
 from pathlib import Path
 from typing import Optional
+import shutil

 import torch
 import torch.nn.functional as F
@@ -21,9 +22,12 @@ import colossalai
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.nn.optimizer import HybridAdam
 from colossalai.utils import get_current_device
-from colossalai.zero import ColoInitContext, GeminiAdamOptimizer
+from colossalai.zero import ColoInitContext
 from colossalai.zero.gemini import get_static_torch_model
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin

 disable_existing_loggers()
 logger = get_dist_logger()
@@ -58,6 +62,13 @@ def parse_args(input_args=None):
         required=True,
         help="Path to pretrained model or model identifier from huggingface.co/models.",
     )
+    parser.add_argument(
+        "--externel_unet_path",
+        type=str,
+        default=None,
+        required=False,
+        help="Path to the externel unet model.",
+    )
     parser.add_argument(
         "--revision",
         type=str,
@@ -193,6 +204,12 @@ def parse_args(input_args=None):
         default=None,
         help="The name of the repository to keep in sync with the local `output_dir`.",
     )
+    parser.add_argument('-p',
+                        '--plugin',
+                        type=str,
+                        default='torch_ddp',
+                        choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'],
+                        help="plugin to use")
     parser.add_argument(
         "--logging_dir",
         type=str,
@@ -339,18 +356,6 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
     return f"{organization}/{model_id}"


-# Gemini + ZeRO DDP
-def gemini_zero_dpp(model: torch.nn.Module, placement_policy: str = "auto"):
-    from colossalai.nn.parallel import GeminiDDP
-
-    model = GeminiDDP(model,
-                      device=get_current_device(),
-                      placement_policy=placement_policy,
-                      pin_memory=True,
-                      search_range_mb=64)
-    return model
-
-
 def main(args):
     if args.seed is None:
         colossalai.launch_from_torch(config={})
@@ -392,7 +397,7 @@ def main(args):
                 images = pipeline(example["prompt"]).images

                 for i, image in enumerate(images):
-                    hash_image = hashlib.sha1(image.tobytes()).hexdigest()
+                    hash_image = hashlib.sha256(image.tobytes()).hexdigest()
                     image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                     image.save(image_filename)

@@ -452,12 +457,18 @@ def main(args):
         revision=args.revision,
     )

-    logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
     with ColoInitContext(device=get_current_device()):
-        unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
-                                                    subfolder="unet",
-                                                    revision=args.revision,
-                                                    low_cpu_mem_usage=False)
+        if args.externel_unet_path is None:
+            logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
+            unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
+                                                        subfolder="unet",
+                                                        revision=args.revision,
+                                                        low_cpu_mem_usage=False)
+        else:
+            logger.info(f"Loading UNet2DConditionModel from {args.externel_unet_path}", ranks=[0])
+            unet = UNet2DConditionModel.from_pretrained(args.externel_unet_path,
+                                                        revision=args.revision,
+                                                        low_cpu_mem_usage=False)

     vae.requires_grad_(False)
     text_encoder.requires_grad_(False)
@@ -468,10 +479,22 @@ def main(args):
     if args.scale_lr:
         args.learning_rate = args.learning_rate * args.train_batch_size * world_size

-    unet = gemini_zero_dpp(unet, args.placement)
+    # Use Booster API to use Gemini/Zero with ColossalAI
+
+    booster_kwargs = {}
+    if args.plugin == 'torch_ddp_fp16':
+        booster_kwargs['mixed_precision'] = 'fp16'
+    if args.plugin.startswith('torch_ddp'):
+        plugin = TorchDDPPlugin()
+    elif args.plugin == 'gemini':
+        plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2 ** 5)
+    elif args.plugin == 'low_level_zero':
+        plugin = LowLevelZeroPlugin(initial_scale=2 ** 5)

+    booster = Booster(plugin=plugin, **booster_kwargs)
+
     # config optimizer for colossalai zero
-    optimizer = GeminiAdamOptimizer(unet, lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm)
+    optimizer = HybridAdam(unet.parameters(), lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm)

     # load noise_scheduler
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
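
Note for readers following this change: the hunk above swaps the Gemini-specific `gemini_zero_dpp` wrapper for the plugin-based Booster API. Below is a minimal sketch, not part of the commit, that restates the same wiring as a standalone helper; it reuses only the calls that appear in this diff, and the names `build_booster_and_optimizer`, `model`, `plugin_name`, and `lr` are placeholders.

    # Minimal sketch: select a Booster plugin the same way the diff does for the
    # UNet, then build the Booster and a HybridAdam optimizer.
    import torch.nn as nn

    from colossalai.booster import Booster
    from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
    from colossalai.nn.optimizer import HybridAdam


    def build_booster_and_optimizer(model: nn.Module, plugin_name: str, lr: float):
        booster_kwargs = {}
        if plugin_name == 'torch_ddp_fp16':
            booster_kwargs['mixed_precision'] = 'fp16'
        if plugin_name.startswith('torch_ddp'):
            plugin = TorchDDPPlugin()
        elif plugin_name == 'gemini':
            plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2 ** 5)
        elif plugin_name == 'low_level_zero':
            plugin = LowLevelZeroPlugin(initial_scale=2 ** 5)
        else:
            raise ValueError(f"unknown plugin {plugin_name!r}")
        booster = Booster(plugin=plugin, **booster_kwargs)
        optimizer = HybridAdam(model.parameters(), lr=lr)
        return booster, optimizer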
@@ -554,6 +577,8 @@ def main(args):
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

+    unet, optimizer, _, _, lr_scheduler = booster.boost(unet, optimizer, lr_scheduler=lr_scheduler)
+
     # Train!
     total_batch_size = args.train_batch_size * world_size

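
Note: `booster.boost` returns a 5-tuple of (model, optimizer, criterion, dataloader, lr_scheduler), which is why the added line above discards two slots with `_`. A sketch of the same call with every slot named, assuming the `booster`, `unet`, `optimizer`, and `lr_scheduler` objects from the hunks above:

    # Sketch only: the identical boost call, with all five return values spelled out.
    # No criterion or dataloader was passed in, so those slots come back unused here.
    unet, optimizer, criterion, dataloader, lr_scheduler = booster.boost(
        unet, optimizer, lr_scheduler=lr_scheduler)
    # From this point on, the wrapped unet/optimizer/lr_scheduler are the objects
    # the training loop and booster.save_model work with.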
@@ -642,36 +667,24 @@ def main(args):

             if global_step % args.save_steps == 0:
                 torch.cuda.synchronize()
-                torch_unet = get_static_torch_model(unet)
                 if local_rank == 0:
-                    pipeline = DiffusionPipeline.from_pretrained(
-                        args.pretrained_model_name_or_path,
-                        unet=torch_unet,
-                        revision=args.revision,
-                    )
                     save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
-                    pipeline.save_pretrained(save_path)
+                    booster.save_model(unet, os.path.join(save_path, "diffusion_pytorch_model.bin"))
+                    if not os.path.exists(os.path.join(save_path, "config.json")):
+                        shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), save_path)
                     logger.info(f"Saving model checkpoint to {save_path}", ranks=[0])
             if global_step >= args.max_train_steps:
                 break

     torch.cuda.synchronize()
-    unet = get_static_torch_model(unet)
-
-    if local_rank == 0:
-        pipeline = DiffusionPipeline.from_pretrained(
-            args.pretrained_model_name_or_path,
-            unet=unet,
-            revision=args.revision,
-        )
-
-        pipeline.save_pretrained(args.output_dir)
-        logger.info(f"Saving model checkpoint to {args.output_dir}", ranks=[0])
-
+    booster.save_model(unet, os.path.join(args.output_dir, "diffusion_pytorch_model.bin"))
+    logger.info(f"Saving model checkpoint to {args.output_dir} on rank {local_rank}")
+    if not os.path.exists(os.path.join(args.output_dir, "config.json")):
+        shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), args.output_dir)
     if args.push_to_hub:
         repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)


 if __name__ == "__main__":
     args = parse_args()
     main(args)
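
Note: after this change, each saved directory holds diffusion_pytorch_model.bin (written by `booster.save_model`) plus the copied unet config.json, which is the layout `from_pretrained` expects. Below is a sketch, not part of the commit, of reloading such a checkpoint the same way the `--externel_unet_path` branch added earlier in this diff loads one; the path is a placeholder.

    # Sketch only: reload a UNet checkpoint written by the save logic above.
    from diffusers import UNet2DConditionModel

    unet = UNet2DConditionModel.from_pretrained(
        "output_dir/checkpoint-400",  # placeholder: any directory written by the save loop
        low_cpu_mem_usage=False,
    )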