feat: current working train

commit e8ccaea0bf
parent a69475a5f1
@@ -10,10 +10,8 @@ from peft import get_peft_model, LoraConfig, TaskType
 from gpt4all.data.retrieval_dataloader import load_retrieval_augmented_data
 from torchmetrics import MeanMetric
 from tqdm import tqdm
-from gpt4all.models import GPTJRForCausalLM
-from gpt4all.train.metrics import f1_score, exact_match_score
+from gpt4all.models import GPTJRForCausalLM, PythiaSeekForCausalLM
 import wandb
-import torch.distributed as dist
 
 torch.backends.cuda.matmul.allow_tf32 = True
 
@@ -70,7 +68,7 @@ def train(accelerator, config):
     checkpoint = config["gradient_checkpointing"]
     #ensures back compat with non retrieval models
     if 'encoder_dim' in config:
-        with accelerator.main_process_first():
+        if "gptj" in config["model_name"]:
             model = GPTJRForCausalLM.from_pretrained(config["model_name"],
                                                      revision=config['version'] if 'version' in config else None,
                                                      use_cache=False if checkpoint else True,
@@ -78,6 +76,26 @@ def train(accelerator, config):
                                                      # hardcoded!! TODO: change to config
                                                      total_alpha_steps=10000,
                                                      )
+        elif "pythia" in config["model_name"]:
+            cross_attn_layer = config["cross_attn_layer"]
+            learnable_alpha = config["learnable_alpha"]
+            model, info = PythiaSeekForCausalLM.from_pretrained(config["model_name"],
+                                                                revision=config['version'] if 'version' in config else None,
+                                                                use_cache=False if checkpoint else True,
+                                                                encoder_dim=config["encoder_dim"],
+                                                                # hardcoded!! TODO: change to config
+                                                                total_alpha_steps=2500,
+                                                                # mem transformer uses 9/12, pythia1-b is 16 layers
+                                                                cross_attn_layer=cross_attn_layer,
+                                                                learnable_alpha=learnable_alpha,
+                                                                output_loading_info=True
+                                                                )
+            # freeze all pretrained layers
+            if config["freeze_pretrained"]:
+                for name, param in model.named_parameters():
+                    if name not in info["missing_keys"]:
+                        param.requires_grad = False
+
     else:
         model = AutoModelForCausalLM.from_pretrained(config["model_name"],
                                                      use_cache=False if checkpoint else True,
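The new Pythia branch above passes `output_loading_info=True`, so `from_pretrained` also returns a loading-info dict whose `missing_keys` entry lists parameters that had no weights in the checkpoint; that is what lets the freeze loop keep only the newly added retrieval layers trainable. A minimal standalone sketch of the same pattern with the stock `transformers` API, using an illustrative checkpoint name that is not taken from this commit:

# Sketch only: freeze every parameter restored from the checkpoint and leave
# newly initialized ones trainable. "EleutherAI/pythia-1b" is illustrative.
from transformers import AutoModelForCausalLM

model, info = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/pythia-1b", output_loading_info=True
)
for name, param in model.named_parameters():
    # info["missing_keys"] names parameters absent from the checkpoint,
    # i.e. the freshly (randomly) initialized ones; only those stay trainable.
    if name not in info["missing_keys"]:
        param.requires_grad = False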
@@ -154,16 +172,20 @@ def train(accelerator, config):
 
     main_process = accelerator.is_main_process
 
+    if learnable_alpha is False and isinstance(cross_attn_layer, int):
+        accelerator.log({"alpha": model.gpt_neox.layers[cross_attn_layer]._update_alpha(0)}, step=0)
+
     for epoch in range(config["num_epochs"]):
         train_loss = MeanMetric(nan_strategy="error").to(model.device)
         for step, batch in enumerate(tqdm(train_dataloader, disable=not main_process)):
             model.train()
+            curr_step = step + len(train_dataloader) * epoch
             outputs = model(input_ids=batch["input_ids"],
                             labels=batch["labels"],
                             encoder_hidden_states=batch["encoder_hidden_states"],
-                            step=step)
+                            )
+                            #step=curr_step)
             loss = outputs.loss
-            accelerator.print(f"Loss: {loss:.4f}")
 
             if config["debug"]:
                 logits = outputs.logits
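This hunk also hoists the step counter out of the individual logging branches: `curr_step = step + len(train_dataloader) * epoch` is now computed once per batch and reused by the wandb, checkpointing, and eval code in the later hunks. A small sketch of how that counter behaves, using an illustrative dataloader length rather than anything from the commit:

# Sketch only: the global step keeps increasing across epochs, so anything
# logged or saved with step=curr_step never resets between epochs.
steps_per_epoch = 4  # stands in for len(train_dataloader)
for epoch in range(2):
    for step in range(steps_per_epoch):
        curr_step = step + steps_per_epoch * epoch
        print(epoch, step, curr_step)
# epoch 0 -> curr_step 0..3, epoch 1 -> curr_step 4..7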
@@ -191,10 +213,16 @@ def train(accelerator, config):
             # log LR in case something weird happens
             if config["wandb"]:
                 if step > 0 and step % (config["log_lr_every"] ) == 0:
-                    curr_step = step + epoch * len(train_dataloader)
                     lr = optimizer.param_groups[0]["lr"]
                     accelerator.log({"lr": lr}, step=curr_step)
 
+            if learnable_alpha is False and isinstance(cross_attn_layer, int):
+                curr_alpha = model.gpt_neox.layers[cross_attn_layer]._update_alpha(curr_step)
+                accelerator.log({"alpha": curr_alpha}, step=curr_step)
+            elif isinstance(cross_attn_layer, int):
+                curr_alpha = model.gpt_neox.layers[cross_attn_layer].alpha.item()
+                accelerator.log({"alpha": curr_alpha}, step=curr_step)
+
             if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                 optimizer.step()
                 if scheduler:
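The alpha logging above distinguishes a scheduled alpha, recomputed each step via `_update_alpha(curr_step)`, from a learnable alpha read back with `.alpha.item()`. The real layer lives in `gpt4all.models`; the sketch below only illustrates the shape such a layer could take, assuming a linear ramp over `total_alpha_steps` for the scheduled case and an `nn.Parameter` for the learnable case (names and the ramp are assumptions, not the repo's implementation):

# Sketch only: two ways a cross-attention mixing weight alpha might be handled.
import torch
import torch.nn as nn


class CrossAttnBlockSketch(nn.Module):
    def __init__(self, learnable_alpha: bool, total_alpha_steps: int = 2500):
        super().__init__()
        self.total_alpha_steps = total_alpha_steps
        if learnable_alpha:
            # learnable: trained by the optimizer, logged with .alpha.item()
            self.alpha = nn.Parameter(torch.tensor(0.0))
        else:
            # scheduled: overwritten each step by _update_alpha(step)
            self.alpha = 0.0

    def _update_alpha(self, step: int) -> float:
        # assumed linear ramp from 0 to 1 over total_alpha_steps
        self.alpha = min(1.0, step / self.total_alpha_steps)
        return self.alpha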
@@ -203,11 +231,9 @@ def train(accelerator, config):
 
 
             if step > 0 and config["save_every"] > 0 and step % config["save_every"] == 0:
-                curr_step = step + epoch * len(train_dataloader)
                 accelerator.save_state(f"{config['output_dir']}/step_{curr_step}")
 
             if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1):
-                curr_step = step + epoch * len(train_dataloader)
                 val_loss = evaluate(model, val_dataloader, step=curr_step, main_process=main_process)
 
                 log_train = {