mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-11 05:49:55 +00:00
fix ckpt
This commit is contained in:
@@ -106,7 +106,20 @@ def get_parser(**parser_kwargs):
|
||||
nargs="?",
|
||||
help="disable test",
|
||||
)
|
||||
parser.add_argument("-p", "--project", help="name of new or path to existing project")
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--project",
|
||||
help="name of new or path to existing project",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-c",
|
||||
"--ckpt",
|
||||
type=str,
|
||||
const=True,
|
||||
default="",
|
||||
nargs="?",
|
||||
help="load pretrained checkpoint from stable AI",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--debug",
|
||||
@@ -145,22 +158,7 @@ def get_parser(**parser_kwargs):
|
||||
default=True,
|
||||
help="scale base-lr by ngpu * batch_size * n_accumulate",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_fp16",
|
||||
type=str2bool,
|
||||
nargs="?",
|
||||
const=True,
|
||||
default=True,
|
||||
help="whether to use fp16",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--flash",
|
||||
type=str2bool,
|
||||
const=True,
|
||||
default=False,
|
||||
nargs="?",
|
||||
help="whether to use flash attention",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@@ -341,6 +339,12 @@ class SetupCallback(Callback):
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
# def on_fit_end(self, trainer, pl_module):
|
||||
# if trainer.global_rank == 0:
|
||||
# ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
|
||||
# rank_zero_info(f"Saving final checkpoint in {ckpt_path}.")
|
||||
# trainer.save_checkpoint(ckpt_path)
|
||||
|
||||
|
||||
class ImageLogger(Callback):
|
||||
|
||||
@@ -536,6 +540,7 @@ if __name__ == "__main__":
|
||||
"If you want to resume training in a new log folder, "
|
||||
"use -n/--name in combination with --resume_from_checkpoint")
|
||||
if opt.resume:
|
||||
rank_zero_info("Resuming from {}".format(opt.resume))
|
||||
if not os.path.exists(opt.resume):
|
||||
raise ValueError("Cannot find {}".format(opt.resume))
|
||||
if os.path.isfile(opt.resume):
|
||||
@@ -543,13 +548,13 @@ if __name__ == "__main__":
|
||||
# idx = len(paths)-paths[::-1].index("logs")+1
|
||||
# logdir = "/".join(paths[:idx])
|
||||
logdir = "/".join(paths[:-2])
|
||||
rank_zero_info("logdir: {}".format(logdir))
|
||||
ckpt = opt.resume
|
||||
else:
|
||||
assert os.path.isdir(opt.resume), opt.resume
|
||||
logdir = opt.resume.rstrip("/")
|
||||
ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
|
||||
|
||||
opt.resume_from_checkpoint = ckpt
|
||||
base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
|
||||
opt.base = base_configs + opt.base
|
||||
_tmp = logdir.split("/")
|
||||
@@ -558,6 +563,7 @@ if __name__ == "__main__":
|
||||
if opt.name:
|
||||
name = "_" + opt.name
|
||||
elif opt.base:
|
||||
rank_zero_info("Using base config {}".format(opt.base))
|
||||
cfg_fname = os.path.split(opt.base[0])[-1]
|
||||
cfg_name = os.path.splitext(cfg_fname)[0]
|
||||
name = "_" + cfg_name
|
||||
@@ -566,6 +572,9 @@ if __name__ == "__main__":
|
||||
nowname = now + name + opt.postfix
|
||||
logdir = os.path.join(opt.logdir, nowname)
|
||||
|
||||
if opt.ckpt:
|
||||
ckpt = opt.ckpt
|
||||
|
||||
ckptdir = os.path.join(logdir, "checkpoints")
|
||||
cfgdir = os.path.join(logdir, "configs")
|
||||
seed_everything(opt.seed)
|
||||
@@ -582,14 +591,11 @@ if __name__ == "__main__":
|
||||
for k in nondefault_trainer_args(opt):
|
||||
trainer_config[k] = getattr(opt, k)
|
||||
|
||||
print(trainer_config)
|
||||
if not trainer_config["accelerator"] == "gpu":
|
||||
del trainer_config["accelerator"]
|
||||
cpu = True
|
||||
print("Running on CPU")
|
||||
else:
|
||||
cpu = False
|
||||
print("Running on GPU")
|
||||
trainer_opt = argparse.Namespace(**trainer_config)
|
||||
lightning_config.trainer = trainer_config
|
||||
|
||||
@@ -597,10 +603,12 @@ if __name__ == "__main__":
|
||||
use_fp16 = trainer_config.get("precision", 32) == 16
|
||||
if use_fp16:
|
||||
config.model["params"].update({"use_fp16": True})
|
||||
print("Using FP16 = {}".format(config.model["params"]["use_fp16"]))
|
||||
else:
|
||||
config.model["params"].update({"use_fp16": False})
|
||||
print("Using FP16 = {}".format(config.model["params"]["use_fp16"]))
|
||||
|
||||
if ckpt is not None:
|
||||
config.model["params"].update({"ckpt": ckpt})
|
||||
rank_zero_info("Using ckpt_path = {}".format(config.model["params"]["ckpt"]))
|
||||
|
||||
model = instantiate_from_config(config.model)
|
||||
# trainer and callbacks
|
||||
@@ -639,7 +647,6 @@ if __name__ == "__main__":
|
||||
# config the strategy, default is ddp
|
||||
if "strategy" in trainer_config:
|
||||
strategy_cfg = trainer_config["strategy"]
|
||||
print("Using strategy: {}".format(strategy_cfg["target"]))
|
||||
strategy_cfg["target"] = LIGHTNING_PACK_NAME + strategy_cfg["target"]
|
||||
else:
|
||||
strategy_cfg = {
|
||||
@@ -648,7 +655,6 @@ if __name__ == "__main__":
|
||||
"find_unused_parameters": False
|
||||
}
|
||||
}
|
||||
print("Using strategy: DDPStrategy")
|
||||
|
||||
trainer_kwargs["strategy"] = instantiate_from_config(strategy_cfg)
|
||||
|
||||
@@ -664,7 +670,6 @@ if __name__ == "__main__":
|
||||
}
|
||||
}
|
||||
if hasattr(model, "monitor"):
|
||||
print(f"Monitoring {model.monitor} as checkpoint metric.")
|
||||
default_modelckpt_cfg["params"]["monitor"] = model.monitor
|
||||
default_modelckpt_cfg["params"]["save_top_k"] = 3
|
||||
|
||||
@@ -673,7 +678,6 @@ if __name__ == "__main__":
|
||||
else:
|
||||
modelckpt_cfg = OmegaConf.create()
|
||||
modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
|
||||
print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}")
|
||||
if version.parse(pl.__version__) < version.parse('1.4.0'):
|
||||
trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)
|
||||
|
||||
@@ -710,8 +714,6 @@ if __name__ == "__main__":
|
||||
"target": "main.CUDACallback"
|
||||
},
|
||||
}
|
||||
if version.parse(pl.__version__) >= version.parse('1.4.0'):
|
||||
default_callbacks_cfg.update({'checkpoint_callback': modelckpt_cfg})
|
||||
|
||||
if "callbacks" in lightning_config:
|
||||
callbacks_cfg = lightning_config.callbacks
|
||||
@@ -737,15 +739,11 @@ if __name__ == "__main__":
|
||||
default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)
|
||||
|
||||
callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
|
||||
if 'ignore_keys_callback' in callbacks_cfg and hasattr(trainer_opt, 'resume_from_checkpoint'):
|
||||
callbacks_cfg.ignore_keys_callback.params['ckpt_path'] = trainer_opt.resume_from_checkpoint
|
||||
elif 'ignore_keys_callback' in callbacks_cfg:
|
||||
del callbacks_cfg['ignore_keys_callback']
|
||||
|
||||
trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
|
||||
|
||||
trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
|
||||
trainer.logdir = logdir ###
|
||||
trainer.logdir = logdir
|
||||
|
||||
# data
|
||||
data = instantiate_from_config(config.data)
|
||||
@@ -754,9 +752,9 @@ if __name__ == "__main__":
|
||||
# lightning still takes care of proper multiprocessing though
|
||||
data.prepare_data()
|
||||
data.setup()
|
||||
print("#### Data #####")
|
||||
|
||||
for k in data.datasets:
|
||||
print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")
|
||||
rank_zero_info(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")
|
||||
|
||||
# configure learning rate
|
||||
bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
|
||||
@@ -768,17 +766,17 @@ if __name__ == "__main__":
|
||||
accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
|
||||
else:
|
||||
accumulate_grad_batches = 1
|
||||
print(f"accumulate_grad_batches = {accumulate_grad_batches}")
|
||||
rank_zero_info(f"accumulate_grad_batches = {accumulate_grad_batches}")
|
||||
lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
|
||||
if opt.scale_lr:
|
||||
model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
|
||||
print(
|
||||
rank_zero_info(
|
||||
"Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)"
|
||||
.format(model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr))
|
||||
else:
|
||||
model.learning_rate = base_lr
|
||||
print("++++ NOT USING LR SCALING ++++")
|
||||
print(f"Setting learning rate to {model.learning_rate:.2e}")
|
||||
rank_zero_info("++++ NOT USING LR SCALING ++++")
|
||||
rank_zero_info(f"Setting learning rate to {model.learning_rate:.2e}")
|
||||
|
||||
# allow checkpointing via USR1
|
||||
def melk(*args, **kwargs):
|
||||
|
Reference in New Issue
Block a user