Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-25 11:44:03 +00:00)
fixed model saving bugs
@@ -667,9 +667,9 @@ def main(args):
 
             if global_step % args.save_steps == 0:
                 torch.cuda.synchronize()
-                if local_rank == 0:
-                    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
-                    booster.save_model(unet, os.path.join(save_path, "diffusion_pytorch_model.bin"))
+                save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+                booster.save_model(unet, os.path.join(save_path, "diffusion_pytorch_model.bin"))
+                if local_rank == 0:
                     if not os.path.exists(os.path.join(save_path, "config.json")):
                         shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), save_path)
                     logger.info(f"Saving model checkpoint to {save_path}", ranks=[0])
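The fix above moves save_path and booster.save_model out of the if local_rank == 0: block, presumably because Booster.save_model has to be reached by every rank when the plugin shards or gathers parameters across processes; only the config copy and the log message stay rank-0-only. A minimal sketch of the resulting pattern, assuming the booster, unet, logger, args, global_step and local_rank already defined in the training script (the save_checkpoint helper name and the os.makedirs call are illustrative additions, not part of the commit):

import os
import shutil

import torch


def save_checkpoint(booster, unet, logger, args, global_step, local_rank):
    # Flush in-flight CUDA work so the weights on disk match the current step.
    torch.cuda.synchronize()
    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
    os.makedirs(save_path, exist_ok=True)  # defensive addition, not in the original diff
    # Called on every rank: under distributed plugins the parameters may be
    # sharded, and the fixed code no longer hides this call behind rank 0.
    booster.save_model(unet, os.path.join(save_path, "diffusion_pytorch_model.bin"))
    if local_rank == 0:
        # Rank 0 alone copies the UNet config and logs, to avoid duplicate work.
        if not os.path.exists(os.path.join(save_path, "config.json")):
            shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), save_path)
        logger.info(f"Saving model checkpoint to {save_path}", ranks=[0])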
@@ -693,9 +693,9 @@ def main(args):
 
             if global_step % args.save_steps == 0:
                 torch.cuda.synchronize()
-                if local_rank == 0:
-                    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
-                    booster.save_model(unet, os.path.join(save_path, "diffusion_pytorch_model.bin"))
+                save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+                booster.save_model(unet, os.path.join(save_path, "diffusion_pytorch_model.bin"))
+                if local_rank == 0:
                     if not os.path.exists(os.path.join(save_path, "config.json")):
                         shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), save_path)
                     logger.info(f"Saving model checkpoint to {save_path}", ranks=[0])
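The second hunk repeats the same reordering at another checkpoint-save site. Copying unet/config.json next to diffusion_pytorch_model.bin is what makes each checkpoint-{global_step} directory loadable as a standalone diffusers model; a hedged reload sketch (not part of the commit, assuming the saved state dict keeps the plain diffusers UNet keys, which the file name and the copied config suggest):

from diffusers import UNet2DConditionModel

# Hypothetical checkpoint path for illustration; real runs write under args.output_dir.
save_path = "output/checkpoint-400"

# from_pretrained only needs config.json and diffusion_pytorch_model.bin in the
# same directory, which is exactly the layout the fixed code produces.
unet = UNet2DConditionModel.from_pretrained(save_path)
unet.eval()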