diff --git a/applications/ColossalEval/colossal_eval/dataset/__init__.py b/applications/ColossalEval/colossal_eval/dataset/__init__.py
index 4ea173198..5b029e267 100644
--- a/applications/ColossalEval/colossal_eval/dataset/__init__.py
+++ b/applications/ColossalEval/colossal_eval/dataset/__init__.py
@@ -6,6 +6,7 @@ from .colossalai import ColossalDataset
 from .gaokaobench import GaoKaoBenchDataset
 from .longbench import LongBenchDataset
 from .mmlu import MMLUDataset
+from .mtbench import MTBenchDataset
 
 __all__ = [
     "AGIEvalDataset",
@@ -16,4 +17,5 @@ __all__ = [
     "LongBenchDataset",
     "MMLUDataset",
     "ColossalDataset",
+    "MTBenchDataset",
 ]
diff --git a/applications/ColossalEval/colossal_eval/dataset/mtbench.py b/applications/ColossalEval/colossal_eval/dataset/mtbench.py
new file mode 100644
index 000000000..9e74a4d82
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/dataset/mtbench.py
@@ -0,0 +1,72 @@
+import copy
+import json
+import os
+from collections import defaultdict
+from typing import Dict, List
+
+from colossal_eval.utils import get_json_list
+
+from colossalai.logging import DistributedLogger
+
+from .base import BaseDataset
+
+default_inference_kwargs = {
+    "calculate_loss": False,
+    "all_classes": None,
+    "language": "English",
+    "pretrain": False,
+    "max_new_tokens": 1024,
+    "turns": 2,
+}
+
+
+class MTBenchDataset(BaseDataset):
+    """
+    Dataset class for mt_bench dataset.
+    Data source: https://github.com/lm-sys/FastChat/blob/main/fastchat/llm_judge/data/mt_bench/question.jsonl
+    This dataset class will convert the original dataset into the inference dataset.
+    """
+
+    def __init__(self, path, logger, few_shot):
+        self.multiturn = True
+        self.dataset = self.load(path, logger, few_shot)
+
+    @staticmethod
+    def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
+        dataset = {"test": defaultdict(dict)}
+
+        file_path = os.path.join(path, "question.jsonl")
+        ref_path = os.path.join(path, "reference_answer/gpt-4.jsonl")
+
+        reference = defaultdict(list)
+        ref_origin = get_json_list(ref_path)
+        for ref in ref_origin:
+            reference[ref["question_id"]] = ref["choices"][0]["turns"]
+
+        with open(file_path, "r", encoding="utf-8") as file:
+            for line in file:
+                question = json.loads(line)
+                category = question["category"]
+                turn_number = len(question["turns"])
+                data_point = {
+                    "id": question["question_id"],
+                    "dataset": "mtbench",
+                    "split": "test",
+                    "category": category,
+                    "instruction": question["turns"],
+                    "input": "",
+                    "output": [],
+                    "target": [""] * turn_number
+                    if question["question_id"] not in reference
+                    else reference[question["question_id"]],
+                }
+
+                if category in dataset["test"]:
+                    dataset["test"][category]["data"].append(data_point)
+                else:
+                    dataset["test"][category] = {
+                        "data": [data_point],
+                        "inference_kwargs": copy.deepcopy(default_inference_kwargs),
+                    }
+
+        return dataset
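For orientation, here is a minimal sketch (not part of the patch) of the conversion MTBenchDataset.load performs on a single question.jsonl entry. The sample question is made up, and a non-math category is assumed so that no GPT-4 reference answer exists.

# Illustrative sketch only: mirrors the conversion logic in MTBenchDataset.load above.
import json

sample_line = json.dumps(
    {
        "question_id": 81,
        "category": "writing",
        "turns": ["Compose a short travel blog post about Hawaii.", "Rewrite it so every sentence starts with A."],
    }
)

question = json.loads(sample_line)
reference = {}  # non-math categories typically have no GPT-4 reference answers

data_point = {
    "id": question["question_id"],
    "dataset": "mtbench",
    "split": "test",
    "category": question["category"],
    "instruction": question["turns"],  # both turns kept as a list
    "input": "",
    "output": [],  # filled in turn by turn during inference
    "target": reference.get(question["question_id"], [""] * len(question["turns"])),
}
print(data_point["category"], len(data_point["instruction"]))  # writing 2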
"matthews_correlation"] LossBasedMetrics = ["perplexity", "ppl_score", "ppl_score_over_choices", "per_byte_perplexity", "per_byte_ppl_score"] CombinedMetrics = ["combined_single_choice_accuracy"] +GPTMetrics = ["mtbench_single_judge"] OtherMetrics = [ "f1_score", "f1_zh_score", @@ -29,8 +32,9 @@ class DatasetEvaluator(object): """ - def __init__(self): - pass + def __init__(self, config_path: str, save_path: str): + self.config_path = config_path + self.save_path = save_path def _calculate_label_metrics(self, metric: str, category: str): """Calculate label-based metrics.""" @@ -156,6 +160,24 @@ class DatasetEvaluator(object): self.evaluation_results[metric][category] = (total_score, len(self.data[category]["data"])) self.evaluation_results[metric]["ALL"] += total_score * weight + def _calculate_gpt_metrics(self, metric: str, category: str): + """Calculate gpt metrics.""" + weight = len(self.data[category]["data"]) / self.metric_total_length[metric] + + metric_method = eval("gpt_helper." + metric) + + judgements, avg_ratings = metric_method(self.data[category]["data"], self.config_path) + self.judgements[category] = judgements + + self.evaluation_results[metric][category] = (np.mean(avg_ratings), len(self.data[category]["data"])) + self.evaluation_results[metric]["ALL"] += np.mean(avg_ratings) * weight + + for i in range(avg_ratings.shape[0]): + if f"{metric}_{i+1}" not in self.evaluation_results: + self.evaluation_results[f"{metric}_{i+1}"] = {cat: 0 for cat in (["ALL"] + self.categories)} + self.evaluation_results[f"{metric}_{i+1}"][category] = (avg_ratings[i], len(self.data[category]["data"])) + self.evaluation_results[f"{metric}_{i+1}"]["ALL"] += avg_ratings[i] * weight + def _calculate_loss_metrics(self, metric: str, category: str): """Calculate perplexity.""" if metric == "perplexity": @@ -217,10 +239,20 @@ class DatasetEvaluator(object): for category in self.suggested_categories[metric]: self._calculate_combined_metrics(metric, category) pbar.update(1) + elif metric in GPTMetrics: + for category in self.suggested_categories[metric]: + self._calculate_gpt_metrics(metric, category) + pbar.update(1) elif metric in OtherMetrics: for category in self.suggested_categories[metric]: self._calculate_other_metrics(metric, category) pbar.update(1) + else: + raise Exception(f"{metric} not supported.") + + if self.judgements: + judgement_path = os.path.join(self.save_path, f"{self.model_name}_judgements.json") + jdump(self.judgements, judgement_path) return self.evaluation_results @@ -240,6 +272,7 @@ class DatasetEvaluator(object): self.model_name = model_name self.categories = list(data.keys()) self.metrics = metrics + self.judgements = {} self.evaluation_results = { metric: {category: 0 for category in (["ALL"] + self.categories)} for metric in self.metrics diff --git a/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/gpt_judge.py b/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/gpt_judge.py new file mode 100644 index 000000000..cd41dd7fd --- /dev/null +++ b/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/gpt_judge.py @@ -0,0 +1,151 @@ +# Code adapted from https://github.com/lm-sys/FastChat/tree/main/fastchat/llm_judge + +import ast +import concurrent.futures +import copy +import json +import os +import re +import time +from typing import Any, Dict, List + +import numpy as np +import openai +import tqdm + +MODEL = "gpt-4" + +API_MAX_RETRY = 16 +API_RETRY_SLEEP = 10 +API_ERROR_OUTPUT = "$ERROR$" + +NEED_REF_CATS = ["math", "reasoning", 
"coding"] + +one_score_pattern = re.compile("\[\[(\d+\.?\d*)\]\]") +one_score_pattern_backup = re.compile("\[(\d+\.?\d*)\]") + + +def load_mt_prompts(prompt_file: str): + prompts = {} + with open(prompt_file) as fin: + for line in fin: + line = json.loads(line) + prompts[line["name"]] = line + return prompts + + +def get_mt_prompt(prompts: Dict[str, str], multiturn: bool, math: bool): + if math and multiturn: + return prompts["single-math-v1-multi-turn"] + elif math and not multiturn: + return prompts["single-math-v1"] + elif not math and multiturn: + return prompts["single-v1-multi-turn"] + elif not math and not multiturn: + return prompts["single-v1"] + + +def chat_compeletion_openai(messages: List[Dict], temperature: float = 0.0, max_tokens: int = 2048): + output = API_ERROR_OUTPUT + model = MODEL + for _ in range(API_MAX_RETRY): + try: + response = openai.ChatCompletion.create( + model=model, + messages=messages, + n=1, + temperature=temperature, + max_tokens=max_tokens, + ) + output = response["choices"][0]["message"]["content"] + break + except openai.error.OpenAIError as e: + print(type(e), e) + time.sleep(API_RETRY_SLEEP) + + return output + + +def get_mtbench_judgements(question: Dict[str, Any], prompts: Dict[str, str]): + id = question["id"] + judgement = {"id": id, "judgements": [], "ratings": []} + category = question["category"] + math = category in NEED_REF_CATS + turn_number = len(question["instruction"]) + + for num in range(turn_number): + assert (len(question["target"]) >= 1 and math) or not math + kwargs = {} + if num >= 1: + prompt = get_mt_prompt(prompts, multiturn=True, math=math) + if len(question["target"]) >= 1 and math: + kwargs = {f"ref_answer_{i+1}": question["target"][i] for i in range(len(question["target"]))} + user_prompt = prompt["prompt_template"].format( + question_1=question["instruction"][0], + question_2=question["instruction"][1], + answer_1=question["output"][0], + answer_2=question["output"][1], + **kwargs, + ) + else: + prompt = get_mt_prompt(prompts, multiturn=False, math=math) + if len(question["target"]) >= 1 and math: + kwargs = {"ref_answer_1": question["target"][0]} + user_prompt = prompt["prompt_template"].format( + question=question["instruction"][0], + answer=question["output"][0], + **kwargs, + ) + + rating = -1 + sys_prompt = prompt["system_prompt"] + messages = [{"role": "system", "content": sys_prompt}, {"role": "user", "content": user_prompt}] + + judgement_str = chat_compeletion_openai(messages, temperature=0.0, max_tokens=2048) + match = re.search(one_score_pattern, judgement_str) + if not match: + match = re.search(one_score_pattern_backup, judgement_str) + if match: + rating = ast.literal_eval(match.groups()[0]) + else: + rating = -1 + + judgement["judgements"].append(judgement_str) + judgement["ratings"].append(rating) + + return judgement + + +def mtbench_single_judge(data: List[Dict], config_path: str): + judgements = [] + + prompt_dir = os.path.dirname(config_path) + prompts = load_mt_prompts(os.path.join(prompt_dir, "mtbench_judge_prompts.jsonl")) + + with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor: + futures = [] + for i, question in enumerate(data): + future = executor.submit(get_mtbench_judgements, question, prompts) + futures.append(future) + + for future in tqdm.tqdm( + concurrent.futures.as_completed(futures), + desc=f"MTBench single judge for {data[0]['category']}", + total=len(futures), + ): + judgements.append(future.result()) + + judgements.sort(key=lambda x: x["id"]) + + judgements_by_id = 
{j["id"]: j for j in judgements} + + data_to_dump = copy.deepcopy(data) + + for d in data_to_dump: + id = d["id"] + d["judgements"] = judgements_by_id[id]["judgements"] + d["ratings"] = judgements_by_id[id]["ratings"] + + avg_ratings = np.mean([j["ratings"] for j in judgements], axis=0) + + return data_to_dump, avg_ratings diff --git a/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/metrics.py b/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/metrics.py index 45a12756d..eae35bb9b 100644 --- a/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/metrics.py +++ b/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/metrics.py @@ -185,6 +185,7 @@ metrics4subcategory = { "ppl_score_over_choices": ["ALL"], "ppl_score": ["ALL"], }, + "mtbench": {"mtbench_single_judge": ["ALL"]}, } diff --git a/applications/ColossalEval/colossal_eval/models/huggingface.py b/applications/ColossalEval/colossal_eval/models/huggingface.py index 47259c1db..693e02153 100644 --- a/applications/ColossalEval/colossal_eval/models/huggingface.py +++ b/applications/ColossalEval/colossal_eval/models/huggingface.py @@ -333,9 +333,12 @@ class HuggingFaceModel(BaseModel): self.str_label_map = {choice: idx for idx, choice in enumerate(self.choices)} + turn = 0 if not isinstance(data[0]["output"], list) else len(data[0]["output"]) + 1 + turn_desc = "" if turn == 0 else f"-turn{turn}" + bar = tqdm( range(math.ceil(len(data) / self.batch_size)), - desc=f"{data[0]['dataset']}-{data[0]['category']} Inference steps", + desc=f"{data[0]['dataset']}-{data[0]['category']}{turn_desc} Inference steps", disable=not is_rank_0(), ) loss_fct = torch.nn.CrossEntropyLoss(reduction="none") @@ -384,7 +387,10 @@ class HuggingFaceModel(BaseModel): for j in range(len(batch_prompt)): if not pretrain: - answers[i + j]["output"] = batch_decodes[j].strip() + if isinstance(answers[i + j]["output"], list): + answers[i + j]["output"].append(batch_decodes[j].strip()) + else: + answers[i + j]["output"] = batch_decodes[j].strip() if isinstance(scores, torch.Tensor): answers[i + j]["softmax_over_choices"] = probs[j] diff --git a/applications/ColossalEval/colossal_eval/utils/conversation.py b/applications/ColossalEval/colossal_eval/utils/conversation.py index 6c096a852..54ea21246 100644 --- a/applications/ColossalEval/colossal_eval/utils/conversation.py +++ b/applications/ColossalEval/colossal_eval/utils/conversation.py @@ -171,6 +171,9 @@ def get_batch_prompt( for b in batch: few_shot_prefix = "" if few_shot_data is not None: + assert not isinstance(b["instruction"], list), print( + f"When performing few-shot, {b['dataset']} shouldn't be a multiturn dataset." + ) # For few-shot, only need input. Otherwise use instruction (in AGIEval). 
query_text = b["input"] if b.get("input", "") != "" else b["instruction"] @@ -181,11 +184,24 @@ def get_batch_prompt( raise Exception("When using few-shot, target answer should be a string.") few_shot_prefix = get_few_shot_prefix(conv, few_shot_data, tokenizer, language, max_tokens) - else: - query_text = b["instruction"] + "\n\n" + b["input"] if b.get("input", "") != "" else b["instruction"] - conv.append_message(conv.roles[0], few_shot_prefix + query_text) - conv.append_message(conv.roles[1], None) + conv.append_message(conv.roles[0], few_shot_prefix + query_text) + conv.append_message(conv.roles[1], None) + else: + if not isinstance(b["instruction"], list): + query_text = ( + b["instruction"] + "\n\n" + b["input"] if b.get("input", "") != "" else b["instruction"] + ) + conv.append_message(conv.roles[0], query_text) + conv.append_message(conv.roles[1], None) + else: + assert len(b["instruction"]) >= len(b["output"]) + 1 + cur_turns = len(b["output"]) + for turn in range(cur_turns): + conv.append_message(conv.roles[0], b["instruction"][turn]) + conv.append_message(conv.roles[1], b["output"][turn]) + conv.append_message(conv.roles[0], b["instruction"][cur_turns]) + conv.append_message(conv.roles[1], None) batch_prompt.append(conv.get_prompt()) diff --git a/applications/ColossalEval/examples/dataset_evaluation/eval_dataset.py b/applications/ColossalEval/examples/dataset_evaluation/eval_dataset.py index ec81cf0ce..5724c6e40 100644 --- a/applications/ColossalEval/examples/dataset_evaluation/eval_dataset.py +++ b/applications/ColossalEval/examples/dataset_evaluation/eval_dataset.py @@ -11,7 +11,7 @@ def main(args): evaluation_results = {dataset["name"]: {} for dataset in config["dataset"]} evaluation_results_table = {dataset["name"]: {} for dataset in config["dataset"]} - evaluator = DatasetEvaluator() + evaluator = DatasetEvaluator(args.config, args.evaluation_results_save_path) for dataset_parameter in config["dataset"]: dataset_name = dataset_parameter["name"] @@ -26,6 +26,8 @@ def main(args): results = evaluator.get_evaluation_results(data, dataset_name, model_name, metrics) for metric, score in results.items(): + if metric not in results_metric_model: + results_metric_model[metric] = {model["name"]: None for model in config["model"]} results_metric_model[metric][model_name] = score["ALL"] evaluation_results[dataset_name][model_name] = results diff --git a/applications/ColossalEval/examples/dataset_evaluation/inference.py b/applications/ColossalEval/examples/dataset_evaluation/inference.py index 657fc33bf..b3579424a 100644 --- a/applications/ColossalEval/examples/dataset_evaluation/inference.py +++ b/applications/ColossalEval/examples/dataset_evaluation/inference.py @@ -71,6 +71,7 @@ def main(args): inference_data = {} debug_args = {} few_shot_args = {} + multiturn_args = {} config = utils.jload(args.config) @@ -102,6 +103,13 @@ def main(args): dataset_ = dataset_class(path, logger, dataset_parameter["few_shot"]) dataset_.save(save_path) + + if hasattr(dataset_, "multiturn") and dataset_.multiturn: + multiturn_args[dataset_name] = True + logger.info(f"{dataset_parameter['dataset_class']} is a multiturn dataset.") + else: + multiturn_args[dataset_name] = False + inference_data[dataset_name] = dataset_.dataset["test"] for model_parameter in model_parameters: @@ -117,7 +125,10 @@ def main(args): for dataset_name, split_data in inference_data.items(): start = 0 + prev_questions = None for category, category_data in split_data.items(): + num_turn = category_data["inference_kwargs"].get("turns", 
diff --git a/applications/ColossalEval/examples/dataset_evaluation/eval_dataset.py b/applications/ColossalEval/examples/dataset_evaluation/eval_dataset.py
index ec81cf0ce..5724c6e40 100644
--- a/applications/ColossalEval/examples/dataset_evaluation/eval_dataset.py
+++ b/applications/ColossalEval/examples/dataset_evaluation/eval_dataset.py
@@ -11,7 +11,7 @@ def main(args):
     evaluation_results = {dataset["name"]: {} for dataset in config["dataset"]}
     evaluation_results_table = {dataset["name"]: {} for dataset in config["dataset"]}
 
-    evaluator = DatasetEvaluator()
+    evaluator = DatasetEvaluator(args.config, args.evaluation_results_save_path)
 
     for dataset_parameter in config["dataset"]:
         dataset_name = dataset_parameter["name"]
@@ -26,6 +26,8 @@ def main(args):
            results = evaluator.get_evaluation_results(data, dataset_name, model_name, metrics)
 
            for metric, score in results.items():
+                if metric not in results_metric_model:
+                    results_metric_model[metric] = {model["name"]: None for model in config["model"]}
                results_metric_model[metric][model_name] = score["ALL"]
 
            evaluation_results[dataset_name][model_name] = results
diff --git a/applications/ColossalEval/examples/dataset_evaluation/inference.py b/applications/ColossalEval/examples/dataset_evaluation/inference.py
index 657fc33bf..b3579424a 100644
--- a/applications/ColossalEval/examples/dataset_evaluation/inference.py
+++ b/applications/ColossalEval/examples/dataset_evaluation/inference.py
@@ -71,6 +71,7 @@ def main(args):
     inference_data = {}
     debug_args = {}
     few_shot_args = {}
+    multiturn_args = {}
 
     config = utils.jload(args.config)
 
@@ -102,6 +103,13 @@ def main(args):
         dataset_ = dataset_class(path, logger, dataset_parameter["few_shot"])
 
         dataset_.save(save_path)
+
+        if hasattr(dataset_, "multiturn") and dataset_.multiturn:
+            multiturn_args[dataset_name] = True
+            logger.info(f"{dataset_parameter['dataset_class']} is a multiturn dataset.")
+        else:
+            multiturn_args[dataset_name] = False
+
         inference_data[dataset_name] = dataset_.dataset["test"]
 
     for model_parameter in model_parameters:
@@ -117,7 +125,10 @@ def main(args):
 
         for dataset_name, split_data in inference_data.items():
             start = 0
+            prev_questions = None
             for category, category_data in split_data.items():
+                num_turn = category_data["inference_kwargs"].get("turns", 1)
+
                 if few_shot_args[dataset_name] and category_data["inference_kwargs"].get("few_shot_data", None) is None:
                     raise Exception(f"Dataset {dataset_name} doesn't have few-shot data for category {category}!")
 
@@ -132,11 +143,16 @@ def main(args):
 
                 start = (start + redundant) % world_size
 
-                questions = category_data["data"][sum(lengths[0:rank]) : sum(lengths[0:rank]) + lengths[rank]]
+                for turn in range(num_turn):
+                    if turn == 0:
+                        questions = category_data["data"][sum(lengths[0:rank]) : sum(lengths[0:rank]) + lengths[rank]]
+                    else:
+                        questions = prev_questions
 
-                answers_per_rank = model_.inference(
-                    questions, inference_kwargs=category_data["inference_kwargs"], debug=debug_args[dataset_name]
-                )
+                    answers_per_rank = model_.inference(
+                        questions, inference_kwargs=category_data["inference_kwargs"], debug=debug_args[dataset_name]
+                    )
+                    prev_questions = answers_per_rank
 
                 answers_to_dump["data"] = answers_per_rank
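Putting the loop above in miniature: for a dataset with turns = 2, the first pass consumes the raw questions and the second pass re-feeds the answers produced by the first, so each data point accumulates one output per turn. In this sketch, model_inference is a stand-in for model_.inference and the data is made up.

# Illustrative sketch only: the two-pass flow the per-turn loop above implements.
def model_inference(questions):
    # Pretend generation: append a canned reply for the next unanswered turn.
    for q in questions:
        q["output"].append(f"reply to: {q['instruction'][len(q['output'])]}")
    return questions

questions = [{"instruction": ["Q1 turn 1", "Q1 turn 2"], "output": []}]
num_turn = 2

prev_questions = None
for turn in range(num_turn):
    batch = questions if turn == 0 else prev_questions  # turn 2 re-feeds the turn-1 answers
    answers_per_rank = model_inference(batch)
    prev_questions = answers_per_rank

print(prev_questions[0]["output"])  # ['reply to: Q1 turn 1', 'reply to: Q1 turn 2']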