diff --git a/applications/Colossal-LLaMA-2/train.py b/applications/Colossal-LLaMA-2/train.py
index 1c1389b5c..20ec2a7c8 100644
--- a/applications/Colossal-LLaMA-2/train.py
+++ b/applications/Colossal-LLaMA-2/train.py
@@ -23,7 +23,7 @@ from colossal_llama2.utils.froze import freeze_non_embeds_parameters
 from colossal_llama2.utils.neftune_patch import activate_neftune, deactivate_neftune
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
-from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
+from transformers import LlamaForCausalLM, LlamaTokenizer
 
 import colossalai
 from colossalai.accelerator import get_accelerator
@@ -232,10 +232,12 @@ def main() -> None:
         else nullcontext()
     )
     with init_ctx:
-        model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained))
+        model = LlamaForCausalLM.from_pretrained(args.pretrained)
     # Freeze part of parameters.
     if args.freeze_non_embeds_params:
         freeze_non_embeds_parameters(model=model)
+    # this is essential, otherwise the grad checkpoint will not work.
+    model.train()
 
     if args.use_grad_checkpoint:
         model.gradient_checkpointing_enable()
@@ -277,8 +279,6 @@ def main() -> None:
         lr_scheduler=lr_scheduler,
         dataloader=dataloader,
    )
-    if args.load_checkpoint is None:
-        booster.load_model(model, args.pretrained)
 
     torch.set_default_dtype(torch.float)
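For context, a minimal standalone sketch of why the three changes belong together (this is outside the `init_ctx`/Booster setup of the real script, and `pretrained` is a placeholder checkpoint path): `from_pretrained` loads the weights during construction, which is what makes the removed `booster.load_model(model, args.pretrained)` call redundant, but it also returns the model in eval mode, and transformers only applies activation checkpointing while the module is in training mode, hence the explicit `model.train()`.

```python
import torch
from transformers import LlamaConfig, LlamaForCausalLM

pretrained = "path/to/llama-2-checkpoint"  # placeholder checkpoint path

# Old initialization: only the config is read, so the weights come out randomly
# initialized and a separate booster.load_model(model, args.pretrained) was
# needed afterwards to pull in the pretrained weights.
config_only_model = LlamaForCausalLM(LlamaConfig.from_pretrained(pretrained))

# New initialization: weights are loaded during construction, so the extra
# booster.load_model call is no longer needed.
model = LlamaForCausalLM.from_pretrained(pretrained)
print(model.training)  # False: from_pretrained returns the model in eval mode

# transformers gates activation checkpointing on the module being in training
# mode, so without this call gradient checkpointing is silently a no-op.
model.train()
model.gradient_checkpointing_enable()
```

In the actual script the construction happens under `init_ctx` and the Booster handles sharding and optimizer wrapping; the sketch only demonstrates the eval-mode and weight-loading behaviour the patch relies on.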