[example] fix save_load bug for dreambooth (#2280)

Author:    BlueRum
Date:      2023-01-03 17:13:29 +08:00
Committer: GitHub
Parent:    f027ef7913
Commit:    1405b4381e

5 changed files with 53 additions and 41 deletions


@@ -11,6 +11,7 @@ import torch
 import torch.distributed as dist
 import torch.nn.functional as F
 import torch.utils.checkpoint
+from copy import deepcopy
 from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
 from huggingface_hub import HfFolder, Repository, whoami
@@ -359,6 +360,7 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy:
                           placement_policy=placememt_policy,
                           pin_memory=True,
                           search_range_mb=32)
     elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
         from colossalai.gemini import ChunkManager, GeminiManager
         chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
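
Note: the hunk above sits inside the example's gemini_zero_dpp helper, which picks a Gemini wrapper based on the installed ColossalAI version. A minimal sketch of that dispatch, reusing the keyword arguments shown above (the function skeleton, imports, and the NotImplementedError fallback are assumptions, not lines from this diff):

    import torch
    from packaging import version

    import colossalai
    from colossalai.tensor import ProcessGroup
    from colossalai.utils import get_current_device

    def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
        cai_version = colossalai.__version__
        if version.parse(cai_version) > version.parse("0.1.10"):
            # Newer releases expose GeminiDDP directly.
            from colossalai.nn.parallel import GeminiDDP
            model = GeminiDDP(model,
                              device=get_current_device(),
                              placement_policy=placememt_policy,
                              pin_memory=True,
                              search_range_mb=32)
        elif version.parse("0.1.9") <= version.parse(cai_version) <= version.parse("0.1.10"):
            # Older releases assemble the wrapper from ChunkManager/GeminiManager;
            # only the chunk-size search appears in this hunk, the rest is elided.
            from colossalai.gemini import ChunkManager, GeminiManager
            chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
            ...
        else:
            raise NotImplementedError(f"ColossalAI version {cai_version} is not supported")
        return model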
@@ -381,6 +383,7 @@ def main(args):
"gradient_accumulation_steps": args.gradient_accumulation_steps,
"clip_grad_norm": args.max_grad_norm,
}
colossalai.launch_from_torch(config=config)
pg = ProcessGroup()
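
Note: launch_from_torch reads the distributed environment (rank, world size, master address) that the PyTorch launcher sets up, so the script only has to supply the config dict built above. A minimal usage sketch with hypothetical values:

    import colossalai

    config = {
        "BATCH_SIZE": 2,                       # hypothetical values for illustration
        "gradient_accumulation_steps": 1,
        "clip_grad_norm": 1.0,
    }
    colossalai.launch_from_torch(config=config)

    # typically launched as (script name illustrative):
    #   torchrun --nproc_per_node=2 train.py ...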
@@ -465,21 +468,21 @@ def main(args):
     text_encoder = text_encoder_cls.from_pretrained(args.pretrained_model_name_or_path,
                                                      subfolder="text_encoder",
-                                                     revision=args.revision,
-                                                     low_cpu_mem_usage=False)
+                                                     revision=args.revision,)
     logger.info(f"Loading AutoencoderKL from {args.pretrained_model_name_or_path}", ranks=[0])
     vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path,
                                         subfolder="vae",
-                                        revision=args.revision,
-                                        low_cpu_mem_usage=False)
+                                        revision=args.revision,)

-    with ColoInitContext(device='cpu'):
-        logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
+    logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
+    with ColoInitContext():
         unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
-                                                    subfolder="unet",
-                                                    revision=args.revision,
-                                                    low_cpu_mem_usage=False)
+                                                    subfolder="unet",
+                                                    revision=args.revision,
+                                                    low_cpu_mem_usage=False)

     vae.requires_grad_(False)
     text_encoder.requires_grad_(False)
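
Note: the hunk above splits model loading by trainability. The frozen text encoder and VAE drop the low_cpu_mem_usage=False workaround, while the trainable UNet is created inside ColoInitContext so its parameters become ColoTensors that the Gemini wrapper can manage; the flag stays on the UNet because diffusers' low-CPU-memory path initializes weights on the meta device, which would bypass the context. A minimal sketch of the split, with a hypothetical model id and the ColoInitContext import path used in ColossalAI examples of this era:

    from colossalai.utils.model.colo_init_context import ColoInitContext
    from diffusers import AutoencoderKL, UNet2DConditionModel

    MODEL = "runwayml/stable-diffusion-v1-5"    # hypothetical model id

    # Frozen parts: ordinary torch modules, loaded the default diffusers way.
    vae = AutoencoderKL.from_pretrained(MODEL, subfolder="vae")
    vae.requires_grad_(False)

    # Trainable part: parameters created under ColoInitContext are ColoTensors.
    with ColoInitContext():
        unet = UNet2DConditionModel.from_pretrained(MODEL,
                                                    subfolder="unet",
                                                    low_cpu_mem_usage=False)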
@@ -597,7 +600,7 @@ def main(args):
     for epoch in range(args.num_train_epochs):
         unet.train()
         for step, batch in enumerate(train_dataloader):
+            torch.cuda.reset_peak_memory_stats()
             # Move batch to gpu
             for key, value in batch.items():
                 batch[key] = value.to(get_current_device(), non_blocking=True)
@@ -653,7 +656,7 @@ def main(args):
             optimizer.step()
             lr_scheduler.step()
+            logger.info(f"max GPU_mem cost is {torch.cuda.max_memory_allocated()/2**20} MB", ranks=[0])
             # Checks if the accelerator has performed an optimization step behind the scenes
             progress_bar.update(1)
             global_step += 1
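
Note: the two logging hunks bracket each optimizer step with CUDA peak-memory accounting. The pattern in isolation is plain PyTorch:

    import torch

    torch.cuda.reset_peak_memory_stats()       # zero the allocator's high-water mark

    # ... forward pass, backward pass, optimizer.step() ...

    peak_mb = torch.cuda.max_memory_allocated() / 2**20    # bytes -> MiB
    print(f"max GPU_mem cost is {peak_mb} MB")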
@@ -678,13 +681,15 @@ def main(args):
                 break

     torch.cuda.synchronize()
+    unet=convert_to_torch_module(unet)
     if gpc.get_local_rank(ParallelMode.DATA) == 0:
         pipeline = DiffusionPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
-            unet=convert_to_torch_module(unet),
+            unet=unet,
             revision=args.revision,
         )
         pipeline.save_pretrained(args.output_dir)
+        logger.info(f"Saving model checkpoint to {args.output_dir}", ranks=[0])