From e8ccaea0bf0c4d2cdba8c8fb6934fea627833b10 Mon Sep 17 00:00:00 2001
From: Zach Nussbaum
Date: Mon, 15 May 2023 14:17:02 +0000
Subject: [PATCH] feat: current working train

---
 gpt4all/train/train_retrieval.py | 44 +++++++++++++++++++++++++++++---------
 1 file changed, 35 insertions(+), 9 deletions(-)

diff --git a/gpt4all/train/train_retrieval.py b/gpt4all/train/train_retrieval.py
index df0bb977..0c6baaf2 100644
--- a/gpt4all/train/train_retrieval.py
+++ b/gpt4all/train/train_retrieval.py
@@ -10,10 +10,8 @@ from peft import get_peft_model, LoraConfig, TaskType
 from gpt4all.data.retrieval_dataloader import load_retrieval_augmented_data
 from torchmetrics import MeanMetric
 from tqdm import tqdm
-from gpt4all.models import GPTJRForCausalLM
-from gpt4all.train.metrics import f1_score, exact_match_score
+from gpt4all.models import GPTJRForCausalLM, PythiaSeekForCausalLM
 import wandb
-import torch.distributed as dist
 
 torch.backends.cuda.matmul.allow_tf32 = True
 
@@ -70,7 +68,7 @@ def train(accelerator, config):
     checkpoint = config["gradient_checkpointing"]
     #ensures back compat with non retrieval models
     if 'encoder_dim' in config:
-        with accelerator.main_process_first():
+        if "gptj" in config["model_name"]:
             model = GPTJRForCausalLM.from_pretrained(config["model_name"],
                                                      revision=config['version'] if 'version' in config else None,
                                                      use_cache=False if checkpoint else True,
                                                      encoder_dim=config["encoder_dim"],
                                                      # hardcoded!! TODO: change to config
                                                      total_alpha_steps=10000,
                                                      )
+        elif "pythia" in config["model_name"]:
+            cross_attn_layer = config["cross_attn_layer"]
+            learnable_alpha = config["learnable_alpha"]
+            model, info = PythiaSeekForCausalLM.from_pretrained(config["model_name"],
+                                                                revision=config['version'] if 'version' in config else None,
+                                                                use_cache=False if checkpoint else True,
+                                                                encoder_dim=config["encoder_dim"],
+                                                                # hardcoded!! TODO: change to config
+                                                                total_alpha_steps=2500,
+                                                                # mem transformer uses 9/12, pythia-1b is 16 layers
+                                                                cross_attn_layer=cross_attn_layer,
+                                                                learnable_alpha=learnable_alpha,
+                                                                output_loading_info=True
+                                                                )
+            # freeze all pretrained layers
+            if config["freeze_pretrained"]:
+                for name, param in model.named_parameters():
+                    if name not in info["missing_keys"]:
+                        param.requires_grad = False
+
     else:
         model = AutoModelForCausalLM.from_pretrained(config["model_name"],
                                                      use_cache=False if checkpoint else True,
@@ -154,16 +172,20 @@ def train(accelerator, config):
 
     main_process = accelerator.is_main_process
 
+    if learnable_alpha is False and isinstance(cross_attn_layer, int):
+        accelerator.log({"alpha": model.gpt_neox.layers[cross_attn_layer]._update_alpha(0)}, step=0)
+
     for epoch in range(config["num_epochs"]):
         train_loss = MeanMetric(nan_strategy="error").to(model.device)
         for step, batch in enumerate(tqdm(train_dataloader, disable=not main_process)):
             model.train()
+            curr_step = step + len(train_dataloader) * epoch
             outputs = model(input_ids=batch["input_ids"],
                             labels=batch["labels"],
                             encoder_hidden_states=batch["encoder_hidden_states"],
-                            step=step)
+                            )
+                            #step=curr_step)
             loss = outputs.loss
-            accelerator.print(f"Loss: {loss:.4f}")
 
             if config["debug"]:
                 logits = outputs.logits
@@ -191,10 +213,16 @@ def train(accelerator, config):
             # log LR in case something weird happens
             if config["wandb"]:
                 if step > 0 and step % (config["log_lr_every"] ) == 0:
-                    curr_step = step + epoch * len(train_dataloader)
                     lr = optimizer.param_groups[0]["lr"]
                     accelerator.log({"lr": lr}, step=curr_step)
 
+                    if learnable_alpha is False and isinstance(cross_attn_layer, int):
+                        curr_alpha = model.gpt_neox.layers[cross_attn_layer]._update_alpha(curr_step)
+                        accelerator.log({"alpha": curr_alpha}, step=curr_step)
+                    elif isinstance(cross_attn_layer, int):
+                        curr_alpha = model.gpt_neox.layers[cross_attn_layer].alpha.item()
+                        accelerator.log({"alpha": curr_alpha}, step=curr_step)
+
             if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                 optimizer.step()
                 if scheduler:
@@ -203,11 +231,9 @@ def train(accelerator, config):
 
 
             if step > 0 and config["save_every"] > 0 and step % config["save_every"] == 0:
-                curr_step = step + epoch * len(train_dataloader)
                 accelerator.save_state(f"{config['output_dir']}/step_{curr_step}")
 
             if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1):
-                curr_step = step + epoch * len(train_dataloader)
                 val_loss = evaluate(model, val_dataloader, step=curr_step, main_process=main_process)
 
                 log_train = {
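
Note (illustrative, not part of the patch): the `output_loading_info=True` pattern in the Pythia branch is the standard Hugging Face `from_pretrained` mechanism for discovering which parameters were freshly initialized rather than loaded from the checkpoint; the patch then freezes everything that is not in `missing_keys`. A minimal self-contained sketch of the same idea, with a stock `transformers` model standing in for `PythiaSeekForCausalLM` and a hypothetical helper name:

    # Sketch only: freeze parameters that came from the checkpoint and keep
    # gradients for parameters reported as missing (i.e. newly initialized).
    # For a stock checkpoint missing_keys is usually empty; it becomes non-empty
    # when the model class adds modules the checkpoint does not contain, as the
    # retrieval cross-attention layers do in the patch above.
    from transformers import AutoModelForCausalLM

    def freeze_pretrained(model, loading_info):
        missing = set(loading_info["missing_keys"])
        for name, param in model.named_parameters():
            param.requires_grad = name in missing
        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total = sum(p.numel() for p in model.parameters())
        print(f"trainable params: {trainable:,} / {total:,}")

    # "EleutherAI/pythia-1b" is only an example checkpoint.
    model, loading_info = AutoModelForCausalLM.from_pretrained(
        "EleutherAI/pythia-1b", output_loading_info=True
    )
    freeze_pretrained(model, loading_info)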
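
A similar sketch (also not from the patch) of the `curr_step` bookkeeping the diff moves to the top of the inner loop: one monotonically increasing step index across epochs, so that `accelerator.log(..., step=curr_step)` and the `step_{curr_step}` checkpoint names never collide between epochs. The loop sizes below are placeholders:

    # Sketch only: global step counter across epochs, same formula as the patch.
    steps_per_epoch = 1000   # placeholder for len(train_dataloader)
    num_epochs = 2           # placeholder for config["num_epochs"]

    for epoch in range(num_epochs):
        for step in range(steps_per_epoch):
            curr_step = step + steps_per_epoch * epoch
            # e.g. accelerator.log({"lr": lr}, step=curr_step)
            # e.g. accelerator.save_state(f"{output_dir}/step_{curr_step}")

    assert curr_step == num_epochs * steps_per_epoch - 1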