Mirror of https://github.com/nomic-ai/gpt4all.git
Commit e8ccaea0bf (parent a69475a5f1): feat: current working train
@@ -10,10 +10,8 @@ from peft import get_peft_model, LoraConfig, TaskType

from gpt4all.data.retrieval_dataloader import load_retrieval_augmented_data
from torchmetrics import MeanMetric
from tqdm import tqdm
from gpt4all.models import GPTJRForCausalLM
from gpt4all.train.metrics import f1_score, exact_match_score
from gpt4all.models import GPTJRForCausalLM, PythiaSeekForCausalLM
import wandb
import torch.distributed as dist

torch.backends.cuda.matmul.allow_tf32 = True
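A note on the TF32 switch above: on Ampere-and-newer GPUs it lets matmuls run in TensorFloat-32, trading a few mantissa bits for a large speedup. A minimal sketch of the related flags (standard PyTorch settings, not part of this diff):

```python
import torch

# Allow TF32 in matrix multiplications (Ampere+ GPUs); this is the flag
# the diff flips. cuDNN convolutions have a matching, independent switch.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
```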
@@ -70,7 +68,7 @@ def train(accelerator, config):

    checkpoint = config["gradient_checkpointing"]

    # ensures backwards compatibility with non-retrieval models
    if 'encoder_dim' in config:
        with accelerator.main_process_first():
            if "gptj" in config["model_name"]:
                model = GPTJRForCausalLM.from_pretrained(config["model_name"],
                        revision=config['version'] if 'version' in config else None,
                        use_cache=False if checkpoint else True,
@@ -78,6 +76,26 @@ def train(accelerator, config):
                        # hardcoded!! TODO: change to config
                        total_alpha_steps=10000,
                        )
            elif "pythia" in config["model_name"]:
                cross_attn_layer = config["cross_attn_layer"]
                learnable_alpha = config["learnable_alpha"]
                model, info = PythiaSeekForCausalLM.from_pretrained(config["model_name"],
                        revision=config['version'] if 'version' in config else None,
                        use_cache=False if checkpoint else True,
                        encoder_dim=config["encoder_dim"],
                        # hardcoded!! TODO: change to config
                        total_alpha_steps=2500,
                        # mem transformer uses 9/12; pythia-1b has 16 layers
                        cross_attn_layer=cross_attn_layer,
                        learnable_alpha=learnable_alpha,
                        output_loading_info=True,
                        )

                # freeze all pretrained layers
                if config["freeze_pretrained"]:
                    for name, param in model.named_parameters():
                        if name not in info["missing_keys"]:
                            param.requires_grad = False

    else:
        model = AutoModelForCausalLM.from_pretrained(config["model_name"],
                use_cache=False if checkpoint else True,
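The freeze_pretrained block relies on output_loading_info=True, which makes from_pretrained also return a dict whose missing_keys entry lists parameters that were freshly initialized rather than loaded from the checkpoint. A minimal sketch of the same pattern, with a stock Hugging Face model standing in for PythiaSeekForCausalLM:

```python
from transformers import AutoModelForCausalLM

# output_loading_info=True returns (model, info); info["missing_keys"]
# names parameters that were newly initialized, not loaded from the
# checkpoint. Everything else counts as pretrained and gets frozen.
model, info = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/pythia-70m", output_loading_info=True
)
for name, param in model.named_parameters():
    if name not in info["missing_keys"]:
        param.requires_grad = False
```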
@@ -154,16 +172,20 @@ def train(accelerator, config):

    main_process = accelerator.is_main_process

    if learnable_alpha is False and isinstance(cross_attn_layer, int):
        accelerator.log({"alpha": model.gpt_neox.layers[cross_attn_layer]._update_alpha(0)}, step=0)

    for epoch in range(config["num_epochs"]):
        train_loss = MeanMetric(nan_strategy="error").to(model.device)
        for step, batch in enumerate(tqdm(train_dataloader, disable=not main_process)):
            model.train()
            curr_step = step + len(train_dataloader) * epoch
            outputs = model(input_ids=batch["input_ids"],
                    labels=batch["labels"],
                    encoder_hidden_states=batch["encoder_hidden_states"],
                    step=step)
                    # step=curr_step)
            loss = outputs.loss
            accelerator.print(f"Loss: {loss:.4f}")

            if config["debug"]:
                logits = outputs.logits
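train_loss uses torchmetrics' MeanMetric to keep a running average of the loss across the epoch; nan_strategy="error" makes a NaN loss raise immediately instead of silently corrupting the mean. A small self-contained sketch of that behavior:

```python
import torch
from torchmetrics import MeanMetric

train_loss = MeanMetric(nan_strategy="error")  # raise on NaN updates
train_loss.update(torch.tensor(2.0))
train_loss.update(torch.tensor(1.0))
print(train_loss.compute())  # tensor(1.5000)
```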
@@ -191,10 +213,16 @@ def train(accelerator, config):
            # log LR in case something weird happens
            if config["wandb"]:
                if step > 0 and step % config["log_lr_every"] == 0:
                    curr_step = step + epoch * len(train_dataloader)
                    lr = optimizer.param_groups[0]["lr"]
                    accelerator.log({"lr": lr}, step=curr_step)

                if learnable_alpha is False and isinstance(cross_attn_layer, int):
                    curr_alpha = model.gpt_neox.layers[cross_attn_layer]._update_alpha(curr_step)
                    accelerator.log({"alpha": curr_alpha}, step=curr_step)
                elif isinstance(cross_attn_layer, int):
                    curr_alpha = model.gpt_neox.layers[cross_attn_layer].alpha.item()
                    accelerator.log({"alpha": curr_alpha}, step=curr_step)

            if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                optimizer.step()
                if scheduler:
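The _update_alpha calls suggest that, when the gate is not learnable, the cross-attention mixing weight alpha is annealed on a fixed schedule over total_alpha_steps. The actual schedule lives inside the model's layers and is not shown in this diff; a hypothetical linear ramp, for illustration only:

```python
# Hypothetical stand-in for the layer's _update_alpha method: a linear
# ramp of the cross-attention gate from 0 to 1 over total_alpha_steps.
# The real implementation may differ; this is an assumption.
def update_alpha(step: int, total_alpha_steps: int) -> float:
    return min(step / total_alpha_steps, 1.0)

assert update_alpha(0, 2500) == 0.0
assert update_alpha(1250, 2500) == 0.5
assert update_alpha(5000, 2500) == 1.0
```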
@@ -203,11 +231,9 @@ def train(accelerator, config):

            if step > 0 and config["save_every"] > 0 and step % config["save_every"] == 0:
                curr_step = step + epoch * len(train_dataloader)
                accelerator.save_state(f"{config['output_dir']}/step_{curr_step}")

            if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1):
                curr_step = step + epoch * len(train_dataloader)
                val_loss = evaluate(model, val_dataloader, step=curr_step, main_process=main_process)

                log_train = {
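The optimizer-step condition a few hunks up implements gradient accumulation: gradients from several micro-batches are summed before a single optimizer step, and the trailing partial window at the end of the dataloader still triggers a step. A runnable plain-PyTorch sketch of the same cadence, with stand-in model and data (accelerator.backward becomes loss.backward):

```python
import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
gradient_accumulation_steps = 4
batches = [torch.randn(2, 4) for _ in range(10)]  # 10 % 4 != 0 on purpose

for step, x in enumerate(batches):
    # Scale so the accumulated gradient matches one large-batch step.
    loss = model(x).pow(2).mean() / gradient_accumulation_steps
    loss.backward()
    # Step on every full window, and once more for the final partial one.
    if (step + 1) % gradient_accumulation_steps == 0 or step == len(batches) - 1:
        optimizer.step()
        optimizer.zero_grad()
```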