From 27a9b2b10c91bb6b0116136066ae4c56e63f0abd Mon Sep 17 00:00:00 2001
From: Zach Nussbaum
Date: Wed, 3 May 2023 21:39:05 +0000
Subject: [PATCH] fix: option for no schedule

---
 configs/train/finetune_gptjr.yaml |  1 +
 gpt4all/train/train_retrieval.py  | 73 +++++++++++++++++++++----------
 2 files changed, 51 insertions(+), 23 deletions(-)

diff --git a/configs/train/finetune_gptjr.yaml b/configs/train/finetune_gptjr.yaml
index 5cfccb46..e71e81b0 100644
--- a/configs/train/finetune_gptjr.yaml
+++ b/configs/train/finetune_gptjr.yaml
@@ -33,6 +33,7 @@ lora: false
 warmup_steps: 500
 num_epochs: 5
 debug: false
+scheduler: false
 
 # logging
 wandb: true
diff --git a/gpt4all/train/train_retrieval.py b/gpt4all/train/train_retrieval.py
index 2e4dfc2e..df0bb977 100644
--- a/gpt4all/train/train_retrieval.py
+++ b/gpt4all/train/train_retrieval.py
@@ -75,13 +75,15 @@ def train(accelerator, config):
             revision=config['version'] if 'version' in config else None,
             use_cache=False if checkpoint else True,
             encoder_dim=config["encoder_dim"],
-            total_alpha_steps=total_num_steps
+            # hardcoded!! TODO: change to config
+            total_alpha_steps=10000,
         )
     else:
         model = AutoModelForCausalLM.from_pretrained(config["model_name"],
                                                      use_cache=False if checkpoint else True,
                                                      trust_remote_code=True)
 
+    accelerator.print(f"Training a {model.num_parameters():,} parameter model")
 
     if checkpoint:
         model.gradient_checkpointing_enable()
@@ -106,28 +108,35 @@ def train(accelerator, config):
 
     # Creates Dummy Scheduler if `scheduler` was spcified in the config file else creates `args.lr_scheduler_type` Scheduler
-    if (
-        accelerator.state.deepspeed_plugin is None
-        or "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
-    ):
-        scheduler = get_scheduler(
-            name="cosine",
-            optimizer=optimizer,
-            num_warmup_steps=config["warmup_steps"] * accelerator.num_processes,
-            num_training_steps=total_num_steps,
+    if config["scheduler"] or (accelerator.state.deepspeed_plugin is not None and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config):
+        if (
+            accelerator.state.deepspeed_plugin is None
+            or "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
+        ):
+            scheduler = get_scheduler(
+                name="cosine",
+                optimizer=optimizer,
+                num_warmup_steps=config["warmup_steps"] * accelerator.num_processes,
+                num_training_steps=total_num_steps,
+            )
+        else:
+            scheduler = DummyScheduler(
+                optimizer, total_num_steps=total_num_steps, warmup_num_steps=config["warmup_steps"]
+            )
+        model, optimizer, train_dataloader, val_dataloader, scheduler = accelerator.prepare(
+            model, optimizer, train_dataloader, val_dataloader, scheduler
         )
+        # `scheduler` is now the prepared scheduler object, so the `if scheduler:` checks below work
     else:
-        scheduler = DummyScheduler(
-            optimizer, total_num_steps=total_num_steps, warmup_num_steps=config["warmup_steps"]
+        model, optimizer, train_dataloader, val_dataloader = accelerator.prepare(
+            model, optimizer, train_dataloader, val_dataloader
         )
-
-    model, optimizer, train_dataloader, val_dataloader, scheduler = accelerator.prepare(
-        model, optimizer, train_dataloader, val_dataloader, scheduler
-    )
+        scheduler = False
 
-    # setup for saving training states in case preemption
-    accelerator.register_for_checkpointing(scheduler)
+    # setup for saving training states in case of preemption
+    if scheduler:
+        accelerator.register_for_checkpointing(scheduler)
 
     if config["checkpoint"]:
         accelerator.load_state(config["checkpoint"])
@@ -154,6 +163,22 @@ def train(accelerator, config):
                             encoder_hidden_states=batch["encoder_hidden_states"], step=step)
 
             loss = outputs.loss
+            accelerator.print(f"Loss: {loss:.4f}")
+
+            if config["debug"]:
+                logits = outputs.logits
+                pred_tokens = torch.argmax(logits, dim=-1)
+                labels = batch["labels"].clone()
+
+                for row in range(labels.shape[0]):
+                    curr_label = labels[row]
+                    mask = curr_label != -100
+                    decoded_true = tokenizer.decode(curr_label[mask], skip_special_tokens=True)
+                    decoded_pred = tokenizer.decode(pred_tokens[row][mask], skip_special_tokens=True)
+
+                    accelerator.print(f"Predicted tokens: {decoded_pred}")
+                    accelerator.print(f"True tokens: {decoded_true}")
+
             # gather loss before backprop in case of gradient accumulation
             loss_values = accelerator.gather_for_metrics({"loss": loss.detach().float()})
 
@@ -167,15 +192,17 @@ def train(accelerator, config):
             if config["wandb"]:
                 if step > 0 and step % (config["log_lr_every"] ) == 0:
                     curr_step = step + epoch * len(train_dataloader)
-                    accelerator.log({"lr": scheduler.get_last_lr()[0]}, step=curr_step)
+                    lr = optimizer.param_groups[0]["lr"]
+                    accelerator.log({"lr": lr}, step=curr_step)
 
             if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                 optimizer.step()
-                scheduler.step()
+                if scheduler:
+                    scheduler.step()
                 optimizer.zero_grad()
 
-            if step > 0 and step % config["save_every"] == 0:
+            if step > 0 and config["save_every"] > 0 and step % config["save_every"] == 0:
                 curr_step = step + epoch * len(train_dataloader)
                 accelerator.save_state(f"{config['output_dir']}/step_{curr_step}")
@@ -194,17 +221,17 @@ def train(accelerator, config):
                     curr_step = step + epoch * len(train_dataloader)
                     accelerator.log({**log_train, **log_val}, step=curr_step)
 
-                accelerator.print(f"Current LR: {scheduler.get_last_lr()[0]}")
+                accelerator.print(f"Current LR: {optimizer.param_groups[0]['lr']}")
                 accelerator.print(format_metrics(log_train, "train", f" step {step} "))
                 accelerator.print(format_metrics(log_val, "val", f" step {step} "))
 
                 train_loss.reset()
 
         accelerator.print(f"Epoch {epoch} finished")
-        accelerator.print(f"Pushing to HF hub")
         accelerator.wait_for_everyone()
         unwrapped_model = accelerator.unwrap_model(model)
         if config["push_to_hub"]:
+            accelerator.print(f"Pushing to HF hub")
            try:
                 if accelerator.is_main_process:
                     unwrapped_model.push_to_hub(config["save_name"] + f"-epoch_{epoch}", private=True)
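
For reference, a minimal sketch of how the new `scheduler` key is meant to be set in configs/train/finetune_gptjr.yaml. The keys shown are only the ones visible in the hunk above; the comments describe the behaviour of the scheduler branch added to train_retrieval.py in this patch, and everything else about the config layout is assumed.

    lora: false
    warmup_steps: 500   # passed to the cosine / DeepSpeed dummy scheduler when one is created
    num_epochs: 5
    debug: false
    scheduler: false    # false: train at the optimizer's fixed lr (no scheduler is created unless the
                        #   DeepSpeed config itself defines one); true: cosine schedule with warmup_steps,
                        #   or DeepSpeed's DummyScheduler when the DeepSpeed config provides a scheduler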