feat: training script
commit 48e07be9e9 (parent 80d810322a)
@@ -1,41 +1,40 @@
 # model/tokenizer
-model_name: "nomic-ai/gpt4all-j"
-tokenizer_name: "nomic-ai/gpt4all-j"
-version: 'v1.2-jazzy'
+model_name: "EleutherAI/gpt-j-6B"
+tokenizer_name: "EleutherAI/gpt-j-6B"
+version: null
 gradient_checkpointing: true
-save_name: # CHANGE
+save_name: "nomic-ai/gpt-jr-decay-alpha"
+encoder_dim: 384

 # dataset
 streaming: false
 num_proc: 64
-dataset_path: "squad"
+dataset_path: "/home/paperspace/gpt4all/gpt4all/index/squad_supplemented_train"
 max_length: 1024
 batch_size: 32
+pct_test: 0.05
+q_column: "question"
+a_column: "answers"
+encoder_column: "neighbor_embeddings"

+#index
+index_path: "/home/paperspace/gpt4all/gpt4all/index/wiki-sample-index.bin"
+index_database: "/home/paperspace/gpt4all/gpt4all/index/wiki_sample_tokenized_embedded_with_text"
+index_space: "cosine"
+index_dim: 384
+query_embedding_field: 'question'

 # train dynamics
-lr: 2.0e-5
+lr: 1.0e-4
 min_lr: 0
 weight_decay: 0.0
-eval_every: 500
-eval_steps: 105
+eval_every: 50
 save_every: 500
 log_grads_every: 100
-output_dir: # CHANGE
+log_lr_every: 10
+output_dir: "ckpts/decay_alpha"
 checkpoint: null
 lora: false
 warmup_steps: 500
-num_epochs: 2
+num_epochs: 5

 # logging
-wandb: false
-wandb_entity: # CHANGE
-wandb_project_name: # CHANGE
+wandb: true
+wandb_entity: gpt4all
+wandb_project_name: retrieval
 seed: 42
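For orientation, a minimal sketch of how a YAML config like the one above is consumed. The training script reads it via gpt4all.utils.read.read_config; its internals are assumed here to be a plain YAML load, and the local file name is hypothetical:

    import yaml

    # Assumption: read_config is approximately a plain YAML load.
    with open("config.yaml") as f:  # hypothetical local copy of the config above
        config = yaml.safe_load(f)

    # Fields added in this commit gate the retrieval code paths below:
    print(config.get("encoder_dim"))  # 384 -> selects the GPTJRForCausalLM branch
    print(config.get("index_path"))   # location of the prebuilt hnswlib index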
@@ -5,58 +5,7 @@ import os
 import hnswlib
 from torch.utils.data import DataLoader
 from transformers import DefaultDataCollator
+from .preprocess import tokenize_inputs


-def tokenize_inputs(config, tokenizer, examples):
-    max_length = config["max_length"]
-
-    # hacky backward compatible
-    different_eos = tokenizer.eos_token != "</s>"
-    out = {"labels": [], "input_ids": []}
-    for prompt, response in zip(examples["prompt"], examples["response"]):
-        if different_eos:
-            if response.count("</s> \n") > 0:
-                response = response.replace("</s> \n", f"{tokenizer.eos_token} \n")
-
-        prompt_len = len(tokenizer(prompt + "\n", return_tensors="pt")["input_ids"][0])
-
-        # hack if our prompt is super long
-        # we need to include some labels so we arbitrarily truncate at max_length // 2
-        # if the length is too long
-        if prompt_len >= max_length // 2:
-            # if prompt is too long, truncate
-            # but make sure to truncate to at max 1024 tokens
-            new_len = min(max_length // 2, len(prompt) // 2)
-            prompt = prompt[:new_len]
-            # get new prompt length
-            prompt_len = tokenizer(prompt + "\n", return_tensors="pt", max_length=max_length // 2, truncation=True).input_ids.ne(tokenizer.pad_token_id).sum().item()
-
-        assert prompt_len <= max_length // 2, f"prompt length {prompt_len} exceeds max length {max_length}"
-
-        input_tokens = tokenizer(prompt + "\n" + response + tokenizer.eos_token,
-                                 truncation=True, max_length=max_length, return_tensors="pt")["input_ids"].squeeze()
-
-        labels = input_tokens.clone()
-        labels[:prompt_len] = -100
-        if len(labels) < max_length:
-            # pad to max_length with -100
-            labels = torch.cat([labels, torch.full((max_length - len(labels),), -100)])
-
-        assert (labels == -100).sum() < len(labels), f"Labels are all -100, something wrong. prompt length {prompt_len} exceeds max length {max_length}"
-
-        if (labels == -100).sum() == len(labels) - 1:
-            print(prompt)
-            print(response)
-            raise
-
-        input_tokens = tokenizer.pad({"input_ids": input_tokens}, padding="max_length", max_length=max_length)["input_ids"]
-        out["labels"].append(labels)
-        out["input_ids"].append(input_tokens)
-
-    out = {k: torch.stack(v) if isinstance(v, list) else v for k, v in out.items()}
-
-    return out


 def load_data(config, tokenizer):
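The deleted helper (now imported from .preprocess) relies on the standard causal-LM masking trick: prompt and padding positions get label -100, which PyTorch's cross-entropy skips, so loss is computed only on response tokens. A self-contained toy illustration of just that trick, with invented shapes and values:

    import torch
    import torch.nn.functional as F

    # 6 positions, vocab of 8; only the response tokens (3, 5, 1) keep labels,
    # prompt and padding positions are masked with -100.
    logits = torch.randn(6, 8)
    labels = torch.tensor([-100, -100, 3, 5, 1, -100])

    # ignore_index=-100 (also the default) drops masked positions from the loss.
    loss = F.cross_entropy(logits, labels, ignore_index=-100)
    print(loss)  # averaged over the three unmasked positions only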
@@ -86,13 +35,13 @@ def load_data(config, tokenizer):

     # tokenize inputs and return labels and attention mask
     train_dataset = train_dataset.map(
-        lambda ele: tokenize_inputs(config, tokenizer, ele),
+        lambda ele: tokenize_inputs(config, tokenizer, ele, "prompt", "response"),
         batched=True,
         remove_columns=["source", "prompt"],
         **kwargs
     )
     val_dataset = val_dataset.map(
-        lambda ele: tokenize_inputs(config, tokenizer, ele),
+        lambda ele: tokenize_inputs(config, tokenizer, ele, "prompt", "response"),
         batched=True,
         remove_columns=["source", "prompt"],
         **kwargs
@@ -154,12 +103,12 @@ def load_data_for_inference(config, tokenizer):

     # tokenize inputs and return labels and attention mask
     train_dataset = train_dataset.map(
-        lambda ele: tokenize_inputs(config, tokenizer, ele),
+        lambda ele: tokenize_inputs(config, tokenizer, ele, "prompt", "response"),
         batched=True,
         **kwargs
     )
     val_dataset = val_dataset.map(
-        lambda ele: tokenize_inputs(config, tokenizer, ele),
+        lambda ele: tokenize_inputs(config, tokenizer, ele, "prompt", "response"),
         batched=True,
         **kwargs
     )
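The hnswlib import in this file pairs with the new index_* config keys. A hedged sketch of how such a prebuilt index is loaded and queried; space and dim must match the build-time values from the config, and the query embeddings here are invented:

    import hnswlib
    import numpy as np

    index = hnswlib.Index(space="cosine", dim=384)  # must match index_space/index_dim
    index.load_index("wiki-sample-index.bin")       # file named by index_path

    # Query with a batch of question embeddings; returns neighbor ids and distances.
    queries = np.random.rand(2, 384).astype(np.float32)
    labels, distances = index.knn_query(queries, k=5)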
@@ -7,11 +7,13 @@ from gpt4all.utils.read import read_config
 from accelerate import Accelerator
 from accelerate.utils import DummyScheduler, DummyOptim, set_seed
 from peft import get_peft_model, LoraConfig, TaskType
-from gpt4all.utils.data import load_data, load_retrieval_augmented_data
+from gpt4all.data.retrieval_dataloader import load_retrieval_augmented_data
 from torchmetrics import MeanMetric
 from tqdm import tqdm
 from gpt4all.models import GPTJRForCausalLM
+from gpt4all.train.metrics import f1_score, exact_match_score
 import wandb
+import torch.distributed as dist

 torch.backends.cuda.matmul.allow_tf32 = True
@@ -22,15 +24,18 @@ def format_metrics(metrics, split, prefix=""):
     return log


-def evaluate(model, val_dataloader):
+def evaluate(model, val_dataloader, step, main_process=False):
     model.eval()
     val_loss = MeanMetric(nan_strategy="error").to(model.device)

     with torch.no_grad():
-        for batch in tqdm(val_dataloader):
-            loss = model(**batch).loss
-
-            loss_values = accelerator.gather_for_metrics({"loss": loss.detach()})
+        for batch in tqdm(val_dataloader, disable=not main_process):
+            outputs = model(input_ids=batch["input_ids"],
+                            labels=batch["labels"],
+                            encoder_hidden_states=batch["encoder_hidden_states"],
+                            step=step)
+            loss_values = accelerator.gather_for_metrics({"loss": outputs["loss"].detach()})

             val_loss.update(loss_values["loss"])
|
|||||||
|
|
||||||
|
|
||||||
with accelerator.main_process_first():
|
with accelerator.main_process_first():
|
||||||
|
train_dataloader, val_dataloader = load_retrieval_augmented_data(config, tokenizer)
|
||||||
if 'index_path' in config:
|
|
||||||
train_dataloader, val_dataloader = load_retrieval_augmented_data(config, tokenizer)
|
|
||||||
else:
|
|
||||||
train_dataloader, val_dataloader = load_data(config, tokenizer)
|
|
||||||
|
|
||||||
|
|
||||||
checkpoint = config["gradient_checkpointing"]
|
checkpoint = config["gradient_checkpointing"]
|
||||||
#ensures back compat with non retrieval models
|
#ensures back compat with non retrieval models
|
||||||
if 'index_path' in config:
|
if 'encoder_dim' in config:
|
||||||
model = GPTJRForCausalLM.from_pretrained(config["model_name"],
|
with accelerator.main_process_first():
|
||||||
revision=config['version'],
|
model = GPTJRForCausalLM.from_pretrained(config["model_name"],
|
||||||
use_cache=False if checkpoint else True,
|
revision=config['version'] if 'version' in config else None,
|
||||||
trust_remote_code=True)
|
use_cache=False if checkpoint else True,
|
||||||
|
encoder_dim=config["encoder_dim"],
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
model = AutoModelForCausalLM.from_pretrained(config["model_name"],
|
model = AutoModelForCausalLM.from_pretrained(config["model_name"],
|
||||||
use_cache=False if checkpoint else True,
|
use_cache=False if checkpoint else True,
|
||||||
@@ -117,13 +120,14 @@ def train(accelerator, config):
         )
     else:
         scheduler = DummyScheduler(
-            optimizer, total_num_steps=config["warmup_steps"], warmup_num_steps=config["warmup_steps"]
+            optimizer, total_num_steps=total_num_steps, warmup_num_steps=config["warmup_steps"]
         )

     model, optimizer, train_dataloader, val_dataloader, scheduler = accelerator.prepare(
         model, optimizer, train_dataloader, val_dataloader, scheduler
     )

     # setup for saving training states in case preemption
     accelerator.register_for_checkpointing(scheduler)
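The one-line change here is a real fix: total_num_steps was previously set to warmup_steps, so the schedule "finished" immediately after warmup. For a warmup-then-decay schedule the horizon must be the full run length, as in this stand-in using transformers' stock scheduler (DummyScheduler defers to the DeepSpeed config, so this is an analogy, not the same class; the run length is invented):

    import torch
    from transformers import get_linear_schedule_with_warmup

    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.AdamW(params, lr=1.0e-4)

    total_num_steps = 10_000  # invented run length
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=500,               # warmup_steps from the config
        num_training_steps=total_num_steps,  # decay spans the whole run
    )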
@@ -141,11 +145,16 @@ def train(accelerator, config):
     if accelerator.is_main_process and config["wandb"]:
         wandb.watch(model, log_freq=config["log_grads_every"], log="all")

+    main_process = accelerator.is_main_process
+
     for epoch in range(config["num_epochs"]):
         train_loss = MeanMetric(nan_strategy="error").to(model.device)
-        for step, batch in enumerate(tqdm(train_dataloader)):
+        for step, batch in enumerate(tqdm(train_dataloader, disable=not main_process)):
             model.train()
-            outputs = model(**batch)
+            outputs = model(input_ids=batch["input_ids"],
+                            labels=batch["labels"],
+                            encoder_hidden_states=batch["encoder_hidden_states"],
+                            step=step)
             loss = outputs.loss

             # gather loss before backprop in case of gradient accumulation
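Both loops accumulate loss with torchmetrics' MeanMetric; nan_strategy="error" makes a NaN loss fail fast rather than silently polluting the average. A minimal sketch with invented values:

    import torch
    from torchmetrics import MeanMetric

    train_loss = MeanMetric(nan_strategy="error")
    for loss in [torch.tensor(2.0), torch.tensor(1.0)]:
        train_loss.update(loss)
    print(train_loss.compute())  # tensor(1.5000)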
@@ -157,8 +166,8 @@ def train(accelerator, config):
             # get gradient norm of all params

             # log LR in case something weird happens
-            if step > 0 and step % (config["eval_every"] // 10) == 0:
-                if config["wandb"]:
+            if config["wandb"]:
+                if step > 0 and step % (config["log_lr_every"]) == 0:
                     curr_step = step + epoch * len(train_dataloader)
                     accelerator.log({"lr": scheduler.get_last_lr()[0]}, step=curr_step)
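Under the old config (eval_every: 500) the LR was logged every eval_every // 10 = 50 steps; the new log_lr_every: 10 key decouples the cadence from eval frequency. A worked check of the two cadences using those config values:

    eval_every_old, log_lr_every = 500, 10
    old_steps = [s for s in range(1, 101) if s % (eval_every_old // 10) == 0]  # [50, 100]
    new_steps = [s for s in range(1, 101) if s % log_lr_every == 0]            # [10, 20, ...]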
@@ -173,13 +182,14 @@ def train(accelerator, config):
                 accelerator.save_state(f"{config['output_dir']}/step_{curr_step}")

             if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1):
-                val_loss = evaluate(model, val_dataloader)
+                curr_step = step + epoch * len(train_dataloader)
+                val_loss = evaluate(model, val_dataloader, step=curr_step, main_process=main_process)

                 log_train = {
                     "train_loss": train_loss.compute()
                 }
                 log_val = {
-                    "val_loss": val_loss.compute()
+                    "val_loss": val_loss.compute(),
                 }

                 if config["wandb"]: