[llama] polish training script and fix optim ckpt (#5368)
@@ -23,7 +23,7 @@ from colossal_llama2.utils.froze import freeze_non_embeds_parameters
 from colossal_llama2.utils.neftune_patch import activate_neftune, deactivate_neftune
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
-from transformers import LlamaForCausalLM, LlamaTokenizer
+from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
 
 import colossalai
 from colossalai.accelerator import get_accelerator
@@ -232,7 +232,7 @@ def main() -> None:
         else nullcontext()
     )
     with init_ctx:
-        model = LlamaForCausalLM.from_pretrained(args.pretrained)
+        model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained))
     # Freeze part of parameters.
     if args.freeze_non_embeds_params:
         freeze_non_embeds_parameters(model=model)
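Note on the model-construction change: LlamaForCausalLM.from_pretrained(...) loads the full pretrained weights while building the model, which defeats lazy initialization inside init_ctx. Building the model from its config keeps construction weight-free; the pretrained weights are loaded afterwards through the booster (see the next hunk). A minimal sketch of the pattern, with a hypothetical local path standing in for args.pretrained:

from transformers import LlamaConfig, LlamaForCausalLM

pretrained_path = "path/to/llama"                        # hypothetical placeholder for args.pretrained
config = LlamaConfig.from_pretrained(pretrained_path)    # reads config.json only, no weight I/O
model = LlamaForCausalLM(config)                         # builds the architecture with fresh weights
# In the script this construction runs inside init_ctx (a lazy-init context for some
# plugins), and the pretrained weights are materialized later via the booster.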
@@ -277,6 +277,8 @@ def main() -> None:
         lr_scheduler=lr_scheduler,
         dataloader=dataloader,
     )
+    if args.load_checkpoint is None:
+        booster.load_model(model, args.pretrained)
 
     torch.set_default_dtype(torch.float)
 
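The two added lines load the pretrained weights through the booster only when no training checkpoint is being resumed; with --load_checkpoint set, the weights come from that checkpoint instead. A small sketch of the decision, wrapped in a hypothetical helper for illustration (the booster.load_model call follows the hunk above):

def load_initial_weights(booster, model, pretrained_path, load_checkpoint):
    # Hypothetical helper mirroring the decision added in this hunk.
    if load_checkpoint is None:
        # Fresh fine-tuning run: pull the original pretrained weights into the boosted model.
        booster.load_model(model, pretrained_path)
    # Otherwise the model, optimizer and scheduler states are restored from the saved
    # training checkpoint elsewhere in the script, so no pretrained load happens here.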
@@ -329,7 +331,12 @@ def main() -> None:
 
     for epoch in range(start_epoch, args.num_epochs):
         dataloader.sampler.set_epoch(epoch=epoch)
-        pbar = tqdm(desc=f"Epoch {epoch}", disable=not coordinator.is_master(), total=num_steps_per_epoch, initial=start_step // args.accumulation_steps)
+        pbar = tqdm(
+            desc=f"Epoch {epoch}",
+            disable=not coordinator.is_master(),
+            total=num_steps_per_epoch,
+            initial=start_step // args.accumulation_steps,
+        )
         total_loss = torch.tensor(0.0, device=get_current_device())
         for step, batch in enumerate(dataloader, start=start_step):
             batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
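The tqdm call is only reflowed here; behaviour is unchanged. Note the units, though: total and initial count optimizer steps, while start_step (also used as the enumerate start below) counts micro-batches, hence the division by args.accumulation_steps. A sketch of that arithmetic with hypothetical values (num_steps_per_epoch and start_step are defined elsewhere in the script):

accumulation_steps = 4             # hypothetical value of args.accumulation_steps
micro_batches_per_epoch = 1024     # hypothetical len(dataloader)
num_steps_per_epoch = micro_batches_per_epoch // accumulation_steps  # optimizer steps per epoch
start_step = 200                   # micro-batch index restored when resuming mid-epoch
initial = start_step // accumulation_steps  # -> 50 optimizer steps already completed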
@@ -369,6 +376,7 @@ def main() -> None:
             coordinator.print_on_master("Deactivate NEFTune before saving model.")
             deactivate_neftune(model, handle)
 
+            accelerator.empty_cache()
             save_checkpoint(
                 save_dir=args.save_dir,
                 booster=booster,
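The added accelerator.empty_cache() releases cached device memory right before save_checkpoint, reducing the chance of running out of memory while checkpoint buffers are gathered. A minimal sketch, assuming the accelerator object comes from the get_accelerator import shown at the top of the file:

from colossalai.accelerator import get_accelerator

accelerator = get_accelerator()    # in the script this is obtained once in main()
accelerator.empty_cache()          # drop cached allocator blocks (torch.cuda.empty_cache() on CUDA)
# ... then save_checkpoint(save_dir=..., booster=..., ...) runs with more headroom.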