mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-08 20:40:34 +00:00)
[feature] ColossalEval: Evaluation Pipeline for LLMs (#4786)
* Add ColossalEval
* Delete evaluate in Chat

---------

Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
0
applications/ColossalEval/colossal_eval/__init__.py
Normal file
19
applications/ColossalEval/colossal_eval/dataset/__init__.py
Normal file
@@ -0,0 +1,19 @@
from .agieval import AGIEvalDataset
from .base import BaseDataset
from .ceval import CEvalDataset
from .cmmlu import CMMLUDataset
from .colossalai import ColossalDataset
from .gaokaobench import GaoKaoBenchDataset
from .longbench import LongBenchDataset
from .mmlu import MMLUDataset

__all__ = [
    "AGIEvalDataset",
    "BaseDataset",
    "CEvalDataset",
    "CMMLUDataset",
    "GaoKaoBenchDataset",
    "LongBenchDataset",
    "MMLUDataset",
    "ColossalDataset",
]
247
applications/ColossalEval/colossal_eval/dataset/agieval.py
Normal file
@@ -0,0 +1,247 @@
# Adapted from https://github.com/ruixiangcui/AGIEval/blob/main/src/dataset_loader.py.

import ast
import glob
import os
from copy import deepcopy
from typing import Dict, List

import pandas as pd
from colossal_eval.utils import get_json_list

from colossalai.logging import DistributedLogger

from .base import BaseDataset

# define the datasets
english_qa_datasets = [
    "lsat-ar",
    "lsat-lr",
    "lsat-rc",
    "logiqa-en",
    "sat-math",
    "sat-en",
    "aqua-rat",
    "sat-en-without-passage",
    "gaokao-english",
]
chinese_qa_datasets = [
    "logiqa-zh",
    "jec-qa-kd",
    "jec-qa-ca",
    "gaokao-chinese",
    "gaokao-geography",
    "gaokao-history",
    "gaokao-biology",
    "gaokao-chemistry",
    "gaokao-physics",
    "gaokao-mathqa",
]
english_cloze_datasets = ["math"]
chinese_cloze_datasets = ["gaokao-mathcloze"]

multi_choice_datasets = ["jec-qa-kd", "jec-qa-ca", "gaokao-physics", "gaokao-mathqa"]
math_output_datasets = {"gaokao-mathcloze", "math"}

default_inference_kwargs = {
    "calculate_loss": True,
    "all_classes": None,
    "language": "Chinese",
    "pretrain": False,
    "max_new_tokens": 32,
}


def get_prompt(line: Dict, dataset_name: str, logger: DistributedLogger) -> Dict:
    """Modified from https://github.com/microsoft/AGIEval/blob/main/src/dataset_loader.py#L190"""
    try:
        all_classes = None
        passage = line["passage"] if line["passage"] is not None else ""

        if dataset_name in english_qa_datasets:
            option_string = "ABCDEFG"
            count = len(line["options"])

            input = (
                "Question: "
                + line["question"]
                + " "
                + "Choose from the following options: "
                + " ".join(line["options"])
                + "\n"
                + "Answer: "
            )

            all_classes = list(option_string[0:count])

        elif dataset_name in chinese_qa_datasets:
            option_string = "ABCDEFG"
            count = len(line["options"])

            input = "问题:" + line["question"] + " " + "从以下选项中选择:" + " ".join(line["options"]) + "\n" + "答案:"

            all_classes = list(option_string[0:count])

        elif dataset_name in english_cloze_datasets:
            input = "Question: " + line["question"] + "\n" + "Answer: "

        elif dataset_name in chinese_cloze_datasets:
            input = "问题:" + line["question"] + "\n" + "答案:"

        return {
            "instruction": input if not passage else passage + "\n\n" + input,
            "target": line["label"] if line["label"] else line["answer"],
        }, all_classes

    except NameError:
        logger.info("Dataset not defined.")


# process few-shot raw_prompts
def combine_prompt(prompt_path, dataset_name, load_explanation=True, chat_mode=False):
    skip_passage = False
    if dataset_name == "sat-en-without-passage":
        skip_passage = True
        dataset_name = "sat-en"
    demonstrations = []
    # read the prompts by context and explanation
    context_row = [0, 1, 3, 5, 7, 9]
    explanation_row = [0, 2, 4, 6, 8, 10]
    raw_prompts_context = pd.read_csv(
        prompt_path, header=0, skiprows=lambda x: x not in context_row, keep_default_na=False
    )
    raw_prompts_explanation = pd.read_csv(
        prompt_path, header=0, skiprows=lambda x: x not in explanation_row, keep_default_na=False
    ).replace(r"\n\n", "\n", regex=True)
    contexts = []
    for line in list(raw_prompts_context[dataset_name]):
        if line:
            contexts.append(ast.literal_eval(line))
    explanations = [exp for exp in raw_prompts_explanation[dataset_name] if exp]

    for idx, (con, exp) in enumerate(zip(contexts, explanations)):
        passage = con["passage"] if con["passage"] is not None and not skip_passage else ""
        question = con["question"]
        options = con["options"] if con["options"] is not None else ""
        label = con["label"] if con["label"] is not None else ""
        answer = con["answer"] if "answer" in con and con["answer"] is not None else ""

        if dataset_name in english_qa_datasets:
            question_input = (
                "Question: "
                + passage
                + " "
                + question
                + "\n"
                + "Choose from the following options: "
                + " ".join(options)
                + "\n"
                + "Answer: {}".format(label)
            )
        elif dataset_name in chinese_qa_datasets:
            question_input = (
                "问题:" + passage + " " + question + "\n" + "从以下选项中选择:" + " ".join(options) + "\n" + "答案:{}".format(label)
            )
        elif dataset_name in english_cloze_datasets:
            question_input = "Question: " + question + "\n" + "Answer: {}".format(answer)
        elif dataset_name in chinese_cloze_datasets:
            question_input = "问题:" + question + "\n" + "答案:{}".format(answer)
        else:
            raise ValueError(f"During loading few-shot examples, found unknown dataset: {dataset_name}")

        if chat_mode:
            demonstrations.append((question_input,))
        else:
            demonstrations.append(question_input + "\n")

    return demonstrations


class AGIEvalDataset(BaseDataset):
    """
    Dataset wrapper for the AGIEval dataset.
    Data source: https://github.com/microsoft/AGIEval
    This dataset class will convert the original dataset into the inference dataset.

    A few dirty data samples needed to be manually corrected in the original dataset:
    Issue link: https://github.com/microsoft/AGIEval/issues/16
    1. Invalid options in line 190 in gaokao-chemistry.jsonl.
    2. Option D (They may increase in value as those same resources become rare on Earth.) missing in line 17 in sat-en-without-passage.jsonl.
    3. Option D (They may increase in value as those same resources become rare on Earth.) missing in line 17 in sat-en.jsonl.
    4. Option D (No, because the data do not indicate whether the honeybees had been infected with mites.) missing in line 57 in sat-en-without-passage.jsonl.
    5. Option D (No, because the data do not indicate whether the honeybees had been infected with mites.) missing in line 57 in sat-en.jsonl.
    6. Option D (Published theories of scientists who developed earlier models of the Venus flytrap) missing in line 98 in sat-en-without-passage.jsonl.
    7. Option D (Published theories of scientists who developed earlier models of the Venus flytrap) missing in line 98 in sat-en.jsonl.
    8. Label is empty in line 212 in jec-qa-kd.jsonl. Content is also dirty.
    9. gaokao-mathqa.jsonl is actually also a multi-choice dataset. See lines 149, 286 and 287.
    """

    @staticmethod
    def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
        dataset = {"test": {}}

        files = glob.glob(os.path.join(path, "*.jsonl"))
        files.sort()

        if few_shot:
            prompt_path = os.path.join(path, "few_shot_prompts.csv")

        for file in files:
            dataset_name = os.path.basename(file)[0 : -len(".jsonl")]

            few_shot_data = []
            if few_shot:
                # process the demonstrations once if it is few-shot-CoT
                few_shot_data = combine_prompt(prompt_path, dataset_name, load_explanation=False, chat_mode=False)

            dataset["test"][dataset_name] = {"data": []}

            file_dir = os.path.join(path, os.path.basename(file))

            loaded_jsonl = get_json_list(file_dir)

            # It's been tested that each data sample in one subcategory has the same inference arguments.
            _, all_classes = get_prompt(loaded_jsonl[0], dataset_name, logger)
            inference_kwargs = deepcopy(default_inference_kwargs)
            if all_classes is not None and dataset_name not in multi_choice_datasets:
                inference_kwargs["all_classes"] = all_classes

            if dataset_name in english_qa_datasets:
                inference_kwargs["language"] = "English"
            if dataset_name in chinese_qa_datasets:
                inference_kwargs["language"] = "Chinese"
            inference_kwargs["few_shot_data"] = few_shot_data

            dataset["test"][dataset_name]["inference_kwargs"] = inference_kwargs

            for line in loaded_jsonl:
                info, all_classes = get_prompt(line, dataset_name, logger)

                # Convert multi-choice answers to a single string.
                # We will convert it back when evaluating.
                # We do this because if the target is a list, it should only be used for multiple target answers.
                if dataset_name in multi_choice_datasets:
                    if isinstance(info["target"], str) and len(info["target"]) > 1:
                        # "gaokao-mathqa" actually contains multi-choice questions.
                        # This branch is specifically for it.
                        info["target"] = "".join(info["target"].split())
                    else:
                        info["target"] = "".join(info["target"])

                if isinstance(info["target"], list) and len(info["target"]) == 1:
                    info["target"] = info["target"][0]

                data_sample = {
                    "dataset": "agieval",
                    "split": "test",
                    "category": dataset_name,
                    "instruction": info["instruction"],
                    "input": "",
                    "output": "",
                    "target": info["target"],
                }

                dataset["test"][dataset_name]["data"].append(data_sample)

        return dataset
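As a quick illustration of what `get_prompt` above produces for an English QA sample, here is a small sketch (the record below is made up, not taken from the dataset, and the module path assumes the package layout of this commit):

```python
from colossal_eval.dataset.agieval import get_prompt

# A made-up AGIEval-style record for illustration.
line = {
    "passage": None,
    "question": "Which of the following is a prime number?",
    "options": ["(A)4", "(B)6", "(C)7", "(D)9"],
    "label": "C",
    "answer": "",
}

info, all_classes = get_prompt(line, "sat-math", logger=None)
# all_classes == ["A", "B", "C", "D"]
# info["instruction"] == "Question: Which of the following is a prime number? "
#                        "Choose from the following options: (A)4 (B)6 (C)7 (D)9\nAnswer: "
# info["target"] == "C"
```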
24
applications/ColossalEval/colossal_eval/dataset/base.py
Normal file
@@ -0,0 +1,24 @@
from abc import abstractstaticmethod

from colossal_eval.utils import jdump


class BaseDataset:
    """
    Base class for dataset wrappers.

    Args:
        path: The path to the original dataset.
        logger: Logger for the dataset.
        few_shot: Whether to load few-shot examples for the dataset.
    """

    def __init__(self, path, logger, few_shot):
        self.dataset = self.load(path, logger, few_shot)

    def save(self, save_path):
        """Save the converted dataset."""
        jdump(self.dataset, save_path)

    @abstractstaticmethod
    def load(path, logger, few_shot):
        """Load the original dataset and convert it into the inference dataset."""
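For orientation, a minimal usage sketch of the wrapper API defined above (the paths are placeholders, and `get_dist_logger` is assumed to be the usual Colossal-AI logger factory):

```python
# Illustrative usage of a BaseDataset subclass; paths below are made up.
from colossal_eval.dataset import MMLUDataset
from colossalai.logging import get_dist_logger

logger = get_dist_logger()

# __init__ calls MMLUDataset.load(path, logger, few_shot) and keeps the converted
# inference dataset on the instance; save() dumps it to JSON via jdump.
mmlu = MMLUDataset(path="path/to/mmlu_data", logger=logger, few_shot=True)
mmlu.save("path/to/inference_data/mmlu.json")
```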
132
applications/ColossalEval/colossal_eval/dataset/ceval.py
Normal file
@@ -0,0 +1,132 @@
import copy
import csv
import os
from typing import Dict, List

from colossalai.logging import DistributedLogger

from .base import BaseDataset

ceval_subject_mapping = {
    "computer_network": ["Computer Network", "计算机网络", "STEM"],
    "operating_system": ["Operating System", "操作系统", "STEM"],
    "computer_architecture": ["Computer Architecture", "计算机组成", "STEM"],
    "college_programming": ["College Programming", "大学编程", "STEM"],
    "college_physics": ["College Physics", "大学物理", "STEM"],
    "college_chemistry": ["College Chemistry", "大学化学", "STEM"],
    "advanced_mathematics": ["Advanced Mathematics", "高等数学", "STEM"],
    "probability_and_statistics": ["Probability and Statistics", "概率统计", "STEM"],
    "discrete_mathematics": ["Discrete Mathematics", "离散数学", "STEM"],
    "electrical_engineer": ["Electrical Engineer", "注册电气工程师", "STEM"],
    "metrology_engineer": ["Metrology Engineer", "注册计量师", "STEM"],
    "high_school_mathematics": ["High School Mathematics", "高中数学", "STEM"],
    "high_school_physics": ["High School Physics", "高中物理", "STEM"],
    "high_school_chemistry": ["High School Chemistry", "高中化学", "STEM"],
    "high_school_biology": ["High School Biology", "高中生物", "STEM"],
    "middle_school_mathematics": ["Middle School Mathematics", "初中数学", "STEM"],
    "middle_school_biology": ["Middle School Biology", "初中生物", "STEM"],
    "middle_school_physics": ["Middle School Physics", "初中物理", "STEM"],
    "middle_school_chemistry": ["Middle School Chemistry", "初中化学", "STEM"],
    "veterinary_medicine": ["Veterinary Medicine", "兽医学", "STEM"],
    "college_economics": ["College Economics", "大学经济学", "Social Science"],
    "business_administration": ["Business Administration", "工商管理", "Social Science"],
    "marxism": ["Marxism", "马克思主义基本原理", "Social Science"],
    "mao_zedong_thought": ["Mao Zedong Thought", "毛泽东思想和中国特色社会主义理论体系概论", "Social Science"],
    "education_science": ["Education Science", "教育学", "Social Science"],
    "teacher_qualification": ["Teacher Qualification", "教师资格", "Social Science"],
    "high_school_politics": ["High School Politics", "高中政治", "Social Science"],
    "high_school_geography": ["High School Geography", "高中地理", "Social Science"],
    "middle_school_politics": ["Middle School Politics", "初中政治", "Social Science"],
    "middle_school_geography": ["Middle School Geography", "初中地理", "Social Science"],
    "modern_chinese_history": ["Modern Chinese History", "近代史纲要", "Humanities"],
    "ideological_and_moral_cultivation": ["Ideological and Moral Cultivation", "思想道德修养与法律基础", "Humanities"],
    "logic": ["Logic", "逻辑学", "Humanities"],
    "law": ["Law", "法学", "Humanities"],
    "chinese_language_and_literature": ["Chinese Language and Literature", "中国语言文学", "Humanities"],
    "art_studies": ["Art Studies", "艺术学", "Humanities"],
    "professional_tour_guide": ["Professional Tour Guide", "导游资格", "Humanities"],
    "legal_professional": ["Legal Professional", "法律职业资格", "Humanities"],
    "high_school_chinese": ["High School Chinese", "高中语文", "Humanities"],
    "high_school_history": ["High School History", "高中历史", "Humanities"],
    "middle_school_history": ["Middle School History", "初中历史", "Humanities"],
    "civil_servant": ["Civil Servant", "公务员", "Other"],
    "sports_science": ["Sports Science", "体育学", "Other"],
    "plant_protection": ["Plant Protection", "植物保护", "Other"],
    "basic_medicine": ["Basic Medicine", "基础医学", "Other"],
    "clinical_medicine": ["Clinical Medicine", "临床医学", "Other"],
    "urban_and_rural_planner": ["Urban and Rural Planner", "注册城乡规划师", "Other"],
    "accountant": ["Accountant", "注册会计师", "Other"],
    "fire_engineer": ["Fire Engineer", "注册消防工程师", "Other"],
    "environmental_impact_assessment_engineer": ["Environmental Impact Assessment Engineer", "环境影响评价工程师", "Other"],
    "tax_accountant": ["Tax Accountant", "税务师", "Other"],
    "physician": ["Physician", "医师资格", "Other"],
}

default_inference_kwargs = {
    "calculate_loss": False,
    "all_classes": ["A", "B", "C", "D"],
    "language": "Chinese",
    "pretrain": False,
    "max_new_tokens": 32,
}


def get_few_shot_data(data: List[Dict]):
    few_shot_data = []
    for i in data:
        few_shot_data.append(i["input"] + i["target"])
    return few_shot_data


class CEvalDataset(BaseDataset):
    """
    Dataset class for the CEval dataset.
    Data source: https://huggingface.co/datasets/ceval/ceval-exam
    This dataset class will convert the original dataset into the inference dataset.
    """

    @staticmethod
    def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
        dataset = {"dev": {}, "test": {}}
        for split in ["dev", "test"]:
            files = os.listdir(os.path.join(path, split))
            files.sort()

            for file in files:
                subject = file[0 : -len(f"_{split}.csv")]
                subject = ceval_subject_mapping[subject][1]

                file_dir = os.path.join(path, split, file)

                dataset[split][subject] = {"data": []}

                # It's been tested that each data sample in one subcategory has the same inference arguments.
                dataset[split][subject]["inference_kwargs"] = copy.deepcopy(default_inference_kwargs)

                if split == "test" and few_shot:
                    dataset[split][subject]["inference_kwargs"]["few_shot_data"] = get_few_shot_data(
                        dataset["dev"][subject]["data"]
                    )

                with open(file_dir, encoding="utf-8") as f:
                    reader = csv.reader(f)
                    _ = next(reader)
                    for row in reader:
                        # The dev split has an answer and an explanation, so len(row) is 8,
                        # but the test split doesn't contain them, so len(row) is 6.
                        assert len(row) >= 6
                        choices = f"A. {row[2]}\nB. {row[3]}\nC. {row[4]}\nD. {row[5]}"
                        data_sample = {
                            "dataset": "ceval",
                            "split": split,
                            "category": subject,
                            "instruction": f"以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。",
                            "input": f"题目:{row[1]}\n{choices}\n答案:",
                            "output": "",
                            "target": row[6] if split == "dev" else "",
                            "id": int(row[0]),
                        }

                        dataset[split][subject]["data"].append(data_sample)

        return dataset
144
applications/ColossalEval/colossal_eval/dataset/cmmlu.py
Normal file
@@ -0,0 +1,144 @@
import copy
import csv
import os
from typing import Dict, List

from colossalai.logging import DistributedLogger

from .base import BaseDataset

cmmlu_subject_mapping = {
    "agronomy": "农学",
    "anatomy": "解剖学",
    "ancient_chinese": "古汉语",
    "arts": "艺术学",
    "astronomy": "天文学",
    "business_ethics": "商业伦理",
    "chinese_civil_service_exam": "中国公务员考试",
    "chinese_driving_rule": "中国驾驶规则",
    "chinese_food_culture": "中国饮食文化",
    "chinese_foreign_policy": "中国外交政策",
    "chinese_history": "中国历史",
    "chinese_literature": "中国文学",
    "chinese_teacher_qualification": "中国教师资格",
    "clinical_knowledge": "临床知识",
    "college_actuarial_science": "大学精算学",
    "college_education": "大学教育学",
    "college_engineering_hydrology": "大学工程水文学",
    "college_law": "大学法律",
    "college_mathematics": "大学数学",
    "college_medical_statistics": "大学医学统计",
    "college_medicine": "大学医学",
    "computer_science": "计算机科学",
    "computer_security": "计算机安全",
    "conceptual_physics": "概念物理学",
    "construction_project_management": "建设工程管理",
    "economics": "经济学",
    "education": "教育学",
    "electrical_engineering": "电气工程",
    "elementary_chinese": "小学语文",
    "elementary_commonsense": "小学常识",
    "elementary_information_and_technology": "小学信息技术",
    "elementary_mathematics": "初等数学",
    "ethnology": "民族学",
    "food_science": "食品科学",
    "genetics": "遗传学",
    "global_facts": "全球事实",
    "high_school_biology": "高中生物",
    "high_school_chemistry": "高中化学",
    "high_school_geography": "高中地理",
    "high_school_mathematics": "高中数学",
    "high_school_physics": "高中物理学",
    "high_school_politics": "高中政治",
    "human_sexuality": "人类性行为",
    "international_law": "国际法学",
    "journalism": "新闻学",
    "jurisprudence": "法理学",
    "legal_and_moral_basis": "法律与道德基础",
    "logical": "逻辑学",
    "machine_learning": "机器学习",
    "management": "管理学",
    "marketing": "市场营销",
    "marxist_theory": "马克思主义理论",
    "modern_chinese": "现代汉语",
    "nutrition": "营养学",
    "philosophy": "哲学",
    "professional_accounting": "专业会计",
    "professional_law": "专业法学",
    "professional_medicine": "专业医学",
    "professional_psychology": "专业心理学",
    "public_relations": "公共关系",
    "security_study": "安全研究",
    "sociology": "社会学",
    "sports_science": "体育学",
    "traditional_chinese_medicine": "中医中药",
    "virology": "病毒学",
    "world_history": "世界历史",
    "world_religions": "世界宗教",
}

default_inference_kwargs = {
    "calculate_loss": True,
    "all_classes": ["A", "B", "C", "D"],
    "language": "Chinese",
    "pretrain": False,
    "max_new_tokens": 32,
}


def get_few_shot_data(data: List[Dict]):
    few_shot_data = []
    for i in data:
        few_shot_data.append(i["input"] + i["target"])
    return few_shot_data


class CMMLUDataset(BaseDataset):
    """
    Dataset class for the CMMLU dataset.
    Data source: https://github.com/haonan-li/CMMLU/tree/master/data
    This dataset class will convert the original dataset into the inference dataset.
    """

    @staticmethod
    def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
        dataset = {"dev": {}, "test": {}}
        for split in ["dev", "test"]:
            files = os.listdir(os.path.join(path, split))
            files.sort()

            for file in files:
                subject = file[0 : -len(".csv")]
                subject = cmmlu_subject_mapping[subject]

                file_dir = os.path.join(path, split, file)

                dataset[split][subject] = {"data": []}

                # It's been tested that each data sample in one subcategory has the same inference arguments.
                dataset[split][subject]["inference_kwargs"] = copy.deepcopy(default_inference_kwargs)

                if split == "test" and few_shot:
                    dataset[split][subject]["inference_kwargs"]["few_shot_data"] = get_few_shot_data(
                        dataset["dev"][subject]["data"]
                    )

                with open(file_dir, encoding="utf-8") as f:
                    reader = csv.reader(f)
                    _ = next(reader)
                    for row in reader:
                        assert len(row) == 7
                        choices = f"A. {row[2]}\nB. {row[3]}\nC. {row[4]}\nD. {row[5]}"
                        data_sample = {
                            "dataset": "cmmlu",
                            "split": split,
                            "category": subject,
                            "instruction": f"以下是关于{subject}的单项选择题,请直接给出正确答案的选项。",
                            "input": f"题目:{row[1]}\n{choices}\n答案:",
                            "output": "",
                            "target": row[6],
                        }

                        dataset[split][subject]["data"].append(data_sample)

        return dataset
@@ -0,0 +1,70 @@
from collections import defaultdict
from copy import deepcopy
from typing import Dict, List

from colossal_eval.utils import jload

from colossalai.logging import DistributedLogger

from .base import BaseDataset

default_inference_kwargs = {
    "calculate_loss": False,
    "all_classes": None,
    "language": "Chinese",
    "pretrain": False,
    "max_new_tokens": 256,
}

# You can add your own subcategories of questions here and specify whether a subcategory
# consists of single-choice questions, or has target answers and needs to calculate loss.
single_choice_question = set()
calculate_loss = set()


def get_data_per_category(data):
    data_per_category = defaultdict(list)
    for item in data:
        category = item["category"]
        data_per_category[category].append(item)

    return data_per_category


class ColossalDataset(BaseDataset):
    """
    Dataset class for the Colossal dataset.
    This dataset class will convert the original dataset into the inference dataset.
    """

    @staticmethod
    def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
        dataset = {"test": {}}
        data = jload(path)
        data_per_category = get_data_per_category(data)
        categories = list(data_per_category.keys())

        for category in categories:
            dataset["test"][category] = {"data": []}
            category_data = data_per_category[category]

            dataset["test"][category]["inference_kwargs"] = deepcopy(default_inference_kwargs)

            if category in calculate_loss:
                dataset["test"][category]["inference_kwargs"]["calculate_loss"] = True
            if category in single_choice_question:
                dataset["test"][category]["inference_kwargs"]["all_classes"] = ["A", "B", "C", "D"]

            for item in category_data:
                data_sample = {
                    "dataset": "colossal",
                    "split": "test",
                    "category": category,
                    "instruction": item["instruction"],
                    "input": item["input"],
                    "output": "",
                    "target": item["target"],
                    "id": item["id"],
                }
                dataset["test"][category]["data"].append(data_sample)

        return dataset
122
applications/ColossalEval/colossal_eval/dataset/gaokaobench.py
Normal file
@@ -0,0 +1,122 @@
import json
import os
import re
from copy import deepcopy
from typing import Dict, List

from colossalai.logging import DistributedLogger

from .base import BaseDataset

multi_choice_datasets = [
    "Chinese Lang and Usage MCQs",
    "Chinese Modern Lit",
    "English Fill in Blanks",
    "English Reading Comp",
    "Geography MCQs",
    "Physics MCQs",
    "English Cloze Test",
]

chinese_qa_datasets = [
    "Biology MCQs",
    "Chemistry MCQs",
    "Chinese Lang and Usage MCQs",
    "Chinese Modern Lit",
    "Geography MCQs",
    "History MCQs",
    "Math I MCQs",
    "Math II MCQs",
    "Physics MCQs",
    "Political Science MCQs",
]
english_qa_datasets = ["English MCQs", "English Fill in Blanks", "English Reading Comp", "English Cloze Test"]

default_inference_kwargs = {
    "calculate_loss": True,
    "all_classes": None,
    "language": "Chinese",
    "pretrain": False,
    "max_new_tokens": 32,
}


def get_all_classes(instruction: str):
    letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    pattern = r"([A-Z]\. |[A-Z].|[A-Z]\.)"
    options = sorted(list(set(re.findall(pattern, instruction))))
    options = sorted(list(set([string[0] for string in options])))

    for i in range(len(options)):
        if options[i] == letters[i]:
            continue
        else:
            return options[0:i]
    return options


class GaoKaoBenchDataset(BaseDataset):
    """
    Dataset class for the GAOKAO-Bench dataset.
    Data source: https://github.com/OpenLMLab/GAOKAO-Bench/tree/main/data
    This dataset class will convert the original dataset into the inference dataset.

    A few typos in the original dataset needed to be corrected manually; some of the following are already fixed.
    Issue link: https://github.com/OpenLMLab/GAOKAO-Bench/issues/20
    1. Option C missing in index 111 in 2010-2022_Chemistry_MCQs.json
    2. Option B missing "." after it in index 16 in 2012-2022_English_Cloze_Test.json
    3. Option G missing "." after it in index 23 in 2012-2022_English_Cloze_Test.json
    """

    @staticmethod
    def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
        dataset = {"test": {}}
        for category in ["Fill-in-the-blank_Questions", "Multiple-choice_Questions", "Open-ended_Questions"]:
            files = os.listdir(os.path.join(path, "data", category))
            files.sort()

            for file in files:
                subject = file[10:-5].split("_")
                subject = " ".join(subject)
                dataset["test"][subject] = {"data": []}

                file_dir = os.path.join(path, "data", category, file)

                with open(file_dir, encoding="utf-8") as f:
                    data = json.load(f)

                # It's been tested that each data sample in one subcategory has the same inference arguments.
                inference_kwargs = deepcopy(default_inference_kwargs)
                if category == "Multiple-choice_Questions" and subject not in multi_choice_datasets:
                    all_classes = get_all_classes(data["example"][0]["question"])
                    inference_kwargs["all_classes"] = all_classes
                if subject in english_qa_datasets:
                    inference_kwargs["language"] = "English"
                if subject in chinese_qa_datasets:
                    inference_kwargs["language"] = "Chinese"

                dataset["test"][subject]["inference_kwargs"] = inference_kwargs

                for sample in data["example"]:
                    # Convert multi-choice answers to a single string.
                    # We will convert it back when evaluating.
                    # We do this because if the target is a list, it should only be used for multiple target answers.
                    if subject in multi_choice_datasets:
                        sample["answer"] = "".join(sample["answer"])

                    if isinstance(sample["answer"], list) and len(sample["answer"]) == 1:
                        sample["answer"] = sample["answer"][0]

                    data_sample = {
                        "dataset": "gaokaobench",
                        "split": "test",
                        "category": f"{category[:-10]}-{subject}",
                        "instruction": sample["question"].strip() + "\n答案:",
                        "input": "",
                        "output": "",
                        "target": sample["answer"],
                    }

                    dataset["test"][subject]["data"].append(data_sample)

        return dataset
120
applications/ColossalEval/colossal_eval/dataset/longbench.py
Normal file
@@ -0,0 +1,120 @@
import os
from copy import deepcopy
from typing import Dict, List

from colossal_eval.utils import get_json_list

from colossalai.logging import DistributedLogger

from .base import BaseDataset

dataset2prompt = {
    "narrativeqa": "You are given a story, which can be either a novel or a movie script, and a question. Answer the question as concisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nStory: {context}\n\nNow, answer the question based on the story as concisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:",
    "qasper": 'You are given a scientific article and a question. Answer the question as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write "unanswerable". If the question is a yes/no question, answer "yes", "no", or "unanswerable". Do not provide any explanation.\n\nArticle: {context}\n\n Answer the question based on the above article as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write "unanswerable". If the question is a yes/no question, answer "yes", "no", or "unanswerable". Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:',
    "multifieldqa_en": "Read the following text and answer briefly.\n\n{context}\n\nNow, answer the following question based on the above text, only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
    "multifieldqa_zh": "阅读以下文字并用中文简短回答:\n\n{context}\n\n现在请基于上面的文章回答下面的问题,只告诉我答案,不要输出任何其他字词。\n\n问题:{input}\n回答:",
    "hotpotqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
    "2wikimqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
    "musique": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
    "dureader": "请基于给定的文章回答下述问题。\n\n文章:{context}\n\n请基于上述文章回答下面的问题。\n\n问题:{input}\n回答:",
    "gov_report": "You are given a report by a government agency. Write a one-page summary of the report.\n\nReport:\n{context}\n\nNow, write a one-page summary of the report.\n\nSummary:",
    "qmsum": "You are given a meeting transcript and a query containing a question or instruction. Answer the query in one or more sentences.\n\nTranscript:\n{context}\n\nNow, answer the query based on the above meeting transcript in one or more sentences.\n\nQuery: {input}\nAnswer:",
    "multi_news": "You are given several news passages. Write a one-page summary of all news. \n\nNews:\n{context}\n\nNow, write a one-page summary of all the news.\n\nSummary:",
    "vcsum": "下面有一段会议记录,请你阅读后,写一段总结,总结会议的内容。\n会议记录:\n{context}\n\n会议总结:",
    "trec": "Please determine the type of the question below. Here are some examples of questions.\n\n{context}\n{input}",
    "triviaqa": "Answer the question based on the given passage. Only give me the answer and do not output any other words. The following are some examples.\n\n{context}\n\n{input}",
    "samsum": "Summarize the dialogue into a few short sentences. The following are some examples.\n\n{context}\n\n{input}",
    "lsht": "请判断给定新闻的类别,下面是一些例子。\n\n{context}\n{input}",
    "passage_count": "There are some paragraphs below sourced from Wikipedia. Some of them may be duplicates. Please carefully read these paragraphs and determine how many unique paragraphs there are after removing duplicates. In other words, how many non-repeating paragraphs are there in total?\n\n{context}\n\nPlease enter the final count of unique paragraphs after removing duplicates. The output format should only contain the number, such as 1, 2, 3, and so on.\n\nThe final answer is: ",
    "passage_retrieval_en": 'Here are 30 paragraphs from Wikipedia, along with an abstract. Please determine which paragraph the abstract is from.\n\n{context}\n\nThe following is an abstract.\n\n{input}\n\nPlease enter the number of the paragraph that the abstract is from. The answer format must be like "Paragraph 1", "Paragraph 2", etc.\n\nThe answer is: ',
    "passage_retrieval_zh": '以下是若干段落文字,以及其中一个段落的摘要。请确定给定的摘要出自哪一段。\n\n{context}\n\n下面是一个摘要\n\n{input}\n\n请输入摘要所属段落的编号。答案格式必须是"段落1","段落2"等格式\n\n答案是:',
    "lcc": "Please complete the code given below. \n{context}Next line of code:\n",
    "repobench-p": "Please complete the code given below. \n{context}{input}Next line of code:\n",
}

dataset2maxlen = {
    "narrativeqa": 128,
    "qasper": 128,
    "multifieldqa_en": 64,
    "multifieldqa_zh": 64,
    "hotpotqa": 32,
    "2wikimqa": 32,
    "musique": 32,
    "dureader": 128,
    "gov_report": 512,
    "qmsum": 512,
    "multi_news": 512,
    "vcsum": 512,
    "trec": 64,
    "triviaqa": 32,
    "samsum": 128,
    "lsht": 64,
    "passage_count": 32,
    "passage_retrieval_en": 32,
    "passage_retrieval_zh": 32,
    "lcc": 64,
    "repobench-p": 64,
}

default_inference_kwargs = {
    "calculate_loss": True,
    "all_classes": None,
    "language": "Chinese",
    "pretrain": False,
    "max_new_tokens": 32,
}


class LongBenchDataset(BaseDataset):
    """
    Dataset class for the LongBench dataset.
    Data source: https://huggingface.co/datasets/THUDM/LongBench
    This dataset class will convert the original dataset into the inference dataset.

    Issue link: https://github.com/THUDM/LongBench/issues/15 (fixed)
    There are duplicate target answers in `nq.jsonl`, but this doesn't affect the evaluation results.
    It also doesn't affect the perplexity calculation (the program only needs to select the minimum loss).
    """

    @staticmethod
    def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
        dataset = {"test": {}}

        files = os.listdir(path)
        files.sort()

        for file in files:
            category = file[0:-6]

            if category.endswith("_e"):
                continue

            dataset["test"][category] = {"data": []}

            file_dir = os.path.join(path, file)

            loaded_jsonl = get_json_list(file_dir)

            # It's been tested that each data sample in one subcategory has the same inference arguments.
            inference_kwargs = deepcopy(default_inference_kwargs)
            if loaded_jsonl[0]["all_classes"] is not None:
                inference_kwargs["all_classes"] = loaded_jsonl[0]["all_classes"]
            inference_kwargs["max_new_tokens"] = dataset2maxlen[category]
            dataset["test"][category]["inference_kwargs"] = inference_kwargs

            for sample in loaded_jsonl:
                prompt = dataset2prompt[category].format(**sample)

                data_sample = {
                    "dataset": "longbench",
                    "split": "test",
                    "category": category,
                    "instruction": prompt,
                    "input": "",
                    "output": "",
                    "target": sample["answers"],
                }

                dataset["test"][category]["data"].append(data_sample)

        return dataset
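To make the template mechanism concrete, here is a self-contained sketch of the `.format(**sample)` step used in `load()` above; the template string is copied from `dataset2prompt["hotpotqa"]`, and the sample record is made up with the usual LongBench fields:

```python
# Illustrative only: fill a LongBench prompt template for one made-up sample.
# Extra keys in the sample (answers, all_classes, length) are simply ignored by format().
template = (
    "Answer the question based on the given passages. Only give me the answer and do not "
    "output any other words.\n\nThe following are given passages.\n{context}\n\n"
    "Answer the question based on the given passages. Only give me the answer and do not "
    "output any other words.\n\nQuestion: {input}\nAnswer:"
)
sample = {
    "context": "Passage 1: ...\nPassage 2: ...",
    "input": "Which passage mentions the Eiffel Tower?",
    "answers": ["Passage 2"],
    "all_classes": None,
    "length": 1234,
}
prompt = template.format(**sample)
print(prompt)
```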
73
applications/ColossalEval/colossal_eval/dataset/mmlu.py
Normal file
@@ -0,0 +1,73 @@
import copy
import csv
import os
from typing import Dict, List

from colossalai.logging import DistributedLogger

from .base import BaseDataset

default_inference_kwargs = {
    "calculate_loss": True,
    "all_classes": ["A", "B", "C", "D"],
    "language": "English",
    "pretrain": False,
    "max_new_tokens": 32,
}


def get_few_shot_data(data: List[Dict]):
    few_shot_data = []
    for i in data:
        few_shot_data.append(i["input"] + i["target"])
    return few_shot_data


class MMLUDataset(BaseDataset):
    """
    Dataset class for the MMLU dataset.
    Data source: https://github.com/hendrycks/test
    This dataset class will convert the original dataset into the inference dataset.
    """

    @staticmethod
    def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
        dataset = {"dev": {}, "test": {}}
        for split in ["dev", "test"]:
            files = os.listdir(os.path.join(path, split))
            files.sort()

            for file in files:
                subject = file[0 : -len(f"_{split}.csv")].split("_")
                subject = " ".join([word.title() if word != "us" else "US" for word in subject])

                file_dir = os.path.join(path, split, file)

                dataset[split][subject] = {"data": [], "inference_kwargs": {}}

                # It's been tested that each data sample in one subcategory has the same inference arguments.
                dataset[split][subject]["inference_kwargs"] = copy.deepcopy(default_inference_kwargs)

                if split == "test" and few_shot:
                    dataset[split][subject]["inference_kwargs"]["few_shot_data"] = get_few_shot_data(
                        dataset["dev"][subject]["data"]
                    )

                with open(file_dir, encoding="utf-8") as f:
                    reader = csv.reader(f)
                    for row in reader:
                        assert len(row) == 6
                        choices = f"A. {row[1]}\nB. {row[2]}\nC. {row[3]}\nD. {row[4]}"
                        data_sample = {
                            "dataset": "mmlu",
                            "split": split,
                            "category": subject,
                            "instruction": f"The following is a single-choice question on {subject}. Answer the question by replying A, B, C or D.",
                            "input": f"Question: {row[0]}\n{choices}\nAnswer: ",
                            "output": "",
                            "target": row[5],
                        }

                        dataset[split][subject]["data"].append(data_sample)

        return dataset
@@ -0,0 +1,248 @@
# GPT Evaluation

## Table of Contents

- [Overview](#overview)
- [GPT Evaluation](#gpt-evaluation)
  - [Evaluation Category](#evaluation-category)
  - [Evaluation Category Examples](#evaluation-category-examples)
  - [Evaluation Metrics](#evaluation-metrics)
- [Evaluation Process](#evaluation-process)
  - [Data Format](#data-format)
  - [Prompt](#prompt)
    - [Battle Prompt](#battle-prompt)
    - [Evaluation Prompt](#evaluation-prompt)
  - [Evaluation](#evaluation)
    - [Configuration](#configuration)
    - [Evaluate](#evaluate)
- [FAQ](#faq)
- [Citations](#citations)

## Overview

In this directory, we introduce how you can evaluate your model with GPT models. Evaluation is available for both Chinese and English capability, and we provide the following functions:

* Compare the performance of two different models (battle).
* Rate a model according to pre-defined metrics using carefully designed prompts.
* Rate a model according to pre-defined metrics, with an additional reference answer, using carefully designed prompts.

## GPT Evaluation

### Evaluation Category

Our evaluation pipeline can examine a model's capability with different categories of questions. The following table lists some example categories; you can also add your own questions.

| Evaluation Category | Description |
| :-----------------: | :----------------------------------------------------------- |
| Brainstorming | Models are asked to generate a range of creative and diverse ideas according to the question. Creativity is required. |
| Chat | Models are asked to continue a multi-round dialogue given the roles involved. The ability to understand the context, memorize previous rounds of the dialogue and answer according to the provided persona is required. |
| Generation | Models are asked to generate an email, letter, article, etc. The ability to produce high-quality, human-like text is required. |
| Open QA | Models are asked to answer an open QA question (without context provided). The ability to answer questions from the model's own knowledge base is required. |
| Roleplay | Models are asked to play the role provided. The ability to engage in the scenario and interact effectively with the user is required. |

### Evaluation Category Examples
To better understand each evaluation category, here are some example questions. The example questions are in the `configs/gpt_evaluation/data` folder.

| Evaluation Category | Chinese Example | English Example |
| :-----------------: | :----------------------------------------------------------- | :----------------------------------------------------------- |
| Brainstorming | 列举一些可以促进头发生长的食物。 | How do you properly chop an onion without crying? |
| Chat | 基于以下角色信息完成一段对话。小张是一名新手爱好者,对养鸡有浓厚的兴趣。老李是一名有丰富经验的养鸡大师。<br/>小张:您好,老李,我最近开始对养鸡感兴趣了,想请教您一些问题。 <br/>老李:你好,小张,我很乐意帮助你。你想问些什么? <br/>小张:我想知道如何确定鸡的品种和性别? <br/>老李:确切的品种可以通过鸡的外貌特征来确定,而性别一般是通过鸡卵的大小和形状来判断。还有什么问题吗?<br/> 小张:<br/> | Complete a dialogue based on the following character information. Alex: A novice writer who is struggling to find inspiration and develop his writing skills. Emma: A successful author with many published works, providing guidance and advice to Alex.<br/>Alex: Hi Emma, I have been writing for a while now but can't seem to make any progress. Can you give me any advice? <br/>Emma: Hi Alex, sure. What kind of writing are you doing?<br/>Alex: I'm trying to write a novel, but I just can't seem to find any inspiration.<br/>Emma: <br/> |
| Generation | 请为一家咖啡店编写一篇简短的广告语,吸引更多的顾客。 | Write a set of guidelines for first-time pet owners on how to properly care for a new puppy. |
| Open QA | 解释什么是RNA病毒和DNA病毒。 | Explain the process of osmosis in biological systems. |
| Roleplay | 我要你把我写的句子翻译成表情符号。我会写句子,你会用表情符号表达它。我只是想让你用表情符号来表达它。除了表情符号,我不希望你回复任何内容。当我需要用中文告诉你一些事情时,我会用 {} 这样的大括号括起来。我的第一句话是“{我的职业是消防员。}” | I want you to act as a rapper. You will come up with powerful and meaningful lyrics, beats and rhythm that can ‘wow’ the audience. Your lyrics should have an intriguing meaning and message which people can relate too. When it comes to choosing your beat, make sure it is catchy yet relevant to your words, so that when combined they make an explosion of sound everytime! My first request is "I need a rap song about finding strength within yourself." |
### Evaluation Metrics

GPT evaluation uses GPT models to evaluate the predictions of different models, and different pre-defined evaluation metrics are applied to different categories. The following table shows the 10 pre-defined evaluation metrics in both Chinese and English:

| Evaluation Metric | Prompt Words | CoT(Chain-of-Thought) |
| :-------------------: | :----------------------------------------------------------- | :----------------------------------------------------------- |
| 语言组织<br/>(Language organization) | 语言组织(1-5):答案语言是否流畅、连贯,使用正确的语法,具有一定逻辑性,使用恰当的连接词、过渡词等等。</br></br>Language organization (1-5): whether the answer language is fluent and coherent, uses correct grammar, has a certain logic, uses appropriate connecting words, transition words, etc. | 1. 阅读答案,并检查是否有语法错误、用词不当或其他显著的错误。<br/> 2. 检查答案是否具有逻辑性,能够按照合理的顺序传达信息并且能够自圆其说<br/> 3. 确定答案是否与问题或主题相关,并且能够传达清晰的信息。<br/> 4. 检查答案是否连贯,是否使用适当的转换和过渡来保持句子和段落之间的连贯性。<br/> 5. 检查答案是否具有明确的结构和组织方式,使得读者可以轻松理解信息的层次和结构。<br/> 6. 根据以上因素综合评估答案的语言组织,并给出一个1到5的分数,其中5表示语言组织非常好,而1表示语言组织非常差。</br></br>1. Read the answers and check for grammatical errors, poor word choice, or other significant mistakes.<br>2. Check that the answer is logical, conveys the information in a logical order, and is self-explanatory.<br>3. Determine if the answer is relevant to the question or topic and conveys a clear message.<br>4. Check that the answer is coherent and that appropriate transitions and switches are used to maintain coherence between sentences and paragraphs.<br>5. Check that the answer is clearly structured and organized in such a way that the reader can easily understand the hierarchy and structure of the information.<br>6. Evaluate the linguistic organization of the answer based on a combination of the above factors and give a score of 1 to 5, where 5 indicates very good linguistic organization and 1 indicates very poor linguistic organization. |
| 切题<br/>(Relevance) | 切题(1-5):答案内容是否切题,不答非所问,并且严格遵照题目要求。</br></br>Relevance (1-5): whether the content of the answer is relevant to the topic, does not answer the wrong question, and strictly follows the requirements of the topic. | 1. 阅读题目,确定题目所问的问题是什么,以及需要回答哪些方面的问题。<br/> 2. 阅读答案,确认答案是否直接回答了题目所问的问题。<br/> 3. 检查答案是否严格遵照了题目的要求,包括答题方式、答题长度、答题格式等等。<br/> 4. 根据以上因素综合评估答案的切题程度,并给出一个1到5的分数,其中5表示答案非常切题,而1表示答案完全没有切题。</br></br>1. Read the question to determine what the question asks and what aspects of the question need to be answered.<br>2. Read the answers to make sure that they directly answer the question asked.<br>3. Check that the answer follows the requirements of the question, including the way it is answered, the length of the answer, the format of the answer, etc.<br>4. Evaluate how relevant the answer is based on the above factors and give a score of 1 to 5, where 5 means the answer is very relevant and 1 means the answer is not relevant at all. |
| 创意性<br/>(Creativity) | 创意性(1-5):某些头脑风暴问题可能需要答案具有创意,提出新的思路。</br></br>Creativity (1-5): Some brainstorming questions may require answers that are creative and suggest new ideas. | 1. 仔细阅读所提供的头脑风暴问题,确保你理解问题的要点和背景。<br/> 2. 根据你的知识和经验,判断所提供的答案是否可行。如果答案不可行,则创意性评分可能会受到影响。<br/> 3. 考虑答案中是否包含新颖的想法或独特的思路。答案可能与已知的解决方案有所重叠,但仍然可以被认为是有创意的,只要它提供了新的角度或方法来解决问题。<br/> 4. 根据答案的创意性,给出一个1到5的评分。如果答案缺乏创意,则应给出一个较低的评分。如果答案具有创意并提供了新的思路,应给出一个较高的评分。</br></br>1. Read the provided brainstorming questions carefully to make sure you understand the gist and context of the questions.<br>2. Based on your knowledge and experience, determine if the answers provided are feasible. If the answer is not feasible, the creativity score may be affected.<br>3. Consider whether the answer contains novel ideas or unique thoughts. An answer may overlap with a known solution and still be considered creative, as long as it offers a new perspective or approach to the problem.<br>4. Give a score of 1 to 5 depending on the creativity of the answer. If the answer lacks creativity, a lower score should be given. If the answer is creative and provides a new idea, a higher score should be given. |
| 实用性<br/>(Practicality) | 实用性(1-5):某些头脑风暴问题可能需要答案提出实用的建议或解决方法。</br></br>Practicality (1-5): Some brainstorming questions may require answers to suggest practical suggestions or solutions. | 1. 仔细阅读所提供的头脑风暴问题,确保你理解问题的要点和背景。<br/> 2. 根据你的知识和经验,判断所提供的答案是否可行。如果答案不可行,则实用性评分可能会受到影响。<br/> 3. 考虑答案中提出的建议或解决方法是否实用并可行。答案可能看起来很好,但如果无法实现或应用,则实用性评分可能会受到影响。<br/> 4. 根据答案的实用性,给出一个1到5的评分。如果答案缺乏实用性,则应给出一个较低的评分。如果答案提出了实用的建议或解决方法,并且可以很好地解决问题,则应给出一个较高的评分。</br></br>1. Read the provided brainstorming questions carefully to make sure you understand the gist and context of the questions.<br>2. Based on your knowledge and experience, determine if the answers provided are feasible. If the answer is not feasible, the practicality score may be affected.<br>3. Consider whether the suggestions or solutions presented in the answer are practical and workable. The answer may look good, but if it cannot be implemented or applied, the practicality score may be affected.<br>4. Give a score of 1 to 5 depending on the practicality of the answer. If the answer lacks practicality, a lower score should be given. If the answer makes a practical suggestion or solution and solves the problem well, a higher score should be given. |
| 正确性<br/>(Correctness) | 正确性(1-5):答案是否正确。</br></br> Correctness (1-5): whether the answer is correct or not. | 1. 仔细阅读题目,尝试自己回答该问题。<br/>2. 检查答案的准确性。您可以使用已知的事实或研究来验证答案是否正确。如果答案是正确的,则可以将正确性得分为5分。如果答案是部分正确的,则可以给予适当的得分,例如2分、3分或4分。如果答案完全不正确,则只得1分。<br/><br/>1. Read the question carefully and try to answer the question yourself. <br/>2. Check the correctness of the answer. You can use known facts or research to verify that the answer is correct. If the answer is correct, you can give a score of 5 for correctness. If the answer is partially correct, an appropriate score, such as 2, 3, or 4, may be given. If the answer is completely incorrect, only 1 point is awarded. |
| 自然<br/>(Naturalness) | 自然(1-5):答案是否自然,并且符合问题给定的身份。</br></br>Naturalness (1-5): whether the answer is natural and fits the identity given by the question. | 1. 阅读题目,确定题目提供的身份信息。<br/> 2. 检查答案内容是否符合题目给定的身份。<br/> 3. 根据以上因素,对该回答的自然性进行打分,分数从1到5,其中1表示不自然,5表示非常自然,并符合问题给定的身份。</br></br>1. Read the question and determine the identity information provided in the question.<br>2. Check whether the content of the answer matches the identity given in the question.<br>3. Based on the above factors, score the naturalness of the response on a scale from 1 to 5, where 1 means unnatural and 5 means very natural and in accordance with the identity given in the question. |
| 参与感<br/>(Engagingness) | 参与感(1-5):答案是否对前面的对话内容做出了恰当的反应,是否理解对话的语境和背景。</br></br>Engagingness (1-5): whether the answer responds appropriately to the content of the preceding conversation and whether it understands the context and background of the conversation. | 1. 阅读题目,确定对话的语境和背景。<br/> 2. 检查答案是否充分理解对话的语境和背景,能否自然地融入到对话中而不显得突兀。<br/> 3. 根据以上因素,对该回答的参与感进行打分,分数从1到5,其中1表示没有参与感,5表示非常有参与感,并且恰当地理解了对话的语境和背景。</br></br>1. Read the questions to determine the context and background of the dialogue.<br>2. Check that the answer fully understands the context and background of the conversation and that it fits naturally into the conversation without seeming abrupt.<br>3. Based on the above factors, rate the response's engagement on a scale from 1 to 5, where 1 means not engaged and 5 means very engaged and appropriately understands the context and background of the conversation. |
| 合理性<br/>(Reasonableness) | 合理性(1-5):答案是否能够与前面的对话内容形成逻辑上的衔接,是否符合常理,能否在这个上下文中合理存在。</br></br>Reasonableness (1-5): Whether the answer can form a logical connection with the content of the previous dialogue, whether it is consistent with common sense, and whether it can reasonably exist in this context. | 1. 阅读题目,确定对话的主题以及问题期望的回答方向。<br/> 2. 判断答案是否能够与前面的对话内容形成逻辑上的衔接,是否符合常理,能否在这个上下文中合理存在。<br/> 3. 根据以上因素,对该回答的合理性进行打分,分数从1到5,其中1表示不合理,5表示非常合理,并且能够与前面的对话内容形成逻辑上的衔接,并符合常理。</br></br>1. Read the question and determine the topic of the conversation and the direction the question expects the answer to go.<br>2. Determine whether the answer can be logically connected to the preceding conversation, whether it makes common sense, and whether it can reasonably exist in this context.<br>3. Based on the above factors, rate the reasonableness of the answer on a scale from 1 to 5, where 1 means unreasonable and 5 means very reasonable and able to form a logical connection with the preceding dialogue content and consistent with common sense. |
| 多样性<br/>(Diversity) | 多样性(1-5):答案使用语言是否优美,具有有一定的创造性和想象力。然而,回答也应该保持合理和适度,不要过于夸张或离题。</br></br>Diversity (1-5): Whether the answers use beautiful language and have some creativity and imagination. However, answers should also be kept reasonable and moderate, not overly exaggerated or off-topic. | 1. 仔细阅读整个回答,确保完全理解回答所表达的内容和主题。<br/> 2. 在阅读回答的同时,注意语言的质量,例如措辞是否正确,语言是否生动等。<br/> 3. 检查回答的创造性和想象力,看看回答是否能够吸引人阅读下去。<br/> 4. 检查回答的合理性和适度,看看回答是否夸张或离题。5. 将多样性的评分打分在1到5之间,5分表示回答的质量很好,能够吸引人阅读,1分表示回答的内容生硬或者有离题的问题。</br></br>1. Read the entire response carefully to ensure that you fully understand the content and theme expressed in the response.<br>2. While reading the response, pay attention to the quality of the language, such as whether the wording is correct and the language is vivid.<br>3. Check the creativity and imagination of the response to see if the response is engaging to read on.<br>4. Check the reasonableness and appropriateness of the responses to see if the responses are exaggerated or off-topic.<br>5. Rate the diversity on a scale of 1 to 5, with a 5 indicating a good quality response that is engaging to read and a 1 indicating a raw response or a question that is off-topic. |
| 保真度<br/>(Fidelity) | 保真度(1-5):答案是否能够严格遵守角色的设定回答给定的请求。</br></br>Fidelity (1-5): whether the answer is able to answer the given request in strict compliance with the role setting. | 1. 仔细阅读问题,了解角色在问题中的设定和表现,包括职业、背景、观点、性格等方面。<br/> 阅读题目的请求,确认回答请求时需要注意的细节。<br/> 3. 对比提供的回答与该角色的设定,评估回答是否能够严格遵守角色的设定。<br/> 4. 结合以上评估结果给出保真度的评分,范围从1到5分,其中1分表示回答与角色设定完全不符,5分表示回答完全符合角色设定且满足给定请求。</br></br>1. Read the question carefully to understand how the character is set up and represented in the question, including aspects such as occupation, background, point of view, and personality.<br>2. Read the question's request and confirm the details that need to be taken into account when answering the request.<br>3. Compare the provided answer with the setting of the role and assess whether the answer can strictly adhere to the setting of the role.<br>4. Combine the results of the above assessment to give a fidelity score ranging from 1 to 5, where a score of 1 means that the response does not match the persona at all, and a score of 5 means that the response fully complies with the persona and satisfies the given request. |
|
||||
|
||||
GPT models evaluate the quality of model predictions based on the given prompt words and gives a score between 1-5.
|
||||
|
||||
> **NOTE 1:** You can find all the prompt words and CoT(Chain-of-Thought) in `configs/gpt_evaluation/prompt/evaluation_prompt`.
|
||||
|
||||
> **NOTE 2:** To add customized metrics, you can refer to [FAQ](#faq).
|
||||
|
||||
## Evaluation Process
|
||||
|
||||
### Data Format
|
||||
|
||||
A JSON file contains one list. Each element in the list is a target answer / prediction record for one instruction / question.
|
||||
An element should have the following fields:
|
||||
|
||||
* `category` (str, compulsory): The category of the instruction / question.
|
||||
* `instruction` (str, compulsory): The instruction / question for the LLM.
|
||||
* `input` (str, optional): The additional context of the instruction / question.
|
||||
* `output` (str, optional): The model output of the instruction, models will fill in this field during inference time.
|
||||
* `target` (str, optional): The target answer for the instruction.
|
||||
* `id` (int, compulsory): The ID of the instruction / question.
|
||||
|
||||
Example:
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"category": "brainstorming",
|
||||
"instruction": "请问如何制作一份美味的西红柿炒鸡蛋?",
|
||||
"input": "",
|
||||
"output": "",
|
||||
"target": "",
|
||||
"id": 1
|
||||
},
|
||||
{
|
||||
"category": "chat",
|
||||
"instruction": "基于以下角色信息完成一段对话。小张是一名新手爱好者,对养鸡有浓厚的兴趣。老李是一名有丰富经验的养鸡大师。",
|
||||
"input": "小张:您好,老李,我最近开始对养鸡感兴趣了,想请教您一些问题。 老李:你好,小张,我很乐意帮助你。你想问些什么? 小张:我想知道如何确定鸡的品种和性别? 老李:确切的品种可以通过鸡的外貌特征来确定,而性别一般是通过鸡卵的大小和形状来判断。还有什么问题吗? 小张:",
|
||||
"output": "",
|
||||
"target": "",
|
||||
"id": 2
|
||||
}
|
||||
]
|
||||
```
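Such a file can be read with the standard `json` module. The snippet below is only a minimal sketch (not part of the pipeline) that loads a file in this format and checks the compulsory fields; `answers.json` is a placeholder path.

```python
import json

# Minimal sketch: load a file in the format above and verify the compulsory fields.
# "answers.json" is a placeholder path, not a file shipped with ColossalEval.
with open("answers.json", "r", encoding="utf-8") as f:
    records = json.load(f)

for record in records:
    assert isinstance(record["id"], int)
    assert isinstance(record["category"], str)
    assert isinstance(record["instruction"], str)
    # "input", "output" and "target" are optional and may be empty strings.
```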
### Prompt

#### Battle Prompt

The following is the Chinese battle prompt. In the battle prompt, the question and answers from two different models are fed into the prompt template. You can find example battle prompt files for Chinese and English in `configs/gpt_evaluation/prompt/battle_prompt`.

```json
{
    "id": 1,
    "system_prompt": "你是一个检查回答质量的好助手。",
    "prompt_template": "[问题]\n{question}\n\n[1号AI助手的答案]\n{answer_1}\n\n[1号AI助手答案终止]\n\n[2号AI助手的答案]\n{answer_2}\n\n[2号AI助手答案终止]\n\n[要求]\n{prompt}\n\n",
    "prompt": "我们需要你评价这两个AI助手回答的性能。\n请对他们的回答的有用性、相关性、准确性、详细程度进行评分。每个AI助手都会得到一个1到10分的总分,分数越高表示整体表现越好。\n请首先输出一行,该行只包含两个数值,分别表示1号和2号AI助手的分数。这两个分数之间要有一个空格。在随后的一行中,请对你的评价作出全面的解释,避免任何潜在的偏见,并确保AI助手回答的顺序不会影响您的判断。"
}
```
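The placeholders `{question}`, `{answer_1}`, `{answer_2}` and `{prompt}` are filled in by the pipeline before the request is sent to GPT-4. The following is a minimal sketch of that step (the file name is a placeholder, and the sample answers are hypothetical):

```python
import json

# Minimal sketch: fill the battle prompt template with one question and the answers
# of two models, the same substitution eval.py performs internally.
with open("battle_prompt_cn.json", "r", encoding="utf-8") as f:
    battle_prompt = json.load(f)

user_prompt = battle_prompt["prompt_template"].format(
    question="请问如何制作一份美味的西红柿炒鸡蛋?",
    answer_1="model 1's answer",
    answer_2="model 2's answer",
    prompt=battle_prompt["prompt"],
)
# battle_prompt["system_prompt"] and user_prompt are then sent to GPT-4 for scoring.
```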
#### Evaluation Prompt

The following is an example of a Chinese GPT evaluation prompt. In an evaluation prompt, you should define your metrics in `metrics` and provide the CoT (Chain-of-Thought) steps in `CoT`. You can find example evaluation prompt files for Chinese and English in `configs/gpt_evaluation/prompt/evaluation_prompt`.

```json
{
    "brainstorming": {
        "id": 1,
        "category": "brainstorming",
        "metrics": {
            "language organization": "语言组织(1-5):答案语言是否流畅、连贯,使用正确的语法,具有一定逻辑性,使用恰当的连接词、过渡词等等。"
        },
        "CoT": {
            "language organization": "1. 阅读答案,并检查是否有语法错误、用词不当或其他显著的错误。\n2. 检查答案是否具有逻辑性,能够按照合理的顺序传达信息并且能够自圆其说。\n3. 确定答案是否与问题或主题相关,并且能够传达清晰的信息。\n4. 检查答案是否连贯,是否使用适当的转换和过渡来保持句子和段落之间的连贯性。\n5. 检查答案是否具有明确的结构和组织方式,使得读者可以轻松理解信息的层次和结构。\n6. 根据以上因素综合评估答案的语言组织,并给出一个1到5的分数,其中5表示语言组织非常好,而1表示语言组织非常差。\n\n语言组织:"
        },
        "prompt": "你是一个好助手。请你为下面“头脑风暴”问题的答案打分。\n\n问题如下:\n\n{question}\n\n答案如下:\n\n{answer}\n\n评分的指标如下:\n\n{metric}\n\n请你遵照以下的评分步骤:\n\n{steps}"
    }
}
```

`"metrics"`: the metrics that can be used in GPT evaluation. This field determines which metrics can be added to your config file.

`"CoT"`: the evaluation steps that are sent to the GPT model for each metric defined in `"metrics"`.
### Evaluation

#### Configuration

The following is an example of a Chinese config file. The configuration file controls how the pipeline evaluates the model. You need to specify the GPT evaluation metrics under the key `GPT`. You can find an example English config file in `configs/gpt_evaluation/config/config_en.json`.

```json
{
    "language": "cn",
    "category": {
        "brainstorming": {
            "GPT": [
                "language organization",
                "relevance",
                "creativity",
                "practicality",
                "reasonableness"
            ]
        }
    }
}
```

`"language"`: the language used to evaluate the model capability. We only support Chinese `"cn"` for now.

`"category"`: the category/categories needed to evaluate the model capability.

`"GPT"`: the metrics you want to use for GPT evaluation.
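
Every metric listed under a category's `GPT` key must be defined in the `metrics` field of the evaluation prompt file. The sketch below is a minimal, hypothetical check of that constraint; the file names are placeholders, not files shipped with the repository.

```python
import json

# Minimal sketch (placeholder file names): verify that every GPT metric requested
# in the config is defined in the corresponding evaluation prompt file.
with open("config_cn.json", "r", encoding="utf-8") as f:
    config = json.load(f)
with open("evaluation_prompt_cn.json", "r", encoding="utf-8") as f:
    prompts = json.load(f)

for category, spec in config["category"].items():
    # Fall back to the "general" prompt when a category has no dedicated prompt.
    defined_metrics = prompts.get(category, prompts.get("general", {})).get("metrics", {})
    for metric in spec.get("GPT", []):
        if metric not in defined_metrics:
            print(f"Metric '{metric}' for category '{category}' is missing from the prompt file.")
```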
#### Evaluate

After setting the configuration file, you can evaluate the model using `examples/gpt_evaluation/eval.py`. If you want to compare the answers of two different models, you should specify two answer files in the argument `answer_file_list` and two model names in the argument `model_name_list`. If you want to evaluate one answer file, the length of both `answer_file_list` and `model_name_list` should be 1, and the program will perform evaluation using automatic metrics and GPT models.

An example script is provided as follows:

```shell
python eval.py \
    --config_file "path to the config file" \
    --battle_prompt_file "path to the prompt file for battle" \
    --gpt_evaluation_prompt_file "path to the prompt file for gpt evaluation" \
    --target_file "path to the target answer file" \
    --answer_file_list "path to the answer files of at most 2 models" \
    --model_name_list "the names of at most 2 models" \
    --gpt_model "which GPT model to use for evaluation" \
    --save_path "path to save results" \
    --openai_key "your openai key"
```

If you want GPT evaluation with reference, you can add the argument `--gpt_with_reference`, but make sure the reference file has target answers.
## FAQ

<details><summary><b>How can I add a new GPT evaluation metric?</b></summary>

For example, if you want to add a new metric `persuasiveness` into category `brainstorming`, you should add the metric definition and its corresponding CoT (Chain-of-Thought) to the evaluation prompt file in `configs/gpt_evaluation/prompt/evaluation_prompt`. The CoT can be generated using ChatGPT; you can prompt ChatGPT to generate evaluation steps for the new metric.

```json
{
    "brainstorming": {
        "id": 1,
        "category": "brainstorming",
        "metrics": {
            "persuasiveness": "persuasiveness(1-5):a short description for persuasiveness"
        },
        "CoT": {
            "persuasiveness": "CoT for persuasiveness\n\npersuasiveness:"
        },
        "prompt": "You are a good assistant. Please rate the given answer to the \"brainstorming\" question below.\n\nThe question is as follows:\n\n{question}\n\nThe answer is as follows:\n\n{answer}\n\nThe metric for evaluation is as follows:\n\n{metric}\n\nYou should follow the following evaluation steps:\n\n{steps}"
    }
}
```

</details>

## Citations

```bibtex
@misc{vicuna2023,
    title = {Vicuna: An Open-Source Chatbot Impressing GPT-4 with 90\%* ChatGPT Quality},
    url = {https://vicuna.lmsys.org},
    author = {Chiang, Wei-Lin and Li, Zhuohan and Lin, Zi and Sheng, Ying and Wu, Zhanghao and Zhang, Hao and Zheng, Lianmin and Zhuang, Siyuan and Zhuang, Yonghao and Gonzalez, Joseph E. and Stoica, Ion and Xing, Eric P.},
    month = {March},
    year = {2023}
}

@misc{liu2023geval,
    title={G-Eval: NLG Evaluation using GPT-4 with Better Human Alignment},
    author={Yang Liu and Dan Iter and Yichong Xu and Shuohang Wang and Ruochen Xu and Chenguang Zhu},
    year={2023},
    eprint={2303.16634},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
```
@@ -0,0 +1,3 @@
from .dataset_evaluator import DatasetEvaluator

__all__ = ["DatasetEvaluator"]
@@ -0,0 +1,269 @@
|
||||
from typing import Dict, List
|
||||
|
||||
import colossal_eval.evaluate.dataset_evaluator.metrics as metric_helper
|
||||
import numpy as np
|
||||
import tqdm
|
||||
|
||||
LabelBasedMetrics = ["first_token_accuracy", "matthews_correlation"]
|
||||
LossBasedMetrics = ["perplexity", "ppl_score", "ppl_score_over_choices", "per_byte_perplexity", "per_byte_ppl_score"]
|
||||
CombinedMetrics = ["combined_single_choice_accuracy"]
|
||||
OtherMetrics = [
|
||||
"f1_score",
|
||||
"f1_zh_score",
|
||||
"rouge_score",
|
||||
"rouge_zh_score",
|
||||
"retrieval_score",
|
||||
"retrieval_zh_score",
|
||||
"classification_score",
|
||||
"code_sim_score",
|
||||
"count_score",
|
||||
"multi_choice_accuracy",
|
||||
"math_equivalence",
|
||||
"single_choice_accuracy",
|
||||
]
|
||||
|
||||
|
||||
class DatasetEvaluator(object):
|
||||
"""
|
||||
Dataset evaluator.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def _calculate_label_metrics(self, metric: str, category: str):
|
||||
"""Calculate label-based metrics."""
|
||||
weight = len(self.data[category]["data"]) / self.metric_total_length[metric]
|
||||
|
||||
str_label_map = {
|
||||
choice: idx for idx, choice in enumerate(self.data[category]["inference_kwargs"]["all_classes"])
|
||||
}
|
||||
|
||||
references = [str_label_map[sample["target"]] for sample in self.data[category]["data"]]
|
||||
[sample["output"] for sample in self.data[category]["data"]]
|
||||
|
||||
flag = False
|
||||
softmaxs = []
|
||||
for i, sample in enumerate(self.data[category]["data"]):
|
||||
if np.any(np.isnan(np.array(list(sample["softmax_over_choices"].values())))):
|
||||
if not flag:
|
||||
print(
|
||||
f"NaN in the softmax, switch to exact match for category {category} in dataset {self.dataset_name} in model {self.model_name}."
|
||||
)
|
||||
flag = True
|
||||
score = 0
|
||||
for ref in sample["target"]:
|
||||
score = max(
|
||||
score,
|
||||
metric_helper.single_choice_accuracy(
|
||||
sample["output"], ref, all_classes=self.data[category]["inference_kwargs"]["all_classes"]
|
||||
),
|
||||
)
|
||||
softmaxs.append(references[i] if score == 1 else -1)
|
||||
else:
|
||||
softmaxs.append(np.argmax(np.array(list(sample["softmax_over_choices"].values()))))
|
||||
|
||||
references = np.array(references)
|
||||
softmaxs = np.array(softmaxs)
|
||||
scores = np.sum(references == softmaxs) / len(self.data[category]["data"]) * 100
|
||||
|
||||
self.evaluation_results[metric][category] = (scores, len(self.data[category]["data"]))
|
||||
self.evaluation_results[metric]["ALL"] += scores * weight
|
||||
|
||||
def _calculate_combined_metrics(self, metric: str, category: str):
|
||||
"""Calculate combined metrics."""
|
||||
weight = len(self.data[category]["data"]) / self.metric_total_length[metric]
|
||||
|
||||
references = [sample["target"] for sample in self.data[category]["data"]]
|
||||
predictions = [sample["output"] for sample in self.data[category]["data"]]
|
||||
|
||||
str_label_map = {
|
||||
choice: idx for idx, choice in enumerate(self.data[category]["inference_kwargs"]["all_classes"])
|
||||
}
|
||||
|
||||
references_labels = [str_label_map[sample["target"][0]] for sample in self.data[category]["data"]]
|
||||
predictions = [sample["output"] for sample in self.data[category]["data"]]
|
||||
|
||||
flag = False
|
||||
softmaxs = []
|
||||
for i, sample in enumerate(self.data[category]["data"]):
|
||||
if np.any(np.isnan(np.array(list(sample["softmax_over_choices"].values())))):
|
||||
if not flag:
|
||||
print(
|
||||
f"NaN in the softmax, switch to exact match for category {category} in dataset {self.dataset_name} in model {self.model_name}."
|
||||
)
|
||||
flag = True
|
||||
score = 0
|
||||
for ref in sample["target"]:
|
||||
score = max(
|
||||
score,
|
||||
metric_helper.single_choice_accuracy(
|
||||
sample["output"], ref, all_classes=self.data[category]["inference_kwargs"]["all_classes"]
|
||||
),
|
||||
)
|
||||
softmaxs.append(references[i] if score == 1 else -1)
|
||||
else:
|
||||
softmaxs.append(np.argmax(np.array(list(sample["softmax_over_choices"].values()))))
|
||||
|
||||
metric_method = getattr(metric_helper, metric)  # resolve the metric function by name
|
||||
|
||||
total_score = 0.0
|
||||
for prediction, reference, references_label, softmax in zip(
|
||||
predictions, references, references_labels, softmaxs
|
||||
):
|
||||
score = 0.0
|
||||
|
||||
for ref in reference:
|
||||
score = max(
|
||||
score,
|
||||
metric_method(prediction, ref, all_classes=self.data[category]["inference_kwargs"]["all_classes"]),
|
||||
)
|
||||
if references_label == softmax:
|
||||
score = 1
|
||||
|
||||
total_score += score
|
||||
total_score = total_score * 100 / len(self.data[category]["data"])
|
||||
|
||||
self.evaluation_results[metric][category] = (total_score, len(self.data[category]["data"]))
|
||||
self.evaluation_results[metric]["ALL"] += total_score * weight
|
||||
|
||||
def _calculate_other_metrics(self, metric: str, category: str):
|
||||
"""Calculate other metrics."""
|
||||
weight = len(self.data[category]["data"]) / self.metric_total_length[metric]
|
||||
|
||||
references = [sample["target"] for sample in self.data[category]["data"]]
|
||||
predictions = [sample["output"] for sample in self.data[category]["data"]]
|
||||
|
||||
metric_method = getattr(metric_helper, metric)  # resolve the metric function by name
|
||||
|
||||
total_score = 0.0
|
||||
for prediction, reference in zip(predictions, references):
|
||||
score = 0.0
|
||||
for ref in reference:
|
||||
score = max(
|
||||
score,
|
||||
metric_method(prediction, ref, all_classes=self.data[category]["inference_kwargs"]["all_classes"]),
|
||||
)
|
||||
total_score += score
|
||||
total_score = total_score * 100 / len(predictions)
|
||||
|
||||
self.evaluation_results[metric][category] = (total_score, len(self.data[category]["data"]))
|
||||
self.evaluation_results[metric]["ALL"] += total_score * weight
|
||||
|
||||
def _calculate_loss_metrics(self, metric: str, category: str):
|
||||
"""Calculate perplexity."""
|
||||
if metric == "perplexity":
|
||||
weight = len(self.data[category]["data"]) / self.metric_total_length[metric]
|
||||
losses = [min(sample["loss"]) for sample in self.data[category]["data"]]
|
||||
perplexity = np.mean(np.exp(np.array(losses)))
|
||||
|
||||
self.evaluation_results["perplexity"][category] = (perplexity, len(self.data[category]["data"]))
|
||||
self.evaluation_results["perplexity"]["ALL"] += perplexity * weight
|
||||
elif metric == "ppl_score":
|
||||
weight = len(self.data[category]["data"]) / self.metric_total_length[metric]
|
||||
losses = [min(sample["loss"]) for sample in self.data[category]["data"]]
|
||||
perplexity_score = np.mean(np.exp(-np.array(losses))) * 100
|
||||
|
||||
self.evaluation_results["ppl_score"][category] = (perplexity_score, len(self.data[category]["data"]))
|
||||
self.evaluation_results["ppl_score"]["ALL"] += perplexity_score * weight
|
||||
elif metric == "ppl_score_over_choices" and self.data[category]["inference_kwargs"]["all_classes"] is not None:
|
||||
weight = len(self.data[category]["data"]) / self.metric_total_length[metric]
|
||||
loss_over_choices = [sample["loss_over_choices"] for sample in self.data[category]["data"]]
|
||||
perplexity_score_over_choices = np.mean(np.exp(-np.array(loss_over_choices))) * 100
|
||||
|
||||
self.evaluation_results["ppl_score_over_choices"][category] = (
|
||||
perplexity_score_over_choices,
|
||||
len(self.data[category]["data"]),
|
||||
)
|
||||
self.evaluation_results["ppl_score_over_choices"]["ALL"] += perplexity_score_over_choices * weight
|
||||
elif metric == "per_byte_perplexity":
|
||||
weight = len(self.data[category]["data"]) / self.metric_total_length[metric]
|
||||
losses = [min(sample["loss_sum"]) for sample in self.data[category]["data"]]
|
||||
perplexity = np.mean(np.exp(np.array(losses) / np.array(self.N_bytes[category])))
|
||||
|
||||
self.evaluation_results["per_byte_perplexity"][category] = perplexity
|
||||
self.evaluation_results["per_byte_perplexity"]["ALL"] += perplexity * weight
|
||||
elif metric == "per_byte_ppl_score":
|
||||
weight = len(self.data[category]["data"]) / self.metric_total_length[metric]
|
||||
losses = [min(sample["loss_sum"]) for sample in self.data[category]["data"]]
|
||||
perplexity_score = np.mean(np.exp(-np.array(losses) / np.array(self.N_bytes[category]))) * 100
|
||||
|
||||
self.evaluation_results["per_byte_ppl_score"][category] = perplexity_score
|
||||
self.evaluation_results["per_byte_ppl_score"]["ALL"] += perplexity_score * weight
|
||||
|
||||
def _evaluate(self):
|
||||
"""Calculate and return evaluation results"""
|
||||
|
||||
for metric in self.metrics:
|
||||
pbar = tqdm.tqdm(
|
||||
desc=f"{self.dataset_name}-{metric}-{self.model_name}", total=len(self.suggested_categories[metric])
|
||||
)
|
||||
|
||||
if metric in LabelBasedMetrics:
|
||||
for category in self.suggested_categories[metric]:
|
||||
self._calculate_label_metrics(metric, category)
|
||||
pbar.update(1)
|
||||
elif metric in LossBasedMetrics:
|
||||
for category in self.suggested_categories[metric]:
|
||||
self._calculate_loss_metrics(metric, category)
|
||||
pbar.update(1)
|
||||
elif metric in CombinedMetrics:
|
||||
for category in self.suggested_categories[metric]:
|
||||
self._calculate_combined_metrics(metric, category)
|
||||
pbar.update(1)
|
||||
elif metric in OtherMetrics:
|
||||
for category in self.suggested_categories[metric]:
|
||||
self._calculate_other_metrics(metric, category)
|
||||
pbar.update(1)
|
||||
|
||||
return self.evaluation_results
|
||||
|
||||
def get_evaluation_results(self, data: List[Dict], dataset_name: str, model_name: str, metrics: List[str]):
|
||||
"""
|
||||
Evaluate inference data on the given metrics.
|
||||
|
||||
Args:
|
||||
data: Data to be evaluated.
|
||||
dataset_name: Name of the dataset
|
||||
model_name: Name of the model
|
||||
metrics: Metrics used to evaluate.
|
||||
|
||||
"""
|
||||
self.data = data
|
||||
self.dataset_name = dataset_name
|
||||
self.model_name = model_name
|
||||
self.categories = list(data.keys())
|
||||
self.metrics = metrics
|
||||
|
||||
self.evaluation_results = {
|
||||
metric: {category: 0 for category in (["ALL"] + self.categories)} for metric in self.metrics
|
||||
}
|
||||
|
||||
self.total_length = 0
|
||||
self.total_single_choices = 0
|
||||
for value in self.data.values():
|
||||
self.total_length += len(value["data"])
|
||||
if value["inference_kwargs"]["all_classes"] is not None:
|
||||
self.total_single_choices += len(value["data"])
|
||||
|
||||
self.metric_total_length = {metric: 0 for metric in self.metrics}
|
||||
self.suggested_categories = {metric: [] for metric in self.metrics}
|
||||
|
||||
for metric in self.metrics:
|
||||
self.suggested_categories[metric] = metric_helper.metrics4subcategory[self.dataset_name][metric]
|
||||
if "ALL" in self.suggested_categories[metric]:
|
||||
self.suggested_categories[metric] = self.categories
|
||||
self.metric_total_length[metric] = self.total_length
|
||||
continue
|
||||
for category in self.suggested_categories[metric]:
|
||||
self.metric_total_length[metric] += len(self.data[category]["data"])
|
||||
|
||||
if "per_byte_perplexity" in self.metrics or "per_byte_ppl_score" in self.metrics:
|
||||
self.N_bytes = {category: [] for category in self.categories}
|
||||
for category in self.categories:
|
||||
samples = self.data[category]["data"]
|
||||
for sample in samples:
|
||||
self.N_bytes[category].append(sample["byte_num"][0])
|
||||
|
||||
return self._evaluate()
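As an illustration of the interface above, `DatasetEvaluator` can be driven directly on hand-built inference results. This is a minimal sketch, not part of this commit: the import path is assumed from the package layout, and the sample record and the `longbench`/`perplexity` choices are only for demonstration.

```python
from colossal_eval.evaluate.dataset_evaluator import DatasetEvaluator

# One category with a single inference record; the field names follow the code above.
data = {
    "hotpotqa": {
        "inference_kwargs": {"all_classes": None},
        "data": [
            {"target": "Paris", "output": "Paris", "loss": [1.2]},  # per-sample losses are stored as a list
        ],
    }
}

evaluator = DatasetEvaluator()
results = evaluator.get_evaluation_results(
    data, dataset_name="longbench", model_name="demo-model", metrics=["perplexity"]
)
print(results["perplexity"]["ALL"])  # exp(1.2) ≈ 3.32 for this single sample
```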
|
@@ -0,0 +1,623 @@
|
||||
# Code adapted from https://github.com/THUDM/LongBench/blob/main/metrics.py
|
||||
# Code adapted from https://github.com/hendrycks/math/blob/main/modeling/math_equivalence.py
|
||||
# Code adapted from https://github.com/ruixiangcui/AGIEval/blob/main/src/evaluation.py
|
||||
|
||||
import difflib
|
||||
import re
|
||||
import string
|
||||
from collections import Counter
|
||||
|
||||
import jieba
|
||||
from fuzzywuzzy import fuzz
|
||||
from rouge import Rouge
|
||||
|
||||
metrics4subcategory = {
|
||||
"pretrain": {
|
||||
"perplexity": ["ALL"],
|
||||
"ppl_score": ["ALL"],
|
||||
"per_byte_perplexity": ["ALL"],
|
||||
"per_byte_ppl_score": ["ALL"],
|
||||
},
|
||||
# The commented-out datasets are not 4-choice questions.
|
||||
"agieval": {
|
||||
"combined_single_choice_accuracy": [
|
||||
# "lsat-ar",
|
||||
# "lsat-lr",
|
||||
# "lsat-rc",
|
||||
"logiqa-en",
|
||||
"sat-math",
|
||||
"sat-en",
|
||||
# "aqua-rat",
|
||||
"sat-en-without-passage",
|
||||
"gaokao-english",
|
||||
"logiqa-zh",
|
||||
"gaokao-chinese",
|
||||
"gaokao-geography",
|
||||
"gaokao-history",
|
||||
"gaokao-biology",
|
||||
"gaokao-chemistry",
|
||||
],
|
||||
"first_token_accuracy": [
|
||||
# "lsat-ar",
|
||||
# "lsat-lr",
|
||||
# "lsat-rc",
|
||||
"logiqa-en",
|
||||
"sat-math",
|
||||
"sat-en",
|
||||
# "aqua-rat",
|
||||
"sat-en-without-passage",
|
||||
"gaokao-english",
|
||||
"logiqa-zh",
|
||||
"gaokao-chinese",
|
||||
"gaokao-geography",
|
||||
"gaokao-history",
|
||||
"gaokao-biology",
|
||||
"gaokao-chemistry",
|
||||
],
|
||||
"single_choice_accuracy": [
|
||||
# "lsat-ar",
|
||||
# "lsat-lr",
|
||||
# "lsat-rc",
|
||||
"logiqa-en",
|
||||
"sat-math",
|
||||
"sat-en",
|
||||
# "aqua-rat",
|
||||
"sat-en-without-passage",
|
||||
"gaokao-english",
|
||||
"logiqa-zh",
|
||||
"gaokao-chinese",
|
||||
"gaokao-geography",
|
||||
"gaokao-history",
|
||||
"gaokao-biology",
|
||||
"gaokao-chemistry",
|
||||
],
|
||||
"multi_choice_accuracy": ["jec-qa-kd", "jec-qa-ca", "gaokao-physics", "gaokao-mathqa"],
|
||||
"math_equivalence": ["gaokao-mathcloze", "math"],
|
||||
"perplexity": ["ALL"],
|
||||
"ppl_score_over_choices": [
|
||||
"lsat-ar",
|
||||
"lsat-lr",
|
||||
"lsat-rc",
|
||||
"logiqa-en",
|
||||
"sat-math",
|
||||
"sat-en",
|
||||
"aqua-rat",
|
||||
"sat-en-without-passage",
|
||||
"gaokao-english",
|
||||
"logiqa-zh",
|
||||
"jec-qa-kd",
|
||||
"jec-qa-ca",
|
||||
"gaokao-chinese",
|
||||
"gaokao-geography",
|
||||
"gaokao-history",
|
||||
"gaokao-biology",
|
||||
"gaokao-chemistry",
|
||||
"gaokao-physics",
|
||||
"gaokao-mathqa",
|
||||
],
|
||||
"ppl_score": ["ALL"],
|
||||
},
|
||||
"cmmlu": {
|
||||
"first_token_accuracy": ["ALL"],
|
||||
"single_choice_accuracy": ["ALL"],
|
||||
"perplexity": ["ALL"],
|
||||
"ppl_score_over_choices": ["ALL"],
|
||||
"ppl_score": ["ALL"],
|
||||
},
|
||||
"gaokaobench": {
|
||||
"combined_single_choice_accuracy": [
|
||||
"English MCQs",
|
||||
"Biology MCQs",
|
||||
"Chemistry MCQs",
|
||||
"History MCQs",
|
||||
"Math I MCQs",
|
||||
"Math II MCQs",
|
||||
"Political Science MCQs",
|
||||
],
|
||||
"first_token_accuracy": [
|
||||
"English MCQs",
|
||||
"Biology MCQs",
|
||||
"Chemistry MCQs",
|
||||
"History MCQs",
|
||||
"Math I MCQs",
|
||||
"Math II MCQs",
|
||||
"Political Science MCQs",
|
||||
],
|
||||
"single_choice_accuracy": [
|
||||
"English MCQs",
|
||||
"Biology MCQs",
|
||||
"Chemistry MCQs",
|
||||
"History MCQs",
|
||||
"Math I MCQs",
|
||||
"Math II MCQs",
|
||||
"Political Science MCQs",
|
||||
],
|
||||
"multi_choice_accuracy": [
|
||||
"Chinese Lang and Usage MCQs",
|
||||
"Chinese Modern Lit",
|
||||
"English Fill in Blanks",
|
||||
"English Reading Comp",
|
||||
"Geography MCQs",
|
||||
"Physics MCQs",
|
||||
"English Cloze Test",
|
||||
],
|
||||
"math_equivalence": ["Math I Fill-in-the-Blank", "Math II Fill-in-the-Blank"],
|
||||
"rouge_score": ["English Language Cloze Passage"],
|
||||
"rouge_zh_score": [
|
||||
"Chinese Language Famous Passages and Sentences Dictation",
|
||||
"Chemistry Open-ended Questions",
|
||||
"History Open-ended Questions",
|
||||
"Biology Open-ended Questions",
|
||||
"Political Science Open-ended Questions",
|
||||
"English Language Error Correction",
|
||||
"Chinese Language Language and Writing Skills Open-ended Questions",
|
||||
"Math II Open-ended Questions",
|
||||
"Chinese Language Literary Text Reading",
|
||||
"Chinese Language Ancient Poetry Reading",
|
||||
"Chinese Language Classical Chinese Reading",
|
||||
"Physics Open-ended Questions",
|
||||
"Math I Open-ended Questions",
|
||||
"Geography Open-ended Questions",
|
||||
"Chinese Language Practical Text Reading",
|
||||
],
|
||||
"perplexity": ["ALL"],
|
||||
"ppl_score_over_choices": ["ALL"],
|
||||
"ppl_score": ["ALL"],
|
||||
},
|
||||
"longbench": {
|
||||
"f1_score": ["hotpotqa", "2wikimqa", "musique", "narrativeqa", "qasper", "multifieldqa_en", "triviaqa"],
|
||||
"f1_zh_score": ["multifieldqa_zh"],
|
||||
"rouge_score": ["gov_report", "qmsum", "multi_news", "samsum"],
|
||||
"rouge_zh_score": ["dureader", "vcsum"],
|
||||
"retrieval_score": ["passage_retrieval_en"],
|
||||
"retrieval_zh_score": ["passage_retrieval_zh"],
|
||||
"classification_score": ["trec", "lsht"],
|
||||
"code_sim_score": ["lcc", "repobench-p"],
|
||||
"count_score": ["passage_count"],
|
||||
"perplexity": ["ALL"],
|
||||
"ppl_score": ["ALL"],
|
||||
},
|
||||
"mmlu": {
|
||||
"first_token_accuracy": ["ALL"],
|
||||
"single_choice_accuracy": ["ALL"],
|
||||
"accuracy": ["ALL"],
|
||||
"perplexity": ["ALL"],
|
||||
"ppl_score_over_choices": ["ALL"],
|
||||
"ppl_score": ["ALL"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _fix_fracs(string):
|
||||
substrs = string.split("\\frac")
|
||||
new_str = substrs[0]
|
||||
if len(substrs) > 1:
|
||||
substrs = substrs[1:]
|
||||
for substr in substrs:
|
||||
new_str += "\\frac"
|
||||
if substr[0] == "{":
|
||||
new_str += substr
|
||||
else:
|
||||
try:
|
||||
assert len(substr) >= 2
|
||||
except:
|
||||
return string
|
||||
a = substr[0]
|
||||
b = substr[1]
|
||||
if b != "{":
|
||||
if len(substr) > 2:
|
||||
post_substr = substr[2:]
|
||||
new_str += "{" + a + "}{" + b + "}" + post_substr
|
||||
else:
|
||||
new_str += "{" + a + "}{" + b + "}"
|
||||
else:
|
||||
if len(substr) > 2:
|
||||
post_substr = substr[2:]
|
||||
new_str += "{" + a + "}" + b + post_substr
|
||||
else:
|
||||
new_str += "{" + a + "}" + b
|
||||
string = new_str
|
||||
return string
|
||||
|
||||
|
||||
def _fix_a_slash_b(string):
|
||||
if len(string.split("/")) != 2:
|
||||
return string
|
||||
a = string.split("/")[0]
|
||||
b = string.split("/")[1]
|
||||
try:
|
||||
a = int(a)
|
||||
b = int(b)
|
||||
assert string == "{}/{}".format(a, b)
|
||||
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
|
||||
return new_string
|
||||
except:
|
||||
return string
|
||||
|
||||
|
||||
def _remove_right_units(string):
|
||||
# "\\text{ " only ever occurs (at least in the val set) when describing units
|
||||
if "\\text{ " in string:
|
||||
splits = string.split("\\text{ ")
|
||||
assert len(splits) == 2
|
||||
return splits[0]
|
||||
else:
|
||||
return string
|
||||
|
||||
|
||||
def _fix_sqrt(string):
|
||||
if "\\sqrt" not in string:
|
||||
return string
|
||||
splits = string.split("\\sqrt")
|
||||
new_string = splits[0]
|
||||
for split in splits[1:]:
|
||||
if split[0] != "{":
|
||||
a = split[0]
|
||||
new_substr = "\\sqrt{" + a + "}" + split[1:]
|
||||
else:
|
||||
new_substr = "\\sqrt" + split
|
||||
new_string += new_substr
|
||||
return new_string
|
||||
|
||||
|
||||
def _strip_string(string):
|
||||
# linebreaks
|
||||
string = string.replace("\n", "")
|
||||
# print(string)
|
||||
|
||||
# remove inverse spaces
|
||||
string = string.replace("\\!", "")
|
||||
# print(string)
|
||||
|
||||
# replace \\ with \
|
||||
string = string.replace("\\\\", "\\")
|
||||
# print(string)
|
||||
|
||||
# replace tfrac and dfrac with frac
|
||||
string = string.replace("tfrac", "frac")
|
||||
string = string.replace("dfrac", "frac")
|
||||
# print(string)
|
||||
|
||||
# remove \left and \right
|
||||
string = string.replace("\\left", "")
|
||||
string = string.replace("\\right", "")
|
||||
# print(string)
|
||||
|
||||
# Remove circ (degrees)
|
||||
string = string.replace("^{\\circ}", "")
|
||||
string = string.replace("^\\circ", "")
|
||||
|
||||
# remove dollar signs
|
||||
string = string.replace("\\$", "")
|
||||
|
||||
# remove units (on the right)
|
||||
string = _remove_right_units(string)
|
||||
|
||||
# remove percentage
|
||||
string = string.replace("\\%", "")
|
||||
string = string.replace("\%", "")
|
||||
|
||||
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
|
||||
string = string.replace(" .", " 0.")
|
||||
string = string.replace("{.", "{0.")
|
||||
# if empty, return empty string
|
||||
if len(string) == 0:
|
||||
return string
|
||||
if string[0] == ".":
|
||||
string = "0" + string
|
||||
|
||||
# to consider: get rid of e.g. "k = " or "q = " at beginning
|
||||
if len(string.split("=")) == 2:
|
||||
if len(string.split("=")[0]) <= 2:
|
||||
string = string.split("=")[1]
|
||||
|
||||
# fix sqrt3 --> sqrt{3}
|
||||
string = _fix_sqrt(string)
|
||||
|
||||
# remove spaces
|
||||
string = string.replace(" ", "")
|
||||
|
||||
# \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
|
||||
string = _fix_fracs(string)
|
||||
|
||||
# manually change 0.5 --> \frac{1}{2}
|
||||
if string == "0.5":
|
||||
string = "\\frac{1}{2}"
|
||||
|
||||
# NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
|
||||
string = _fix_a_slash_b(string)
|
||||
|
||||
return string
|
||||
|
||||
|
||||
def parse_math_answer(raw_string):
|
||||
def remove_boxed(s):
|
||||
left = "\\boxed{"
|
||||
try:
|
||||
assert s[: len(left)] == left
|
||||
assert s[-1] == "}"
|
||||
answer = s[len(left) : -1]
|
||||
if "=" in answer:
|
||||
answer = answer.split("=")[-1].lstrip(" ")
|
||||
return answer
|
||||
except:
|
||||
return None
|
||||
|
||||
def last_boxed_only_string(string):
|
||||
idx = string.rfind("\\boxed")
|
||||
if idx < 0:
|
||||
idx = string.rfind("\\fbox")
|
||||
if idx < 0:
|
||||
return None
|
||||
i = idx
|
||||
right_brace_idx = None
|
||||
num_left_braces_open = 0
|
||||
while i < len(string):
|
||||
if string[i] == "{":
|
||||
num_left_braces_open += 1
|
||||
if string[i] == "}":
|
||||
num_left_braces_open -= 1
|
||||
if num_left_braces_open == 0:
|
||||
right_brace_idx = i
|
||||
break
|
||||
i += 1
|
||||
|
||||
if right_brace_idx is None:
|
||||
retval = None
|
||||
else:
|
||||
retval = string[idx : right_brace_idx + 1]
|
||||
|
||||
return retval
|
||||
|
||||
def get_answer_with_dollar_sign(s):
|
||||
first_pattern = "\$(.*)\$"
|
||||
last_match = None
|
||||
matches = re.findall(first_pattern, s)
|
||||
if matches:
|
||||
last_match = matches[-1]
|
||||
if "=" in last_match:
|
||||
last_match = last_match.split("=")[-1].lstrip(" ")
|
||||
return last_match
|
||||
|
||||
def get_answer_without_dollar_sign(s):
|
||||
last_match = None
|
||||
if "=" in s:
|
||||
last_match = s.split("=")[-1].lstrip(" ").rstrip(".")
|
||||
if "\\n" in last_match:
|
||||
last_match = last_match.split("\\n")[0]
|
||||
else:
|
||||
pattern = "(?:\\$)?\d+(?:\.\d+)?(?![\w\d])"
|
||||
matches = re.findall(pattern, s)
|
||||
if matches:
|
||||
last_match = matches[-1]
|
||||
return last_match
|
||||
|
||||
if "\\boxed" in raw_string:
|
||||
answer = remove_boxed(last_boxed_only_string(raw_string))
|
||||
else:
|
||||
answer = get_answer_with_dollar_sign(raw_string)
|
||||
if not answer:
|
||||
answer = get_answer_without_dollar_sign(raw_string)
|
||||
return answer
|
||||
|
||||
|
||||
def math_equivalence(prediction, reference, **kwargs):
|
||||
prediction = parse_math_answer(prediction)
|
||||
|
||||
if prediction is None and reference is None:
|
||||
print("WARNING: Both None")
|
||||
return False
|
||||
|
||||
if prediction is None or reference is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
ss1 = _strip_string(prediction)
|
||||
ss2 = _strip_string(reference)
|
||||
return ss1 == ss2
|
||||
except:
|
||||
return prediction == reference
|
||||
|
||||
|
||||
def multi_choice_accuracy(prediction, reference, **kwargs):
|
||||
# Only find uppercase letters not surrounded by lowercase letters
|
||||
all_classes = kwargs.get("all_classes", None)
|
||||
if all_classes:
|
||||
pattern = f"(?<![a-z])[{all_classes[0]}-{all_classes[-1]}](?![a-z])"
|
||||
else:
|
||||
pattern = "(?<![a-z])[A-F](?![a-z])"
|
||||
|
||||
prediction = re.findall(pattern, prediction)
|
||||
reference = re.findall(pattern, reference)
|
||||
|
||||
prediction_set = set(prediction)
|
||||
reference_set = set(reference)
|
||||
|
||||
score = 0.0
|
||||
for p in prediction_set:
|
||||
if p not in reference_set:
|
||||
return 0.0
|
||||
else:
|
||||
score += 1 / len(reference_set)
|
||||
|
||||
return score
|
||||
|
||||
|
||||
def combined_single_choice_accuracy(prediction, reference, **kwargs):
|
||||
return single_choice_accuracy(prediction, reference, **kwargs)
|
||||
|
||||
|
||||
def single_choice_accuracy(prediction, reference, **kwargs):
|
||||
# Only find uppercase letters not surrounded by lowercase letters
|
||||
all_classes = kwargs.get("all_classes", None)
|
||||
if all_classes:
|
||||
pattern = f"(?<![a-z])[{all_classes[0]}-{all_classes[-1]}](?![a-z])"
|
||||
else:
|
||||
pattern = "(?<![a-z])[A-F](?![a-z])"
|
||||
|
||||
prediction = re.findall(pattern, prediction)[0:1]
|
||||
reference = re.findall(pattern, reference)
|
||||
|
||||
assert len(reference) == 1
|
||||
|
||||
prediction_set = set(prediction)
|
||||
reference_set = set(reference)
|
||||
|
||||
if prediction_set == reference_set:
|
||||
return 1.0
|
||||
|
||||
return 0.0
|
||||
|
||||
|
||||
def normalize_answer(s):
|
||||
"""Lower text and remove punctuation, articles and extra whitespace."""
|
||||
|
||||
def remove_articles(text):
|
||||
return re.sub(r"\b(a|an|the)\b", " ", text)
|
||||
|
||||
def white_space_fix(text):
|
||||
return " ".join(text.split())
|
||||
|
||||
def remove_punc(text):
|
||||
exclude = set(string.punctuation)
|
||||
return "".join(ch for ch in text if ch not in exclude)
|
||||
|
||||
def lower(text):
|
||||
return text.lower()
|
||||
|
||||
return white_space_fix(remove_articles(remove_punc(lower(s))))
|
||||
|
||||
|
||||
def normalize_zh_answer(s):
|
||||
"""Lower text and remove punctuation, extra whitespace."""
|
||||
|
||||
def white_space_fix(text):
|
||||
return "".join(text.split())
|
||||
|
||||
def remove_punc(text):
|
||||
cn_punctuation = "!?。。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."
|
||||
all_punctuation = set(string.punctuation + cn_punctuation)
|
||||
return "".join(ch for ch in text if ch not in all_punctuation)
|
||||
|
||||
def lower(text):
|
||||
return text.lower()
|
||||
|
||||
return white_space_fix(remove_punc(lower(s)))
|
||||
|
||||
|
||||
def count_score(prediction, reference, **kwargs):
|
||||
numbers = re.findall(r"\d+", prediction)
|
||||
right_num = 0
|
||||
for number in numbers:
|
||||
if str(number) == str(reference):
|
||||
right_num += 1
|
||||
final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
|
||||
return float(final_score)
|
||||
|
||||
|
||||
def retrieval_score(prediction, reference, **kwargs):
|
||||
pattern = r"Paragraph (\d+)"
|
||||
matches = re.findall(pattern, reference)
|
||||
ground_truth_id = matches[0]
|
||||
numbers = re.findall(r"\d+", prediction)
|
||||
right_num = 0
|
||||
for number in numbers:
|
||||
if str(number) == str(ground_truth_id):
|
||||
right_num += 1
|
||||
final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
|
||||
return float(final_score)
|
||||
|
||||
|
||||
def retrieval_zh_score(prediction, reference, **kwargs):
|
||||
pattern = r"段落(\d+)"
|
||||
matches = re.findall(pattern, reference)
|
||||
ground_truth_id = matches[0]
|
||||
numbers = re.findall(r"\d+", prediction)
|
||||
right_num = 0
|
||||
for number in numbers:
|
||||
if str(number) == str(ground_truth_id):
|
||||
right_num += 1
|
||||
final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
|
||||
return float(final_score)
|
||||
|
||||
|
||||
def code_sim_score(prediction, reference, **kwargs):
|
||||
all_lines = prediction.lstrip("\n").split("\n")
|
||||
prediction = ""
|
||||
for line in all_lines:
|
||||
if ("`" not in line) and ("#" not in line) and ("//" not in line):
|
||||
prediction = line
|
||||
break
|
||||
return fuzz.ratio(prediction, reference) / 100
|
||||
|
||||
|
||||
def classification_score(prediction, reference, **kwargs):
|
||||
em_match_list = []
|
||||
all_classes = kwargs["all_classes"]
|
||||
for class_name in all_classes:
|
||||
if class_name in prediction:
|
||||
em_match_list.append(class_name)
|
||||
for match_term in em_match_list:
|
||||
if match_term in reference and match_term != reference:
|
||||
em_match_list.remove(match_term)
|
||||
if len(em_match_list) != 0:  # exact class-name matches found; otherwise fall back to fuzzy matching below
|
||||
if reference in em_match_list:
|
||||
score = 1.0 / len(em_match_list)
|
||||
else:
|
||||
score = 0.0
|
||||
else:
|
||||
best_match = None
|
||||
highest_similarity = 0
|
||||
for string in all_classes:
|
||||
similarity = difflib.SequenceMatcher(None, string, prediction).ratio()
|
||||
if similarity > highest_similarity:
|
||||
highest_similarity = similarity
|
||||
best_match = string
|
||||
score = float(best_match == reference)
|
||||
return score
|
||||
|
||||
|
||||
def rouge_score(prediction, reference, **kwargs):
|
||||
rouge = Rouge()
|
||||
try:
|
||||
scores = rouge.get_scores([prediction], [reference], avg=True)
|
||||
except:
|
||||
return 0.0
|
||||
return scores["rouge-l"]["f"]
|
||||
|
||||
|
||||
def rouge_zh_score(prediction, reference, **kwargs):
|
||||
prediction = " ".join(list(jieba.cut(prediction, cut_all=False)))
|
||||
reference = " ".join(list(jieba.cut(reference, cut_all=False)))
|
||||
score = rouge_score(prediction, reference)
|
||||
return score
|
||||
|
||||
|
||||
def _f1_score(prediction, reference, **kwargs):
|
||||
common = Counter(prediction) & Counter(reference)
|
||||
num_same = sum(common.values())
|
||||
if num_same == 0:
|
||||
return 0
|
||||
precision = 1.0 * num_same / len(prediction)
|
||||
recall = 1.0 * num_same / len(reference)
|
||||
f1 = (2 * precision * recall) / (precision + recall)
|
||||
return f1
|
||||
|
||||
|
||||
def f1_score(prediction, reference, **kwargs):
|
||||
normalized_prediction = normalize_answer(prediction)
|
||||
normalized_ground_truth = normalize_answer(reference)
|
||||
|
||||
prediction_tokens = normalized_prediction.split()
|
||||
ground_truth_tokens = normalized_ground_truth.split()
|
||||
return _f1_score(prediction_tokens, ground_truth_tokens)
|
||||
|
||||
|
||||
def f1_zh_score(prediction, reference, **kwargs):
|
||||
prediction_tokens = list(jieba.cut(prediction, cut_all=False))
|
||||
ground_truth_tokens = list(jieba.cut(reference, cut_all=False))
|
||||
prediction_tokens = [normalize_zh_answer(token) for token in prediction_tokens]
|
||||
ground_truth_tokens = [normalize_zh_answer(token) for token in ground_truth_tokens]
|
||||
prediction_tokens = [token for token in prediction_tokens if len(token) > 0]
|
||||
ground_truth_tokens = [token for token in ground_truth_tokens if len(token) > 0]
|
||||
return _f1_score(prediction_tokens, ground_truth_tokens)
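For a quick sanity check of the string metrics above, the following is a minimal sketch (not part of this commit); the import path mirrors the `metric_helper` import used earlier in this commit.

```python
import colossal_eval.evaluate.dataset_evaluator.metrics as metric_helper

# English F1 is computed on normalized tokens: articles and punctuation are ignored,
# so both strings below reduce to "cat sat" and the score is 1.0.
print(metric_helper.f1_score("The cat sat.", "a cat sat"))

# Single-choice accuracy extracts standalone capital letters (A-F by default) and
# compares the first one found in the prediction against the reference.
print(metric_helper.single_choice_accuracy("The answer is B.", "B"))  # 1.0
```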
|
110
applications/ColossalEval/colossal_eval/evaluate/evaluator.py
Normal file
110
applications/ColossalEval/colossal_eval/evaluate/evaluator.py
Normal file
@@ -0,0 +1,110 @@
|
||||
import os
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import colossal_eval.evaluate.gpt_evaluate as gpt_evaluate
|
||||
|
||||
from .utils import get_data_per_category
|
||||
|
||||
|
||||
class Evaluator(object):
|
||||
"""
|
||||
Evaluator wraps GPT-3.5/GPT-4 based evaluation of model answers.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params: Dict[str, Any],
|
||||
battle_prompt: Dict[str, Any],
|
||||
gpt_evaluation_prompt: Dict[str, Any],
|
||||
gpt_model: str,
|
||||
language: str,
|
||||
gpt_with_reference: bool,
|
||||
) -> None:
|
||||
self.params = params
|
||||
self.battle_prompt = battle_prompt
|
||||
self.gpt_evaluation_prompt = gpt_evaluation_prompt
|
||||
self.gpt_model = gpt_model
|
||||
self.language = language
|
||||
self.gpt_with_reference = gpt_with_reference
|
||||
self.gpt_evaluation_results = dict()
|
||||
self.battle_results = []
|
||||
|
||||
def battle(self, answers1: List[Dict], answers2: List[Dict]) -> None:
|
||||
"""
|
||||
Comparison between two models using GPT-4 as the reviewer.
|
||||
"""
|
||||
|
||||
self.battle_results = gpt_evaluate.battle(answers1, answers2, self.battle_prompt)
|
||||
|
||||
def evaluate(self, answers: List[Dict], targets: List[Dict], save_path: str, model_name: str) -> None:
|
||||
"""
|
||||
A comprehensive evaluation of the answers from the model.
|
||||
The function evaluates the model's performance from different perspectives
|
||||
using GPT-3.5, GPT-4, and off-the-shelf evaluation metrics.
|
||||
|
||||
The metrics will be decided by the config file.
|
||||
|
||||
"""
|
||||
|
||||
answers_per_category = get_data_per_category(answers, list(self.params.keys()))
|
||||
targets_per_category = get_data_per_category(targets, list(self.params.keys()))
|
||||
|
||||
# gpt evaluation
|
||||
for category in self.params:
|
||||
if len(answers_per_category[category]) == 0:
|
||||
print(f"Category {category} specified in your config doesn't have corresponding answers!")
|
||||
continue
|
||||
|
||||
if self.params[category].get("GPT", None) is None:
|
||||
continue
|
||||
|
||||
category_metrics = self.params[category]["GPT"]
|
||||
|
||||
prompt = self.gpt_evaluation_prompt.get(category, None)
|
||||
if prompt is None:
|
||||
print(f"No prompt for category {category}! Use prompt for category general now.")
|
||||
prompt = self.gpt_evaluation_prompt["general"]
|
||||
|
||||
self.gpt_evaluation_results[category] = gpt_evaluate.evaluate(
|
||||
answers_per_category[category],
|
||||
prompt,
|
||||
category_metrics,
|
||||
category,
|
||||
save_path,
|
||||
model_name,
|
||||
self.gpt_model,
|
||||
self.language,
|
||||
references=targets_per_category[category] if self.gpt_with_reference else None,
|
||||
)
|
||||
|
||||
def save(self, path: str, model_name_list: List[str]) -> None:
|
||||
"""
|
||||
Save evaluation results of GPT-3.5, GPT-4, and off-the-shelf evaluation metrics.
|
||||
|
||||
"""
|
||||
|
||||
if len(model_name_list) == 2:
|
||||
save_path = os.path.join(path, "gpt_evaluate", "battle_results")
|
||||
gpt_evaluate.save_battle_results(self.battle_results, model_name_list[0], model_name_list[1], save_path)
|
||||
else:
|
||||
if self.gpt_evaluation_results:
|
||||
# Save evaluation results for GPT evaluation metrics.
|
||||
gpt_base_save_path = os.path.join(path, "gpt_evaluate", "gpt_evaluate_results")
|
||||
gpt_evaluation_results_save_path = os.path.join(gpt_base_save_path, "evaluation_results")
|
||||
|
||||
all_evaluations = gpt_evaluate.save_gpt_evaluation_results(
|
||||
model_name_list[0], self.gpt_evaluation_results, gpt_evaluation_results_save_path
|
||||
)
|
||||
|
||||
# Start to calculate scores and save statistics.
|
||||
gpt_evaluation_statistics_save_path = os.path.join(gpt_base_save_path, "evaluation_statistics")
|
||||
gpt_evaluate.save_gpt_evaluation_statistics(
|
||||
model_name_list[0], all_evaluations, gpt_evaluation_statistics_save_path
|
||||
)
|
||||
|
||||
# Save charts and csv.
|
||||
gpt_evaluation_analyses_save_path = os.path.join(gpt_base_save_path, "evaluation_analyses")
|
||||
gpt_evaluate.analyze_gpt_evaluation_statistics(
|
||||
gpt_evaluation_statistics_save_path, gpt_evaluation_analyses_save_path
|
||||
)
|
852
applications/ColossalEval/colossal_eval/evaluate/gpt_evaluate.py
Normal file
852
applications/ColossalEval/colossal_eval/evaluate/gpt_evaluate.py
Normal file
@@ -0,0 +1,852 @@
|
||||
import concurrent.futures
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import openai
|
||||
import pandas as pd
|
||||
import seaborn as sns
|
||||
import tqdm
|
||||
from colossal_eval.utils import jdump, jload
|
||||
|
||||
ref_step_template = {
|
||||
"en": "Now please compare the answer with the {adjective} answer, determine whether the answer is able to achieve the same level of {metric}.\n\n",
|
||||
"cn": "请比较答案与上面的{adjective}答案,确定答案是否可以达到与该{adjective}答案同样水平的{metric}。\n\n",
|
||||
}
|
||||
|
||||
ref_answer_template_general = {
|
||||
"en": "\nAn example answer with good quality is as follows:\n\n{answer}\n\n",
|
||||
"cn": "\n一个优质的示例答案如下:\n\n{answer}\n\n",
|
||||
}
|
||||
|
||||
ref_answer_template_correctness = {
|
||||
"en": "\nA correct answer is as follows:\n\n{answer}\n\n",
|
||||
"cn": "\n标准答案如下:\n\n{answer}\n\n",
|
||||
}
|
||||
|
||||
|
||||
def get_battle_result(sys_prompt: str, user_prompt: str, id: int, max_tokens: int = 2048) -> Dict[str, Any]:
|
||||
"""
|
||||
Get battle evaluation from GPT-4.
|
||||
|
||||
Args:
|
||||
sys_prompt: prompt for the system.
|
||||
user_prompt: prompt for the user.
|
||||
id: id of the answers for comparison.
|
||||
max_tokens: the maximum number of tokens to generate in the chat completion.
|
||||
|
||||
Returns:
|
||||
An evaluation of one comparison.
|
||||
"""
|
||||
|
||||
MAX_API_RETRY = 3
|
||||
for _ in range(MAX_API_RETRY):
|
||||
try:
|
||||
response = openai.ChatCompletion.create(
|
||||
model="gpt-4",
|
||||
messages=[
|
||||
{"role": "system", "content": sys_prompt},
|
||||
{
|
||||
"role": "user",
|
||||
"content": user_prompt,
|
||||
},
|
||||
],
|
||||
temperature=0.2,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
evaluation = response["choices"][0]["message"]["content"]
|
||||
return {"evaluation": evaluation, "id": id}
|
||||
except Exception as e:
|
||||
print(e)
|
||||
time.sleep(1)
|
||||
print(f"Evaluation {id} failed after {MAX_API_RETRY} retries.")
|
||||
return {"evaluation": "", "id": id}
|
||||
|
||||
|
||||
def parse_battle_score(evaluation: str) -> List[float]:
|
||||
"""
|
||||
Parse evaluation from GPT-4 and get the scores of model 1 and 2.
|
||||
|
||||
Args:
|
||||
evaluation: evaluation from GPT-4.
|
||||
|
||||
Returns:
|
||||
A score pair of two different model answers.
|
||||
"""
|
||||
|
||||
try:
|
||||
pattern = re.compile("([0-9]|10) out of 10")
|
||||
sp = re.findall(pattern, evaluation)
|
||||
if len(re.findall(pattern, evaluation)) == 2:
|
||||
return [float(sp[0]), float(sp[1])]
|
||||
|
||||
pattern = re.compile("a score of ([0-9]|10)")
|
||||
sp = re.findall(pattern, evaluation)
|
||||
if len(re.findall(pattern, evaluation)) == 2:
|
||||
return [float(sp[0]), float(sp[1])]
|
||||
|
||||
pattern = re.compile("([0-9]|10)/10")
|
||||
sp = re.findall(pattern, evaluation)
|
||||
if len(re.findall(pattern, evaluation)) == 2:
|
||||
return [float(sp[0]), float(sp[1])]
|
||||
|
||||
score_pair = evaluation.split("\n")[0]
|
||||
score_pair = score_pair.replace(",", " ")
|
||||
sp = score_pair.split(" ")
|
||||
if len(sp) == 2:
|
||||
return [float(sp[0]), float(sp[1])]
|
||||
else:
|
||||
raise Exception(f"Invalid score pair. Got {evaluation}.")
|
||||
except Exception:
|
||||
return [-1, -1]
|
||||
|
||||
|
||||
def battle(answer1: List[Dict], answer2: List[Dict], prompt_dict: Dict[str, Any]) -> List[Dict]:
|
||||
"""
|
||||
Use GPT-4 to compare answers of two different models.
|
||||
|
||||
Args:
|
||||
answer1: answers of model 1.
|
||||
answer2: answers of model 2.
|
||||
prompt_dict: prompt for battle.
|
||||
|
||||
Returns:
|
||||
Evaluations of all comparison pairs.
|
||||
"""
|
||||
|
||||
assert len(answer1) == len(answer2)
|
||||
|
||||
total_len = len(answer1)
|
||||
question_idx_list = list(range(total_len))
|
||||
|
||||
print(f" Total number of answers: {len(answer1)}.")
|
||||
|
||||
evaluations = []
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
|
||||
futures = []
|
||||
for i in question_idx_list:
|
||||
assert answer1[i]["id"] == answer2[i]["id"]
|
||||
answer_id = answer1[i]["id"]
|
||||
|
||||
ques = (
|
||||
answer1[i]["instruction"]
|
||||
if answer1[i]["input"] == ""
|
||||
else answer1[i]["instruction"] + " " + answer1[i]["input"]
|
||||
)
|
||||
answer1[i]["category"]
|
||||
ans1 = answer1[i]["output"]
|
||||
ans2 = answer2[i]["output"]
|
||||
|
||||
sys_prompt = prompt_dict["system_prompt"]
|
||||
prompt_template = prompt_dict["prompt_template"]
|
||||
prompt = prompt_template.format(
|
||||
question=ques,
|
||||
answer_1=ans1,
|
||||
answer_2=ans2,
|
||||
prompt=prompt_dict["prompt"],
|
||||
)
|
||||
|
||||
future = executor.submit(get_battle_result, sys_prompt, prompt, answer_id, 2048)
|
||||
futures.append(future)
|
||||
|
||||
for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
|
||||
evaluations.append(future.result())
|
||||
|
||||
evaluations.sort(key=lambda x: x["id"])
|
||||
|
||||
return evaluations
|
||||
|
||||
|
||||
def save_battle_results(evaluations: List[Dict], name1: str, name2: str, save_path: str) -> None:
|
||||
"""
|
||||
Save evaluation results (model 1 vs model 2) from GPT-4.
|
||||
|
||||
Args:
|
||||
evaluations: evaluation results from GPT-4.
|
||||
name1: model 1 's name.
|
||||
name2: model 2 's name.
|
||||
save_path: path to save battle results.
|
||||
"""
|
||||
|
||||
evaluation_file = deepcopy(evaluations)
|
||||
|
||||
ans1_score = 0
|
||||
ans2_score = 0
|
||||
better_count = 0
|
||||
worse_count = 0
|
||||
tie_count = 0
|
||||
invalid_count = 0
|
||||
|
||||
better_file = []
|
||||
worse_file = []
|
||||
tie_file = []
|
||||
invalid_file = []
|
||||
|
||||
for idx, evaluation in enumerate(evaluations):
|
||||
scores = parse_battle_score(evaluation["evaluation"])
|
||||
evaluation_file[idx]["score"] = scores
|
||||
|
||||
if scores[0] == -1 and scores[1] == -1:
|
||||
invalid_count += 1
|
||||
invalid_file.append(evaluation_file[idx])
|
||||
print(f'Invalid score pair: {evaluation_file[idx]["id"]}.')
|
||||
else:
|
||||
if scores[0] > scores[1]:
|
||||
worse_count += 1
|
||||
worse_file.append(evaluation_file[idx])
|
||||
elif scores[0] < scores[1]:
|
||||
better_count += 1
|
||||
better_file.append(evaluation_file[idx])
|
||||
else:
|
||||
tie_count += 1
|
||||
tie_file.append(evaluation_file[idx])
|
||||
ans1_score += scores[0]
|
||||
ans2_score += scores[1]
|
||||
|
||||
prefix = f"{name1}_vs_{name2}"
|
||||
|
||||
if not os.path.exists(save_path):
|
||||
os.makedirs(save_path)
|
||||
|
||||
jdump(better_file, os.path.join(save_path, prefix, f"{name2}_better.json"))
|
||||
jdump(worse_file, os.path.join(save_path, prefix, f"{name2}_worse.json"))
|
||||
jdump(tie_file, os.path.join(save_path, prefix, f"{prefix}_tie.json"))
|
||||
jdump(invalid_file, os.path.join(save_path, prefix, f"{prefix}_invalid.json"))
|
||||
jdump(evaluation_file, os.path.join(save_path, prefix, f"{prefix}_evaluations.json"))
|
||||
|
||||
if os.path.exists(os.path.join(save_path, "battle_results.json")):
|
||||
results = jload(os.path.join(save_path, "battle_results.json"))
|
||||
else:
|
||||
results = {}
|
||||
|
||||
results[prefix] = {
|
||||
"model": [name1, name2],
|
||||
"better": better_count,
|
||||
"worse": worse_count,
|
||||
"tie": tie_count,
|
||||
"win_rate": better_count / (len(evaluations) - invalid_count),
|
||||
"score": [
|
||||
ans1_score / (len(evaluations) - invalid_count),
|
||||
ans2_score / (len(evaluations) - invalid_count),
|
||||
],
|
||||
}
|
||||
jdump(results, os.path.join(save_path, "battle_results.json"))
|
||||
|
||||
print(f"Total {invalid_count} invalid score pair(s).")
|
||||
print(f"Model {name2} has {better_count} better answer(s).")
|
||||
print(f"Model {name2} has {worse_count} worse answer(s).")
|
||||
print(f"{tie_count} answer(s) play(s) to a tie.")
|
||||
print(f"Win rate of model {name2}: {better_count/(len(evaluations)-invalid_count):.2f}")
|
||||
print(f"Model {name1} average score: {ans1_score/(len(evaluations)-invalid_count):.2f}")
|
||||
print(f"Model {name2} average score: {ans2_score/(len(evaluations)-invalid_count):.2f}")
|
||||
|
||||
|
||||
def reference_template(metric: str, language: str, reference: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Get prompt template for GPT evaluation with reference.
|
||||
|
||||
Different languages have different prompt templates.
|
||||
|
||||
Args:
|
||||
metric: metric used in GPT evaluation with reference.
|
||||
language: language for the template.
|
||||
reference: the instruction that contains target answer.
|
||||
|
||||
Returns:
|
||||
Prompt template for GPT evaluation with reference.
|
||||
"""
|
||||
|
||||
step_to_add = ref_step_template[language]
|
||||
|
||||
for_the_given_answer = (
|
||||
"{metric} (1-5) (directly give the score for the given answer):"
|
||||
if language == "en"
|
||||
else "{metric} (1-5) (直接对给定答案打分)"
|
||||
)
|
||||
|
||||
# adjective is used to describe the word "answer" in the prompt.
|
||||
adjective = "example" if language == "en" else "示例"
|
||||
answer_to_add = ref_answer_template_general[language]
|
||||
|
||||
# Only for correctness, we will provide a correct answer and so the adjective for "answer" will be "correct". The prompt words will be "a correct answer".
|
||||
# In other cases, the prompt words will be "an example answer with good quality" by default.
|
||||
if metric.lower() == "correctness":
|
||||
adjective = "correct" if language == "en" else "标准"
|
||||
answer_to_add = ref_answer_template_correctness[language]
|
||||
|
||||
answer_to_add = answer_to_add.format(answer=reference["target"] if reference["target"] else reference["output"])
|
||||
step_to_add = step_to_add.format(metric=metric.lower(), adjective=adjective) + for_the_given_answer.format(
|
||||
metric=metric
|
||||
)
|
||||
|
||||
return answer_to_add + step_to_add
|
||||
|
||||
|
||||
def fill_in_message(role: str, content: str) -> Dict[str, str]:
|
||||
"""
|
||||
Generate one formatted message to send through chat completion.
|
||||
|
||||
Args:
|
||||
role: the role of the author of this message.
|
||||
content: the contents of the message.
|
||||
|
||||
Returns:
|
||||
One message to send through chat completion.
|
||||
"""
|
||||
|
||||
return {"role": role, "content": content}
|
||||
|
||||
|
||||
def multiturn_chat_completion(user_messages: List[str], model: str, max_tokens: int = 1, turns=2) -> Dict[str, Any]:
|
||||
"""
|
||||
Do multi-turn chat completion.
|
||||
|
||||
When turns == 1, it is a one-turn conversation for normal GPT evaluation.
|
||||
When turns == 2, it is a two-turn conversation which is used for GPT evaluation with reference answers.
|
||||
|
||||
Args:
|
||||
user_messages: messages user wants to send.
|
||||
model: the model used to evaluate answers.
|
||||
max_tokens: the maximum number of tokens to generate in the chat completion.
|
||||
turns: the number of turns for conversation.
|
||||
|
||||
Returns:
|
||||
Last turn's response.
|
||||
"""
|
||||
|
||||
if len(user_messages) != turns:
|
||||
raise Exception("The length of user messages should be equal to the turn number!")
|
||||
|
||||
assistant_responses = []
|
||||
|
||||
for i in range(turns):
|
||||
messages_to_send = []
|
||||
|
||||
for j in range(i):
|
||||
messages_to_send.append(fill_in_message("user", user_messages[j]))
|
||||
messages_to_send.append(
|
||||
fill_in_message("assistant", assistant_responses[j]["choices"][0]["message"]["content"])
|
||||
)
|
||||
|
||||
# Length of user messages == Length of assistant messages + 1
|
||||
# Because we always expect the API to respond.
|
||||
messages_to_send.append(fill_in_message("user", user_messages[i]))
|
||||
|
||||
response = openai.ChatCompletion.create(
|
||||
model=model,
|
||||
messages=messages_to_send,
|
||||
temperature=0,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
# Avoid exceeding rate limits.
|
||||
# You can comment this line if your request doesn't contain many tokens.
|
||||
time.sleep(1)
|
||||
|
||||
assistant_responses.append(response)
|
||||
|
||||
return assistant_responses[-1]
|
||||
|
||||
|
||||
def get_gpt_evaluation_without_logprobs(
|
||||
prompt: Dict[str, Any],
|
||||
inst: Dict[str, Any],
|
||||
metrics: List[str],
|
||||
language: str,
|
||||
reference: Dict[str, Any] = None,
|
||||
model: str = "gpt-3.5-turbo",
|
||||
max_tokens: int = 2048,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Use chat models (gpt-3.5-turbo or gpt-4) to evaluate one model answer.
|
||||
|
||||
Temperature is set to 0 to make the model more deterministic.
|
||||
|
||||
Args:
|
||||
prompt: a dictionary including prompt template, CoT and metrics.
|
||||
inst: the instruction that is needed to be evaluated.
|
||||
metrics: the metrics for evaluation.
|
||||
language: language used to change the CoT(add one more step about comparing the given answer and reference) if reference is not None.
|
||||
reference: the reference answer.
|
||||
model: the model used to evaluate answers.
|
||||
max_tokens: the maximum number of tokens to generate in the chat completion.
|
||||
|
||||
Returns:
|
||||
An evaluation of one answer.
|
||||
"""
|
||||
|
||||
MAX_API_RETRY = 3
|
||||
|
||||
question = inst["instruction"] if inst["input"] == "" else inst["instruction"] + "\n" + inst["input"]
|
||||
answer = inst["output"]
|
||||
inst["evaluation"] = {}
|
||||
|
||||
for metric in metrics:
|
||||
if prompt["metrics"].get(metric, None) is None:
|
||||
raise Exception(
|
||||
f"Unsupported metric {metric} for category {inst['category']}! You should add this metric in the prompt file!"
|
||||
)
|
||||
for i in range(MAX_API_RETRY):
|
||||
try:
|
||||
prompt_reference = "" if reference is None else reference_template(metric, language, reference)
|
||||
|
||||
prompt_1st_round = prompt["prompt"].format(
|
||||
question=question,
|
||||
answer=answer,
|
||||
metric=prompt["metrics"][metric],
|
||||
steps=prompt["CoT"][metric],
|
||||
)
|
||||
|
||||
if prompt_reference and (reference["target"] or reference["output"]):
|
||||
# Do a 2-round conversation
|
||||
response = multiturn_chat_completion(
|
||||
[prompt_1st_round, prompt_reference], model, max_tokens=max_tokens, turns=2
|
||||
)
|
||||
else:
|
||||
response = multiturn_chat_completion([prompt_1st_round], model, max_tokens=max_tokens, turns=1)
|
||||
|
||||
inst["evaluation"][metric] = {
|
||||
"response": response["choices"][0]["message"]["content"],
|
||||
"logprobs": None,
|
||||
}
|
||||
|
||||
# Prevent exceeding rate limits because we have multiple workers.
|
||||
# But this will slow down the evaluation process.
|
||||
# You can comment this line if your request doesn't contain many tokens.
|
||||
time.sleep(len(metrics) * 0.5)
|
||||
|
||||
break
|
||||
except Exception as e:
|
||||
print(e)
|
||||
time.sleep(1)
|
||||
if metric not in inst["evaluation"]:
|
||||
print(f"Evaluation {inst['id']} for metric {metric} failed after {MAX_API_RETRY} retries.")
|
||||
inst["evaluation"][metric] = {}
|
||||
return inst
|
||||
|
||||
|
||||
def get_gpt_evaluation_with_logprobs(
|
||||
prompt: Dict[str, Any], inst: Dict[str, Any], metrics: List[str], max_tokens: int = 2048
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Use a completion model (text-davinci-003) to evaluate one model answer.
|
||||
Only completion models can return log probabilities.
|
||||
|
||||
Temperature is set to 0 to make the model more deterministic.
|
||||
|
||||
Args:
|
||||
prompt: a dictionary including prompt template, CoT and metrics.
|
||||
inst: the instruction that is needed to be evaluated.
|
||||
metrics: the metrics for evaluation.
|
||||
max_tokens: the maximum number of tokens to generate in the completion.
|
||||
|
||||
Returns:
|
||||
An evaluation of one answer.
|
||||
"""
|
||||
|
||||
MAX_API_RETRY = 3
|
||||
|
||||
question = inst["instruction"] if inst["input"] == "" else inst["instruction"] + "\n" + inst["input"]
|
||||
answer = inst["output"]
|
||||
inst["evaluation"] = {}
|
||||
|
||||
for metric in metrics:
|
||||
if prompt["metrics"].get(metric, None) is None:
|
||||
raise Exception(
|
||||
f"Unsupported metric {metric} for category {inst['category']}! You should add this metric in the prompt file!"
|
||||
)
|
||||
for i in range(MAX_API_RETRY):
|
||||
try:
|
||||
response = openai.Completion.create(
|
||||
model="text-davinci-003",
|
||||
prompt=prompt["prompt"].format(
|
||||
question=question,
|
||||
answer=answer,
|
||||
metric=prompt["metrics"][metric],
|
||||
steps=prompt["CoT"][metric],
|
||||
),
|
||||
logprobs=5,
|
||||
temperature=0,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
inst["evaluation"][metric] = {
|
||||
"response": response["choices"][0]["text"],
|
||||
"logprobs": response["choices"][0]["logprobs"]["top_logprobs"],
|
||||
}
|
||||
|
||||
# Prevent exceeding rate limits because we have multiple workers.
|
||||
# But this will slow down the evaluation process.
|
||||
# You can comment this line if your request doesn't contain many tokens.
|
||||
time.sleep(len(metrics) * 0.5)
|
||||
|
||||
break
|
||||
except Exception as e:
|
||||
print(e)
|
||||
time.sleep(1)
|
||||
if metric not in inst["evaluation"]:
|
||||
print(f"Evaluation {inst['id']} for metric {metric} failed after {MAX_API_RETRY} retries.")
|
||||
inst["evaluation"][metric] = {}
|
||||
return inst
|
||||
|
||||
|
||||
def evaluate(
|
||||
answers: List[Dict],
|
||||
prompt: Dict[str, Any],
|
||||
metrics: List[str],
|
||||
category: str,
|
||||
save_path: str,
|
||||
model_name: str,
|
||||
model: str,
|
||||
language: str,
|
||||
references: List[Dict] = None,
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Use GPT models to evaluate model answers and save evaluation results.
|
||||
|
||||
Args:
|
||||
answers: model answers.
|
||||
prompt: prompt for GPT evaluation.
|
||||
metrics: metrics for GPT evaluation.
|
||||
category: the category of the model answers for evaluation.
save_path: path under which the evaluation results are saved.
model_name: name of the model whose answers are being evaluated.
|
||||
model: the specific GPT model used to evaluate answers.
|
||||
language: language used in GPT evaluation
|
||||
references: references for GPT evaluation
|
||||
|
||||
Returns:
|
||||
Evaluations of the given answers.
|
||||
"""
|
||||
|
||||
print(f"The number of instances of category {category}'s is {len(answers)}.")
|
||||
|
||||
evaluations = []
|
||||
|
||||
metrics_str = ", ".join(x for x in metrics)
|
||||
print(f"Category {category}'s metrics are {metrics_str}.")
|
||||
|
||||
gpt_base_save_path = os.path.join(save_path, "gpt_evaluate", "gpt_evaluate_results")
|
||||
gpt_evaluation_results_save_path = os.path.join(gpt_base_save_path, "evaluation_results")
|
||||
category_file = os.path.join(gpt_evaluation_results_save_path, model_name, f"{category}_evaluation_results.json")
|
||||
|
||||
if os.path.exists(category_file):
|
||||
print(f"Evaluation results for category {category}, model {model_name} already exists.")
|
||||
print("Skip evaluating.")
|
||||
|
||||
evaluations = jload(category_file)
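# Resume logic: reload the saved results, re-submit only the instances whose evaluation dict
# has an empty entry for any metric, and keep the successfully evaluated instances as they are.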
|
||||
|
||||
retry = []
|
||||
evaluations_copy = deepcopy(evaluations)
|
||||
|
||||
success = []
|
||||
for idx, e in enumerate(evaluations_copy):
|
||||
keys = list(e["evaluation"].keys())
|
||||
for key in keys:
|
||||
if e["evaluation"][key] == {}:
|
||||
retry.append(e["id"])
|
||||
print(f"Re-evaluate id {e['id']} now.")
|
||||
break
|
||||
if e["id"] not in retry:
|
||||
success.append(e)
|
||||
|
||||
if len(retry) == 0:
|
||||
evaluations.sort(key=lambda x: x["id"])
|
||||
print(f"{category} done.")
|
||||
return evaluations
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
|
||||
futures = []
|
||||
for idx, inst in enumerate(answers):
|
||||
if not inst["id"] in retry:
|
||||
continue
|
||||
# Completion models can return log probabilities.
|
||||
if model == "text-davinci-003":
|
||||
future = executor.submit(get_gpt_evaluation_with_logprobs, prompt, inst, metrics, 1)
|
||||
else:
|
||||
future = executor.submit(
|
||||
get_gpt_evaluation_without_logprobs,
|
||||
prompt,
|
||||
inst,
|
||||
metrics,
|
||||
language,
|
||||
reference=None if references is None else references[idx],
|
||||
model=model,
|
||||
max_tokens=1,
|
||||
)
|
||||
|
||||
futures.append(future)
|
||||
|
||||
for future in tqdm.tqdm(
|
||||
concurrent.futures.as_completed(futures),
|
||||
desc=f"{category}: ",
|
||||
total=len(futures),
|
||||
):
|
||||
success.append(future.result())
|
||||
|
||||
success.sort(key=lambda x: x["id"])
|
||||
|
||||
print(f"Saving evaluation results for category {category}, model {model_name}.")
|
||||
|
||||
jdump(success, category_file)
|
||||
|
||||
return success
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
|
||||
futures = []
|
||||
for idx, inst in enumerate(answers):
|
||||
# Completion models can return log probabilities.
|
||||
if model == "text-davinci-003":
|
||||
future = executor.submit(get_gpt_evaluation_with_logprobs, prompt, inst, metrics, 1)
|
||||
else:
|
||||
future = executor.submit(
|
||||
get_gpt_evaluation_without_logprobs,
|
||||
prompt,
|
||||
inst,
|
||||
metrics,
|
||||
language,
|
||||
reference=None if references is None else references[idx],
|
||||
model=model,
|
||||
max_tokens=1,
|
||||
)
|
||||
|
||||
futures.append(future)
|
||||
|
||||
for future in tqdm.tqdm(
|
||||
concurrent.futures.as_completed(futures),
|
||||
desc=f"{category}: ",
|
||||
total=len(futures),
|
||||
):
|
||||
evaluations.append(future.result())
|
||||
|
||||
evaluations.sort(key=lambda x: x["id"])
|
||||
|
||||
print(f"{category} done.")
|
||||
|
||||
print(f"Saving evaluation results for category {category}, model {model_name}.")
|
||||
|
||||
jdump(evaluations, category_file)
|
||||
|
||||
return evaluations
|
||||
|
||||
|
||||
def calculate_scores_form_logprobs(logprobs: Dict[str, Any]) -> float:
|
||||
"""
|
||||
Calculate the score according to log probabilities returned by text-davinci-003.
|
||||
|
||||
Calculation formula:
|
||||
score = sum(score_i * exp(value)) where score_i is the score which corresponds to the key(predicted token) and value is its log probability.
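For example, if the top tokens are "5" with log probability -0.1 and "4" with log probability -2.4,
the score is 5 * exp(-0.1) + 4 * exp(-2.4) ≈ 4.89.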
|
||||
|
||||
Ref: https://arxiv.org/abs/2303.16634
|
||||
This paper proposes NLG evaluation methods using text-davinci-003 (log probabilities returned by completion models) and GPT-4 (probabilities obtained by sampling).
|
||||
|
||||
Args:
|
||||
logprobs: logprobs returned by openai.Completion.
|
||||
|
||||
Returns:
|
||||
The score of one answer.
|
||||
"""
|
||||
|
||||
# The GPT evaluation only uses scores from 1 to 5, so we track probabilities for the five possible scores.
|
||||
prob = np.zeros(5)
|
||||
|
||||
for key, value in logprobs.items():
|
||||
# Sometimes the key will be one byte of a unicode character which takes the form of "bytes:\\xe7".
|
||||
# It is meaningless and thus we don't calculate probability.
|
||||
if "bytes" in key:
|
||||
continue
|
||||
# results[0] is the score which corresponds to the key(predicted token).
|
||||
# For example, key "5" corresponds to score 5.
|
||||
results = re.findall(r"\d", key)
|
||||
if len(results) == 1:
|
||||
prob[int(results[0]) - 1] = prob[int(results[0]) - 1] + np.exp(value)
|
||||
|
||||
score = np.dot(np.arange(1, 6), prob)
|
||||
|
||||
return score
|
||||
|
||||
|
||||
def calculate_scores_form_response(response: str, evaluation: Dict[str, Any]) -> int:
|
||||
"""
|
||||
Calculate the score from the response returned by gpt-3.5-turbo or gpt-4.
|
||||
Unlike text-davinci-003, this function directly calculates the score from the plain response returned by gpt-3.5-turbo or gpt-4.
|
||||
Although text-davinci-003 can return log probabilities, it costs ten times as much as gpt-3.5-turbo.
|
||||
|
||||
Args:
|
||||
response: the plain text response returned by gpt-3.5-turbo or gpt-4.
|
||||
evaluation: the evaluation corresponds to the question.
|
||||
|
||||
Returns:
|
||||
The score of one answer.
|
||||
"""
|
||||
|
||||
try:
|
||||
results = re.findall(r"\d", response)
|
||||
if len(results) == 1:
|
||||
return int(results[0])
|
||||
else:
|
||||
raise Exception(f"Invalid score pair. Got {evaluation}.")
|
||||
except Exception:
|
||||
return 0
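# A response that does not contain exactly one digit is treated as a failed evaluation
# and contributes the lowest possible score of 0.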
|
||||
|
||||
|
||||
def save_gpt_evaluation_results(
|
||||
model_name: str, gpt_evaluation_results: Dict[str, Any], save_path: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Save evaluation results for different categories for one model.
|
||||
|
||||
Args:
|
||||
model_name: name of the model for saving evaluation results.
|
||||
gpt_evaluation_results: evaluation results for all of the model answers.
|
||||
save_path: path to save GPT evaluation statistics.
|
||||
"""
|
||||
|
||||
all_evaluations = []
|
||||
for category, evaluations in gpt_evaluation_results.items():
|
||||
jdump(evaluations, os.path.join(save_path, model_name, f"{category}_evaluation_results.json"))
|
||||
all_evaluations.extend(evaluations)
|
||||
|
||||
jdump(all_evaluations, os.path.join(save_path, f"{model_name}_evaluation_results.json"))
|
||||
|
||||
return all_evaluations
|
||||
|
||||
|
||||
def save_gpt_evaluation_statistics(model_name: str, evaluations: List[Dict], save_path: str) -> None:
|
||||
"""
|
||||
Generate statistics for one model.
|
||||
|
||||
Args:
|
||||
model_name: name of the model for saving statistics.
|
||||
evaluations: evaluations for all of the model answers.
|
||||
save_path: path to save GPT evaluation statistics.
|
||||
"""
|
||||
|
||||
if not os.path.exists(save_path):
|
||||
os.makedirs(save_path)
|
||||
|
||||
data_per_category = {}
|
||||
for evaluation in evaluations:
|
||||
category = evaluation["category"]
|
||||
if evaluation["category"] in data_per_category.keys():
|
||||
data_per_category[category].append(evaluation)
|
||||
else:
|
||||
data_per_category[category] = [evaluation]
|
||||
|
||||
all_statistics = {}
|
||||
for category, data in data_per_category.items():
|
||||
metrics = data[0]["evaluation"].keys()
|
||||
scores = {metric: [] for metric in metrics}
|
||||
for evaluation in data:
|
||||
for metric in metrics:
|
||||
if evaluation["evaluation"][metric] == {}:
|
||||
# This means after 3 retries, the server still returns an error and we set the score to 0.
|
||||
scores[metric].append(0)
|
||||
elif evaluation["evaluation"][metric]["logprobs"] is not None:
|
||||
scores[metric].append(
|
||||
calculate_scores_form_logprobs(evaluation["evaluation"][metric]["logprobs"][0])
|
||||
)
|
||||
else:
|
||||
scores[metric].append(
|
||||
calculate_scores_form_response(evaluation["evaluation"][metric]["response"], evaluation)
|
||||
)
|
||||
|
||||
statistics = {}
|
||||
for metric in metrics:
|
||||
arg_sort = np.argsort(scores[metric])
|
||||
statistics[metric] = {}
|
||||
statistics[metric]["avg_score"] = sum(scores[metric]) / len(data)
|
||||
statistics[metric]["best_3"] = {data[i]["id"]: scores[metric][i] for i in arg_sort[-3:][::-1]}
|
||||
statistics[metric]["worst_3"] = {data[i]["id"]: scores[metric][i] for i in arg_sort[:3]}
|
||||
|
||||
all_statistics[category] = statistics
|
||||
|
||||
jdump(
|
||||
all_statistics,
|
||||
os.path.join(save_path, f"{model_name}_evaluation_statistics.json"),
|
||||
)
|
||||
|
||||
|
||||
def analyze_gpt_evaluation_statistics(statistics_path: str, save_path: str) -> None:
|
||||
"""
|
||||
Analyze and visualize all GPT evaluation statistics in the given directory.
|
||||
|
||||
Args:
|
||||
statistics_path: path to all the models' statistics.
|
||||
save_path: path to save table and visualization results.
|
||||
"""
|
||||
|
||||
if not os.path.exists(statistics_path):
|
||||
raise Exception(f'The given directory "{statistics_path}" doesn\'t exist! No statistics found!')
|
||||
|
||||
all_statistics = {}
|
||||
|
||||
for file_name in os.listdir(statistics_path):
|
||||
if file_name.endswith("_evaluation_statistics.json"):
|
||||
model_name = file_name.split("_evaluation_statistics.json")[0]
|
||||
all_statistics[model_name] = jload(os.path.join(statistics_path, file_name))
|
||||
|
||||
if len(all_statistics) == 0:
|
||||
raise Exception(f'There are no statistics in the given directory "{statistics_path}"!')
|
||||
|
||||
frame_all = {
|
||||
"model": [],
|
||||
"category": [],
|
||||
"metric": [],
|
||||
"avg_score": [],
|
||||
"best_3": [],
|
||||
"worst_3": [],
|
||||
}
|
||||
frame_per_category = {}
|
||||
for model_name, model_statistics in all_statistics.items():
|
||||
for category, category_statistics in model_statistics.items():
|
||||
if frame_per_category.get(category) is None:
|
||||
frame_per_category[category] = {
|
||||
"model": [],
|
||||
"metric": [],
|
||||
"avg_score": [],
|
||||
"best_3": [],
|
||||
"worst_3": [],
|
||||
}
|
||||
|
||||
for metric, metric_statistics in category_statistics.items():
|
||||
frame_all["model"].append(model_name)
|
||||
frame_all["category"].append(category)
|
||||
frame_all["metric"].append(metric)
|
||||
frame_all["avg_score"].append(metric_statistics["avg_score"])
|
||||
frame_all["best_3"].append(metric_statistics["best_3"])
|
||||
frame_all["worst_3"].append(metric_statistics["worst_3"])
|
||||
|
||||
frame_per_category[category]["model"].append(model_name)
|
||||
frame_per_category[category]["metric"].append(metric)
|
||||
frame_per_category[category]["avg_score"].append(metric_statistics["avg_score"])
|
||||
frame_per_category[category]["best_3"].append(metric_statistics["best_3"])
|
||||
frame_per_category[category]["worst_3"].append(metric_statistics["worst_3"])
|
||||
|
||||
if not os.path.exists(save_path):
|
||||
os.makedirs(save_path)
|
||||
|
||||
frame_all = pd.DataFrame(frame_all)
|
||||
frame_all.to_csv(os.path.join(save_path, "gpt_evaluation_statistics.csv"))
|
||||
|
||||
for category in tqdm.tqdm(
|
||||
frame_per_category.keys(),
|
||||
desc=f"GPT evaluation: ",
|
||||
total=len(frame_per_category.keys()),
|
||||
):
|
||||
data = pd.DataFrame(frame_per_category[category])
|
||||
|
||||
sns.set()
|
||||
fig = plt.figure(figsize=(16, 10))
|
||||
plt.ylim((0, 5))
|
||||
|
||||
fig = sns.barplot(x="metric", y="avg_score", hue="model", data=data, dodge=True)
|
||||
fig.set_title(f"Comparison between Different Models for Category {category.title()}")
|
||||
plt.xlabel("Evaluation Metric")
|
||||
plt.ylabel("Average Score")
|
||||
|
||||
figure = fig.get_figure()
|
||||
figure.savefig(os.path.join(save_path, f"{category}.png"), dpi=400)
|
||||
|
||||
plt.close()
|
@@ -0,0 +1,8 @@
|
||||
def get_data_per_category(data, categories):
|
||||
data_per_category = {category: [] for category in categories}
|
||||
for item in data:
|
||||
category = item["category"]
|
||||
if category in categories:
|
||||
data_per_category[category].append(item)
|
||||
|
||||
return data_per_category
|
@@ -0,0 +1,5 @@
|
||||
from .base import BaseModel
|
||||
from .chatglm import ChatGLM2Model, ChatGLMModel
|
||||
from .huggingface import HuggingFaceCausalLM, HuggingFaceModel
|
||||
|
||||
__all__ = ["BaseModel", "HuggingFaceModel", "HuggingFaceCausalLM", "ChatGLMModel", "ChatGLM2Model"]
|
78
applications/ColossalEval/colossal_eval/models/base.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Dict, List
|
||||
|
||||
from colossal_eval.utils import Conversation, prompt_templates
|
||||
|
||||
from colossalai.logging import DistributedLogger
|
||||
|
||||
|
||||
class BaseModel:
|
||||
"""
|
||||
Base class for model wrapper.
|
||||
|
||||
Args:
|
||||
path: The path to the model.
|
||||
model_max_length: The maximum sequence length of the model.
|
||||
prompt_template: The model's prompt template.
|
||||
batch_size: Batch size for inference.
|
||||
logger: Logger for the model.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path: str,
|
||||
model_max_length: int = 2048,
|
||||
prompt_template: Conversation = None,
|
||||
batch_size: int = 1,
|
||||
logger: DistributedLogger = None,
|
||||
):
|
||||
self.path = path
|
||||
self.model_max_length = model_max_length
|
||||
|
||||
if prompt_template:
|
||||
self.prompt_template = prompt_template
|
||||
else:
|
||||
self.prompt_template = prompt_templates["plain"]
|
||||
|
||||
self.batch_size = batch_size
|
||||
self.logger = logger
|
||||
|
||||
@abstractmethod
|
||||
def inference(self, data: List[Dict]) -> None:
|
||||
"""
|
||||
Infer the given data.
|
||||
This function will call self.generate() to get model outputs and also self.model(input) to get logits.
|
||||
|
||||
Args:
|
||||
data: The data for inference.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def generate(self, inputs: List[str], max_new_tokens: int) -> List[str]:
|
||||
"""
|
||||
Generate results given a list of inputs.
|
||||
|
||||
Args:
|
||||
inputs: A list of strings.
|
||||
max_new_tokens: The maximum length of the output.
|
||||
|
||||
Returns:
|
||||
A list of generated strings.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_loss(self, batch: List[str], batch_target: List[str]) -> List[float]:
|
||||
"""
|
||||
Get loss given batch and batch with target.
|
||||
Use their length difference after tokenization to mask the loss and only compute loss at target tokens.
|
||||
|
||||
Args:
|
||||
batch: batch prompt without target answer.
|
||||
batch_target: batch prompt with target answer.
|
||||
|
||||
Returns:
|
||||
A list of loss.
|
||||
"""
|
||||
|
||||
def to(self, device):
|
||||
self.model.to(device)
|
303
applications/ColossalEval/colossal_eval/models/chatglm.py
Normal file
@@ -0,0 +1,303 @@
|
||||
import copy
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from .huggingface import HuggingFaceModel
|
||||
|
||||
IGNORE_INDEX = -100
|
||||
|
||||
|
||||
class ChatGLMModel(HuggingFaceModel):
|
||||
def _get_truncated_prompts(self, inputs: List[str], max_new_tokens: int) -> List[str]:
|
||||
truncated_inputs = copy.deepcopy(inputs)
|
||||
# Adapted from https://github.com/THUDM/ChatGLM-6B/blob/main/ptuning/main.py#L187
|
||||
for i, input in enumerate(inputs):
|
||||
a_ids = self.tokenizer.encode(text=input, truncation=False, add_special_tokens=False)
|
||||
|
||||
if len(a_ids) > self.model_max_length - max_new_tokens:
|
||||
half = (self.model_max_length - max_new_tokens) // 2
|
||||
prompt = self.tokenizer.decode(a_ids[:half], skip_special_tokens=True) + self.tokenizer.decode(
|
||||
a_ids[-half:], skip_special_tokens=True
|
||||
)
|
||||
truncated_inputs[i] = prompt
|
||||
|
||||
return truncated_inputs
|
||||
|
||||
@torch.no_grad()
|
||||
def get_loss(
|
||||
self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool = False
|
||||
) -> List[List[float]]:
|
||||
"""
|
||||
Calculate loss only on target tokens.
|
||||
|
||||
Args:
|
||||
batch_prompt: A batch of prompts without target answers.
|
||||
batch_target: A batch of target answers. Sometimes one question can have multiple target answers.
|
||||
|
||||
Returns:
|
||||
Loss.
|
||||
|
||||
"""
|
||||
|
||||
# We set max_new_tokens in self._get_truncated_prompts to 0 because we only need logits to calculate loss.
|
||||
# We don't need to generate new tokens.
|
||||
# Target answer's length is usually << model_max_length, but we still truncate it just in case.
|
||||
# We don't call self._get_truncated_prompts for batch_prompt because we need target answer's length first to reserve some space for target answer's tokens.
|
||||
batch_target = [self._get_truncated_prompts(prompt_target, 0) for prompt_target in batch_target]
|
||||
|
||||
# Get the number of target answers for different questions
|
||||
batch_target_nums = [len(prompt_target) for prompt_target in batch_target]
|
||||
|
||||
labels_list = []
|
||||
input_ids_list = []
|
||||
|
||||
for input, targets in zip(batch_prompt, batch_target):
|
||||
for target in targets:
|
||||
# Adapted from https://github.com/THUDM/ChatGLM-6B/blob/main/ptuning/main.py#L187
|
||||
# If there is no history, the prompt is just the query.
|
||||
# We don't need to override self.generate() in ChatGLM-6B but need to override it in ChatGLM2-6B.
|
||||
# See https://huggingface.co/THUDM/chatglm-6b/blob/main/modeling_chatglm.py#L1276
|
||||
target_tokenized = self.tokenizer.encode(text=target, add_special_tokens=False)
|
||||
|
||||
# Get prompt with length model_max_length - len(target_tokenized).
|
||||
# Reserve some space for target answer tokens using max_new_tokens.
|
||||
# This will generate the correct start_idx and end_idx.
|
||||
max_new_tokens = len(target_tokenized)
|
||||
|
||||
# Here 3 tokens are reserved for [gmask_id, bos_token, eos_id]. So we reserve max_new_tokens + 3 tokens.
|
||||
# See https://huggingface.co/THUDM/chatglm-6b/blob/main/tokenization_chatglm.py#L323
|
||||
prompt_with_correct_length = self._get_truncated_prompts([input], max_new_tokens + 3)[0]
|
||||
input_tokenized = self.tokenizer.encode(prompt_with_correct_length, add_special_tokens=False)
|
||||
|
||||
input_ids = self.tokenizer.build_inputs_with_special_tokens(input_tokenized, target_tokenized)
|
||||
|
||||
context_length = input_ids.index(self.tokenizer.bos_token_id)
|
||||
|
||||
|
||||
target_ids = [IGNORE_INDEX] * len(input_ids)
|
||||
|
||||
# -1 is for eos_token, we don't want to calculate loss on eos token.
|
||||
target_ids[-max_new_tokens - 1 : -1] = input_ids[-max_new_tokens - 1 : -1]
|
||||
|
||||
input_ids_list.append(torch.LongTensor(input_ids))
|
||||
labels_list.append(torch.LongTensor(target_ids))
|
||||
|
||||
# Because of multiple target answers, the final batch size may be greater than self.batch_size.
|
||||
# We will generate new batches.
|
||||
losses = []
|
||||
target_token_nums = []
|
||||
|
||||
batched_input_ids = [
|
||||
input_ids_list[i : i + self.batch_size] for i in range(0, len(input_ids_list), self.batch_size)
|
||||
]
|
||||
batched_labels = [labels_list[i : i + self.batch_size] for i in range(0, len(labels_list), self.batch_size)]
|
||||
|
||||
for batch_input_ids, batch_labels in zip(batched_input_ids, batched_labels):
|
||||
losses_per_batch, target_token_num_per_batch = self._calculate_loss(batch_input_ids, batch_labels)
|
||||
losses.extend(losses_per_batch)
|
||||
target_token_nums.extend(target_token_num_per_batch)
|
||||
|
||||
start_indice = 0
|
||||
losses_per_sample = []
|
||||
|
||||
target_token_nums_per_sample = []
|
||||
for length in batch_target_nums:
|
||||
losses_per_sample.append(losses[start_indice : start_indice + length])
|
||||
target_token_nums_per_sample.append(target_token_nums[start_indice : start_indice + length])
|
||||
start_indice += length
|
||||
|
||||
return losses_per_sample, target_token_nums_per_sample, None
|
||||
|
||||
def _calculate_loss(self, input_ids_list: List[torch.LongTensor], labels: List[torch.LongTensor]) -> List[float]:
|
||||
"""
|
||||
Calculate loss only on target tokens.
|
||||
Hugging Face generate() function can't return per sample loss.
|
||||
It will only return the mean of the loss in a batch.
|
||||
In torch.nn.CrossEntropyLoss(), reduction should be specified as "none" to get per sample loss.
|
||||
|
||||
Args:
|
||||
input_ids_list: A batch of input token ids.
|
||||
labels: A batch of labels.
|
||||
|
||||
Returns:
|
||||
A list of loss.
|
||||
|
||||
"""
|
||||
input_ids = torch.nn.utils.rnn.pad_sequence(
|
||||
input_ids_list, batch_first=True, padding_value=self.tokenizer.pad_token_id
|
||||
).to(torch.cuda.current_device())
|
||||
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX).to(
|
||||
torch.cuda.current_device()
|
||||
)
|
||||
|
||||
outputs = self.model(input_ids)[0]
|
||||
|
||||
shift_logits = outputs[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
|
||||
loss_fct = torch.nn.CrossEntropyLoss(reduction="none", ignore_index=IGNORE_INDEX)
|
||||
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)).view(shift_labels.size())
|
||||
|
||||
lens = (labels != IGNORE_INDEX).sum(-1).cpu().numpy()
|
||||
|
||||
loss_sum = loss.sum(-1).to(torch.float32).cpu().detach().numpy()
|
||||
return loss_sum.tolist(), lens.tolist()
|
||||
|
||||
|
||||
class ChatGLM2Model(ChatGLMModel):
|
||||
def _get_truncated_prompts(self, inputs: List[str], max_new_tokens: int) -> List[str]:
|
||||
truncated_inputs = copy.deepcopy(inputs)
|
||||
# Adapted from https://github.com/THUDM/ChatGLM2-6B/blob/main/ptuning/main.py#L180
|
||||
for i, input in enumerate(inputs):
|
||||
a_ids = self.tokenizer.encode(text=input, add_special_tokens=True, truncation=False)
|
||||
|
||||
if len(a_ids) > self.model_max_length - max_new_tokens:
|
||||
half = (self.model_max_length - max_new_tokens) // 2
|
||||
prompt = self.tokenizer.decode(a_ids[:half], skip_special_tokens=True) + self.tokenizer.decode(
|
||||
a_ids[-half:], skip_special_tokens=True
|
||||
)
|
||||
truncated_inputs[i] = prompt
|
||||
|
||||
return truncated_inputs
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(self, inputs: List[str], max_new_tokens: int, **kwargs) -> List[str]:
|
||||
"""Generate results given a list of inputs and get logits of the first new token over choices.
|
||||
|
||||
Args:
|
||||
inputs: A list of strings.
|
||||
max_new_tokens: Max new tokens for generation.
|
||||
kwargs: Keyword arguments for generation.
|
||||
|
||||
Returns:
|
||||
A list of generated strings and logits over choices.
|
||||
|
||||
Note:
|
||||
Currently the function only returns the logits of the first new token.
|
||||
It is used for single choice question.
|
||||
For multiple choices question, please avoid using the loss over choices.
|
||||
You should set argument choices as None in self.inference().
|
||||
|
||||
"""
|
||||
# Follow the process of model.chat() method in modeling_chatglm2.py
|
||||
# See https://huggingface.co/THUDM/chatglm2-6b/blob/main/modeling_chatglm.py#L1020
|
||||
# See https://huggingface.co/THUDM/chatglm2-6b/blob/main/modeling_chatglm.py#L1001
|
||||
|
||||
query = []
|
||||
for input in inputs:
|
||||
prompt = self.tokenizer.build_prompt(input, None)
|
||||
query.append(prompt)
|
||||
|
||||
truncated_query = self._get_truncated_prompts(query, max_new_tokens)
|
||||
|
||||
encoded_inputs = self.tokenizer(
|
||||
truncated_query,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
max_length=self.model_max_length - max_new_tokens,
|
||||
).to(torch.cuda.current_device())
|
||||
|
||||
# Set output_scores=True to get prediction scores.
|
||||
outputs = self.model.generate(
|
||||
**encoded_inputs, max_new_tokens=max_new_tokens, return_dict_in_generate=True, output_scores=True, **kwargs
|
||||
)
|
||||
|
||||
# We only need to decode predicted tokens.
|
||||
sequences = outputs.sequences[:, encoded_inputs["input_ids"].shape[1] :]
|
||||
|
||||
scores = []
|
||||
if self.indices_for_choices:
|
||||
# If the question is a single-choice question, we will return the scores of specific indices for first predicted token.
|
||||
# The indices are the tokenization results of the options for the single-choice question.
|
||||
# For example, if the options of the question are A, B, C and D, we only returns scores at indices of A, B, C and D.
|
||||
for option_indices in self.indices_for_choices:
|
||||
scores.append(outputs.scores[0][:, option_indices].detach().cpu())
|
||||
|
||||
scores = torch.max(torch.stack(scores), dim=0)[0]
|
||||
|
||||
decoded_sequences = self.tokenizer.batch_decode(sequences, skip_special_tokens=True)
|
||||
|
||||
return decoded_sequences, scores
|
||||
|
||||
@torch.no_grad()
|
||||
def get_loss(
|
||||
self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool = False
|
||||
) -> List[List[float]]:
|
||||
"""
|
||||
Calculate loss only on target tokens.
|
||||
|
||||
Args:
|
||||
batch_prompt: A batch of prompts without target answers.
|
||||
batch_target: A batch of target answers. Sometimes one question can have multiple target answers.
|
||||
|
||||
Returns:
|
||||
Loss.
|
||||
|
||||
"""
|
||||
|
||||
# We set max_new_tokens in self._get_truncated_prompts to 0 because we only need logits to calculate loss.
|
||||
# We don't need to generate new tokens.
|
||||
# Target answer's length is usually << model_max_length, but we still truncate it just in case.
|
||||
# We don't call self._get_truncated_prompts for batch_prompt because we need target answer's length first to reserve some space for target answer's tokens.
|
||||
batch_target = [self._get_truncated_prompts(prompt_target, 0) for prompt_target in batch_target]
|
||||
|
||||
# Get the number of target answers for different questions
|
||||
batch_target_nums = [len(prompt_target) for prompt_target in batch_target]
|
||||
|
||||
labels_list = []
|
||||
input_ids_list = []
|
||||
|
||||
for input, targets in zip(batch_prompt, batch_target):
|
||||
for target in targets:
|
||||
# Adapted from https://github.com/THUDM/ChatGLM2-6B/blob/main/ptuning/main.py#L180
|
||||
prompt = self.tokenizer.build_prompt(input, None)
|
||||
|
||||
target_tokenized = self.tokenizer.encode(
|
||||
text=target, add_special_tokens=False, truncation=True, max_length=self.model_max_length
|
||||
)
|
||||
|
||||
max_new_tokens = len(target_tokenized)
|
||||
prompt_with_correct_length = self._get_truncated_prompts([prompt], max_new_tokens)[0]
|
||||
input_tokenized = self.tokenizer.encode(
|
||||
prompt_with_correct_length,
|
||||
add_special_tokens=True,
|
||||
truncation=True,
|
||||
max_length=self.model_max_length,
|
||||
)
|
||||
|
||||
input_ids = input_tokenized + target_tokenized + [self.tokenizer.eos_token_id]
|
||||
target_ids = [IGNORE_INDEX] * len(input_ids)
|
||||
|
||||
# -1 is for "eos"
|
||||
target_ids[-max_new_tokens - 1 : -1] = input_ids[-max_new_tokens - 1 : -1]
|
||||
|
||||
input_ids_list.append(torch.LongTensor(input_ids))
|
||||
labels_list.append(torch.LongTensor(target_ids))
|
||||
|
||||
# Because of multiple target answers, the final batch size may be greater than self.batch_size.
|
||||
# We will generate new batches.
|
||||
losses = []
|
||||
target_token_nums = []
|
||||
|
||||
batched_input_ids = [
|
||||
input_ids_list[i : i + self.batch_size] for i in range(0, len(input_ids_list), self.batch_size)
|
||||
]
|
||||
batched_labels = [labels_list[i : i + self.batch_size] for i in range(0, len(labels_list), self.batch_size)]
|
||||
|
||||
for batch_input_ids, batch_labels in zip(batched_input_ids, batched_labels):
|
||||
losses_per_batch, target_token_num_per_batch = self._calculate_loss(batch_input_ids, batch_labels)
|
||||
losses.extend(losses_per_batch)
|
||||
target_token_nums.extend(target_token_num_per_batch)
|
||||
|
||||
start_indice = 0
|
||||
losses_per_sample = []
|
||||
|
||||
target_token_nums_per_sample = []
|
||||
for length in batch_target_nums:
|
||||
losses_per_sample.append(losses[start_indice : start_indice + length])
|
||||
target_token_nums_per_sample.append(target_token_nums[start_indice : start_indice + length])
|
||||
start_indice += length
|
||||
|
||||
return losses_per_sample, target_token_nums_per_sample, None
|
561
applications/ColossalEval/colossal_eval/models/huggingface.py
Normal file
@@ -0,0 +1,561 @@
|
||||
import copy
|
||||
import math
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from colossal_eval.utils import Conversation, get_batch_prompt, is_rank_0
|
||||
from peft import PeftModel
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from colossalai.logging import DistributedLogger
|
||||
|
||||
from .base import BaseModel
|
||||
|
||||
IGNORE_INDEX = -100
|
||||
|
||||
|
||||
class HuggingFaceModel(BaseModel):
|
||||
"""
|
||||
Model wrapper around HuggingFace AutoModel models.
|
||||
|
||||
Args:
|
||||
path: The path to a HuggingFace model.
|
||||
model_max_length: The maximum sequence length of the model.
|
||||
tokenizer_path: The path to the tokenizer.
|
||||
tokenizer_kwargs: Keyword arguments for the tokenizer.
|
||||
peft_path: The name or path to the HuggingFace's PEFT model.
|
||||
model_kwargs: Keyword arguments for the model.
|
||||
prompt_template: The model's prompt template.
|
||||
batch_size: Batch size for inference.
|
||||
logger: Logger for the model.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path: str,
|
||||
model_max_length: int = 2048,
|
||||
tokenizer_path: Optional[str] = None,
|
||||
tokenizer_kwargs: dict = dict(),
|
||||
peft_path: Optional[str] = None,
|
||||
model_kwargs: Dict = None,
|
||||
prompt_template: Conversation = None,
|
||||
batch_size: int = 1,
|
||||
logger: DistributedLogger = None,
|
||||
):
|
||||
super().__init__(
|
||||
path=path,
|
||||
model_max_length=model_max_length,
|
||||
prompt_template=prompt_template,
|
||||
batch_size=batch_size,
|
||||
logger=logger,
|
||||
)
|
||||
self._load_tokenizer(path=path, tokenizer_path=tokenizer_path, tokenizer_kwargs=tokenizer_kwargs)
|
||||
|
||||
self._load_model(path=path, model_kwargs=model_kwargs, peft_path=peft_path)
|
||||
|
||||
def _get_choices_indices(self, language: str):
|
||||
"""
|
||||
Get indices for each choice
|
||||
|
||||
Some tokenizers, such as Llama-2's, will insert a BOS token unless you specify add_special_tokens=False.
|
||||
The indices for choices may be different given the context. For example, for Llama-2 tokenizer, for Chinese context like "答案:{choice}", indices for choices A, B, C and D are 29909, 29933, 29907 and 29928, for English context like "Answer: {choice}", indices for choices A, B, C and D are 319, 350, 315 and 360.
|
||||
print(self.tokenizer("答案:A")) to see
|
||||
print(self.tokenizer("Answer: A")) to see
|
||||
|
||||
"""
|
||||
|
||||
# A trick for get "all" tokens ids related to given choices.
|
||||
self.indices_for_choices = [[] for _ in range(2)]
|
||||
for choice in self.choices:
|
||||
self.indices_for_choices[0].append(
|
||||
self.tokenizer(f"Answer: {choice}", add_special_tokens=False).input_ids[-1]
|
||||
)
|
||||
self.indices_for_choices[1].append(self.tokenizer(f"答案:{choice}", add_special_tokens=False).input_ids[-1])
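# After this loop, self.indices_for_choices holds two lists of last-token ids, one per context:
# index 0 for the English "Answer:" context and index 1 for the Chinese context.
# generate() later takes the element-wise maximum of the scores over both variants.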
|
||||
|
||||
def _load_tokenizer(self, path: str, tokenizer_path: Optional[str], tokenizer_kwargs: dict):
|
||||
"""
|
||||
Load tokenizer.
|
||||
|
||||
Args:
|
||||
path: The path to the model. Usually it also serves as the path to the tokenizer.
|
||||
tokenizer_path: The path to the tokenizer.
|
||||
tokenizer_kwargs: Keyword arguments for the tokenizer.
|
||||
|
||||
"""
|
||||
|
||||
if self.batch_size > 1:
|
||||
tokenizer_kwargs.update({"padding_side": "left"})
|
||||
tokenizer_kwargs.update({"truncation_side": "left"})
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path if tokenizer_path else path, **tokenizer_kwargs)
|
||||
|
||||
if self.tokenizer.pad_token_id is None:
|
||||
self.logger.warning("pad_token_id is not set for the tokenizer. " "Using eos_token_id as pad_token_id.")
|
||||
if self.tokenizer.eos_token:
|
||||
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||
elif getattr(self.tokenizer, "eod_id", None):
|
||||
# Qwen has an eod token "<|endoftext|>".
|
||||
self.tokenizer.pad_token_id = self.tokenizer.eod_id
|
||||
|
||||
def _load_model(self, path: str, model_kwargs: dict, peft_path: Optional[str] = None):
|
||||
"""
|
||||
Load model.
|
||||
|
||||
Args:
|
||||
path: The path to the model.
|
||||
model_kwargs: Keyword arguments for the model.
|
||||
peft_path: The path to the peft model.
|
||||
|
||||
"""
|
||||
|
||||
if "torch_dtype" in model_kwargs:
|
||||
model_kwargs["torch_dtype"] = eval(model_kwargs["torch_dtype"])
|
||||
|
||||
model_kwargs.setdefault("torch_dtype", torch.float16)
|
||||
|
||||
self.model = AutoModel.from_pretrained(path, **model_kwargs).to(torch.cuda.current_device())
|
||||
if peft_path is not None:
|
||||
self.model = PeftModel.from_pretrained(self.model, peft_path, is_trainable=False)
|
||||
self.model.eval()
|
||||
|
||||
def _calculate_loss(self, input_ids_list: List[torch.LongTensor], labels: List[torch.LongTensor]) -> Tuple[List]:
|
||||
"""
|
||||
Calculate loss only on target tokens.
|
||||
Hugging Face generate() function can't return per sample loss.
|
||||
It will only return the mean of the loss in a batch.
|
||||
In torch.nn.CrossEntropyLoss(), reduction should be specified as "none" to get per sample loss.
|
||||
|
||||
Args:
|
||||
input_ids_list: A batch of input token ids.
|
||||
labels: A batch of labels.
|
||||
|
||||
Returns:
|
||||
A list of loss.
|
||||
|
||||
"""
|
||||
input_ids = torch.nn.utils.rnn.pad_sequence(
|
||||
input_ids_list, batch_first=True, padding_value=self.tokenizer.pad_token_id
|
||||
).to(torch.cuda.current_device())
|
||||
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX).to(
|
||||
torch.cuda.current_device()
|
||||
)
|
||||
attention_mask = input_ids.ne(self.tokenizer.pad_token_id).to(torch.cuda.current_device())
|
||||
|
||||
outputs = self.model(input_ids, attention_mask=attention_mask)[0]
|
||||
|
||||
shift_logits = outputs[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
|
||||
loss_fct = torch.nn.CrossEntropyLoss(reduction="none", ignore_index=IGNORE_INDEX)
|
||||
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)).view(shift_labels.size())
|
||||
|
||||
lens = (labels != IGNORE_INDEX).sum(-1).cpu().numpy()
|
||||
|
||||
loss_sum = loss.sum(-1).to(torch.float32).cpu().detach().numpy()
|
||||
return loss_sum.tolist(), lens.tolist()
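# Returns the summed loss over target tokens and the number of target tokens for each sample;
# callers divide the two to obtain the average per-token loss.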
|
||||
|
||||
def _get_truncated_prompts(self, inputs: List[str], max_new_tokens: int) -> List[str]:
|
||||
"""
|
||||
Truncate the input sequence to fit model_max_length (we suggest truncating in the middle, since the left and right sides may contain crucial instructions).
|
||||
https://github.com/THUDM/LongBench/blob/main/pred.py#L16
|
||||
|
||||
Args:
|
||||
inputs: A batch of input prompts.
|
||||
max_new_tokens: Max new tokens for model to generate.
|
||||
|
||||
Returns:
|
||||
Truncated prompts.
|
||||
|
||||
"""
|
||||
|
||||
truncated_inputs = copy.deepcopy(inputs)
|
||||
for i, input in enumerate(inputs):
|
||||
tokenized_prompt = self.tokenizer(input, truncation=False, return_tensors="pt").input_ids[0]
|
||||
if len(tokenized_prompt) > self.model_max_length - max_new_tokens:
|
||||
half = (self.model_max_length - max_new_tokens) // 2
|
||||
prompt = self.tokenizer.decode(
|
||||
tokenized_prompt[:half], skip_special_tokens=True
|
||||
) + self.tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True)
|
||||
truncated_inputs[i] = prompt
|
||||
|
||||
return truncated_inputs
|
||||
|
||||
def _get_input_ids_and_labels_pretrain(self, batch_prompt: List[str]) -> Tuple[List[torch.LongTensor]]:
|
||||
"""
|
||||
Get input_ids and labels for pretrain data.
|
||||
We only need batch_prompt because for pretrain datasets we don't need to predict new tokens.
|
||||
|
||||
Args:
|
||||
batch_prompt: A batch of prompt.
|
||||
|
||||
Returns:
|
||||
Input_ids and labels for the given batch.
|
||||
|
||||
"""
|
||||
input_ids_list = []
|
||||
labels_list = []
|
||||
bytes_list = []
|
||||
|
||||
for input in batch_prompt:
|
||||
# Pretrain data tends to be very long, often much longer than model_max_length, so we only tokenize 1/ratio of the data first to accelerate the tokenization process.
|
||||
# Once the length of the result is greater than or equal to model_max_length, we stop iterating over ratios and use the result as input_ids and labels.
|
||||
# After all, the rest of the original string doesn't need to be tokenized in the first place.
|
||||
ratio = [16, 8, 4, 2, 1]
|
||||
tokenized = None
|
||||
for r in ratio:
|
||||
tokenized = self.tokenizer(
|
||||
[input[0 : len(input) // r]], truncation=True, max_length=self.model_max_length, return_tensors="pt"
|
||||
)
|
||||
if tokenized.input_ids.size(1) >= self.model_max_length:
|
||||
break
|
||||
|
||||
input_ids = copy.deepcopy(tokenized["input_ids"])[0]
|
||||
target_ids = copy.deepcopy(input_ids)
|
||||
|
||||
string = self.tokenizer.decode(tokenized.input_ids[0], skip_special_tokens=True)
|
||||
|
||||
bytes_list.append(len(string.encode("utf-8")))
|
||||
|
||||
input_ids_list.append(input_ids)
|
||||
labels_list.append(target_ids)
|
||||
|
||||
return input_ids_list, labels_list, bytes_list
|
||||
|
||||
def _get_input_ids_and_labels(
|
||||
self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool
|
||||
) -> Tuple[List[torch.LongTensor]]:
|
||||
"""
|
||||
Get input_ids and labels for the given data.
|
||||
|
||||
Args:
|
||||
batch_prompt: A batch of prompt.
|
||||
batch_target: A batch of target.
|
||||
|
||||
Returns:
|
||||
Input_ids and labels for the given batch.
|
||||
|
||||
"""
|
||||
if pretrain:
|
||||
return self._get_input_ids_and_labels_pretrain(batch_prompt)
|
||||
|
||||
input_ids_list = []
|
||||
labels_list = []
|
||||
|
||||
for input, targets in zip(batch_prompt, batch_target):
|
||||
for target in targets:
|
||||
# TODO: Improve the labeling process. Should annotate the border by adding special tokens.
|
||||
target_tokenized = self.tokenizer(
|
||||
[target], truncation=True, max_length=self.model_max_length, return_tensors="pt"
|
||||
)
|
||||
|
||||
# Get prompt with length model_max_length - len(target_tokenized).
|
||||
# Reserve some space for target answer tokens using max_new_tokens.
|
||||
# This will generate the correct start_idx and end_idx.
|
||||
max_new_tokens = target_tokenized["input_ids"][0].size(0)
|
||||
prompt_with_correct_length = self._get_truncated_prompts([input], max_new_tokens)[0]
|
||||
input_tokenized = self.tokenizer(
|
||||
[prompt_with_correct_length],
|
||||
truncation=True,
|
||||
max_length=self.model_max_length - max_new_tokens,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
target_tokenized = self.tokenizer(
|
||||
[prompt_with_correct_length + target],
|
||||
truncation=True,
|
||||
max_length=self.model_max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
start_idx = input_tokenized["input_ids"][0].size(0)
|
||||
end_idx = target_tokenized["input_ids"][0].size(0)
|
||||
|
||||
# Sometimes, if the target is only a single option such as A, B, C or D, the length of input_tokenized is equal to the length of target_tokenized, so we need -1.
|
||||
# This is caused by the different behavior of tokenizers.
|
||||
# For example, the tokenizer for Baichuan and Llama will cause such problem in a plain prompt setting.
|
||||
# The length of the tokenized sequences for prompt "Answer: " and "Answer: A" is the same.
|
||||
# Baichuan: [29394, 31143, 31106] [29394, 31143, 703]
|
||||
# Llama: [673, 29901, 29871] [673, 29901, 319]
|
||||
# The length for sequence "prompt" and "prompt + A" is equal.
|
||||
# For ChatGLM, the length of the tokenized sequences is different.
|
||||
# ChatGLM: [16583, 12] [16583, 12, 167]
|
||||
|
||||
if start_idx == end_idx:
|
||||
start_idx -= 1
|
||||
|
||||
input_ids = copy.deepcopy(target_tokenized["input_ids"])[0]
|
||||
target_ids = copy.deepcopy(input_ids)
|
||||
|
||||
mask = torch.zeros_like(target_ids, dtype=torch.bool)
|
||||
mask[start_idx:end_idx] = True
|
||||
|
||||
target_ids[~mask] = IGNORE_INDEX
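# Only positions between start_idx and end_idx (the target answer tokens) keep their labels;
# everything else is set to IGNORE_INDEX so the loss is computed on the answer tokens only.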
|
||||
|
||||
input_ids_list.append(input_ids)
|
||||
labels_list.append(target_ids)
|
||||
|
||||
return input_ids_list, labels_list, None
|
||||
|
||||
def inference(self, data: List[Dict], inference_kwargs: Dict[str, Any], debug: bool = False) -> List[Dict]:
|
||||
"""
|
||||
Infer the given data.
|
||||
This function will call self.generate() to get model outputs and also self.model() to get logits.
|
||||
|
||||
Args:
|
||||
data: The data for inference.
|
||||
inference_kwargs: Arguments for inference.
|
||||
debug: Whether to display generated prompt for debugging.
|
||||
|
||||
Returns:
|
||||
Inference results.
|
||||
|
||||
"""
|
||||
calculate_loss = inference_kwargs["calculate_loss"]
|
||||
classes = inference_kwargs["all_classes"]
|
||||
language = inference_kwargs["language"]
|
||||
pretrain = inference_kwargs["pretrain"]
|
||||
max_new_tokens = inference_kwargs["max_new_tokens"]
|
||||
few_shot_data = inference_kwargs.get("few_shot_data", None)
|
||||
|
||||
# Some classification questions' options are texts not a single letter such as A, B, C and D.
|
||||
# If the text length is greater than 1, we won't calculate loss over choices.
|
||||
if classes is not None and any(len(c) > 1 for c in classes):
|
||||
classes = None
|
||||
|
||||
self.choices = classes
|
||||
self.indices_for_choices = None
|
||||
if self.choices:
|
||||
# Get indices for each choice
|
||||
self._get_choices_indices(language)
|
||||
|
||||
self.str_label_map = {choice: idx for idx, choice in enumerate(self.choices)}
|
||||
|
||||
bar = tqdm(
|
||||
range(math.ceil(len(data) / self.batch_size)),
|
||||
desc=f"{data[0]['dataset']}-{data[0]['category']} Inference steps",
|
||||
disable=not is_rank_0(),
|
||||
)
|
||||
loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
||||
|
||||
answers = copy.deepcopy(data)
|
||||
for i in range(0, len(data), self.batch_size):
|
||||
batch = data[i : i + self.batch_size]
|
||||
batch_prompt, batch_target = get_batch_prompt(
|
||||
self.prompt_template, batch, few_shot_data, self.tokenizer, language, self.model_max_length
|
||||
)
|
||||
|
||||
if is_rank_0() and debug and i == 0:
|
||||
self.logger.info(
|
||||
f"Inference arguments for dataset {data[0]['dataset']} category {data[0]['category']} is:\n{inference_kwargs}"
|
||||
)
|
||||
self.logger.info("-" * 120)
|
||||
self.logger.info("An example prompt and prompt with target is:")
|
||||
self.logger.info("-" * 120)
|
||||
self.logger.info(batch_prompt[0])
|
||||
self.logger.info("-" * 120)
|
||||
self.logger.info(batch_prompt[0] + batch_target[0][0])
|
||||
|
||||
if not pretrain:
|
||||
batch_decodes, scores = self.generate(batch_prompt, max_new_tokens)
|
||||
|
||||
if calculate_loss:
|
||||
batch_losses, batch_target_token_nums, batch_bytes_nums = self.get_loss(
|
||||
batch_prompt, batch_target, pretrain
|
||||
)
|
||||
|
||||
probs = []
|
||||
if self.indices_for_choices:
|
||||
scores = scores.to(torch.float32)
|
||||
# If we have indices_for_choices(must be single-choice question), there will be only one target answer for one data sample.
|
||||
# Otherwise this will violate the single-choice setting.
|
||||
|
||||
if calculate_loss:
|
||||
labels = [self.str_label_map[answers[i + j]["target"]] for j in range(len(batch_decodes))]
|
||||
|
||||
loss_over_choices = loss_fct(scores, torch.tensor(labels, dtype=torch.long)).numpy().tolist()
|
||||
|
||||
probs = torch.nn.functional.softmax(scores, dim=-1).numpy().tolist()
|
||||
probs = [
|
||||
{choice: probs[i][self.str_label_map[choice]] for choice in self.choices} for i in range(len(probs))
|
||||
]
|
||||
|
||||
for j in range(len(batch_prompt)):
|
||||
if not pretrain:
|
||||
answers[i + j]["output"] = batch_decodes[j].strip()
|
||||
|
||||
if isinstance(scores, torch.Tensor):
|
||||
answers[i + j]["softmax_over_choices"] = probs[j]
|
||||
|
||||
if calculate_loss:
|
||||
answers[i + j]["loss_over_choices"] = loss_over_choices[j]
|
||||
|
||||
if calculate_loss:
|
||||
answers[i + j]["loss"] = (np.array(batch_losses[j]) / np.array(batch_target_token_nums[j])).tolist()
|
||||
|
||||
# loss_sum is specifically used for pretrain datasets to calculate per-byte perplexity.
|
||||
# However, loss (which is per sample loss) suffices for most cases.
|
||||
answers[i + j]["loss_sum"] = batch_losses[j]
|
||||
answers[i + j]["token_num"] = batch_target_token_nums[j]
|
||||
|
||||
if batch_bytes_nums:
|
||||
answers[i + j]["byte_num"] = batch_bytes_nums[j]
|
||||
|
||||
bar.update()
|
||||
|
||||
return answers
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(self, inputs: List[str], max_new_tokens: int, **kwargs) -> List[str]:
|
||||
"""Generate results given a list of inputs and get logits of the first new token over choices.
|
||||
|
||||
Args:
|
||||
inputs: A list of strings.
|
||||
max_new_tokens: Max new tokens for generation.
|
||||
kwargs: Keyword arguments for generation.
|
||||
|
||||
Returns:
|
||||
A list of generated strings and logits over choices.
|
||||
|
||||
Note:
|
||||
Currently the function only returns the logits of the first new token.
|
||||
It is used for single choice question.
|
||||
For multiple choices question, please avoid using the loss over choices.
|
||||
You should set argument choices as None in self.inference().
|
||||
|
||||
"""
|
||||
truncated_inputs = self._get_truncated_prompts(inputs, max_new_tokens)
|
||||
|
||||
encoded_inputs = self.tokenizer(
|
||||
truncated_inputs,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
return_token_type_ids=False,
|
||||
max_length=self.model_max_length - max_new_tokens,
|
||||
).to(torch.cuda.current_device())
|
||||
|
||||
# Set output_scores=True to get prediction scores.
|
||||
outputs = self.model.generate(
|
||||
**encoded_inputs, max_new_tokens=max_new_tokens, return_dict_in_generate=True, output_scores=True, **kwargs
|
||||
)
|
||||
|
||||
# We only need to decode predicted tokens.
|
||||
sequences = outputs.sequences[:, encoded_inputs["input_ids"].shape[1] :]
|
||||
|
||||
scores = []
|
||||
if self.indices_for_choices:
|
||||
# If the question is a single-choice question, we will return the scores of specific indices for first predicted token.
|
||||
# The indices are the tokenization results of the options for the single-choice question.
|
||||
# For example, if the options of the question are A, B, C and D, we only returns scores at indices of A, B, C and D.
|
||||
for option_indices in self.indices_for_choices:
|
||||
scores.append(outputs.scores[0][:, option_indices].detach().cpu())
|
||||
|
||||
scores = torch.max(torch.stack(scores), dim=0)[0]
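# Take the element-wise maximum over the candidate index sets (English and Chinese contexts),
# so whichever tokenization the model actually used for the option letter dominates the score.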
|
||||
|
||||
decoded_sequences = self.tokenizer.batch_decode(sequences, skip_special_tokens=True)
|
||||
|
||||
return decoded_sequences, scores
|
||||
|
||||
@torch.no_grad()
|
||||
def get_loss(self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool) -> List[List[float]]:
|
||||
"""
|
||||
Calculate loss only on target tokens.
|
||||
|
||||
Args:
|
||||
batch_prompt: A batch of prompts without target answers.
|
||||
batch_target: A batch of target answers. Sometimes one question can have multiple target answers.
|
||||
|
||||
Returns:
|
||||
Loss.
|
||||
|
||||
"""
|
||||
|
||||
# We set max_new_tokens in self._get_truncated_prompts to 0 because we only need logits to calculate loss.
|
||||
# We don't need to generate new tokens.
|
||||
# Target answer's length is usually << model_max_length, but we still truncate it just in case.
|
||||
# We don't call self._get_truncated_prompts for batch_prompt because we need target answer's length first to reserve some space for target answer's tokens.
|
||||
if not pretrain:
|
||||
batch_target = [self._get_truncated_prompts(prompt_target, 0) for prompt_target in batch_target]
|
||||
|
||||
# Get the number of target answers for different questions
|
||||
batch_target_nums = [len(prompt_target) for prompt_target in batch_target]
|
||||
|
||||
input_ids_list, labels_list, bytes_list = self._get_input_ids_and_labels(batch_prompt, batch_target, pretrain)
|
||||
|
||||
# Because of multiple target answers, the final batch size may be greater than self.batch_size.
|
||||
# We will generate new batches.
|
||||
losses = []
|
||||
target_token_nums = []
|
||||
|
||||
batched_input_ids = [
|
||||
input_ids_list[i : i + self.batch_size] for i in range(0, len(input_ids_list), self.batch_size)
|
||||
]
|
||||
batched_labels = [labels_list[i : i + self.batch_size] for i in range(0, len(labels_list), self.batch_size)]
|
||||
|
||||
for batch_input_ids, batch_labels in zip(batched_input_ids, batched_labels):
|
||||
losses_per_batch, target_token_num_per_batch = self._calculate_loss(batch_input_ids, batch_labels)
|
||||
losses.extend(losses_per_batch)
|
||||
target_token_nums.extend(target_token_num_per_batch)
|
||||
|
||||
start_indice = 0
|
||||
losses_per_sample = []
|
||||
|
||||
target_token_nums_per_sample = []
|
||||
bytes_nums_per_sample = []
|
||||
for length in batch_target_nums:
|
||||
losses_per_sample.append(losses[start_indice : start_indice + length])
|
||||
target_token_nums_per_sample.append(target_token_nums[start_indice : start_indice + length])
|
||||
|
||||
if bytes_list:
|
||||
bytes_nums_per_sample.append(bytes_list[start_indice : start_indice + length])
|
||||
|
||||
start_indice += length
|
||||
|
||||
if bytes_list:
|
||||
return losses_per_sample, target_token_nums_per_sample, bytes_nums_per_sample
|
||||
|
||||
return losses_per_sample, target_token_nums_per_sample, None
|
||||
|
||||
|
||||
class HuggingFaceCausalLM(HuggingFaceModel):
    """
    Model wrapper around HuggingFace AutoModelForCausalLM models.

    Args:
        path: The path to a HuggingFace model.
        model_max_length: The maximum sequence length of the model.
        tokenizer_path: The path to the tokenizer.
        tokenizer_kwargs: Keyword arguments for the tokenizer.
        peft_path: The name or path to the HuggingFace PEFT model.
        model_kwargs: Keyword arguments for the model.
        prompt_template: The model's prompt template.
        batch_size: Batch size for inference.
        logger: Logger for the model.

    """

    def _load_model(self, path: str, model_kwargs: dict, peft_path: Optional[str] = None):
        """
        Load model.

        Args:
            path: The path to the model.
            model_kwargs: Keyword arguments for the model.
            peft_path: The path to the PEFT model.

        """

        # torch_dtype is given as a string (e.g. "torch.float16") and evaluated to the actual dtype here.
        if "torch_dtype" in model_kwargs:
            model_kwargs["torch_dtype"] = eval(model_kwargs["torch_dtype"])

        if "config" in model_kwargs:
            model_kwargs["config"] = AutoConfig.from_pretrained(model_kwargs["config"])

        model_kwargs.setdefault("torch_dtype", torch.float16)
        self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs).to(torch.cuda.current_device())
        if peft_path is not None:
            self.model = PeftModel.from_pretrained(self.model, peft_path, is_trainable=False)
        self.model.eval()

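# --- Illustrative usage sketch (editor's addition, not part of the original source) ---
# A model_kwargs dict in the form _load_model above consumes: torch_dtype arrives as a
# string and is evaluated to the real dtype, and "config" (if present) is resolved via
# AutoConfig.from_pretrained. The model path, batch size and other values are placeholders.
example_model_kwargs = {
    "torch_dtype": "torch.float16",
    "trust_remote_code": True,
}
# The constructor arguments follow the class docstring above; __init__ itself is defined
# on HuggingFaceModel and is not shown in this hunk, so treat this call as a sketch:
# model = HuggingFaceCausalLM(
#     path="huggyllama/llama-7b",
#     model_max_length=2048,
#     tokenizer_path="huggyllama/llama-7b",
#     tokenizer_kwargs={"use_fast": False},
#     peft_path=None,
#     model_kwargs=example_model_kwargs,
#     prompt_template=None,
#     batch_size=4,
#     logger=None,
# )
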
4
applications/ColossalEval/colossal_eval/utils/__init__.py
Normal file
@@ -0,0 +1,4 @@
from .conversation import Conversation, get_batch_prompt, prompt_templates
from .utilities import get_json_list, is_rank_0, jdump, jload

__all__ = ["Conversation", "prompt_templates", "get_batch_prompt", "is_rank_0", "jload", "jdump", "get_json_list"]
231
applications/ColossalEval/colossal_eval/utils/conversation.py
Normal file
@@ -0,0 +1,231 @@
import dataclasses
from enum import Enum, auto
from typing import Dict, List, Optional, Tuple

from transformers import AutoTokenizer


class SeparatorStyle(Enum):
    ADD_BOS_EOS_TOKEN = auto()
    ALPACA = auto()
    PLAIN = auto()


@dataclasses.dataclass
class Conversation:
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.ADD_BOS_EOS_TOKEN
    sep: str = "</s>"

    def clear(self):
        self.messages = []

    def get_prompt(self):
        if self.sep_style == SeparatorStyle.ADD_BOS_EOS_TOKEN:
            ret = self.system
            for role, message in self.messages:
                if message:
                    ret += role + ": " + "<s>" + message + self.sep
                else:
                    ret += role + ": " + "<s>"
            return ret
        elif self.sep_style == SeparatorStyle.ALPACA:
            ret = self.system + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ":\n" + message + self.sep
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.PLAIN:
            ret = self.system
            for role, message in self.messages:
                if message:
                    ret += message
                else:
                    ret += ""
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def get_prompt_with_target(self, target):
        prompt = self.get_prompt()
        prompt_with_target = []

        # Some datasets provide multiple target answers, which complicates loss calculation.
        # We therefore normalize the target into a List[str] even when the question has only one target answer.
        target_answers = []
        if isinstance(target, str):
            target_answers = [target]
        else:
            target_answers = target

        for target_answer in target_answers:
            if self.sep_style == SeparatorStyle.ADD_BOS_EOS_TOKEN:
                prompt_with_target.append(prompt + target_answer)
            elif self.sep_style == SeparatorStyle.ALPACA:
                prompt_with_target.append(prompt + target_answer)
            elif self.sep_style == SeparatorStyle.PLAIN:
                prompt_with_target.append(prompt + target_answer)
            else:
                raise ValueError(f"Invalid style: {self.sep_style}")

        return prompt_with_target

    def save_prompt(self):
        if self.sep_style == SeparatorStyle.ADD_BOS_EOS_TOKEN:
            ret = self.system
            for role, message in self.messages:
                if message:
                    ret += role + ": " + "<s>" + message + "</s>\n"
                else:
                    ret += role + ": " + "<s>"
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        self.messages.append([role, message])

    def copy(self):
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
        )

    def dict(self):
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep_style": self.sep_style,
            "sep": self.sep,
        }

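# --- Illustrative sketch (editor's addition, not part of the original file) ---
# get_prompt_with_target pairs the rendered prompt with every reference answer; this is
# what the loss-based evaluation path consumes. The arithmetic question is made up.
_example_conv = Conversation(
    system="", roles=("", ""), messages=[], offset=0, sep_style=SeparatorStyle.PLAIN, sep=""
)
_example_conv.append_message(_example_conv.roles[0], "1 + 1 = ")
_example_conv.append_message(_example_conv.roles[1], None)
assert _example_conv.get_prompt_with_target(["2", "two"]) == ["1 + 1 = 2", "1 + 1 = two"]
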
def get_few_shot_prefix(
    conv: Conversation, few_shot_data: List[str], tokenizer: Optional[AutoTokenizer], language: str, max_tokens: int
) -> str:
    """
    Get few-shot prefix.

    Args:
        conv: Conversation template.
        few_shot_data: Few-shot examples used to build the few-shot prompt prefix.
        tokenizer: Tokenizer used to measure the prefix length.
        language: Language of the prefix header ("English" or "Chinese").
        max_tokens: Token budget available for the few-shot prefix.

    Returns:
        Few-shot prompt prefix.
    """

    if language == "English":
        few_shot_prefix = "The following are answers for questions in an exam.\n\n"
    elif language == "Chinese":
        few_shot_prefix = "以下是考试中各个问题的答案。\n\n"

    # Keep appending examples while the tokenized prefix still fits into max_tokens.
    output = None
    for i in range(len(few_shot_data)):
        few_shot_prefix = few_shot_prefix + few_shot_data[i] + "\n\n"

        if len(tokenizer([few_shot_prefix]).input_ids[0]) <= max_tokens:
            output = few_shot_prefix
        else:
            break

    return output if output is not None else few_shot_prefix

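# --- Illustrative sketch (editor's addition, not part of the original file) ---
# How the token budget bounds the few-shot prefix. The whitespace "tokenizer" below is a
# duck-typed stand-in for AutoTokenizer, used only to keep the example self-contained;
# real callers pass the model's tokenizer and a Conversation template (conv is not used
# inside get_few_shot_prefix, so None suffices here).
class _ToyEncoding:
    def __init__(self, input_ids):
        self.input_ids = input_ids


class _ToyWhitespaceTokenizer:
    def __call__(self, texts):
        # One "token" per whitespace-separated word.
        return _ToyEncoding([list(range(len(t.split()))) for t in texts])


_toy_prefix = get_few_shot_prefix(
    None,
    ["Q: 1 + 1 = ? A: 2", "Q: 2 + 2 = ? A: 4"],
    _ToyWhitespaceTokenizer(),
    "English",
    max_tokens=20,  # the 9-word header plus the first 8-word example fits; the second does not
)
assert "A: 2" in _toy_prefix and "A: 4" not in _toy_prefix
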
def get_batch_prompt(
    conv: Conversation,
    batch: List[Dict],
    few_shot_data: List[str],
    tokenizer: Optional[AutoTokenizer],
    language: Optional[str],
    model_max_length: Optional[int],
) -> Tuple[List[str], List[List[str]]]:
    """
    Get batch prompt and target.

    Args:
        conv: Conversation template.
        batch: Batch data to generate prompts from.
        few_shot_data: Few-shot data used to build the few-shot prompt prefix.
        tokenizer: Tokenizer used to measure prompt length when few-shot data is given.
        language: Language of the few-shot prefix header.
        model_max_length: Maximum sequence length of the model.

    Returns:
        Tuple containing batch prompt and target.
    """

    batch_prompt = []
    batch_target = []

    if isinstance(batch[0], dict):
        for b in batch:
            few_shot_prefix = ""
            if few_shot_data is not None:
                # For few-shot, we only need the input. Otherwise we use the instruction (as in AGIEval).
                query_text = b["input"] if b.get("input", "") != "" else b["instruction"]

                if isinstance(b["target"], str):
                    zero_shot_prompt = query_text + b["target"]
                    max_tokens = model_max_length - len(tokenizer([zero_shot_prompt]).input_ids[0])
                else:
                    raise Exception("When using few-shot, target answer should be a string.")

                few_shot_prefix = get_few_shot_prefix(conv, few_shot_data, tokenizer, language, max_tokens)
            else:
                query_text = b["instruction"] + "\n\n" + b["input"] if b.get("input", "") != "" else b["instruction"]

            conv.append_message(conv.roles[0], few_shot_prefix + query_text)
            conv.append_message(conv.roles[1], None)

            batch_prompt.append(conv.get_prompt())

            target = b["target"]
            if isinstance(b["target"], str):
                target = [target]

            batch_target.append(target)

            conv.clear()

    return batch_prompt, batch_target

conv_coati = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
    roles=("Human", "Assistant"),
    messages=[],
    offset=0,
    sep_style=SeparatorStyle.ADD_BOS_EOS_TOKEN,
    sep="</s>",
)

conv_alpaca = Conversation(
    system="Below is an instruction that describes a task. Write a response that appropriately completes the request.",
    roles=("### Instruction", "### Response"),
    messages=[],
    offset=0,
    sep_style=SeparatorStyle.ALPACA,
    sep="\n\n",
)

conv_plain = Conversation(
    system="",
    roles=("", ""),
    messages=[],
    offset=0,
    sep_style=SeparatorStyle.PLAIN,
    sep="",
)

prompt_templates = {"coati": conv_coati, "alpaca": conv_alpaca, "plain": conv_plain}

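# --- Illustrative usage sketch (editor's addition, not part of the original file) ---
# Zero-shot prompt construction with the "coati" template defined above. The batch
# entry follows the {"instruction", "input", "target"} schema expected by
# get_batch_prompt; the question itself is a made-up example.
_example_batch = [
    {"instruction": "Answer the question.", "input": "What is 2 + 2?", "target": "4"},
]
_prompts, _targets = get_batch_prompt(
    conv=prompt_templates["coati"],
    batch=_example_batch,
    few_shot_data=None,  # zero-shot, so no tokenizer or model_max_length is needed
    tokenizer=None,
    language=None,
    model_max_length=None,
)
# _prompts[0] ends with "Human: <s>Answer the question.\n\nWhat is 2 + 2?</s>Assistant: <s>"
# _targets[0] == ["4"]  (targets are always normalized to a list of strings)
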
62
applications/ColossalEval/colossal_eval/utils/utilities.py
Normal file
@@ -0,0 +1,62 @@
import io
import json
import os

import torch.distributed as dist


def is_rank_0() -> bool:
    return not dist.is_initialized() or dist.get_rank() == 0


def _make_w_io_base(f, mode: str):
    if not isinstance(f, io.IOBase):
        f_dirname = os.path.dirname(f)
        if f_dirname != "":
            os.makedirs(f_dirname, exist_ok=True)
        f = open(f, mode=mode, encoding="utf-8")
    return f


def _make_r_io_base(f, mode: str):
    if not isinstance(f, io.IOBase):
        f = open(f, mode=mode, encoding="utf-8")
    return f


def jdump(obj, f, mode="w", indent=4, default=str):
    """
    Dump a str or dictionary to a file in json format.

    Args:
        obj: An object to be written.
        f: A string path to the location on disk.
        mode: Mode for opening the file.
        indent: Indent for storing json dictionaries.
        default: A function to handle non-serializable entries; defaults to `str`.
    """
    f = _make_w_io_base(f, mode)
    if isinstance(obj, (dict, list)):
        json.dump(obj, f, indent=indent, default=default, ensure_ascii=False)
    elif isinstance(obj, str):
        f.write(obj)
    else:
        raise ValueError(f"Unexpected type: {type(obj)}")
    f.close()


def jload(f, mode="r"):
    """Load a .json file into a dictionary."""
    f = _make_r_io_base(f, mode)
    jdict = json.load(f)
    f.close()
    return jdict


def get_json_list(file_path):
    with open(file_path, "r") as f:
        json_list = []
        for line in f:
            json_list.append(json.loads(line))
        return json_list

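# --- Illustrative usage sketch (editor's addition, not part of the original file) ---
# Round-tripping a results dict with jdump/jload, and reading a JSON-lines file with
# get_json_list. The file names are placeholders.
if __name__ == "__main__":
    results = {"model": "demo", "accuracy": 0.5}
    jdump(results, "demo_results.json")  # parent directories are created if needed
    assert jload("demo_results.json") == results

    with open("demo_samples.jsonl", "w", encoding="utf-8") as fout:
        fout.write('{"id": 1}\n{"id": 2}\n')
    print(get_json_list("demo_samples.jsonl"))  # -> [{'id': 1}, {'id': 2}]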