[gemini] add get static torch model (#2356)

HELSON
2023-01-06 13:41:19 +08:00
committed by GitHub
parent 7a332b1734
commit 48d33b1b17
4 changed files with 164 additions and 118 deletions


@@ -8,25 +8,23 @@ from typing import Optional
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from huggingface_hub import HfFolder, Repository, whoami
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig
import colossalai
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
from colossalai.nn.parallel.utils import convert_to_torch_module
from colossalai.tensor import ProcessGroup
from colossalai.nn.parallel.utils import get_static_torch_model
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from huggingface_hub import HfFolder, Repository, whoami
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig
disable_existing_loggers()
logger = get_dist_logger()
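
The functional piece of this import reshuffle is the swap from convert_to_torch_module to get_static_torch_model. Judging from the call sites later in this diff, the helper is called with a single positional argument; that it gathers the Gemini-sharded parameters back into an ordinary torch.nn.Module is an assumption, sketched below with an illustrative filename.

import torch
from colossalai.nn.parallel.utils import get_static_torch_model

# unet stands for the GeminiDDP-wrapped module built further down in this file.
torch_unet = get_static_torch_model(unet)
# Assumed result: a plain torch.nn.Module, so ordinary PyTorch tooling applies.
torch.save(torch_unet.state_dict(), "unet_static.pt")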
@@ -112,10 +110,8 @@ def parse_args(input_args=None):
"--num_class_images",
type=int,
default=100,
help=(
"Minimal class images for prior preservation loss. If there are not enough images already present in"
" class_data_dir, additional images will be sampled with class_prompt."
),
help=("Minimal class images for prior preservation loss. If there are not enough images already present in"
" class_data_dir, additional images will be sampled with class_prompt."),
)
parser.add_argument(
"--output_dir",
@@ -128,10 +124,8 @@ def parse_args(input_args=None):
"--resolution",
type=int,
default=512,
help=(
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"
),
help=("The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"),
)
parser.add_argument(
"--placement",
@@ -139,15 +133,14 @@ def parse_args(input_args=None):
default="cpu",
help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
)
parser.add_argument(
"--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
)
parser.add_argument(
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
)
parser.add_argument(
"--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
)
parser.add_argument("--center_crop",
action="store_true",
help="Whether to center crop images before resizing to resolution")
parser.add_argument("--train_batch_size",
type=int,
default=4,
help="Batch size (per device) for the training dataloader.")
parser.add_argument("--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images.")
parser.add_argument("--num_train_epochs", type=int, default=1)
parser.add_argument(
"--max_train_steps",
@@ -183,17 +176,16 @@ def parse_args(input_args=None):
"--lr_scheduler",
type=str,
default="constant",
help=(
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'
),
)
parser.add_argument(
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument(
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'),
)
parser.add_argument("--lr_warmup_steps",
type=int,
default=500,
help="Number of steps for the warmup in the lr scheduler.")
parser.add_argument("--use_8bit_adam",
action="store_true",
help="Whether or not to use 8-bit Adam from bitsandbytes.")
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
@@ -208,10 +200,8 @@ def parse_args(input_args=None):
"--logging_dir",
type=str,
default="logs",
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."),
)
parser.add_argument(
"--mixed_precision",
@@ -221,8 +211,7 @@ def parse_args(input_args=None):
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
@@ -288,14 +277,12 @@ class DreamBoothDataset(Dataset):
else:
self.class_data_root = None
self.image_transforms = transforms.Compose(
[
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
self.image_transforms = transforms.Compose([
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
])
def __len__(self):
return self._length
@@ -356,26 +343,19 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
# Gemini + ZeRO DDP
def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
def gemini_zero_dpp(model: torch.nn.Module, placememt_policy: str = "auto"):
from colossalai.nn.parallel import GeminiDDP
model = GeminiDDP(
model, device=get_current_device(), placement_policy=placememt_policy, pin_memory=True, search_range_mb=32
)
model = GeminiDDP(model,
device=get_current_device(),
placement_policy=placememt_policy,
pin_memory=True,
search_range_mb=64)
return model
def main(args):
# config for colossalai
config = {
"BATCH": args.train_batch_size,
"gradient_accumulation_steps": args.gradient_accumulation_steps,
"clip_grad_norm": args.max_grad_norm,
}
colossalai.launch_from_torch(config=config)
pg = ProcessGroup()
colossalai.launch_from_torch(config={})
if args.seed is not None:
gpc.set_seed(args.seed)
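
Beyond the yapf reflow, this hunk carries three behavioural changes: gemini_zero_dpp no longer takes a ProcessGroup (presumably falling back to the default group), search_range_mb grows from 32 to 64 to widen GeminiDDP's chunk-size search, and clip_grad_norm disappears from the launch config, resurfacing later as clipping_norm on the optimizer. A minimal sketch of the new setup path, with a toy module standing in for the UNet:

import torch
import colossalai
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext

colossalai.launch_from_torch(config={})            # no clip_grad_norm in the config any more

with ColoInitContext(device=get_current_device()):
    toy_model = torch.nn.Linear(8, 8)              # stand-in for UNet2DConditionModel

toy_model = gemini_zero_dpp(toy_model, "cpu")      # ProcessGroup argument is gone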
@@ -405,9 +385,9 @@ def main(args):
pipeline.to(get_current_device())
for example in tqdm(
sample_dataloader,
desc="Generating class images",
disable=not gpc.get_local_rank(ParallelMode.DATA) == 0,
):
images = pipeline(example["prompt"]).images
@@ -472,10 +452,11 @@ def main(args):
)
logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
with ColoInitContext():
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False
)
with ColoInitContext(device=get_current_device()):
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
subfolder="unet",
revision=args.revision,
low_cpu_mem_usage=False)
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
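
The only functional change in this hunk is that ColoInitContext now receives an explicit device. The assumed effect, sketched below, is that parameters created inside the context are materialised directly on the current CUDA device rather than on CPU, sparing a host-side copy of the UNet weights:

import torch
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext

with ColoInitContext(device=get_current_device()):
    probe = torch.nn.Linear(4, 4)
print(probe.weight.device)    # expected: the current CUDA device, not cpu (assumption)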
@@ -486,10 +467,10 @@ def main(args):
if args.scale_lr:
args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * gpc.get_world_size(ParallelMode.DATA)
unet = gemini_zero_dpp(unet, pg, args.placement)
unet = gemini_zero_dpp(unet, args.placement)
# config optimizer for colossalai zero
optimizer = GeminiAdamOptimizer(unet, lr=args.learning_rate, initial_scale=2**5)
optimizer = GeminiAdamOptimizer(unet, lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm)
# load noise_scheduler
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
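
With clip_grad_norm gone from the launch config, gradient clipping now rides on the ZeRO optimizer via clipping_norm. A rough sketch of the training step this implies; noise_pred, noise and lr_scheduler are assumed from the unchanged parts of the script, and optimizer.backward(loss) is the ZeRO-aware replacement for loss.backward() used there:

optimizer.zero_grad()
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
optimizer.backward(loss)    # scales the loss and accumulates gradients on Gemini chunks
optimizer.step()            # clipping to args.max_grad_norm is assumed to happen in here
lr_scheduler.step()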
@@ -520,7 +501,9 @@ def main(args):
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
input_ids = tokenizer.pad(
{"input_ids": input_ids},
{
"input_ids": input_ids
},
padding="max_length",
max_length=tokenizer.model_max_length,
return_tensors="pt",
@@ -532,9 +515,11 @@ def main(args):
}
return batch
train_dataloader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1
)
train_dataloader = torch.utils.data.DataLoader(train_dataset,
batch_size=args.train_batch_size,
shuffle=True,
collate_fn=collate_fn,
num_workers=1)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
@@ -652,15 +637,16 @@ def main(args):
logs = {
"loss": loss.detach().item(),
"lr": optimizer.param_groups[0]["lr"],
} # lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if global_step % args.save_steps == 0:
torch.cuda.synchronize()
torch_unet = get_static_torch_model(unet)
if gpc.get_local_rank(ParallelMode.DATA) == 0:
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=convert_to_torch_module(unet),
unet=torch_unet,
revision=args.revision,
)
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
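
This is where the renamed helper is actually exercised: every rank calls get_static_torch_model(unet) after a synchronize (a collective gather across Gemini shards is an assumption), and only data-parallel rank 0 materialises the pipeline and writes the checkpoint. The same pattern repeats for the final save below; the save_pretrained call is assumed from the unchanged tail of the script:

torch.cuda.synchronize()
torch_unet = get_static_torch_model(unet)              # called on every rank
if gpc.get_local_rank(ParallelMode.DATA) == 0:
    pipeline = DiffusionPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        unet=torch_unet,
        revision=args.revision,
    )
    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
    pipeline.save_pretrained(save_path)                # only rank 0 writes to disk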
@@ -670,7 +656,7 @@ def main(args):
break
torch.cuda.synchronize()
unet = convert_to_torch_module(unet)
unet = get_static_torch_model(unet)
if gpc.get_local_rank(ParallelMode.DATA) == 0:
pipeline = DiffusionPipeline.from_pretrained(