fix: update train scripts and configs for other models (#1164)

* feat: falcon config

* feat: mpt config

* chore: gitignore

* refactor: step calculation

* fix: attention mask + shuffle on epoch end

* fix: return tensors

* fix: wait for everyone

* chore: config

* chore: ds config

* fix: remove ccols

* fix: logging and saving

* chore: add einops
This commit is contained in:
Zach Nussbaum
2023-07-12 15:18:24 -04:00
committed by GitHub
parent e8b19b8e82
commit 6c4f449b7a
9 changed files with 245 additions and 29 deletions

View File

@@ -1,5 +1,5 @@
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler, LlamaForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler
import torch
from torch.optim import AdamW
from argparse import ArgumentParser
@@ -42,7 +42,7 @@ def train(accelerator, config):
accelerator.print(config)
accelerator.print(f"Using {accelerator.num_processes} GPUs")
tokenizer = AutoTokenizer.from_pretrained(config['tokenizer_name'], model_max_length=config['max_length'])
tokenizer = AutoTokenizer.from_pretrained(config['tokenizer_name'], model_max_length=config['max_length'], use_fast=False)
# if no pad token, set it to eos
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
@@ -53,6 +53,7 @@ def train(accelerator, config):
checkpoint = config["gradient_checkpointing"]
model = AutoModelForCausalLM.from_pretrained(config["model_name"],
use_cache=False if checkpoint else True,
trust_remote_code=True)
@@ -86,7 +87,7 @@ def train(accelerator, config):
# decay to min_lr instead of 0
lr_ratio = config["min_lr"] / config["lr"]
accelerator.print(f"Len of train_dataloader: {len(train_dataloader)}")
total_num_steps = (len(train_dataloader) / gradient_accumulation_steps) * config["num_epochs"]
total_num_steps = (len(train_dataloader) / gradient_accumulation_steps) * (config["num_epochs"])
# instead of decaying to zero, decay to ratio of min_lr / lr
total_num_steps += int(total_num_steps * lr_ratio) + config["warmup_steps"]
accelerator.print(f"Total training steps: {total_num_steps}")
@@ -104,7 +105,7 @@ def train(accelerator, config):
)
else:
scheduler = DummyScheduler(
optimizer, total_num_steps=config["warmup_steps"], warmup_num_steps=config["warmup_steps"]
optimizer, total_num_steps=total_num_steps, warmup_num_steps=config["warmup_steps"]
)
model, optimizer, train_dataloader, val_dataloader, scheduler = accelerator.prepare(
@@ -117,26 +118,34 @@ def train(accelerator, config):
if config["checkpoint"]:
accelerator.load_state(config["checkpoint"])
accelerator.print(f"Resumed from checkpoint: {config['checkpoint']}")
path = os.path.basename(config["train_args"]["resume_from_checkpoint"])
path = os.path.basename(config["checkpoint"])
training_difference = os.path.splitext(path)[0]
resume_step = int(training_difference.replace("step_", ""))
accelerator.skip_first_batches(train_dataloader, resume_step)
train_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
accelerator.print(f"Resuming from step {resume_step}")
else:
resume_step = 0
# log gradients
if accelerator.is_main_process and config["wandb"]:
wandb.watch(model, log_freq=config["log_grads_every"], log="all")
for epoch in range(config["num_epochs"]):
accelerator.wait_for_everyone()
for epoch in range(0, config["num_epochs"]):
train_loss = MeanMetric(nan_strategy="error").to(model.device)
for step, batch in enumerate(tqdm(train_dataloader)):
curr_step = epoch * len(train_dataloader) + step
model.train()
outputs = model(**batch)
loss = outputs.loss
# gather loss before backprop in case of gradient accumulation
loss_values = accelerator.gather_for_metrics({"loss": loss.detach().float()})
if config["wandb"]:
accelerator.log({"loss": torch.mean(loss_values["loss"]).item()}, step=curr_step)
train_loss.update(loss_values["loss"])
loss = loss / gradient_accumulation_steps
@@ -144,9 +153,8 @@ def train(accelerator, config):
# get gradient norm of all params
# log LR in case something weird happens
if step > 0 and step % (config["eval_every"] // 10) == 0:
if step > 0 and step % (config["log_lr_every"]) == 0:
if config["wandb"]:
curr_step = step + epoch * len(train_dataloader)
accelerator.log({"lr": scheduler.get_last_lr()[0]}, step=curr_step)
if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
@@ -156,7 +164,6 @@ def train(accelerator, config):
if step > 0 and step % config["save_every"] == 0:
curr_step = step + epoch * len(train_dataloader)
accelerator.save_state(f"{config['output_dir']}/step_{curr_step}")
if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1):
@@ -170,7 +177,6 @@ def train(accelerator, config):
}
if config["wandb"]:
curr_step = step + epoch * len(train_dataloader)
accelerator.log({**log_train, **log_val}, step=curr_step)
accelerator.print(f"Current LR: {scheduler.get_last_lr()[0]}")
@@ -181,8 +187,14 @@ def train(accelerator, config):
accelerator.print(f"Epoch {epoch} finished")
accelerator.print(f"Pushing to HF hub")
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(
f"{config['output_dir']}/epoch_{epoch}",
is_main_process=accelerator.is_main_process,
save_function=accelerator.save,
state_dict=accelerator.get_state_dict(model),
)
try:
if accelerator.is_main_process:
unwrapped_model.push_to_hub(config["save_name"] + f"-epoch_{epoch}", private=True)
@@ -191,21 +203,16 @@ def train(accelerator, config):
accelerator.print(e)
accelerator.print(f"Failed to push to hub")
if config["num_epochs"] > 1:
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(
f"{config['output_dir']}/epoch_{epoch}",
f"{config['output_dir']}/final",
is_main_process=accelerator.is_main_process,
save_function=accelerator.save,
state_dict=accelerator.get_state_dict(model),
)
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(
f"{config['output_dir']}/final",
is_main_process=accelerator.is_main_process,
save_function=accelerator.save,
state_dict=accelerator.get_state_dict(model),
)
accelerator.end_training()