commit cad1f50512 (parent 6e0faa70e0)
Author: Fazzie
Date:   2023-02-03 15:34:54 +08:00
7 changed files with 831 additions and 658 deletions


@@ -106,7 +106,20 @@ def get_parser(**parser_kwargs):
nargs="?",
help="disable test",
)
parser.add_argument("-p", "--project", help="name of new or path to existing project")
parser.add_argument(
"-p",
"--project",
help="name of new or path to existing project",
)
parser.add_argument(
"-c",
"--ckpt",
type=str,
const=True,
default="",
nargs="?",
help="load pretrained checkpoint from stable AI",
)
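
The new -c/--ckpt option follows the same nargs="?" / const=True pattern as the surrounding flags: giving the flag a value stores that path, giving it with no value resolves to True, and omitting it leaves the empty-string default so no pretrained weights are loaded. A minimal parsing sketch (the checkpoint filename below is a placeholder, not a file shipped with the repo):

    # Hypothetical invocation of the parser built by get_parser() above.
    parser = get_parser()
    opt, _ = parser.parse_known_args(["-c", "sd-v1-4.ckpt"])   # placeholder path
    print(opt.ckpt)   # "sd-v1-4.ckpt", later copied into config.model["params"]["ckpt"]
    opt, _ = parser.parse_known_args([])
    print(opt.ckpt)   # "" -> training starts without a pretrained checkpoint
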
parser.add_argument(
"-d",
"--debug",
@@ -145,22 +158,7 @@ def get_parser(**parser_kwargs):
default=True,
help="scale base-lr by ngpu * batch_size * n_accumulate",
)
parser.add_argument(
"--use_fp16",
type=str2bool,
nargs="?",
const=True,
default=True,
help="whether to use fp16",
)
parser.add_argument(
"--flash",
type=str2bool,
const=True,
default=False,
nargs="?",
help="whether to use flash attention",
)
return parser
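
With --use_fp16 and --flash removed from the CLI, half precision is no longer a parser flag; a later hunk derives it from the trainer's precision setting instead (trainer_config.get("precision", 32) == 16). A minimal sketch of that derivation with an illustrative config dict:

    # Illustrative stand-in for the merged trainer config used in main.py.
    trainer_config = {"accelerator": "gpu", "precision": 16}
    use_fp16 = trainer_config.get("precision", 32) == 16   # True only for 16-bit training
    print(use_fp16)   # True
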
@@ -341,6 +339,12 @@ class SetupCallback(Callback):
except FileNotFoundError:
pass
# def on_fit_end(self, trainer, pl_module):
# if trainer.global_rank == 0:
# ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
# rank_zero_info(f"Saving final checkpoint in {ckpt_path}.")
# trainer.save_checkpoint(ckpt_path)
class ImageLogger(Callback):
@@ -536,6 +540,7 @@ if __name__ == "__main__":
"If you want to resume training in a new log folder, "
"use -n/--name in combination with --resume_from_checkpoint")
if opt.resume:
rank_zero_info("Resuming from {}".format(opt.resume))
if not os.path.exists(opt.resume):
raise ValueError("Cannot find {}".format(opt.resume))
if os.path.isfile(opt.resume):
@@ -543,13 +548,13 @@ if __name__ == "__main__":
# idx = len(paths)-paths[::-1].index("logs")+1
# logdir = "/".join(paths[:idx])
logdir = "/".join(paths[:-2])
rank_zero_info("logdir: {}".format(logdir))
ckpt = opt.resume
else:
assert os.path.isdir(opt.resume), opt.resume
logdir = opt.resume.rstrip("/")
ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
opt.resume_from_checkpoint = ckpt
base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
opt.base = base_configs + opt.base
_tmp = logdir.split("/")
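
When --resume points at a checkpoint file, the run's log directory is recovered by dropping the last two path components (the checkpoints folder and the .ckpt file), and the given file becomes the resume checkpoint. A small worked example with a placeholder run name:

    # Placeholder resume path; mirrors the logdir recovery in the hunk above.
    resume = "logs/2023-02-03T15-34-54_example/checkpoints/last.ckpt"
    paths = resume.split("/")
    logdir = "/".join(paths[:-2])   # "logs/2023-02-03T15-34-54_example"
    ckpt = resume                   # weights are loaded from the file that was passed in
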
@@ -558,6 +563,7 @@ if __name__ == "__main__":
if opt.name:
name = "_" + opt.name
elif opt.base:
rank_zero_info("Using base config {}".format(opt.base))
cfg_fname = os.path.split(opt.base[0])[-1]
cfg_name = os.path.splitext(cfg_fname)[0]
name = "_" + cfg_name
@@ -566,6 +572,9 @@ if __name__ == "__main__":
nowname = now + name + opt.postfix
logdir = os.path.join(opt.logdir, nowname)
if opt.ckpt:
ckpt = opt.ckpt
ckptdir = os.path.join(logdir, "checkpoints")
cfgdir = os.path.join(logdir, "configs")
seed_everything(opt.seed)
@@ -582,14 +591,11 @@ if __name__ == "__main__":
for k in nondefault_trainer_args(opt):
trainer_config[k] = getattr(opt, k)
print(trainer_config)
if not trainer_config["accelerator"] == "gpu":
del trainer_config["accelerator"]
cpu = True
print("Running on CPU")
else:
cpu = False
print("Running on GPU")
trainer_opt = argparse.Namespace(**trainer_config)
lightning_config.trainer = trainer_config
@@ -597,10 +603,12 @@ if __name__ == "__main__":
use_fp16 = trainer_config.get("precision", 32) == 16
if use_fp16:
config.model["params"].update({"use_fp16": True})
print("Using FP16 = {}".format(config.model["params"]["use_fp16"]))
else:
config.model["params"].update({"use_fp16": False})
print("Using FP16 = {}".format(config.model["params"]["use_fp16"]))
if ckpt is not None:
config.model["params"].update({"ckpt": ckpt})
rank_zero_info("Using ckpt_path = {}".format(config.model["params"]["ckpt"]))
model = instantiate_from_config(config.model)
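
Both the fp16 switch and the checkpoint path are injected into config.model["params"] before instantiation, so the model constructor receives them as ordinary config keys rather than CLI state. A minimal sketch of that update on an illustrative OmegaConf config (the checkpoint name is a placeholder):

    from omegaconf import OmegaConf

    # Illustrative stand-in for the YAML model config loaded by main.py.
    config = OmegaConf.create({"model": {"params": {}}})
    config.model["params"].update({"use_fp16": True})
    config.model["params"].update({"ckpt": "sd-v1-4.ckpt"})   # placeholder path
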
# trainer and callbacks
@@ -639,7 +647,6 @@ if __name__ == "__main__":
# configure the strategy; default is DDP
if "strategy" in trainer_config:
strategy_cfg = trainer_config["strategy"]
print("Using strategy: {}".format(strategy_cfg["target"]))
strategy_cfg["target"] = LIGHTNING_PACK_NAME + strategy_cfg["target"]
else:
strategy_cfg = {
@@ -648,7 +655,6 @@ if __name__ == "__main__":
"find_unused_parameters": False
}
}
print("Using strategy: DDPStrategy")
trainer_kwargs["strategy"] = instantiate_from_config(strategy_cfg)
@@ -664,7 +670,6 @@ if __name__ == "__main__":
}
}
if hasattr(model, "monitor"):
print(f"Monitoring {model.monitor} as checkpoint metric.")
default_modelckpt_cfg["params"]["monitor"] = model.monitor
default_modelckpt_cfg["params"]["save_top_k"] = 3
@@ -673,7 +678,6 @@ if __name__ == "__main__":
else:
modelckpt_cfg = OmegaConf.create()
modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}")
if version.parse(pl.__version__) < version.parse('1.4.0'):
trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)
@@ -710,8 +714,6 @@ if __name__ == "__main__":
"target": "main.CUDACallback"
},
}
if version.parse(pl.__version__) >= version.parse('1.4.0'):
default_callbacks_cfg.update({'checkpoint_callback': modelckpt_cfg})
if "callbacks" in lightning_config:
callbacks_cfg = lightning_config.callbacks
@@ -737,15 +739,11 @@ if __name__ == "__main__":
default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)
callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
if 'ignore_keys_callback' in callbacks_cfg and hasattr(trainer_opt, 'resume_from_checkpoint'):
callbacks_cfg.ignore_keys_callback.params['ckpt_path'] = trainer_opt.resume_from_checkpoint
elif 'ignore_keys_callback' in callbacks_cfg:
del callbacks_cfg['ignore_keys_callback']
trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
trainer.logdir = logdir ###
trainer.logdir = logdir
# data
data = instantiate_from_config(config.data)
@@ -754,9 +752,9 @@ if __name__ == "__main__":
# lightning still takes care of proper multiprocessing though
data.prepare_data()
data.setup()
print("#### Data #####")
for k in data.datasets:
print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")
rank_zero_info(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")
# configure learning rate
bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
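
The per-dataset summary now goes through rank_zero_info instead of print, so under DDP it is emitted once by the rank-0 process rather than once per GPU. A minimal sketch of the helper this relies on (assuming the pytorch_lightning.utilities export that main.py imports):

    from pytorch_lightning.utilities import rank_zero_info

    # Printed once on global rank 0, even when the script runs on several GPUs.
    rank_zero_info("train, ExampleDataset, 10000")   # illustrative dataset line
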
@@ -768,17 +766,17 @@ if __name__ == "__main__":
accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
else:
accumulate_grad_batches = 1
print(f"accumulate_grad_batches = {accumulate_grad_batches}")
rank_zero_info(f"accumulate_grad_batches = {accumulate_grad_batches}")
lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
if opt.scale_lr:
model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
print(
rank_zero_info(
"Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)"
.format(model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr))
else:
model.learning_rate = base_lr
print("++++ NOT USING LR SCALING ++++")
print(f"Setting learning rate to {model.learning_rate:.2e}")
rank_zero_info("++++ NOT USING LR SCALING ++++")
rank_zero_info(f"Setting learning rate to {model.learning_rate:.2e}")
# allow checkpointing via USR1
def melk(*args, **kwargs):