diff --git a/train.py b/train.py
index d8ea6161..cc75c5b9 100644
--- a/train.py
+++ b/train.py
@@ -115,48 +115,56 @@ def train(accelerator, config):
         "gradient_accumulation_steps"
     ]

-    for step, batch in enumerate(tqdm(train_dataloader)):
-        model.train()
-        outputs = model(**batch)
-        loss = outputs.loss
-        loss = loss / gradient_accumulation_steps
+    for epoch in range(config["num_epochs"]):
+        for step, batch in enumerate(tqdm(train_dataloader)):
+            model.train()
+            outputs = model(**batch)
+            loss = outputs.loss
+            loss = loss / gradient_accumulation_steps

-        accelerator.backward(loss)
+            accelerator.backward(loss)

-        # log LR in case something weird happens
-        if step > 0 and step % (config["eval_every"] // 10) == 0:
-            if config["wandb"]:
-                accelerator.log({"lr": scheduler.get_last_lr()[0]}, step=step)
+            # log LR in case something weird happens
+            if step > 0 and step % (config["eval_every"] // 10) == 0:
+                if config["wandb"]:
+                    accelerator.log({"lr": scheduler.get_last_lr()[0]}, step=step)

-        if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
-            optimizer.step()
-            scheduler.step()
-            optimizer.zero_grad()
+            if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+                optimizer.step()
+                scheduler.step()
+                optimizer.zero_grad()

-        loss_values = accelerator.gather_for_metrics({"loss": loss.detach()})
-        train_loss.update(loss_values["loss"])
+            loss_values = accelerator.gather_for_metrics({"loss": loss.detach()})
+            train_loss.update(loss_values["loss"])

-        if step > 0 and step % config["save_every"] == 0:
-            accelerator.save_state(f"{config['output_dir']}/step_{step}")
+            if step > 0 and step % config["save_every"] == 0:
+                accelerator.save_state(f"{config['output_dir']}/step_{step}")

-        if step > 0 and step % config["eval_every"] == 0:
-            val_loss = evaluate(config, model, val_dataloader)
+            if step > 0 and step % config["eval_every"] == 0:
+                val_loss = evaluate(config, model, val_dataloader)

-            log_train = {
-                "train_loss": train_loss.compute()
-            }
-            log_val = {
-                "val_loss": val_loss.compute()
-            }
+                log_train = {
+                    "train_loss": train_loss.compute()
+                }
+                log_val = {
+                    "val_loss": val_loss.compute()
+                }

-            if config["wandb"]:
-                accelerator.log({**log_train, **log_val}, step=step)
+                if config["wandb"]:
+                    accelerator.log({**log_train, **log_val}, step=step)

-            accelerator.print(f"Current LR: {scheduler.get_last_lr()[0]}")
-            accelerator.print(format_metrics(log_train, "train", f" step {step} "))
-            accelerator.print(format_metrics(log_val, "val", f" step {step} "))
+                accelerator.print(f"Current LR: {scheduler.get_last_lr()[0]}")
+                accelerator.print(format_metrics(log_train, "train", f" step {step} "))
+                accelerator.print(format_metrics(log_val, "val", f" step {step} "))

-            train_loss.reset()
+                train_loss.reset()
+
+        accelerator.print(f"Epoch {epoch} finished")
+        accelerator.print(f"Pushing to HF hub")
+        accelerator.wait_for_everyone()
+        unwrapped_model = accelerator.unwrap_model(model)
+        if accelerator.is_main_process:
+            unwrapped_model.push_to_hub(config["save_name"], private=True)

     accelerator.wait_for_everyone()

@@ -168,7 +176,8 @@ def train(accelerator, config):
         state_dict=accelerator.get_state_dict(model),
     )

-    unwrapped_model.push_to_hub(config["save_name"], private=True)
+    if accelerator.is_main_process:
+        unwrapped_model.push_to_hub(config["save_name"], private=True)

     accelerator.end_training()