mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2025-07-06 20:09:58 +00:00
feat: eval for retrieval
This commit is contained in:
parent
1b3f18bef2
commit
3736eda56a
18
configs/eval/evaluate_gpt4all_jr.yaml
Normal file
18
configs/eval/evaluate_gpt4all_jr.yaml
Normal file
@ -0,0 +1,18 @@
|
||||
# model/tokenizer: checkpoint to evaluate and the tokenizer used to encode its inputs
|
||||
model_name: "/home/paperspace/gpt4all/gpt4all/train/ckpts/epoch_2"
|
||||
tokenizer_name: "EleutherAI/gpt-j-6B"
|
||||
version: null
|
||||
gradient_checkpointing: true
|
||||
save_name: "nomic-ai/gpt-jr"
|
||||
encoder_dim: 384
|
||||
|
||||
# dataset: location of the evaluation data, column names, and batching options
|
||||
streaming: false
|
||||
num_proc: 64
|
||||
dataset_path: "/home/paperspace/gpt4all/gpt4all/index/squad_supplemented_validation"
|
||||
max_length: 1024
|
||||
batch_size: 32
|
||||
pct_test: 0.05
|
||||
q_column: "question"
|
||||
a_column: "answers"
|
||||
encoder_column: "neighbor_embeddings"
|
54
gpt4all/eval/eval_squad.py
Normal file
54
gpt4all/eval/eval_squad.py
Normal file
@ -0,0 +1,54 @@
|
||||
import torch
|
||||
from gpt4all.models import GPTJRForCausalLM
|
||||
from gpt4all.data.retrieval_dataloader import load_retrieval_augmented_data
|
||||
from gpt4all.train.metrics import f1_score, exact_match_score
|
||||
from gpt4all.utils.read import read_config
|
||||
from transformers import AutoTokenizer
|
||||
from argparse import ArgumentParser
|
||||
from tqdm import tqdm
|
||||
|
||||
# --- Command line and configuration -----------------------------------------
# The only option is the path to the YAML file describing the model and data.
parser = ArgumentParser()
parser.add_argument("--config", default="config.yaml", type=str)
args = parser.parse_args()
config = read_config(args.config)

# --- Tokenizer and data -----------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(config["tokenizer_name"])
# Fall back to the EOS token for padding when the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# split_dataset=False: request a single dataloader instead of a train/val pair
# (presumably -- confirm against load_retrieval_augmented_data).
dataloader = load_retrieval_augmented_data(
    config,
    tokenizer,
    split_dataset=False,
)

# --- Model ------------------------------------------------------------------
# Use the GPU when one is visible, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the checkpoint, move it to the target device and switch to eval mode.
model = GPTJRForCausalLM.from_pretrained(config["model_name"], use_cache=False)
model = model.to(device).eval()
|
||||
|
||||
# Evaluate the model on the SQuAD dataset.
#
# For every batch the model is run once with teacher forcing, the greedy
# (argmax) token at each answer position is decoded, and the decoded text is
# scored against the decoded reference answer with F1 and exact match.

# Label value marking positions that do not belong to the answer
# (the original code masks on this same value).
IGNORE_INDEX = -100

f1s = []
exact_matches = []

with torch.no_grad():
    for batch in tqdm(dataloader):
        labels = batch["labels"]

        outputs = model(
            input_ids=batch["input_ids"].to(device),
            labels=labels.to(device),
            encoder_hidden_states=batch["encoder_hidden_states"].to(device),
        )

        # The logits at position i score the token at position i + 1, so drop
        # the last prediction and the first label to line the two up.
        # NOTE(review): assumes GPTJRForCausalLM follows the Hugging Face
        # causal-LM convention of returning unshifted logits -- confirm.
        predicted_tokens = outputs.logits[:, :-1].argmax(dim=-1).cpu()
        target_tokens = labels[:, 1:].cpu()

        # Only answer positions take part in scoring. Everything else is
        # replaced by the pad token on BOTH sides so that prompt/padding
        # predictions cannot leak into the decoded prediction, and so the
        # batch's label tensor is not modified in place.
        ignored = target_tokens == IGNORE_INDEX
        pad_id = tokenizer.pad_token_id
        predicted_tokens = predicted_tokens.masked_fill(ignored, pad_id)
        target_tokens = target_tokens.masked_fill(ignored, pad_id)

        # skip_special_tokens=True is passed so the pad filler inserted above
        # is dropped from the decoded strings.
        predicted = tokenizer.batch_decode(predicted_tokens, skip_special_tokens=True)
        ground_truth = tokenizer.batch_decode(target_tokens, skip_special_tokens=True)

        # Both metric helpers are expected to return one score per example.
        f1s.extend(f1_score(predicted, ground_truth))
        exact_matches.extend(exact_match_score(predicted, ground_truth))

# Report the mean F1 and mean exact-match rate over all evaluated examples.
print(torch.tensor(f1s).mean())
print(torch.tensor(exact_matches).to(torch.float32).mean())
|
Loading…
Reference in New Issue
Block a user