diff --git a/train.py b/train.py
index 4344ee24..6c2c0515 100644
--- a/train.py
+++ b/train.py
@@ -1,5 +1,5 @@
import os
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, AdamW, get_scheduler
from transformers.trainer_pt_utils import get_parameter_names
import torch
import torch.nn as nn
@@ -11,6 +11,7 @@ from peft import get_peft_model, LoraConfig, TaskType
from data import load_data
from torchmetrics import MeanMetric
from tqdm import tqdm
+import wandb
def format_metrics(metrics, split, prefix=""):
@@ -20,17 +21,12 @@ def format_metrics(metrics, split, prefix=""):
return log
-def evaluate(config, model, val_dataloader):
+def evaluate(model, val_dataloader):
model.eval()
val_loss = MeanMetric().to(model.device)
with torch.no_grad():
- for i, batch in enumerate(
- tqdm(val_dataloader),
- ):
- if i == config["eval_steps"]:
- break
-
+ for batch in tqdm(val_dataloader):
loss = model(**batch).loss
loss_values = accelerator.gather_for_metrics({"loss": loss.detach()})
@@ -49,8 +45,7 @@ def train(accelerator, config):
tokenizer = AutoTokenizer.from_pretrained(config['tokenizer_name'])
# llama has no pad token, set it to new token
if tokenizer.pad_token is None:
- # these tokens are already in the vocab, just not mapped correctly
- added_tokens = tokenizer.add_special_tokens({"bos_token": "", "eos_token": "", "pad_token": ""})
+ tokenizer.pad_token = tokenizer.eos_token
with accelerator.main_process_first():
@@ -61,10 +56,6 @@ def train(accelerator, config):
model = AutoModelForCausalLM.from_pretrained(config["model_name"],
use_cache=False if checkpoint else True,
trust_remote_code=True)
-
- if added_tokens > 0:
- model.resize_token_embeddings(len(tokenizer))
-
if checkpoint:
model.gradient_checkpointing_enable()
@@ -77,19 +68,55 @@ def train(accelerator, config):
model.print_trainable_parameters()
optimizer_cls = (
- torch.optim.AdamW
+ AdamW
if accelerator.state.deepspeed_plugin is None
or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config
else DummyOptim
)
+ no_decay = ["bias", "LayerNorm.weight"]
+ optimizer_grouped_parameters = [
+ {
+ "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+ "weight_decay": config["weight_decay"],
+ },
+ {
+ "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
+ "weight_decay": 0.0,
+ },
+ ]
+
# karpathy doesn't decay embeddding, maybe we should exclude
# https://github.com/karpathy/minGPT/commit/bbbdac74fa9b2e55574d70056163ffbae42310c1#diff-2075fa9c224b395be5bda85544dd36572b59c76c54562819eadadbf268602834R157s
- optimizer = optimizer_cls(model.parameters(), lr=config["lr"])
+ optimizer = optimizer_cls(optimizer_grouped_parameters, lr=config["lr"])
- # scheduler defined in Deepspeed config
- scheduler = DummyScheduler(
- optimizer, warmup_num_steps=config["warmup_steps"],
+ if accelerator.state.deepspeed_plugin is not None:
+ gradient_accumulation_steps = accelerator.state.deepspeed_plugin.deepspeed_config[
+ "gradient_accumulation_steps"
+ ]
+
+ # decay to min_lr instead of 0
+ lr_ratio = config["min_lr"] / config["lr"]
+ accelerator.print(f"Len of train_dataloader: {len(train_dataloader)}")
+ total_num_steps = (len(train_dataloader) / gradient_accumulation_steps) * config["num_epochs"]
+ # instead of decaying to zero, decay to ratio of min_lr / lr
+ total_num_steps += int(total_num_steps * lr_ratio) + config["warmup_steps"]
+ accelerator.print(f"Total training steps: {total_num_steps}")
+
+ # Creates Dummy Scheduler if `scheduler` was spcified in the config file else creates `args.lr_scheduler_type` Scheduler
+ if (
+ accelerator.state.deepspeed_plugin is None
+ or "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
+ ):
+ scheduler = get_scheduler(
+ name="cosine",
+ optimizer=optimizer,
+ num_warmup_steps=config["warmup_steps"] * accelerator.num_processes,
+ num_training_steps=total_num_steps * accelerator.num_processes,
+ )
+ else:
+ scheduler = DummyScheduler(
+ optimizer, total_num_steps=config["warmup_steps"], warmup_num_steps=config["warmup_steps"]
)
model, optimizer, train_dataloader, val_dataloader, scheduler = accelerator.prepare(
@@ -108,14 +135,13 @@ def train(accelerator, config):
accelerator.skip_first_batches(train_dataloader, resume_step)
accelerator.print(f"Resuming from step {resume_step}")
- train_loss = MeanMetric().to(model.device)
- if accelerator.state.deepspeed_plugin is not None:
- gradient_accumulation_steps = accelerator.state.deepspeed_plugin.deepspeed_config[
- "gradient_accumulation_steps"
- ]
+ # log gradients
+ if accelerator.is_local_main_process and config["wandb"]:
+ wandb.watch(model, log_freq=config["log_grads_every"])
for epoch in range(config["num_epochs"]):
+ train_loss = MeanMetric().to(model.device)
for step, batch in enumerate(tqdm(train_dataloader)):
model.train()
outputs = model(**batch)
@@ -139,9 +165,10 @@ def train(accelerator, config):
train_loss.update(loss_values["loss"])
if step > 0 and step % config["save_every"] == 0:
- accelerator.save_state(f"{config['output_dir']}/step_{step}")
+ curr_step = step + epoch * len(train_dataloader)
+ accelerator.save_state(f"{config['output_dir']}/step_{curr_step}")
- if step > 0 and step % config["eval_every"] == 0:
+ if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1):
val_loss = evaluate(config, model, val_dataloader)
log_train = {
@@ -166,7 +193,7 @@ def train(accelerator, config):
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
if accelerator.is_main_process:
- unwrapped_model.push_to_hub(config["save_name"] + "_first_epoch", private=True)
+ unwrapped_model.push_to_hub(config["save_name"] + f"-epoch_{epoch}", private=True)
accelerator.wait_for_everyone()
@@ -178,9 +205,6 @@ def train(accelerator, config):
state_dict=accelerator.get_state_dict(model),
)
- if accelerator.is_main_process:
- unwrapped_model.push_to_hub(config["save_name"], private=True)
-
accelerator.end_training()