mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2025-06-26 23:37:40 +00:00
fix: grad accum loss calc
This commit is contained in:
parent
e4e88dff33
commit
8f2eb1a583
22
train.py
22
train.py
@ -1,8 +1,9 @@
|
|||||||
import os
|
import os
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer, AdamW, get_scheduler
|
from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler
|
||||||
from transformers.trainer_pt_utils import get_parameter_names
|
from transformers.trainer_pt_utils import get_parameter_names
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from torch.optim import AdamW
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from read import read_config
|
from read import read_config
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
@ -24,7 +25,7 @@ def format_metrics(metrics, split, prefix=""):
|
|||||||
|
|
||||||
def evaluate(model, val_dataloader):
|
def evaluate(model, val_dataloader):
|
||||||
model.eval()
|
model.eval()
|
||||||
val_loss = MeanMetric().to(model.device)
|
val_loss = MeanMetric(nan_strategy="error").to(model.device)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for batch in tqdm(val_dataloader):
|
for batch in tqdm(val_dataloader):
|
||||||
@ -43,7 +44,7 @@ def train(accelerator, config):
|
|||||||
accelerator.print(config)
|
accelerator.print(config)
|
||||||
accelerator.print(f"Using {accelerator.num_processes} GPUs")
|
accelerator.print(f"Using {accelerator.num_processes} GPUs")
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(config['tokenizer_name'])
|
tokenizer = AutoTokenizer.from_pretrained(config['tokenizer_name'], model_max_length=config['max_length'])
|
||||||
# llama has no pad token, set it to new token
|
# llama has no pad token, set it to new token
|
||||||
if tokenizer.pad_token is None:
|
if tokenizer.pad_token is None:
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
@ -51,7 +52,7 @@ def train(accelerator, config):
|
|||||||
|
|
||||||
with accelerator.main_process_first():
|
with accelerator.main_process_first():
|
||||||
train_dataloader, val_dataloader = load_data(config, tokenizer)
|
train_dataloader, val_dataloader = load_data(config, tokenizer)
|
||||||
|
|
||||||
|
|
||||||
checkpoint = config["gradient_checkpointing"]
|
checkpoint = config["gradient_checkpointing"]
|
||||||
model = AutoModelForCausalLM.from_pretrained(config["model_name"],
|
model = AutoModelForCausalLM.from_pretrained(config["model_name"],
|
||||||
@ -139,17 +140,22 @@ def train(accelerator, config):
|
|||||||
|
|
||||||
# log gradients
|
# log gradients
|
||||||
if accelerator.is_main_process and config["wandb"]:
|
if accelerator.is_main_process and config["wandb"]:
|
||||||
wandb.watch(model, log_freq=config["log_grads_every"])
|
wandb.watch(model, log_freq=config["log_grads_every"], log="all")
|
||||||
|
|
||||||
for epoch in range(config["num_epochs"]):
|
for epoch in range(config["num_epochs"]):
|
||||||
train_loss = MeanMetric().to(model.device)
|
train_loss = MeanMetric(nan_strategy="error").to(model.device)
|
||||||
for step, batch in enumerate(tqdm(train_dataloader)):
|
for step, batch in enumerate(tqdm(train_dataloader)):
|
||||||
model.train()
|
model.train()
|
||||||
outputs = model(**batch)
|
outputs = model(**batch)
|
||||||
loss = outputs.loss
|
loss = outputs.loss
|
||||||
loss = loss / gradient_accumulation_steps
|
|
||||||
|
|
||||||
|
# gather loss before backprop in case of gradient accumulation
|
||||||
|
loss_values = accelerator.gather_for_metrics({"loss": loss.detach().float()})
|
||||||
|
train_loss.update(loss_values["loss"])
|
||||||
|
|
||||||
|
loss = loss / gradient_accumulation_steps
|
||||||
accelerator.backward(loss)
|
accelerator.backward(loss)
|
||||||
|
# get gradient norm of all params
|
||||||
|
|
||||||
# log LR in case something weird happens
|
# log LR in case something weird happens
|
||||||
if step > 0 and step % (config["eval_every"] // 10) == 0:
|
if step > 0 and step % (config["eval_every"] // 10) == 0:
|
||||||
@ -162,8 +168,6 @@ def train(accelerator, config):
|
|||||||
scheduler.step()
|
scheduler.step()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
loss_values = accelerator.gather_for_metrics({"loss": loss.detach()})
|
|
||||||
train_loss.update(loss_values["loss"])
|
|
||||||
|
|
||||||
if step > 0 and step % config["save_every"] == 0:
|
if step > 0 and step % config["save_every"] == 0:
|
||||||
curr_step = step + epoch * len(train_dataloader)
|
curr_step = step + epoch * len(train_dataloader)
|
||||||
|
Loading…
Reference in New Issue
Block a user