feat: current working train

Zach Nussbaum 2023-05-15 14:17:02 +00:00
parent a69475a5f1
commit e8ccaea0bf


@@ -10,10 +10,8 @@ from peft import get_peft_model, LoraConfig, TaskType
 from gpt4all.data.retrieval_dataloader import load_retrieval_augmented_data
 from torchmetrics import MeanMetric
 from tqdm import tqdm
-from gpt4all.models import GPTJRForCausalLM
-from gpt4all.train.metrics import f1_score, exact_match_score
+from gpt4all.models import GPTJRForCausalLM, PythiaSeekForCausalLM
 import wandb
 import torch.distributed as dist
 torch.backends.cuda.matmul.allow_tf32 = True
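
Aside on the TF32 flag kept in context above: it lets float32 matmuls run on tensor cores with reduced-precision accumulation, a common speed/precision trade when training on Ampere-class GPUs. A minimal sketch of the idiom (the cuDNN companion flag is an assumption here; only the matmul flag appears in this file):

    import torch

    # run float32 matmuls on TF32 tensor cores (Ampere and newer GPUs):
    # higher matmul throughput at slightly reduced mantissa precision
    torch.backends.cuda.matmul.allow_tf32 = True
    # companion flag for cuDNN convolutions (assumed; not set in this file)
    torch.backends.cudnn.allow_tf32 = True
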
@@ -70,7 +68,7 @@ def train(accelerator, config):
     checkpoint = config["gradient_checkpointing"]
     # ensures backwards compatibility with non-retrieval models
     if 'encoder_dim' in config:
         with accelerator.main_process_first():
+            if "gptj" in config["model_name"]:
                 model = GPTJRForCausalLM.from_pretrained(config["model_name"],
                                                          revision=config['version'] if 'version' in config else None,
                                                          use_cache=False if checkpoint else True,
@@ -78,6 +76,26 @@ def train(accelerator, config):
                                                          # hardcoded!! TODO: change to config
                                                          total_alpha_steps=10000,
                                                          )
+            elif "pythia" in config["model_name"]:
+                cross_attn_layer = config["cross_attn_layer"]
+                learnable_alpha = config["learnable_alpha"]
+                model, info = PythiaSeekForCausalLM.from_pretrained(config["model_name"],
+                                                                    revision=config['version'] if 'version' in config else None,
+                                                                    use_cache=False if checkpoint else True,
+                                                                    encoder_dim=config["encoder_dim"],
+                                                                    # hardcoded!! TODO: change to config
+                                                                    total_alpha_steps=2500,
+                                                                    # mem transformer uses 9/12, pythia-1b is 16 layers
+                                                                    cross_attn_layer=cross_attn_layer,
+                                                                    learnable_alpha=learnable_alpha,
+                                                                    output_loading_info=True
+                                                                    )
+                # freeze all pretrained layers
+                if config["freeze_pretrained"]:
+                    for name, param in model.named_parameters():
+                        if name not in info["missing_keys"]:
+                            param.requires_grad = False
     else:
         model = AutoModelForCausalLM.from_pretrained(config["model_name"],
                                                      use_cache=False if checkpoint else True,
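
The freeze block above leans on Hugging Face's `output_loading_info=True`, which makes `from_pretrained` return `(model, loading_info)`; `loading_info["missing_keys"]` names the parameters that were freshly initialized because the checkpoint had no weights for them (here, the new cross-attention modules). Freezing everything *not* in that list trains only the new layers. A minimal sketch of the same pattern against a generic model (the model id and summary print are illustrative, not from this diff):

    from transformers import AutoModelForCausalLM

    # from_pretrained returns (model, loading_info) when output_loading_info=True;
    # loading_info["missing_keys"] lists params initialized from scratch
    model, info = AutoModelForCausalLM.from_pretrained(
        "EleutherAI/pythia-1b", output_loading_info=True
    )

    for name, param in model.named_parameters():
        # train only freshly initialized parameters; freeze everything pretrained
        param.requires_grad = name in info["missing_keys"]

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"trainable parameters: {trainable:,}")
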
@@ -154,16 +172,20 @@ def train(accelerator, config):
     main_process = accelerator.is_main_process
+    if learnable_alpha is False and isinstance(cross_attn_layer, int):
+        accelerator.log({"alpha": model.gpt_neox.layers[cross_attn_layer]._update_alpha(0)}, step=0)
     for epoch in range(config["num_epochs"]):
         train_loss = MeanMetric(nan_strategy="error").to(model.device)
         for step, batch in enumerate(tqdm(train_dataloader, disable=not main_process)):
             model.train()
             curr_step = step + len(train_dataloader) * epoch
             outputs = model(input_ids=batch["input_ids"],
                             labels=batch["labels"],
                             encoder_hidden_states=batch["encoder_hidden_states"],
-                            step=step)
+                            )
+                            #step=curr_step)
             loss = outputs.loss
+            accelerator.print(f"Loss: {loss:.4f}")
             if config["debug"]:
                 logits = outputs.logits
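
`_update_alpha` and `total_alpha_steps` suggest a warm-up schedule for the cross-attention mixing coefficient: when `learnable_alpha` is off, alpha is advanced (and here logged) as a function of the global step. The actual schedule lives inside `PythiaSeekForCausalLM`; a hypothetical linear ramp consistent with the names would be:

    def _update_alpha(step: int, total_alpha_steps: int = 2500) -> float:
        # hypothetical sketch: ramp alpha from 0 to 1 over total_alpha_steps,
        # easing the randomly initialized cross-attention into the residual path
        return min(step / total_alpha_steps, 1.0)

    assert _update_alpha(0) == 0.0       # cross-attention initially gated out
    assert _update_alpha(1250) == 0.5
    assert _update_alpha(5000) == 1.0    # fully blended after the ramp
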
@@ -191,10 +213,16 @@ def train(accelerator, config):
             # log LR in case something weird happens
             if config["wandb"]:
                 if step > 0 and step % config["log_lr_every"] == 0:
                     curr_step = step + epoch * len(train_dataloader)
                     lr = optimizer.param_groups[0]["lr"]
                     accelerator.log({"lr": lr}, step=curr_step)
+                    if learnable_alpha is False and isinstance(cross_attn_layer, int):
+                        curr_alpha = model.gpt_neox.layers[cross_attn_layer]._update_alpha(curr_step)
+                        accelerator.log({"alpha": curr_alpha}, step=curr_step)
+                    elif isinstance(cross_attn_layer, int):
+                        curr_alpha = model.gpt_neox.layers[cross_attn_layer].alpha.item()
+                        accelerator.log({"alpha": curr_alpha}, step=curr_step)
             if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                 optimizer.step()
                 if scheduler:
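
The `(step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1` guard is standard gradient accumulation: gradients from several micro-batches sum in `.grad` before one optimizer step, so the effective batch size is the per-device batch times the accumulation factor, and the trailing clause flushes a partial window at epoch end. A self-contained toy version of the loop shape (names and sizes are illustrative, not this script's):

    import torch

    model = torch.nn.Linear(8, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    batches = [(torch.randn(4, 8), torch.randn(4, 1)) for _ in range(10)]
    accum_steps = 4  # effective batch = per-step batch * accum_steps

    for step, (x, y) in enumerate(batches):
        loss = torch.nn.functional.mse_loss(model(x), y)
        # scale so the accumulated gradient averages over the window
        (loss / accum_steps).backward()
        if (step + 1) % accum_steps == 0 or step == len(batches) - 1:
            optimizer.step()       # one update per accumulation window
            optimizer.zero_grad()  # clear grads for the next window
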
@@ -203,11 +231,9 @@ def train(accelerator, config):
             if step > 0 and config["save_every"] > 0 and step % config["save_every"] == 0:
                 curr_step = step + epoch * len(train_dataloader)
                 accelerator.save_state(f"{config['output_dir']}/step_{curr_step}")
             if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1):
                 curr_step = step + epoch * len(train_dataloader)
                 val_loss = evaluate(model, val_dataloader, step=curr_step, main_process=main_process)
                 log_train = {