Support mtbench (#5025)

Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
Author: Yuanchen
Date: 2023-11-09 13:41:50 +08:00 (committed by GitHub)
parent f71e63b0f3
commit 239cd92eff
9 changed files with 312 additions and 13 deletions

View File

@@ -6,6 +6,7 @@ from .colossalai import ColossalDataset
from .gaokaobench import GaoKaoBenchDataset
from .longbench import LongBenchDataset
from .mmlu import MMLUDataset
+from .mtbench import MTBenchDataset

__all__ = [
    "AGIEvalDataset",
@@ -16,4 +17,5 @@ __all__ = [
    "LongBenchDataset",
    "MMLUDataset",
    "ColossalDataset",
+    "MTBenchDataset",
]

View File

@@ -0,0 +1,72 @@
import copy
import json
import os
from collections import defaultdict
from typing import Dict, List

from colossal_eval.utils import get_json_list

from colossalai.logging import DistributedLogger

from .base import BaseDataset

default_inference_kwargs = {
    "calculate_loss": False,
    "all_classes": None,
    "language": "English",
    "pretrain": False,
    "max_new_tokens": 1024,
    "turns": 2,
}


class MTBenchDataset(BaseDataset):
    """
    Dataset class for mt_bench dataset.
    Data source: https://github.com/lm-sys/FastChat/blob/main/fastchat/llm_judge/data/mt_bench/question.jsonl
    This dataset class will convert the original dataset into the inference dataset.
    """

    def __init__(self, path, logger, few_shot):
        self.multiturn = True
        self.dataset = self.load(path, logger, few_shot)

    @staticmethod
    def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
        dataset = {"test": defaultdict(dict)}

        file_path = os.path.join(path, "question.jsonl")
        ref_path = os.path.join(path, "reference_answer/gpt-4.jsonl")

        reference = defaultdict(list)
        ref_origin = get_json_list(ref_path)
        for ref in ref_origin:
            reference[ref["question_id"]] = ref["choices"][0]["turns"]

        with open(file_path, "r", encoding="utf-8") as file:
            for line in file:
                question = json.loads(line)
                category = question["category"]
                turn_number = len(question["turns"])
                data_point = {
                    "id": question["question_id"],
                    "dataset": "mtbench",
                    "split": "test",
                    "category": category,
                    "instruction": question["turns"],
                    "input": "",
                    "output": [],
                    "target": [""] * turn_number
                    if question["question_id"] not in reference
                    else reference[question["question_id"]],
                }

                if category in dataset["test"]:
                    dataset["test"][category]["data"].append(data_point)
                else:
                    dataset["test"][category] = {
                        "data": [data_point],
                        "inference_kwargs": copy.deepcopy(default_inference_kwargs),
                    }

        return dataset
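
For reference, the following is a minimal sketch (not part of the commit) of the structure MTBenchDataset.load builds, using a hypothetical two-turn "writing" question instead of reading question.jsonl and reference_answer/gpt-4.jsonl from disk:

import copy
from collections import defaultdict

default_inference_kwargs = {
    "calculate_loss": False,
    "all_classes": None,
    "language": "English",
    "pretrain": False,
    "max_new_tokens": 1024,
    "turns": 2,
}

# Hypothetical question.jsonl entry; the real file ships with FastChat's MT-Bench data.
question = {
    "question_id": 81,
    "category": "writing",
    "turns": ["Compose a short travel blog post.", "Rewrite it starting every sentence with the letter A."],
}

dataset = {"test": defaultdict(dict)}
data_point = {
    "id": question["question_id"],
    "dataset": "mtbench",
    "split": "test",
    "category": question["category"],
    "instruction": question["turns"],            # both turns kept as a list
    "input": "",
    "output": [],                                # filled turn by turn during inference
    "target": [""] * len(question["turns"]),     # no GPT-4 reference for non-math categories
}
dataset["test"][question["category"]] = {
    "data": [data_point],
    "inference_kwargs": copy.deepcopy(default_inference_kwargs),
}
print(dataset["test"]["writing"]["inference_kwargs"]["turns"])  # -> 2

Each category receives its own deep copy of default_inference_kwargs, so per-category settings such as turns can diverge later without affecting other categories.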

View File

@@ -1,12 +1,15 @@
+import os
from typing import Dict, List

import colossal_eval.evaluate.dataset_evaluator.metrics as metric_helper
import numpy as np
import tqdm
+from colossal_eval.utils import jdump

LabelBasedMetrics = ["first_token_accuracy", "matthews_correlation"]
LossBasedMetrics = ["perplexity", "ppl_score", "ppl_score_over_choices", "per_byte_perplexity", "per_byte_ppl_score"]
CombinedMetrics = ["combined_single_choice_accuracy"]
+GPTMetrics = ["mtbench_single_judge"]
OtherMetrics = [
    "f1_score",
    "f1_zh_score",
@@ -29,8 +32,9 @@ class DatasetEvaluator(object):
    """

-    def __init__(self):
-        pass
+    def __init__(self, config_path: str, save_path: str):
+        self.config_path = config_path
+        self.save_path = save_path

    def _calculate_label_metrics(self, metric: str, category: str):
        """Calculate label-based metrics."""
@@ -156,6 +160,24 @@ class DatasetEvaluator(object):
        self.evaluation_results[metric][category] = (total_score, len(self.data[category]["data"]))
        self.evaluation_results[metric]["ALL"] += total_score * weight

+    def _calculate_gpt_metrics(self, metric: str, category: str):
+        """Calculate gpt metrics."""
+        weight = len(self.data[category]["data"]) / self.metric_total_length[metric]
+
+        metric_method = eval("gpt_helper." + metric)
+
+        judgements, avg_ratings = metric_method(self.data[category]["data"], self.config_path)
+        self.judgements[category] = judgements
+
+        self.evaluation_results[metric][category] = (np.mean(avg_ratings), len(self.data[category]["data"]))
+        self.evaluation_results[metric]["ALL"] += np.mean(avg_ratings) * weight
+
+        for i in range(avg_ratings.shape[0]):
+            if f"{metric}_{i+1}" not in self.evaluation_results:
+                self.evaluation_results[f"{metric}_{i+1}"] = {cat: 0 for cat in (["ALL"] + self.categories)}
+            self.evaluation_results[f"{metric}_{i+1}"][category] = (avg_ratings[i], len(self.data[category]["data"]))
+            self.evaluation_results[f"{metric}_{i+1}"]["ALL"] += avg_ratings[i] * weight
+
    def _calculate_loss_metrics(self, metric: str, category: str):
        """Calculate perplexity."""
        if metric == "perplexity":
@@ -217,10 +239,20 @@ class DatasetEvaluator(object):
                for category in self.suggested_categories[metric]:
                    self._calculate_combined_metrics(metric, category)
                    pbar.update(1)
+            elif metric in GPTMetrics:
+                for category in self.suggested_categories[metric]:
+                    self._calculate_gpt_metrics(metric, category)
+                    pbar.update(1)
            elif metric in OtherMetrics:
                for category in self.suggested_categories[metric]:
                    self._calculate_other_metrics(metric, category)
                    pbar.update(1)
+            else:
+                raise Exception(f"{metric} not supported.")
+
+        if self.judgements:
+            judgement_path = os.path.join(self.save_path, f"{self.model_name}_judgements.json")
+            jdump(self.judgements, judgement_path)

        return self.evaluation_results
@@ -240,6 +272,7 @@ class DatasetEvaluator(object):
        self.model_name = model_name
        self.categories = list(data.keys())
        self.metrics = metrics
+        self.judgements = {}

        self.evaluation_results = {
            metric: {category: 0 for category in (["ALL"] + self.categories)} for metric in self.metrics
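
For reference, a simplified sketch (not part of the commit) of the bookkeeping done in _calculate_gpt_metrics above, using made-up per-turn ratings and category sizes; the real method obtains avg_ratings from the GPT judge and also stores the raw judgements for dumping:

import numpy as np

metric = "mtbench_single_judge"
categories = {"writing": np.array([8.0, 7.0]), "math": np.array([6.0, 5.0])}  # avg rating per turn (made up)
sizes = {"writing": 10, "math": 10}
total = sum(sizes.values())

results = {metric: {"ALL": 0.0}}
for cat, avg_ratings in categories.items():
    weight = sizes[cat] / total
    # Overall score for this category plus its weighted contribution to "ALL".
    results[metric][cat] = (float(np.mean(avg_ratings)), sizes[cat])
    results[metric]["ALL"] += float(np.mean(avg_ratings)) * weight
    # Per-turn breakdown, e.g. "mtbench_single_judge_1" and "..._2".
    for i in range(avg_ratings.shape[0]):
        key = f"{metric}_{i+1}"
        results.setdefault(key, {"ALL": 0.0})
        results[key][cat] = (float(avg_ratings[i]), sizes[cat])
        results[key]["ALL"] += float(avg_ratings[i]) * weight

print(results[metric]["ALL"])         # 6.5 overall
print(results[f"{metric}_1"]["ALL"])  # 7.0 averaged over first turns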

View File

@@ -0,0 +1,151 @@
# Code adapted from https://github.com/lm-sys/FastChat/tree/main/fastchat/llm_judge

import ast
import concurrent.futures
import copy
import json
import os
import re
import time
from typing import Any, Dict, List

import numpy as np
import openai
import tqdm

MODEL = "gpt-4"

API_MAX_RETRY = 16
API_RETRY_SLEEP = 10
API_ERROR_OUTPUT = "$ERROR$"

NEED_REF_CATS = ["math", "reasoning", "coding"]

one_score_pattern = re.compile("\[\[(\d+\.?\d*)\]\]")
one_score_pattern_backup = re.compile("\[(\d+\.?\d*)\]")


def load_mt_prompts(prompt_file: str):
    prompts = {}
    with open(prompt_file) as fin:
        for line in fin:
            line = json.loads(line)
            prompts[line["name"]] = line
    return prompts


def get_mt_prompt(prompts: Dict[str, str], multiturn: bool, math: bool):
    if math and multiturn:
        return prompts["single-math-v1-multi-turn"]
    elif math and not multiturn:
        return prompts["single-math-v1"]
    elif not math and multiturn:
        return prompts["single-v1-multi-turn"]
    elif not math and not multiturn:
        return prompts["single-v1"]


def chat_compeletion_openai(messages: List[Dict], temperature: float = 0.0, max_tokens: int = 2048):
    output = API_ERROR_OUTPUT
    model = MODEL
    for _ in range(API_MAX_RETRY):
        try:
            response = openai.ChatCompletion.create(
                model=model,
                messages=messages,
                n=1,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            output = response["choices"][0]["message"]["content"]
            break
        except openai.error.OpenAIError as e:
            print(type(e), e)
            time.sleep(API_RETRY_SLEEP)

    return output


def get_mtbench_judgements(question: Dict[str, Any], prompts: Dict[str, str]):
    id = question["id"]
    judgement = {"id": id, "judgements": [], "ratings": []}
    category = question["category"]
    math = category in NEED_REF_CATS
    turn_number = len(question["instruction"])

    for num in range(turn_number):
        assert (len(question["target"]) >= 1 and math) or not math
        kwargs = {}
        if num >= 1:
            prompt = get_mt_prompt(prompts, multiturn=True, math=math)
            if len(question["target"]) >= 1 and math:
                kwargs = {f"ref_answer_{i+1}": question["target"][i] for i in range(len(question["target"]))}
            user_prompt = prompt["prompt_template"].format(
                question_1=question["instruction"][0],
                question_2=question["instruction"][1],
                answer_1=question["output"][0],
                answer_2=question["output"][1],
                **kwargs,
            )
        else:
            prompt = get_mt_prompt(prompts, multiturn=False, math=math)
            if len(question["target"]) >= 1 and math:
                kwargs = {"ref_answer_1": question["target"][0]}
            user_prompt = prompt["prompt_template"].format(
                question=question["instruction"][0],
                answer=question["output"][0],
                **kwargs,
            )

        rating = -1

        sys_prompt = prompt["system_prompt"]
        messages = [{"role": "system", "content": sys_prompt}, {"role": "user", "content": user_prompt}]

        judgement_str = chat_compeletion_openai(messages, temperature=0.0, max_tokens=2048)
        match = re.search(one_score_pattern, judgement_str)
        if not match:
            match = re.search(one_score_pattern_backup, judgement_str)

        if match:
            rating = ast.literal_eval(match.groups()[0])
        else:
            rating = -1

        judgement["judgements"].append(judgement_str)
        judgement["ratings"].append(rating)

    return judgement


def mtbench_single_judge(data: List[Dict], config_path: str):
    judgements = []

    prompt_dir = os.path.dirname(config_path)
    prompts = load_mt_prompts(os.path.join(prompt_dir, "mtbench_judge_prompts.jsonl"))

    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
        futures = []
        for i, question in enumerate(data):
            future = executor.submit(get_mtbench_judgements, question, prompts)
            futures.append(future)

        for future in tqdm.tqdm(
            concurrent.futures.as_completed(futures),
            desc=f"MTBench single judge for {data[0]['category']}",
            total=len(futures),
        ):
            judgements.append(future.result())

    judgements.sort(key=lambda x: x["id"])
    judgements_by_id = {j["id"]: j for j in judgements}

    data_to_dump = copy.deepcopy(data)
    for d in data_to_dump:
        id = d["id"]
        d["judgements"] = judgements_by_id[id]["judgements"]
        d["ratings"] = judgements_by_id[id]["ratings"]

    avg_ratings = np.mean([j["ratings"] for j in judgements], axis=0)

    return data_to_dump, avg_ratings
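
For reference, a self-contained sketch (not part of the commit) of the rating extraction used in get_mtbench_judgements: the judge reply is expected to contain a verdict like [[9]], with [7.5] accepted by the backup pattern, and anything else falls back to -1:

import ast
import re

one_score_pattern = re.compile(r"\[\[(\d+\.?\d*)\]\]")
one_score_pattern_backup = re.compile(r"\[(\d+\.?\d*)\]")

def parse_rating(judgement_str: str):
    # Try the double-bracket verdict first, then the single-bracket fallback.
    match = re.search(one_score_pattern, judgement_str) or re.search(
        one_score_pattern_backup, judgement_str
    )
    return ast.literal_eval(match.groups()[0]) if match else -1

print(parse_rating("The answer is helpful and accurate. Rating: [[9]]"))  # 9
print(parse_rating("Rating: [7.5]"))                                      # 7.5
print(parse_rating("No rating given"))                                    # -1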

View File

@@ -185,6 +185,7 @@ metrics4subcategory = {
        "ppl_score_over_choices": ["ALL"],
        "ppl_score": ["ALL"],
    },
+    "mtbench": {"mtbench_single_judge": ["ALL"]},
}

View File

@@ -333,9 +333,12 @@ class HuggingFaceModel(BaseModel):
        self.str_label_map = {choice: idx for idx, choice in enumerate(self.choices)}

+        turn = 0 if not isinstance(data[0]["output"], list) else len(data[0]["output"]) + 1
+        turn_desc = "" if turn == 0 else f"-turn{turn}"
+
        bar = tqdm(
            range(math.ceil(len(data) / self.batch_size)),
-            desc=f"{data[0]['dataset']}-{data[0]['category']} Inference steps",
+            desc=f"{data[0]['dataset']}-{data[0]['category']}{turn_desc} Inference steps",
            disable=not is_rank_0(),
        )
        loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
@@ -384,6 +387,9 @@ class HuggingFaceModel(BaseModel):
            for j in range(len(batch_prompt)):
                if not pretrain:
+                    if isinstance(answers[i + j]["output"], list):
+                        answers[i + j]["output"].append(batch_decodes[j].strip())
+                    else:
                        answers[i + j]["output"] = batch_decodes[j].strip()

                if isinstance(scores, torch.Tensor):
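
For reference, a small sketch (not part of the commit) of the new turn label used in the progress-bar description above: single-turn datasets keep "output" as a string, so the suffix stays empty, while MT-Bench stores a list that grows by one answer per inference pass:

def turn_suffix(sample_output):
    # Mirrors the turn/turn_desc logic in the diff above.
    turn = 0 if not isinstance(sample_output, list) else len(sample_output) + 1
    return "" if turn == 0 else f"-turn{turn}"

print(turn_suffix("some answer"))    # ""        -> single-turn dataset
print(turn_suffix([]))               # "-turn1"  -> first MT-Bench pass
print(turn_suffix(["first reply"]))  # "-turn2"  -> second pass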

View File

@@ -171,6 +171,9 @@ def get_batch_prompt(
    for b in batch:
        few_shot_prefix = ""
        if few_shot_data is not None:
+            assert not isinstance(b["instruction"], list), print(
+                f"When performing few-shot, {b['dataset']} shouldn't be a multiturn dataset."
+            )
            # For few-shot, only need input. Otherwise use instruction (in AGIEval).
            query_text = b["input"] if b.get("input", "") != "" else b["instruction"]
@@ -181,11 +184,24 @@
                raise Exception("When using few-shot, target answer should be a string.")

            few_shot_prefix = get_few_shot_prefix(conv, few_shot_data, tokenizer, language, max_tokens)
-        else:
-            query_text = b["instruction"] + "\n\n" + b["input"] if b.get("input", "") != "" else b["instruction"]

            conv.append_message(conv.roles[0], few_shot_prefix + query_text)
            conv.append_message(conv.roles[1], None)
+        else:
+            if not isinstance(b["instruction"], list):
+                query_text = (
+                    b["instruction"] + "\n\n" + b["input"] if b.get("input", "") != "" else b["instruction"]
+                )
+                conv.append_message(conv.roles[0], query_text)
+                conv.append_message(conv.roles[1], None)
+            else:
+                assert len(b["instruction"]) >= len(b["output"]) + 1
+                cur_turns = len(b["output"])
+                for turn in range(cur_turns):
+                    conv.append_message(conv.roles[0], b["instruction"][turn])
+                    conv.append_message(conv.roles[1], b["output"][turn])
+                conv.append_message(conv.roles[0], b["instruction"][cur_turns])
+                conv.append_message(conv.roles[1], None)

        batch_prompt.append(conv.get_prompt())
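
For reference, a sketch (not part of the commit) of how the new multiturn branch assembles a prompt, with a stand-in Conversation class in place of the conversation template used by colossal_eval: finished turns are replayed as user/assistant pairs, then the next unanswered instruction is appended with an empty assistant slot:

class Conversation:
    # Stand-in for the real conversation template; only what the sketch needs.
    def __init__(self):
        self.roles = ("USER", "ASSISTANT")
        self.messages = []

    def append_message(self, role, content):
        self.messages.append((role, content))

    def get_prompt(self):
        return "\n".join(f"{r}: {c if c is not None else ''}" for r, c in self.messages)

b = {
    "instruction": ["Write a poem about the sea.", "Now translate it to French."],
    "output": ["The sea is vast..."],  # first turn already answered
}

conv = Conversation()
cur_turns = len(b["output"])
for turn in range(cur_turns):                      # replay finished turns
    conv.append_message(conv.roles[0], b["instruction"][turn])
    conv.append_message(conv.roles[1], b["output"][turn])
conv.append_message(conv.roles[0], b["instruction"][cur_turns])  # next question
conv.append_message(conv.roles[1], None)                         # model fills this in
print(conv.get_prompt())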

View File

@@ -11,7 +11,7 @@ def main(args):
    evaluation_results = {dataset["name"]: {} for dataset in config["dataset"]}
    evaluation_results_table = {dataset["name"]: {} for dataset in config["dataset"]}
-    evaluator = DatasetEvaluator()
+    evaluator = DatasetEvaluator(args.config, args.evaluation_results_save_path)

    for dataset_parameter in config["dataset"]:
        dataset_name = dataset_parameter["name"]
@@ -26,6 +26,8 @@ def main(args):
            results = evaluator.get_evaluation_results(data, dataset_name, model_name, metrics)

            for metric, score in results.items():
+                if metric not in results_metric_model:
+                    results_metric_model[metric] = {model["name"]: None for model in config["model"]}
                results_metric_model[metric][model_name] = score["ALL"]

            evaluation_results[dataset_name][model_name] = results

View File

@@ -71,6 +71,7 @@ def main(args):
    inference_data = {}
    debug_args = {}
    few_shot_args = {}
+    multiturn_args = {}

    config = utils.jload(args.config)
@@ -102,6 +103,13 @@ def main(args):
        dataset_ = dataset_class(path, logger, dataset_parameter["few_shot"])
        dataset_.save(save_path)

+        if hasattr(dataset_, "multiturn") and dataset_.multiturn:
+            multiturn_args[dataset_name] = True
+            logger.info(f"{dataset_parameter['dataset_class']} is a multiturn dataset.")
+        else:
+            multiturn_args[dataset_name] = False
+
        inference_data[dataset_name] = dataset_.dataset["test"]

    for model_parameter in model_parameters:
@@ -117,7 +125,10 @@
        for dataset_name, split_data in inference_data.items():
            start = 0
+            prev_questions = None
            for category, category_data in split_data.items():
+                num_turn = category_data["inference_kwargs"].get("turns", 1)
+
                if few_shot_args[dataset_name] and category_data["inference_kwargs"].get("few_shot_data", None) is None:
                    raise Exception(f"Dataset {dataset_name} doesn't have few-shot data for category {category}!")
@@ -132,11 +143,16 @@
                start = (start + redundant) % world_size

+                for turn in range(num_turn):
+                    if turn == 0:
                        questions = category_data["data"][sum(lengths[0:rank]) : sum(lengths[0:rank]) + lengths[rank]]
+                    else:
+                        questions = prev_questions

                    answers_per_rank = model_.inference(
                        questions, inference_kwargs=category_data["inference_kwargs"], debug=debug_args[dataset_name]
                    )
+                    prev_questions = answers_per_rank

                answers_to_dump["data"] = answers_per_rank
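
For reference, a sketch (not part of the commit) of the new per-turn inference loop, with a dummy inference function standing in for model_.inference: the first pass runs on the raw questions, and each later pass feeds the previous pass's answers back in, so every sample's "output" list grows by one entry per turn:

def fake_inference(questions):
    # Dummy stand-in: append one answer per sample for the current turn.
    for q in questions:
        q["output"].append(f"answer to: {q['instruction'][len(q['output'])]}")
    return questions

category_data = {
    "data": [{"instruction": ["Q1", "Q2"], "output": []}],
    "inference_kwargs": {"turns": 2},
}

num_turn = category_data["inference_kwargs"].get("turns", 1)
prev_questions = None
for turn in range(num_turn):
    questions = category_data["data"] if turn == 0 else prev_questions
    answers_per_rank = fake_inference(questions)
    prev_questions = answers_per_rank

print(prev_questions[0]["output"])  # ['answer to: Q1', 'answer to: Q2']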