Improve logic for selecting metrics (#5196)

Co-authored-by: Xu <yuanchen.xu00@gmail.com>
Yuanchen, 2023-12-22 14:52:50 +08:00, committed by GitHub
parent 4fa689fca1
commit eae01b6740
4 changed files with 62 additions and 23 deletions
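
In short: get_evaluation_results() now receives a payload that bundles the per-category inference results with the dataset class name, and the metrics4subcategory table is keyed by that class name instead of the lowercase dataset-name prefix. A minimal sketch of the new call shape, implied by the diff below; the category, model, and metric names are illustrative, and `evaluator` stands for an existing DatasetEvaluator instance.

```python
# Sketch only: payload shape implied by this diff, with illustrative values.
data = {
    "dataset_class": "MMLUDataset",  # new required key, used to select metrics
    "inference_results": {           # per-category results; its keys become self.categories
        "abstract_algebra": [...],   # list of per-sample records (unchanged by this commit)
    },
}

# evaluator: an existing DatasetEvaluator instance (construction not shown here).
evaluator.get_evaluation_results(
    data,
    dataset_name="mmlu",
    model_name="example-model",
    metrics=["first_token_accuracy"],
)
```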

View File

@@ -1,5 +1,5 @@
 import os
-from typing import Dict, List
+from typing import Dict, List, Union
 import colossal_eval.evaluate.dataset_evaluator.metrics as metric_helper
 import numpy as np
@@ -279,7 +279,9 @@ class DatasetEvaluator(object):
         return self.evaluation_results
-    def get_evaluation_results(self, data: List[Dict], dataset_name: str, model_name: str, metrics: List[str]):
+    def get_evaluation_results(
+        self, data: Dict[str, Union[str, Dict]], dataset_name: str, model_name: str, metrics: List[str]
+    ):
         """
         Evaluate inference data on the given metrics.
@@ -290,10 +292,11 @@ class DatasetEvaluator(object):
             metrics: Metrics used to evaluate.
         """
-        self.data = data
+        self.data = data["inference_results"]
         self.dataset_name = dataset_name
+        self.dataset_class = data["dataset_class"]
         self.model_name = model_name
-        self.categories = list(data.keys())
+        self.categories = list(self.data.keys())
         self.metrics = metrics
         self.judgements = {}
@@ -313,9 +316,7 @@ class DatasetEvaluator(object):
         for metric in self.metrics:
             # Train and reference split use same metric as test split.
-            self.suggested_categories[metric] = metric_helper.metrics4subcategory[self.dataset_name.split("_")[0]][
-                metric
-            ]
+            self.suggested_categories[metric] = metric_helper.metrics4subcategory[self.dataset_class][metric]
             if "ALL" in self.suggested_categories[metric]:
                 self.suggested_categories[metric] = self.categories
                 self.metric_total_length[metric] = self.total_length
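
Condensed, the last hunk above changes the metric lookup from the dataset-name prefix to the dataset class carried in the payload. A side-by-side sketch, where `dataset_name` and `dataset_class` stand in for the corresponding `self` attributes:

```python
# Before: key derived from the dataset name (text before the first underscore).
suggested = metric_helper.metrics4subcategory[dataset_name.split("_")[0]][metric]

# After: key is the dataset class name taken from data["dataset_class"].
suggested = metric_helper.metrics4subcategory[dataset_class][metric]
```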

View File

@@ -25,7 +25,7 @@ metrics4subcategory = {
         "per_byte_ppl_score": ["ALL"],
     },
     # The commented are non 4-choice questions.
-    "agieval": {
+    "AGIEvalDataset": {
         "combined_single_choice_accuracy": [
             # "lsat-ar",
             # "lsat-lr",
@@ -103,14 +103,14 @@ metrics4subcategory = {
         ],
         "ppl_score": ["ALL"],
     },
-    "cmmlu": {
+    "CMMLUDataset": {
         "first_token_accuracy": ["ALL"],
         "single_choice_accuracy": ["ALL"],
         "perplexity": ["ALL"],
         "ppl_score_over_choices": ["ALL"],
         "ppl_score": ["ALL"],
     },
-    "gaokaobench": {
+    "GaoKaoBenchDataset": {
         "combined_single_choice_accuracy": [
             "English MCQs",
             "Biology MCQs",
@@ -170,7 +170,7 @@ metrics4subcategory = {
         "ppl_score_over_choices": ["ALL"],
         "ppl_score": ["ALL"],
     },
-    "longbench": {
+    "LongBenchDataset": {
         "f1_score": ["hotpotqa", "2wikimqa", "musique", "narrativeqa", "qasper", "multifieldqa_en", "triviaqa"],
         "f1_zh_score": ["multifieldqa_zh"],
         "rouge_score": ["gov_report", "qmsum", "multi_news", "samsum"],
@@ -183,7 +183,7 @@ metrics4subcategory = {
         "perplexity": ["ALL"],
         "ppl_score": ["ALL"],
     },
-    "mmlu": {
+    "MMLUDataset": {
         "first_token_accuracy": ["ALL"],
         "single_choice_accuracy": ["ALL"],
         "accuracy": ["ALL"],
@@ -191,11 +191,11 @@ metrics4subcategory = {
         "ppl_score_over_choices": ["ALL"],
         "ppl_score": ["ALL"],
     },
-    "mtbench": {"mtbench_single_judge": ["ALL"]},
-    "cvalues": {"first_token_accuracy": ["ALL"]},
-    "safetybench_zh": {"first_token_accuracy": ["ALL"]},
-    "safetybench_en": {"first_token_accuracy": ["ALL"]},
-    "gsm": {
+    "MTBenchDataset": {"mtbench_single_judge": ["ALL"]},
+    "CValuesDataset": {"first_token_accuracy": ["ALL"]},
+    "SafetyBenchZHDataset": {"first_token_accuracy": ["ALL"]},
+    "SafetyBenchENDataset": {"first_token_accuracy": ["ALL"]},
+    "GSMDataset": {
         "loss_over_all_tokens": ["ALL"],
         "gsm_accuracy": ["ALL"],
     },
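
With the renamed keys, the table is indexed directly by dataset class; a couple of lookups using entries visible in this diff:

```python
metrics4subcategory["MMLUDataset"]["first_token_accuracy"]  # -> ["ALL"], i.e. apply to every category
metrics4subcategory["LongBenchDataset"]["f1_zh_score"]      # -> ["multifieldqa_zh"]
```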