mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-15 22:19:38 +00:00
[example] fix save_load bug for dreambooth (#2280)
This commit is contained in:
@@ -11,6 +11,7 @@ import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from copy import deepcopy
|
||||
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
|
||||
from diffusers.optimization import get_scheduler
|
||||
from huggingface_hub import HfFolder, Repository, whoami
|
||||
@@ -359,6 +360,7 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy:
|
||||
placement_policy=placememt_policy,
|
||||
pin_memory=True,
|
||||
search_range_mb=32)
|
||||
|
||||
elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
|
||||
from colossalai.gemini import ChunkManager, GeminiManager
|
||||
chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
|
||||
@@ -381,6 +383,7 @@ def main(args):
|
||||
"gradient_accumulation_steps": args.gradient_accumulation_steps,
|
||||
"clip_grad_norm": args.max_grad_norm,
|
||||
}
|
||||
|
||||
colossalai.launch_from_torch(config=config)
|
||||
pg = ProcessGroup()
|
||||
|
||||
@@ -465,21 +468,21 @@ def main(args):
|
||||
|
||||
text_encoder = text_encoder_cls.from_pretrained(args.pretrained_model_name_or_path,
|
||||
subfolder="text_encoder",
|
||||
revision=args.revision,
|
||||
low_cpu_mem_usage=False)
|
||||
revision=args.revision,)
|
||||
|
||||
logger.info(f"Loading AutoencoderKL from {args.pretrained_model_name_or_path}", ranks=[0])
|
||||
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path,
|
||||
subfolder="vae",
|
||||
revision=args.revision,
|
||||
low_cpu_mem_usage=False)
|
||||
revision=args.revision,)
|
||||
|
||||
with ColoInitContext(device='cpu'):
|
||||
logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
|
||||
|
||||
logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
|
||||
with ColoInitContext():
|
||||
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
|
||||
subfolder="unet",
|
||||
revision=args.revision,
|
||||
low_cpu_mem_usage=False)
|
||||
subfolder="unet",
|
||||
revision=args.revision,
|
||||
low_cpu_mem_usage=False)
|
||||
|
||||
|
||||
vae.requires_grad_(False)
|
||||
text_encoder.requires_grad_(False)
|
||||
@@ -597,7 +600,7 @@ def main(args):
|
||||
for epoch in range(args.num_train_epochs):
|
||||
unet.train()
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
# Move batch to gpu
|
||||
for key, value in batch.items():
|
||||
batch[key] = value.to(get_current_device(), non_blocking=True)
|
||||
@@ -653,7 +656,7 @@ def main(args):
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
|
||||
logger.info(f"max GPU_mem cost is {torch.cuda.max_memory_allocated()/2**20} MB", ranks=[0])
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
@@ -678,13 +681,15 @@ def main(args):
|
||||
break
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
unet=convert_to_torch_module(unet)
|
||||
|
||||
if gpc.get_local_rank(ParallelMode.DATA) == 0:
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
unet=convert_to_torch_module(unet),
|
||||
unet=unet,
|
||||
revision=args.revision,
|
||||
)
|
||||
|
||||
pipeline.save_pretrained(args.output_dir)
|
||||
logger.info(f"Saving model checkpoint to {args.output_dir}", ranks=[0])
|
||||
|
||||
|
Reference in New Issue
Block a user