mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 01:28:31 +00:00
[gemini] add get static torch model (#2356)
This commit is contained in:
@@ -8,25 +8,23 @@ from typing import Optional
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
|
||||
from diffusers.optimization import get_scheduler
|
||||
from huggingface_hub import HfFolder, Repository, whoami
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import AutoTokenizer, PretrainedConfig
|
||||
|
||||
import colossalai
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
|
||||
from colossalai.nn.parallel.utils import convert_to_torch_module
|
||||
from colossalai.tensor import ProcessGroup
|
||||
from colossalai.nn.parallel.utils import get_static_torch_model
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
|
||||
from diffusers.optimization import get_scheduler
|
||||
from huggingface_hub import HfFolder, Repository, whoami
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import AutoTokenizer, PretrainedConfig
|
||||
|
||||
|
||||
disable_existing_loggers()
|
||||
logger = get_dist_logger()
|
||||
@@ -112,10 +110,8 @@ def parse_args(input_args=None):
|
||||
"--num_class_images",
|
||||
type=int,
|
||||
default=100,
|
||||
help=(
|
||||
"Minimal class images for prior preservation loss. If there are not enough images already present in"
|
||||
" class_data_dir, additional images will be sampled with class_prompt."
|
||||
),
|
||||
help=("Minimal class images for prior preservation loss. If there are not enough images already present in"
|
||||
" class_data_dir, additional images will be sampled with class_prompt."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
@@ -128,10 +124,8 @@ def parse_args(input_args=None):
|
||||
"--resolution",
|
||||
type=int,
|
||||
default=512,
|
||||
help=(
|
||||
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
||||
" resolution"
|
||||
),
|
||||
help=("The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
||||
" resolution"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--placement",
|
||||
@@ -139,15 +133,14 @@ def parse_args(input_args=None):
|
||||
default="cpu",
|
||||
help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
|
||||
)
|
||||
parser.add_argument("--center_crop",
|
||||
action="store_true",
|
||||
help="Whether to center crop images before resizing to resolution")
|
||||
parser.add_argument("--train_batch_size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Batch size (per device) for the training dataloader.")
|
||||
parser.add_argument("--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images.")
|
||||
parser.add_argument("--num_train_epochs", type=int, default=1)
|
||||
parser.add_argument(
|
||||
"--max_train_steps",
|
||||
@@ -183,17 +176,16 @@ def parse_args(input_args=None):
|
||||
"--lr_scheduler",
|
||||
type=str,
|
||||
default="constant",
|
||||
help=(
|
||||
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
||||
' "constant", "constant_with_warmup"]'
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
|
||||
help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
||||
' "constant", "constant_with_warmup"]'),
|
||||
)
|
||||
parser.add_argument("--lr_warmup_steps",
|
||||
type=int,
|
||||
default=500,
|
||||
help="Number of steps for the warmup in the lr scheduler.")
|
||||
parser.add_argument("--use_8bit_adam",
|
||||
action="store_true",
|
||||
help="Whether or not to use 8-bit Adam from bitsandbytes.")
|
||||
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
||||
@@ -208,10 +200,8 @@ def parse_args(input_args=None):
|
||||
"--logging_dir",
|
||||
type=str,
|
||||
default="logs",
|
||||
help=(
|
||||
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
||||
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
||||
),
|
||||
help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
||||
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mixed_precision",
|
||||
@@ -221,8 +211,7 @@ def parse_args(input_args=None):
|
||||
help=(
|
||||
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
||||
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
||||
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
||||
),
|
||||
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."),
|
||||
)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
|
||||
@@ -288,14 +277,12 @@ class DreamBoothDataset(Dataset):
|
||||
else:
|
||||
self.class_data_root = None
|
||||
|
||||
self.image_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
]
|
||||
)
|
||||
self.image_transforms = transforms.Compose([
|
||||
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
])
|
||||
|
||||
def __len__(self):
|
||||
return self._length
|
||||
@@ -356,26 +343,19 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
|
||||
|
||||
|
||||
# Gemini + ZeRO DDP
|
||||
def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
|
||||
def gemini_zero_dpp(model: torch.nn.Module, placememt_policy: str = "auto"):
|
||||
from colossalai.nn.parallel import GeminiDDP
|
||||
|
||||
model = GeminiDDP(
|
||||
model, device=get_current_device(), placement_policy=placememt_policy, pin_memory=True, search_range_mb=32
|
||||
)
|
||||
model = GeminiDDP(model,
|
||||
device=get_current_device(),
|
||||
placement_policy=placememt_policy,
|
||||
pin_memory=True,
|
||||
search_range_mb=64)
|
||||
return model
|
||||
|
||||
|
||||
def main(args):
|
||||
# config for colossalai
|
||||
|
||||
config = {
|
||||
"BATCH": args.train_batch_size,
|
||||
"gradient_accumulation_steps": args.gradient_accumulation_steps,
|
||||
"clip_grad_norm": args.max_grad_norm,
|
||||
}
|
||||
|
||||
colossalai.launch_from_torch(config=config)
|
||||
pg = ProcessGroup()
|
||||
colossalai.launch_from_torch(config={})
|
||||
|
||||
if args.seed is not None:
|
||||
gpc.set_seed(args.seed)
|
||||
@@ -405,9 +385,9 @@ def main(args):
|
||||
pipeline.to(get_current_device())
|
||||
|
||||
for example in tqdm(
|
||||
sample_dataloader,
|
||||
desc="Generating class images",
|
||||
disable=not gpc.get_local_rank(ParallelMode.DATA) == 0,
|
||||
sample_dataloader,
|
||||
desc="Generating class images",
|
||||
disable=not gpc.get_local_rank(ParallelMode.DATA) == 0,
|
||||
):
|
||||
images = pipeline(example["prompt"]).images
|
||||
|
||||
@@ -472,10 +452,11 @@ def main(args):
|
||||
)
|
||||
|
||||
logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
|
||||
with ColoInitContext():
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False
|
||||
)
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
|
||||
subfolder="unet",
|
||||
revision=args.revision,
|
||||
low_cpu_mem_usage=False)
|
||||
|
||||
vae.requires_grad_(False)
|
||||
text_encoder.requires_grad_(False)
|
||||
@@ -486,10 +467,10 @@ def main(args):
|
||||
if args.scale_lr:
|
||||
args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * gpc.get_world_size(ParallelMode.DATA)
|
||||
|
||||
unet = gemini_zero_dpp(unet, pg, args.placement)
|
||||
unet = gemini_zero_dpp(unet, args.placement)
|
||||
|
||||
# config optimizer for colossalai zero
|
||||
optimizer = GeminiAdamOptimizer(unet, lr=args.learning_rate, initial_scale=2**5)
|
||||
optimizer = GeminiAdamOptimizer(unet, lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm)
|
||||
|
||||
# load noise_scheduler
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
||||
@@ -520,7 +501,9 @@ def main(args):
|
||||
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
||||
|
||||
input_ids = tokenizer.pad(
|
||||
{"input_ids": input_ids},
|
||||
{
|
||||
"input_ids": input_ids
|
||||
},
|
||||
padding="max_length",
|
||||
max_length=tokenizer.model_max_length,
|
||||
return_tensors="pt",
|
||||
@@ -532,9 +515,11 @@ def main(args):
|
||||
}
|
||||
return batch
|
||||
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1
|
||||
)
|
||||
train_dataloader = torch.utils.data.DataLoader(train_dataset,
|
||||
batch_size=args.train_batch_size,
|
||||
shuffle=True,
|
||||
collate_fn=collate_fn,
|
||||
num_workers=1)
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
overrode_max_train_steps = False
|
||||
@@ -652,15 +637,16 @@ def main(args):
|
||||
logs = {
|
||||
"loss": loss.detach().item(),
|
||||
"lr": optimizer.param_groups[0]["lr"],
|
||||
} # lr_scheduler.get_last_lr()[0]}
|
||||
} # lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if global_step % args.save_steps == 0:
|
||||
torch.cuda.synchronize()
|
||||
torch_unet = get_static_torch_model(unet)
|
||||
if gpc.get_local_rank(ParallelMode.DATA) == 0:
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
unet=convert_to_torch_module(unet),
|
||||
unet=torch_unet,
|
||||
revision=args.revision,
|
||||
)
|
||||
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
||||
@@ -670,7 +656,7 @@ def main(args):
|
||||
break
|
||||
|
||||
torch.cuda.synchronize()
|
||||
unet = convert_to_torch_module(unet)
|
||||
unet = get_static_torch_model(unet)
|
||||
|
||||
if gpc.get_local_rank(ParallelMode.DATA) == 0:
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
|
Reference in New Issue
Block a user