feat: eval pythia

Zach Nussbaum 2023-05-17 20:41:20 +00:00
parent fd3f9da403
commit db90a15911
3 changed files with 33 additions and 10 deletions


@@ -0,0 +1,18 @@
+# model/tokenizer
+model_name: "/home/paperspace/gpt4all/gpt4all/train/ckpts/learnable_alpha/epoch_0"
+tokenizer_name: "EleutherAI/pythia-1b"
+version: null
+gradient_checkpointing: true
+save_name: "nomic-ai/gpt-jr"
+encoder_dim: 384
+# dataset
+streaming: false
+num_proc: 64
+dataset_path: "/home/paperspace/gpt4all/gpt4all/index/squad_supplemented_validation"
+max_length: 1024
+batch_size: 32
+pct_test: 0.05
+q_column: "question"
+a_column: "answers"
+encoder_column: "neighbor_embeddings"
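
The eval script below reads this file with read_config from gpt4all.utils.read. The commit does not show that helper, but it is presumably a thin YAML loader; a minimal sketch under that assumption (the helper name and config path here are hypothetical, not the repo's actual code):

import yaml

def read_config_sketch(path: str) -> dict:
    # hypothetical stand-in for gpt4all.utils.read.read_config:
    # parse the YAML file into a flat dict of hyperparameters
    with open(path) as f:
        return yaml.safe_load(f)

config = read_config_sketch("configs/eval/pythia.yaml")  # hypothetical path
assert config["tokenizer_name"] == "EleutherAI/pythia-1b"
assert config["encoder_dim"] == 384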


@@ -7,15 +7,15 @@ save_name: "nomic-ai/pythiaseek-large-bs"
 push_to_hub: false
 encoder_dim: 384
 learnable_alpha: true
-cross_attn_layer: 9
+cross_attn_layer: 12
 freeze_pretrained: false
 # dataset
 streaming: false
-num_proc: 604
+num_proc: 64
 dataset_path: "/home/paperspace/gpt4all/gpt4all/index/squad_supplemented_train"
 max_length: 1024
-batch_size: 32
+batch_size: 64
 pct_test: 0.05
 q_column: "question"
 a_column: "answers"
@@ -26,11 +26,11 @@ encoder_column: "neighbor_embeddings"
 lr: 1.0e-4
 min_lr: 0
 weight_decay: 0.0
-eval_every: 50
+eval_every: 100
 save_every: -1
 log_grads_every: 100
 log_lr_every: 10
-output_dir: "ckpts/learnable_alpha"
+output_dir: "ckpts/learnable_alpha_6b"
 checkpoint: null
 lora: false
 warmup_steps: 500
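
The learnable_alpha and cross_attn_layer keys suggest a cross-attention block over the retrieved neighbor embeddings (encoder_dim: 384) spliced into the decoder at layer 12, with a trainable scalar gating how much retrieved signal enters the residual stream. The commit does not show that module, so the following is only a sketch of the presumed pattern; every name in it is illustrative, not the repo's PythiaSeek implementation:

import torch
import torch.nn as nn

class GatedCrossAttentionSketch(nn.Module):
    # illustrative module: decoder hidden states attend over retrieved
    # neighbor embeddings, gated by a single learnable scalar
    def __init__(self, hidden_dim: int, encoder_dim: int, num_heads: int = 8):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(
            hidden_dim, num_heads, kdim=encoder_dim, vdim=encoder_dim,
            batch_first=True)
        # "learnable_alpha": initialized to 0 so training starts from the
        # pretrained decoder's unmodified behavior
        self.alpha = nn.Parameter(torch.zeros(1))

    def forward(self, hidden_states, encoder_hidden_states):
        attn_out, _ = self.cross_attn(
            hidden_states, encoder_hidden_states, encoder_hidden_states)
        # gate the retrieved signal into the residual stream
        return hidden_states + self.alpha * attn_out

# e.g. block = GatedCrossAttentionSketch(hidden_dim=2048, encoder_dim=384)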


@@ -1,5 +1,5 @@
 import torch
-from gpt4all.models import GPTJRForCausalLM
+from gpt4all.models import GPTJRForCausalLM, PythiaSeekForCausalLM
 from gpt4all.data.retrieval_dataloader import load_retrieval_augmented_data
 from gpt4all.train.metrics import f1_score, exact_match_score
 from gpt4all.utils.read import read_config
@@ -22,7 +22,7 @@ dataloader = load_retrieval_augmented_data(config, tokenizer, split_dataset=False)
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-model = GPTJRForCausalLM.from_pretrained(config["model_name"], use_cache=False)
+model = PythiaSeekForCausalLM.from_pretrained(config["model_name"], use_cache=False)
 model.to(device)
 model.eval()
@@ -35,13 +35,18 @@ with torch.no_grad():
                         labels=batch["labels"].to(device),
                         encoder_hidden_states=batch["encoder_hidden_states"].to(device))
-        predicted_tokens = outputs.logits.argmax(dim=-1)
-        predicted = tokenizer.batch_decode(predicted_tokens, skip_special_tokens=True)
+        labels = batch["labels"]
+        mask = labels == -100
+        # since it's causal we predict the next token
+        predicted_tokens = outputs.logits.argmax(dim=-1)[:, :-1]
+        predicted_tokens[mask[:, 1:]] = tokenizer.pad_token_id
+        predicted = tokenizer.batch_decode(predicted_tokens, skip_special_tokens=True)
+        labels[mask] = tokenizer.pad_token_id
+        ground_truth = tokenizer.batch_decode(labels, skip_special_tokens=True)
 
         print(f"Predicted: {predicted}")
         print(f"Ground truth: {ground_truth}")
         f1 = f1_score(predicted, ground_truth)
         exact_match = exact_match_score(predicted, ground_truth)
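
The masking logic in the new eval loop is the subtle part: the logits at position i are the model's prediction for token i+1, so dropping the last logit column aligns predictions with labels[:, 1:], and every position whose label is -100 (the loss ignore index) is overwritten with the pad token so that batch_decode with skip_special_tokens=True drops it before scoring. A self-contained toy demonstration of that alignment (hypothetical shapes and values, no model required):

import torch

pad_token_id = 0                     # assumed; Pythia tokenizers need one set
logits = torch.randn(1, 5, 10)       # (batch, seq_len, vocab)
labels = torch.tensor([[-100, -100, 7, 3, 2]])   # -100 marks prompt positions

mask = labels == -100
# logits[:, i] predicts token i+1, so drop the final position...
predicted_tokens = logits.argmax(dim=-1)[:, :-1]   # shape (1, 4)
# ...and mask with the ignore positions shifted one step left
predicted_tokens[mask[:, 1:]] = pad_token_id

labels = labels.clone()
labels[mask] = pad_token_id
# predicted_tokens now lines up with labels[:, 1:]; decoding both with
# skip_special_tokens=True strips the pad positions before f1_score and
# exact_match_score compare the strings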