diff --git a/train.py b/train.py
index 2cf92aa8..463dd155 100644
--- a/train.py
+++ b/train.py
@@ -1,8 +1,9 @@
 import os
-from transformers import AutoModelForCausalLM, AutoTokenizer, AdamW, get_scheduler
+from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler
 from transformers.trainer_pt_utils import get_parameter_names
 import torch
 import torch.nn as nn
+from torch.optim import AdamW
 from argparse import ArgumentParser
 from read import read_config
 from accelerate import Accelerator
@@ -24,7 +25,7 @@ def format_metrics(metrics, split, prefix=""):
 
 def evaluate(model, val_dataloader):
     model.eval()
-    val_loss = MeanMetric().to(model.device)
+    val_loss = MeanMetric(nan_strategy="error").to(model.device)
 
     with torch.no_grad():
         for batch in tqdm(val_dataloader):
@@ -43,7 +44,7 @@ def train(accelerator, config):
     accelerator.print(config)
     accelerator.print(f"Using {accelerator.num_processes} GPUs")
 
-    tokenizer = AutoTokenizer.from_pretrained(config['tokenizer_name'])
+    tokenizer = AutoTokenizer.from_pretrained(config['tokenizer_name'], model_max_length=config['max_length'])
     # llama has no pad token, set it to new token
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
@@ -51,7 +52,7 @@ def train(accelerator, config):
 
     with accelerator.main_process_first():
         train_dataloader, val_dataloader = load_data(config, tokenizer)
-    
+
     checkpoint = config["gradient_checkpointing"]
 
     model = AutoModelForCausalLM.from_pretrained(config["model_name"],
@@ -139,17 +140,22 @@ def train(accelerator, config):
 
     # log gradients
     if accelerator.is_main_process and config["wandb"]:
-        wandb.watch(model, log_freq=config["log_grads_every"])
+        wandb.watch(model, log_freq=config["log_grads_every"], log="all")
 
     for epoch in range(config["num_epochs"]):
-        train_loss = MeanMetric().to(model.device)
+        train_loss = MeanMetric(nan_strategy="error").to(model.device)
 
         for step, batch in enumerate(tqdm(train_dataloader)):
             model.train()
             outputs = model(**batch)
             loss = outputs.loss
-            loss = loss / gradient_accumulation_steps
+            # gather loss before backprop in case of gradient accumulation
+            loss_values = accelerator.gather_for_metrics({"loss": loss.detach().float()})
+            train_loss.update(loss_values["loss"])
+
+            loss = loss / gradient_accumulation_steps
             accelerator.backward(loss)
+            # get gradient norm of all params
 
             # log LR in case something weird happens
             if step > 0 and step % (config["eval_every"] // 10) == 0:
@@ -162,8 +168,6 @@ def train(accelerator, config):
                 scheduler.step()
                 optimizer.zero_grad()
 
-            loss_values = accelerator.gather_for_metrics({"loss": loss.detach()})
-            train_loss.update(loss_values["loss"])
 
             if step > 0 and step % config["save_every"] == 0:
                 curr_step = step + epoch * len(train_dataloader)
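
For context (not part of the patch): a minimal, self-contained sketch of the loss-logging pattern this diff moves to, namely gathering the unscaled loss with `gather_for_metrics` before it is divided by the accumulation steps and backpropagated. The toy model, dataloader, and `gradient_accumulation_steps` value below are placeholder assumptions, not taken from `train.py`.

```python
import torch
from accelerate import Accelerator
from torchmetrics import MeanMetric

accelerator = Accelerator()

# Toy stand-ins for the real model/data (assumption, for illustration only).
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
dataloader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(torch.randn(64, 10), torch.randn(64, 1)),
    batch_size=8,
)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

gradient_accumulation_steps = 4  # placeholder for config["gradient_accumulation_steps"]
# nan_strategy="error" makes the metric raise instead of silently averaging NaNs.
train_loss = MeanMetric(nan_strategy="error").to(accelerator.device)

for step, (x, y) in enumerate(dataloader):
    model.train()
    loss = torch.nn.functional.mse_loss(model(x), y)

    # Log the unscaled loss: gather across processes *before* dividing by the
    # accumulation steps, so the metric reflects the true per-batch loss.
    loss_values = accelerator.gather_for_metrics({"loss": loss.detach().float()})
    train_loss.update(loss_values["loss"])

    loss = loss / gradient_accumulation_steps
    accelerator.backward(loss)

    if (step + 1) % gradient_accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

print(f"mean train loss: {train_loss.compute().item():.4f}")
```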