mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-05-05 06:58:09 +00:00
[ColossalEval] Support GSM, Data Leakage Evaluation and Tensor Parallel (#5169)
* Support GSM, Data Leakage Evaluation and Tensor Parallel * remove redundant code and update inference.py in examples/gpt_evaluation --------- Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
This commit is contained in:
parent
b07a6f4e27
commit
cefdc32615
applications/ColossalEval
@ -37,7 +37,7 @@
|
||||
- [Citations](#citations)
|
||||
|
||||
## Overview
|
||||
[ColossalEval](https://github.com/hpcaitech/ColossalAI/tree/main/applications/ColossalEval) is a project which provides a uniform pipeline to help evaluate language models on different public dataset or your own dataset using both classic metrics and the help from GPTs. More details can be found in the following sections.
|
||||
[ColossalEval](https://github.com/hpcaitech/ColossalAI/tree/main/applications/ColossalEval) is a project which provides a uniform pipeline to help evaluate language models on different public dataset or your own dataset using both classic metrics and the help from GPTs. Currently we support AGIEval, CEval, CMMLU, CValues, GAOKAO-Bench, GSM8K, LongBench, MMLU, MtBench and SafetyBench. More details can be found in the following sections.
|
||||
|
||||
## Leaderboard
|
||||
|
||||
@ -101,7 +101,7 @@ The evaluation process involves 2 steps which are `inference` and `evaluation`.
|
||||
|
||||
### Inference
|
||||
|
||||
The inference process consists of two parts.
|
||||
The inference process consists of two parts. We now support tensor parallel inference for large models using [ShardFormer](colossalai/shardformer) in the [example](applications/ColossalEval/examples/dataset_evaluation/inference.py) script.
|
||||
1. Preprocess and convert the original dataset.
|
||||
2. Config your tokenizer and model arguments to perform zero-shot or few-shot prompting.
|
||||
|
||||
@ -193,7 +193,7 @@ In this step, you will configure your tokenizer and model arguments to infer on
|
||||
|
||||
A config file consists of two parts.
|
||||
1. Model config. In model config, you need to specify model name, model path, model class, tokenizer arguments and model arguments. For model class, currently we support `HuggingFaceModel`, `HuggingFaceCausalLM`, `ChatGLMModel` and `ChatGLMModel2`. `HuggingFaceModel` is for models that can be loaded with `AutoModel` and `HuggingFaceCausalLM` is for models that can be loaded with `AutoModelForCausalLM`. `ChatGLMModel` and `ChatGLMModel2` are for ChatGLM and ChatGLM2 models respectively. You can check all model classes in `colossal_eval/models/__init__.py`. If your model should set `trust_remote_code` as true, specify it in the `tokenizer_kwargs` and `model_kwargs` fields.
|
||||
2. Dataset config. In dataset config, you need to specify dataset name, path and dataset class. Currently, we support zero-shot on dataset MMLU, CMMLU, AGIEval, GAOKAO-Bench and LongBench and few-shot on dataset MMLU, CMMLU and AGIEval. If you want to enable few shot, set `few_shot` as true. You can check all model classes in `colossal_eval/dataset/__init__.py`.
|
||||
2. Dataset config. In dataset config, you need to specify dataset name, path and dataset class. Currently, we support zero-shot on dataset MMLU, CMMLU, AGIEval, GAOKAO-Bench, GSM8K and LongBench and few-shot on dataset MMLU, CMMLU AGIEval and GSM8K. If you want to enable few shot, set `few_shot` as true. You can check all model classes in `colossal_eval/dataset/__init__.py`.
|
||||
|
||||
Once you have all config ready, the program will run inference on all the given datasets on all the given models.
|
||||
|
||||
@ -236,17 +236,20 @@ An example config using model class `HuggingFaceCausalLM` and dataset class `CMM
|
||||
|
||||
Currently, we support Hugging Face models. The `tokenizer_kwargs` is the arguments used in `AutoTokenizer.from_pretrained()`. The `model_kwargs` is the arguments used in `AutoModel.from_pretrained` or `AutoModelForCausalLM.from_pretrained()`. `few_shot` will be set true if you want to enable few-shot prompting for the dataset. `debug` will be set true if you want to verify whether your prompt is right or wrong.
|
||||
|
||||
> For GSM8K dataset, you can set additional flags `load_train` or `load_reference` for dataset configuration as true and during the inference process, the program will calculate loss summation over all tokens for each data sample. During the evaluation process, you can use metric `loss_over_all_tokens` to calculate the overall loss and use it for data leakage evaluation.
|
||||
|
||||
#### How to Use
|
||||
An example script can be the following. The `configs/dataset_evaluation/inference.py` is the same in all examples provided.
|
||||
|
||||
```shell
|
||||
torchrun --nproc_per_node=1 inference.py \
|
||||
torchrun --nproc_per_node=4 inference.py \
|
||||
--config "path to config file" \
|
||||
--load_dataset \
|
||||
--tp_size 2 \
|
||||
--inference_save_path "path to save inference results"
|
||||
```
|
||||
|
||||
You should specify the path to config file in `config`. You can run the script without specifying `load_dataset` if you already save the converted dataset or otherwise set it to first load the original dataset and save the converted dataset. You should specify the path to save inference results in `inference_save_path`.
|
||||
You should specify the path to config file in `config`. You can run the script without specifying `load_dataset` if you already save the converted dataset or otherwise set it to first load the original dataset and save the converted dataset. You should specify the path to save inference results in `inference_save_path`. If you want to use tensor parallel inference, specify the tensor parallel size in `--tp_size` and the process will automatically calculate data parallel size.
|
||||
|
||||
### Evaluation
|
||||
|
||||
@ -371,11 +374,13 @@ To make it more easier to set the config, you only need to specify all metrics y
|
||||
- `classification_score`: Calculate classification score between prediction and reference. It determines whether the ouput(a class) is equal to the reference. It is used in Longbench.
|
||||
- `code_sim_score`: Calculate similarity score between prediction and reference. It is used in Longbench.
|
||||
- `count_score`: Calculate count score between prediction and reference. It determines whether the ouput(number of given passages) is equal to the reference. It is used in Longbench.
|
||||
- `gsm_accuracy`: Calculate scores between prediction and reference.. It is used in GSM8K.
|
||||
- `perplexity`: Calculate perplexity. The formula is $ perplexity = \frac{1}{n} \sum_i e^{loss_i} $ where $n$ is the number of samples and $ loss_i $ is the average loss for sample $ i $. It can be used in all dataset.
|
||||
- `ppl_score`: Calculate perplexity score. The formula is $ ppl\_score = \frac{1}{n} \sum_i e^{-loss_i} $ where $n$ is the number of samples and $ loss_i $ is the average loss for sample $ i $. It can be used in all dataset.
|
||||
- `ppl_score_over_choices`: Calculate perplexity score over choices. The formula is $ ppl\_score\_over\_choices= \frac{1}{n} \sum_i e^{-loss\_over\_choices_i} $ where $n$ is the number of samples and $ loss\_over\_choices_i $ is the loss on the first predicted token for sample $ i $. It can be used in all dataset that contains single-choice questions.
|
||||
- `per_byte_perplexity`: Calculate per byte perplexity. The formula is $ \frac{1}{n} \sum_i e^{\frac{loss_i}{byte_i}} $ where $n$ is the number of samples, $ loss_i $ is the total loss for sample $ i $ and $ byte_i $ is the number of bytes sample $ i $ occupies. It can be used in all dataset.
|
||||
- `per_byte_ppl_score`: Calculate per byte perplexity score. The formula is $ \frac{1}{n} \sum_i e^{-\frac{loss_i}{byte_i}} $ where $n$ is the number of samples, $ loss_i $ is the total loss for sample $ i $ and $ byte_i $ is the number of bytes sample $ i $ occupies. It can be used in all dataset.
|
||||
- `loss_over_all_tokens`: Calculate loss over all tokens. The formula is $ loss\_over\_all\_tokens = \frac{1}{n} \sum_i loss_i $ where $n$ is the total number of tokens of the dataset and $ loss_i $ is the loss summation for sample $ i $ over all tokens and $ \sum_i loss_i $ is the loss summation for all samples. It can be used in all dataset.
|
||||
|
||||
We use `combined_single_choice_accuracy` and `first_token_logit` in the leaderboard.
|
||||
|
||||
@ -520,6 +525,15 @@ year={2023}
|
||||
primaryClass={cs.CL}
|
||||
}
|
||||
|
||||
@misc{xu2023cvalues,
|
||||
title={CValues: Measuring the Values of Chinese Large Language Models from Safety to Responsibility},
|
||||
author={Guohai Xu and Jiayi Liu and Ming Yan and Haotian Xu and Jinghui Si and Zhuoran Zhou and Peng Yi and Xing Gao and Jitao Sang and Rong Zhang and Ji Zhang and Chao Peng and Fei Huang and Jingren Zhou},
|
||||
year={2023},
|
||||
eprint={2307.09705},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.CL}
|
||||
}
|
||||
|
||||
@inproceedings{Zhang2023EvaluatingTP,
|
||||
title={Evaluating the Performance of Large Language Models on GAOKAO Benchmark},
|
||||
author={Xiaotian Zhang and Chunyang Li and Yi Zong and Zhengyu Ying and Liang He and Xipeng Qiu},
|
||||
@ -542,6 +556,20 @@ year={2023}
|
||||
year={2021}
|
||||
}
|
||||
|
||||
@article{zhang2023safetybench,
|
||||
title={SafetyBench: Evaluating the Safety of Large Language Models with Multiple Choice Questions},
|
||||
author={Zhexin Zhang and Leqi Lei and Lindong Wu and Rui Sun and Yongkang Huang and Chong Long and Xiao Liu and Xuanyu Lei and Jie Tang and Minlie Huang},
|
||||
journal={arXiv preprint arXiv:2309.07045},
|
||||
year={2023}
|
||||
}
|
||||
|
||||
@article{cobbe2021training,
|
||||
title={Training verifiers to solve math word problems},
|
||||
author={Cobbe, Karl and Kosaraju, Vineet and Bavarian, Mohammad and Chen, Mark and Jun, Heewoo and Kaiser, Lukasz and Plappert, Matthias and Tworek, Jerry and Hilton, Jacob and Nakano, Reiichiro and others},
|
||||
journal={arXiv preprint arXiv:2110.14168},
|
||||
year={2021}
|
||||
}
|
||||
|
||||
@article{hendrycks2021ethics,
|
||||
title={Aligning AI With Shared Human Values},
|
||||
author={Dan Hendrycks and Collin Burns and Steven Basart and Andrew Critch and Jerry Li and Dawn Song and Jacob Steinhardt},
|
||||
@ -558,4 +586,12 @@ year={2023}
|
||||
primaryClass={cs.CL}
|
||||
}
|
||||
|
||||
@misc{wei2023skywork,
|
||||
title={Skywork: A More Open Bilingual Foundation Model},
|
||||
author={Tianwen Wei and Liang Zhao and Lichang Zhang and Bo Zhu and Lijie Wang and Haihua Yang and Biye Li and Cheng Cheng and Weiwei Lü and Rui Hu and Chenxia Li and Liu Yang and Xilin Luo and Xuejie Wu and Lunan Liu and Wenjun Cheng and Peng Cheng and Jianhao Zhang and Xiaoyu Zhang and Lei Lin and Xiaokun Wang and Yutuan Ma and Chuanhai Dong and Yanqi Sun and Yifu Chen and Yongyi Peng and Xiaojuan Liang and Shuicheng Yan and Han Fang and Yahui Zhou},
|
||||
year={2023},
|
||||
eprint={2310.19341},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.CL}
|
||||
}
|
||||
```
|
||||
|
@ -5,6 +5,7 @@ from .cmmlu import CMMLUDataset
|
||||
from .colossalai import ColossalDataset
|
||||
from .cvalues import CValuesDataset
|
||||
from .gaokaobench import GaoKaoBenchDataset
|
||||
from .gsm import GSMDataset
|
||||
from .longbench import LongBenchDataset
|
||||
from .mmlu import MMLUDataset
|
||||
from .mtbench import MTBenchDataset
|
||||
@ -24,4 +25,5 @@ __all__ = [
|
||||
"SafetyBenchENDataset",
|
||||
"SafetyBenchZHDataset",
|
||||
"CValuesDataset",
|
||||
"GSMDataset",
|
||||
]
|
||||
|
@ -99,11 +99,20 @@ def get_prompt(line: Dict, dataset_name: str, logger: DistributedLogger) -> Dict
|
||||
|
||||
# process few-shot raw_prompts
|
||||
def combine_prompt(prompt_path, dataset_name, load_explanation=True, chat_mode=False):
|
||||
demostrations = []
|
||||
demostration_en = "Here are the answers for the problems in the exam."
|
||||
demostration_zh = "以下是考试中各个问题的答案。"
|
||||
|
||||
if dataset_name in english_qa_datasets or dataset_name in english_cloze_datasets:
|
||||
demostrations.append(demostration_en)
|
||||
elif dataset_name in chinese_qa_datasets or dataset_name in chinese_cloze_datasets:
|
||||
demostrations.append(demostration_zh)
|
||||
|
||||
skip_passage = False
|
||||
if dataset_name == "sat-en-without-passage":
|
||||
skip_passage = True
|
||||
dataset_name = "sat-en"
|
||||
demostrations = []
|
||||
|
||||
# read the prompts by context and explanation
|
||||
context_row = [0, 1, 3, 5, 7, 9]
|
||||
explanation_row = [0, 2, 4, 6, 8, 10]
|
||||
@ -153,7 +162,7 @@ def combine_prompt(prompt_path, dataset_name, load_explanation=True, chat_mode=F
|
||||
if chat_mode:
|
||||
demostrations.append((question_input,))
|
||||
else:
|
||||
demostrations.append(question_input + "\n")
|
||||
demostrations.append(question_input)
|
||||
|
||||
return demostrations
|
||||
|
||||
@ -178,7 +187,9 @@ class AGIEvalDataset(BaseDataset):
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
|
||||
def load(
|
||||
path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool
|
||||
) -> List[Dict]:
|
||||
dataset = {"test": {}}
|
||||
|
||||
files = glob.glob(os.path.join(path, "*.jsonl"))
|
||||
|
@ -12,8 +12,8 @@ class BaseDataset:
|
||||
logger: Logger for the dataset.
|
||||
"""
|
||||
|
||||
def __init__(self, path, logger, few_shot):
|
||||
self.dataset = self.load(path, logger, few_shot)
|
||||
def __init__(self, path, logger, few_shot, forward_only=False, load_train=False, load_reference=False):
|
||||
self.dataset = self.load(path, logger, few_shot, forward_only, load_train, load_reference)
|
||||
|
||||
def save(self, save_path):
|
||||
"""Save the converted dataset"""
|
||||
|
@ -71,8 +71,8 @@ default_inference_kwargs = {
|
||||
}
|
||||
|
||||
|
||||
def get_few_shot_data(data: List[Dict]):
|
||||
few_shot_data = []
|
||||
def get_few_shot_data(data: List[Dict], subject):
|
||||
few_shot_data = [f"以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。"]
|
||||
for i in data:
|
||||
few_shot_data.append(i["input"] + i["target"])
|
||||
return few_shot_data
|
||||
@ -86,7 +86,9 @@ class CEvalDataset(BaseDataset):
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
|
||||
def load(
|
||||
path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool
|
||||
) -> List[Dict]:
|
||||
dataset = {"dev": {}, "test": {}}
|
||||
for split in ["dev", "test"]:
|
||||
files = os.listdir(os.path.join(path, split))
|
||||
@ -105,7 +107,7 @@ class CEvalDataset(BaseDataset):
|
||||
|
||||
if split == "test" and few_shot:
|
||||
dataset[split][subject]["inference_kwargs"]["few_shot_data"] = get_few_shot_data(
|
||||
dataset["dev"][subject]["data"]
|
||||
dataset["dev"][subject]["data"], subject
|
||||
)
|
||||
|
||||
with open(file_dir, encoding="utf-8") as f:
|
||||
|
@ -86,8 +86,8 @@ default_inference_kwargs = {
|
||||
}
|
||||
|
||||
|
||||
def get_few_shot_data(data: List[Dict]):
|
||||
few_shot_data = []
|
||||
def get_few_shot_data(data: List[Dict], subject):
|
||||
few_shot_data = [f"以下是关于{subject}的单项选择题,请直接给出正确答案的选项。"]
|
||||
for i in data:
|
||||
few_shot_data.append(i["input"] + i["target"])
|
||||
return few_shot_data
|
||||
@ -101,7 +101,9 @@ class CMMLUDataset(BaseDataset):
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
|
||||
def load(
|
||||
path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool
|
||||
) -> List[Dict]:
|
||||
dataset = {"dev": {}, "test": {}}
|
||||
for split in ["dev", "test"]:
|
||||
files = os.listdir(os.path.join(path, split))
|
||||
@ -120,7 +122,7 @@ class CMMLUDataset(BaseDataset):
|
||||
|
||||
if split == "test" and few_shot:
|
||||
dataset[split][subject]["inference_kwargs"]["few_shot_data"] = get_few_shot_data(
|
||||
dataset["dev"][subject]["data"]
|
||||
dataset["dev"][subject]["data"], subject
|
||||
)
|
||||
|
||||
with open(file_dir, encoding="utf-8") as f:
|
||||
|
@ -69,7 +69,9 @@ class GaoKaoBenchDataset(BaseDataset):
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
|
||||
def load(
|
||||
path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool
|
||||
) -> List[Dict]:
|
||||
dataset = {"test": {}}
|
||||
for category in ["Fill-in-the-blank_Questions", "Multiple-choice_Questions", "Open-ended_Questions"]:
|
||||
files = os.listdir(os.path.join(path, "data", category))
|
||||
|
140
applications/ColossalEval/colossal_eval/dataset/gsm.py
Normal file
140
applications/ColossalEval/colossal_eval/dataset/gsm.py
Normal file
@ -0,0 +1,140 @@
|
||||
import copy
|
||||
import os
|
||||
from typing import Dict, List
|
||||
|
||||
from colossal_eval.utils import get_json_list
|
||||
|
||||
from colossalai.logging import DistributedLogger
|
||||
|
||||
from .base import BaseDataset
|
||||
|
||||
few_shot_prompt = """Question: In 2004, there were 60 kids at a cookout. In 2005, half the number of kids came to the cookout as compared to 2004. In 2006, 2/3 as many kids came to the cookout as in 2005. How many kids came to the cookout in 2006?
|
||||
Let's think step by step
|
||||
In 2005, 60/2=30 kids came to the cookout.
|
||||
In 2006, 30/3*2=20 kids came to the cookout.
|
||||
The answer is 20
|
||||
|
||||
Question: Zilla spent 7% of her monthly earnings on rent, half of it on her other monthly expenses, and put the rest in her savings. If she spent $133 on her rent, how much does she deposit into her savings account in a month?
|
||||
Let's think step by step
|
||||
Since $133 is equal to 7% of her earnings, then 1% is equal to $133/7 = $19.
|
||||
The total monthly earning of Zilla is represented by 100%, so $19 x 100 = $1900 is her monthly earnings.
|
||||
So, $1900/2 = $950 is spent on her other monthly expenses.
|
||||
The total amount spent on the rent and other monthly expenses is $133 + $950 = $1083.
|
||||
Hence, she saves $1900 - $1083 = $817 per month.
|
||||
The answer is 817
|
||||
|
||||
Question: If Buzz bought a pizza with 78 slices at a restaurant and then decided to share it with the waiter in the ratio of 5:8, with Buzz's ratio being 5, what's twenty less the number of slices of pizza that the waiter ate?
|
||||
Let's think step by step
|
||||
The total ratio representing the slices of pizza that Buzz bought is 5+8=13
|
||||
If he shared the slices of pizza with the waiter, the waiter received a fraction of 8/13 of the total number of slices, which totals 8/13 * 78 = 48 slices
|
||||
Twenty less the number of slices of pizza that the waiter ate is 48-20 = 28
|
||||
The answer is 28
|
||||
|
||||
Question: Jame gets a raise to $20 per hour and works 40 hours a week. His old job was $16 an hour for 25 hours per week. How much more money does he make per year in his new job than the old job if he works 52 weeks a year?
|
||||
Let's think step by step
|
||||
He makes 20*40=$800 per week
|
||||
He used to make 16*25=$400 per week
|
||||
So his raise was 800-400=$400 per week
|
||||
So he makes 400*52=$20,800 per year more
|
||||
The answer is 20800
|
||||
|
||||
Question: Mr. Gardner bakes 20 cookies, 25 cupcakes, and 35 brownies for his second-grade class of 20 students. If he wants to give each student an equal amount of sweet treats, how many sweet treats will each student receive?
|
||||
Let's think step by step
|
||||
Mr. Gardner bakes a total of 20 + 25 + 35 = 80 sweet treats
|
||||
Each student will receive 80 / 20 = 4 sweet treats
|
||||
The answer is 4
|
||||
|
||||
Question: A used car lot has 24 cars and motorcycles (in total) for sale. A third of the vehicles are motorcycles, and a quarter of the cars have a spare tire included. How many tires are on the used car lot’s vehicles in all?
|
||||
Let's think step by step
|
||||
The used car lot has 24 / 3 = 8 motorcycles with 2 tires each.
|
||||
The lot has 24 - 8 = 16 cars for sale
|
||||
There are 16 / 4 = 4 cars with a spare tire with 5 tires each.
|
||||
The lot has 16 - 4 = 12 cars with 4 tires each.
|
||||
Thus, the used car lot’s vehicles have 8 * 2 + 4 * 5 + 12 * 4 = 16 + 20 + 48 = 84 tires in all.
|
||||
The answer is 84
|
||||
|
||||
Question: Norma takes her clothes to the laundry. She leaves 9 T-shirts and twice as many sweaters as T-shirts in the washer. When she returns she finds 3 sweaters and triple the number of T-shirts. How many items are missing?
|
||||
Let's think step by step
|
||||
Norma left 9 T-shirts And twice as many sweaters, she took 9 * 2= 18 sweaters
|
||||
Adding the T-shirts and sweaters, Norma left 9 + 18 = 27 clothes
|
||||
When she came back, she found 3 sweaters And triple the number of T-shirts, she found 3 * 3 = 9 T-shirts
|
||||
Adding the T-shirts and sweaters, Norma found 3 + 9 = 12 clothes
|
||||
Subtracting the clothes she left from the clothes she found, 27 - 12 = 15 clothes are missing
|
||||
The answer is 15
|
||||
|
||||
Question: Adam has an orchard. Every day for 30 days he picks 4 apples from his orchard. After a month, Adam has collected all the remaining apples, which were 230. How many apples in total has Adam collected from his orchard?
|
||||
Let's think step by step
|
||||
During 30 days Adam picked 4 * 30 = 120 apples.
|
||||
So in total with all the remaining apples, he picked 120 + 230 = 350 apples from his orchard.
|
||||
The answer is 350"""
|
||||
|
||||
default_inference_kwargs = {
|
||||
"calculate_loss": True,
|
||||
"all_classes": None,
|
||||
"language": "English",
|
||||
"pretrain": False,
|
||||
"max_new_tokens": 256,
|
||||
}
|
||||
|
||||
|
||||
def get_few_shot_data():
|
||||
few_shot_data = few_shot_prompt.split("\n\n")
|
||||
# print(few_shot_data)
|
||||
assert len(few_shot_data) == 8
|
||||
|
||||
return few_shot_data
|
||||
|
||||
|
||||
class GSMDataset(BaseDataset):
|
||||
"""
|
||||
Dataset class for GSM dataset.
|
||||
Data source: https://github.com/openai/grade-school-math/tree/master/grade_school_math/data
|
||||
This dataset class will convert the original dataset into the inference dataset.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def load(
|
||||
path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool
|
||||
) -> List[Dict]:
|
||||
dataset = {"test": {}}
|
||||
|
||||
if load_train:
|
||||
dataset["train"] = {}
|
||||
|
||||
if load_reference:
|
||||
dataset["reference"] = {}
|
||||
|
||||
for split in dataset:
|
||||
file_name = f"{split}.jsonl" if split != "reference" else "mock_gsm8k_test.jsonl"
|
||||
file = os.path.join(path, file_name)
|
||||
data = get_json_list(file)
|
||||
subject = "math"
|
||||
|
||||
dataset[split][subject] = {"data": []}
|
||||
dataset[split][subject]["inference_kwargs"] = copy.deepcopy(default_inference_kwargs)
|
||||
|
||||
if forward_only:
|
||||
dataset[split][subject]["inference_kwargs"]["pretrain"] = True
|
||||
|
||||
if split == "test" and few_shot:
|
||||
dataset[split][subject]["inference_kwargs"]["few_shot_data"] = get_few_shot_data()
|
||||
|
||||
for question in data:
|
||||
if forward_only:
|
||||
input_string = question["question"] + " " if split != "reference" else question["text"]
|
||||
else:
|
||||
input_string = f"Question: {question['question']}\nLet's think step by step\n"
|
||||
|
||||
data_sample = {
|
||||
"dataset": "gsm",
|
||||
"split": split,
|
||||
"category": subject,
|
||||
"instruction": "",
|
||||
"input": input_string,
|
||||
"output": "",
|
||||
"target": question["answer"] if split != "reference" else "",
|
||||
}
|
||||
|
||||
dataset[split][subject]["data"].append(data_sample)
|
||||
|
||||
return dataset
|
@ -16,8 +16,8 @@ default_inference_kwargs = {
|
||||
}
|
||||
|
||||
|
||||
def get_few_shot_data(data: List[Dict]):
|
||||
few_shot_data = []
|
||||
def get_few_shot_data(data: List[Dict], subject):
|
||||
few_shot_data = [f"The following are multiple choice questions (with answers) about {subject}."]
|
||||
for i in data:
|
||||
few_shot_data.append(i["input"] + i["target"])
|
||||
return few_shot_data
|
||||
@ -31,7 +31,9 @@ class MMLUDataset(BaseDataset):
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
|
||||
def load(
|
||||
path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool
|
||||
) -> List[Dict]:
|
||||
dataset = {"dev": {}, "test": {}}
|
||||
for split in ["dev", "test"]:
|
||||
files = os.listdir(os.path.join(path, split))
|
||||
@ -50,7 +52,7 @@ class MMLUDataset(BaseDataset):
|
||||
|
||||
if split == "test" and few_shot:
|
||||
dataset[split][subject]["inference_kwargs"]["few_shot_data"] = get_few_shot_data(
|
||||
dataset["dev"][subject]["data"]
|
||||
dataset["dev"][subject]["data"], subject
|
||||
)
|
||||
|
||||
with open(file_dir, encoding="utf-8") as f:
|
||||
|
@ -6,8 +6,17 @@ import numpy as np
|
||||
import tqdm
|
||||
from colossal_eval.utils import jdump
|
||||
|
||||
import colossal_eval.evaluate.dataset_evaluator.gpt_judge as gpt_helper # noqa
|
||||
|
||||
LabelBasedMetrics = ["first_token_accuracy", "matthews_correlation"]
|
||||
LossBasedMetrics = ["perplexity", "ppl_score", "ppl_score_over_choices", "per_byte_perplexity", "per_byte_ppl_score"]
|
||||
LossBasedMetrics = [
|
||||
"perplexity",
|
||||
"ppl_score",
|
||||
"ppl_score_over_choices",
|
||||
"per_byte_perplexity",
|
||||
"per_byte_ppl_score",
|
||||
"loss_over_all_tokens",
|
||||
]
|
||||
CombinedMetrics = ["combined_single_choice_accuracy"]
|
||||
GPTMetrics = ["mtbench_single_judge"]
|
||||
OtherMetrics = [
|
||||
@ -23,6 +32,7 @@ OtherMetrics = [
|
||||
"multi_choice_accuracy",
|
||||
"math_equivalence",
|
||||
"single_choice_accuracy",
|
||||
"gsm_accuracy",
|
||||
]
|
||||
|
||||
|
||||
@ -141,7 +151,10 @@ class DatasetEvaluator(object):
|
||||
"""Calculate other metrics."""
|
||||
weight = len(self.data[category]["data"]) / self.metric_total_length[metric]
|
||||
|
||||
references = [sample["target"] for sample in self.data[category]["data"]]
|
||||
references = [
|
||||
sample["target"] if isinstance(sample["target"], list) else [sample["target"]]
|
||||
for sample in self.data[category]["data"]
|
||||
]
|
||||
predictions = [sample["output"] for sample in self.data[category]["data"]]
|
||||
|
||||
metric_method = eval("metric_helper." + metric)
|
||||
@ -218,6 +231,18 @@ class DatasetEvaluator(object):
|
||||
|
||||
self.evaluation_results["per_byte_ppl_score"][category] = perplexity_score
|
||||
self.evaluation_results["per_byte_ppl_score"]["ALL"] += perplexity_score * weight
|
||||
elif metric == "loss_over_all_tokens":
|
||||
weight = len(self.data[category]["data"]) / self.metric_total_length[metric]
|
||||
losses = [min(sample["loss_sum"]) for sample in self.data[category]["data"]]
|
||||
token_nums = [sample["token_num"][np.argmin(sample["loss_sum"])] for sample in self.data[category]["data"]]
|
||||
perplexity = np.sum(np.array(losses)) / np.sum(np.array(token_nums))
|
||||
|
||||
self.evaluation_results["loss_over_all_tokens"][category] = perplexity
|
||||
self.evaluation_results["loss_over_all_tokens"]["ALL"] += perplexity * weight
|
||||
|
||||
# The number of tokens can be used for normalizing.
|
||||
# See https://github.com/SkyworkAI/Skywork/issues/43#issuecomment-1811733834
|
||||
print(f"{self.model_name} {category} token num: {np.sum(np.array(token_nums))}")
|
||||
|
||||
def _evaluate(self):
|
||||
"""Calculate and return evaluation results"""
|
||||
@ -289,7 +314,10 @@ class DatasetEvaluator(object):
|
||||
self.suggested_categories = {metric: [] for metric in self.metrics}
|
||||
|
||||
for metric in self.metrics:
|
||||
self.suggested_categories[metric] = metric_helper.metrics4subcategory[self.dataset_name][metric]
|
||||
# Train and reference split use same metric as test split.
|
||||
self.suggested_categories[metric] = metric_helper.metrics4subcategory[self.dataset_name.split("_")[0]][
|
||||
metric
|
||||
]
|
||||
if "ALL" in self.suggested_categories[metric]:
|
||||
self.suggested_categories[metric] = self.categories
|
||||
self.metric_total_length[metric] = self.total_length
|
||||
|
@ -1,6 +1,7 @@
|
||||
# Code adapted from https://github.com/THUDM/LongBench/blob/main/metrics.py
|
||||
# Code adapted from https://github.com/hendrycks/math/blob/main/modeling/math_equivalence.py
|
||||
# Code adapted from https://github.com/ruixiangcui/AGIEval/blob/main/src/evaluation.py
|
||||
# https://github.com/SkyworkAI/Skywork/blob/main/eval/eval_gsm8k.py
|
||||
|
||||
import difflib
|
||||
import re
|
||||
@ -11,6 +12,11 @@ import jieba
|
||||
from fuzzywuzzy import fuzz
|
||||
from rouge import Rouge
|
||||
|
||||
ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
|
||||
INVALID_ANS = "[invalid]"
|
||||
ans_re1 = re.compile(r"(\-?[0-9][0-9\.\,]*)")
|
||||
ans_re2 = re.compile(r"=\s*(\$?-?[0-9][0-9\.\,]*)")
|
||||
|
||||
metrics4subcategory = {
|
||||
"pretrain": {
|
||||
"perplexity": ["ALL"],
|
||||
@ -189,6 +195,10 @@ metrics4subcategory = {
|
||||
"cvalues": {"first_token_accuracy": ["ALL"]},
|
||||
"safetybench_zh": {"first_token_accuracy": ["ALL"]},
|
||||
"safetybench_en": {"first_token_accuracy": ["ALL"]},
|
||||
"gsm": {
|
||||
"loss_over_all_tokens": ["ALL"],
|
||||
"gsm_accuracy": ["ALL"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@ -639,3 +649,61 @@ def f1_zh_score(prediction, reference, **kwargs):
|
||||
prediction_tokens = [token for token in prediction_tokens if len(token) > 0]
|
||||
ground_truth_tokens = [token for token in ground_truth_tokens if len(token) > 0]
|
||||
return _f1_score(prediction_tokens, ground_truth_tokens)
|
||||
|
||||
|
||||
def extract_answer_hf(completion):
|
||||
match = ANS_RE.search(completion)
|
||||
if match:
|
||||
match_str = match.group(1).strip()
|
||||
match_str = match_str.replace(",", "")
|
||||
return eval(match_str)
|
||||
else:
|
||||
return INVALID_ANS
|
||||
|
||||
|
||||
def get_match_str(match, idx):
|
||||
match_str = match[idx]
|
||||
match_str = match_str.replace(",", "")
|
||||
if match_str.endswith("."):
|
||||
match_str = match_str[:-1]
|
||||
if match_str.endswith(".00"):
|
||||
match_str = match_str[:-3]
|
||||
if match_str.endswith(".0"):
|
||||
match_str = match_str[:-2]
|
||||
return match_str
|
||||
|
||||
|
||||
def extract_answer(completion):
|
||||
match1 = re.findall(ans_re1, completion)
|
||||
match2 = re.findall(ans_re2, completion)
|
||||
ans = []
|
||||
if match1:
|
||||
match_str1 = get_match_str(match1, -1)
|
||||
ans.append(match_str1)
|
||||
if match2:
|
||||
match_str2 = get_match_str(match2, -1).replace("$", "")
|
||||
ans.append(match_str2)
|
||||
|
||||
answer = INVALID_ANS
|
||||
try:
|
||||
if len(ans) > 0:
|
||||
answer = eval(ans[-1])
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return answer
|
||||
return answer
|
||||
|
||||
|
||||
def is_correct(completion, answer):
|
||||
gold = extract_answer_hf(answer)
|
||||
assert gold != INVALID_ANS, "No ground truth answer found in the document."
|
||||
completion = completion.split("answer is")[-1]
|
||||
return extract_answer(completion) == gold
|
||||
|
||||
|
||||
def gsm_accuracy(prediction, reference, **kwargs):
|
||||
prediction = prediction.split("\n\n\n")[0]
|
||||
prediction = prediction.split("\n\n")[0]
|
||||
prediction = prediction.split("Question:")[0]
|
||||
|
||||
return 1.0 if is_correct(prediction, reference) else 0.0
|
||||
|
@ -10,6 +10,7 @@ from tqdm import tqdm
|
||||
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from colossalai.logging import DistributedLogger
|
||||
from colossalai.shardformer import ShardConfig, ShardFormer
|
||||
|
||||
from .base import BaseModel
|
||||
|
||||
@ -30,6 +31,7 @@ class HuggingFaceModel(BaseModel):
|
||||
prompt_template: The model's prompt template.
|
||||
batch_size: Batch size for inference.
|
||||
logger: Logger for the model.
|
||||
shard_config: Shard config for tensor parallel.
|
||||
|
||||
"""
|
||||
|
||||
@ -44,6 +46,7 @@ class HuggingFaceModel(BaseModel):
|
||||
prompt_template: Conversation = None,
|
||||
batch_size: int = 1,
|
||||
logger: DistributedLogger = None,
|
||||
shard_config: ShardConfig = None,
|
||||
):
|
||||
super().__init__(
|
||||
path=path,
|
||||
@ -54,7 +57,7 @@ class HuggingFaceModel(BaseModel):
|
||||
)
|
||||
self._load_tokenizer(path=path, tokenizer_path=tokenizer_path, tokenizer_kwargs=tokenizer_kwargs)
|
||||
|
||||
self._load_model(path=path, model_kwargs=model_kwargs, peft_path=peft_path)
|
||||
self._load_model(path=path, model_kwargs=model_kwargs, peft_path=peft_path, shard_config=shard_config)
|
||||
|
||||
def _get_choices_indices(self, language: str):
|
||||
"""
|
||||
@ -100,7 +103,9 @@ class HuggingFaceModel(BaseModel):
|
||||
# Qwen has an eod token "<|endoftext|>".
|
||||
self.tokenizer.pad_token_id = self.tokenizer.eod_id
|
||||
|
||||
def _load_model(self, path: str, model_kwargs: dict, peft_path: Optional[str] = None):
|
||||
def _load_model(
|
||||
self, path: str, model_kwargs: dict, peft_path: Optional[str] = None, shard_config: ShardConfig = None
|
||||
):
|
||||
"""
|
||||
Load model.
|
||||
|
||||
@ -108,17 +113,29 @@ class HuggingFaceModel(BaseModel):
|
||||
path: The path to the model.
|
||||
model_kwargs: Keyword arguments for the model.
|
||||
peft_path: The path to the peft model.
|
||||
shard_config: Shard config for tensor parallel.
|
||||
|
||||
"""
|
||||
model_kwargs.setdefault("torch_dtype", torch.float16)
|
||||
|
||||
if "torch_dtype" in model_kwargs:
|
||||
model_kwargs["torch_dtype"] = eval(model_kwargs["torch_dtype"])
|
||||
|
||||
model_kwargs.setdefault("torch_dtype", torch.float16)
|
||||
if "config" in model_kwargs:
|
||||
model_kwargs["config"] = AutoConfig.from_pretrained(model_kwargs["config"])
|
||||
|
||||
self.model = AutoModel.from_pretrained(path, **model_kwargs).to(torch.cuda.current_device())
|
||||
if peft_path is not None:
|
||||
self.model = PeftModel.from_pretrained(self.model, peft_path, is_trainable=False)
|
||||
if shard_config is not None:
|
||||
self.model = AutoModel.from_pretrained(path, **model_kwargs)
|
||||
shard_former = ShardFormer(shard_config)
|
||||
self.model, sharded_parameters = shard_former.optimize(self.model)
|
||||
self.model.to(torch.cuda.current_device())
|
||||
|
||||
if peft_path is not None:
|
||||
raise NotImplementedError("ShardFormer for PEFT models is not implemented.")
|
||||
else:
|
||||
self.model = AutoModel.from_pretrained(path, **model_kwargs).to(torch.cuda.current_device())
|
||||
if peft_path is not None:
|
||||
self.model = PeftModel.from_pretrained(self.model, peft_path, is_trainable=False)
|
||||
self.model.eval()
|
||||
|
||||
def _calculate_loss(self, input_ids_list: List[torch.LongTensor], labels: List[torch.LongTensor]) -> Tuple[List]:
|
||||
@ -152,7 +169,7 @@ class HuggingFaceModel(BaseModel):
|
||||
loss_fct = torch.nn.CrossEntropyLoss(reduction="none", ignore_index=IGNORE_INDEX)
|
||||
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)).view(shift_labels.size())
|
||||
|
||||
lens = (labels != IGNORE_INDEX).sum(-1).cpu().numpy()
|
||||
lens = (labels[..., 1:] != IGNORE_INDEX).sum(-1).cpu().numpy()
|
||||
|
||||
loss_sum = loss.sum(-1).to(torch.float32).cpu().detach().numpy()
|
||||
return loss_sum.tolist(), lens.tolist()
|
||||
@ -239,7 +256,13 @@ class HuggingFaceModel(BaseModel):
|
||||
|
||||
"""
|
||||
if pretrain:
|
||||
return self._get_input_ids_and_labels_pretrain(batch_prompt)
|
||||
batch = []
|
||||
# Concatenate prompt and target answers.
|
||||
# You should decide the concatenation character in the corresponding dataset script in dataset folder. For example, in line 119 dataset/gsm.py, the concatenation character is space.
|
||||
for p, b in zip(batch_prompt, batch_target):
|
||||
batch.append(p + b[0])
|
||||
|
||||
return self._get_input_ids_and_labels_pretrain(batch)
|
||||
|
||||
input_ids_list = []
|
||||
labels_list = []
|
||||
@ -380,7 +403,7 @@ class HuggingFaceModel(BaseModel):
|
||||
|
||||
loss_over_choices = loss_fct(scores, torch.tensor(labels, dtype=torch.long)).numpy().tolist()
|
||||
|
||||
probs = torch.nn.functional.softmax(scores, dim=-1).numpy().tolist()
|
||||
probs = scores.numpy().tolist()
|
||||
probs = [
|
||||
{choice: probs[i][self.str_label_map[choice]] for choice in self.choices} for i in range(len(probs))
|
||||
]
|
||||
@ -393,7 +416,7 @@ class HuggingFaceModel(BaseModel):
|
||||
answers[i + j]["output"] = batch_decodes[j].strip()
|
||||
|
||||
if isinstance(scores, torch.Tensor):
|
||||
answers[i + j]["softmax_over_choices"] = probs[j]
|
||||
answers[i + j]["logits_over_choices"] = probs[j]
|
||||
|
||||
if calculate_loss:
|
||||
answers[i + j]["loss_over_choices"] = loss_over_choices[j]
|
||||
@ -445,7 +468,13 @@ class HuggingFaceModel(BaseModel):
|
||||
|
||||
# Set output_scores=True to get prediction scores.
|
||||
outputs = self.model.generate(
|
||||
**encoded_inputs, max_new_tokens=max_new_tokens, return_dict_in_generate=True, output_scores=True, **kwargs
|
||||
**encoded_inputs,
|
||||
max_new_tokens=max_new_tokens,
|
||||
return_dict_in_generate=True,
|
||||
output_scores=True,
|
||||
do_sample=False,
|
||||
use_cache=True,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# We only need to decode predicted tokens.
|
||||
@ -540,10 +569,13 @@ class HuggingFaceCausalLM(HuggingFaceModel):
|
||||
prompt_template: The model's prompt template.
|
||||
batch_size: Batch size for inference.
|
||||
logger: Logger for the model.
|
||||
shard_config: Shard config for tensor parallel.
|
||||
|
||||
"""
|
||||
|
||||
def _load_model(self, path: str, model_kwargs: dict, peft_path: Optional[str] = None):
|
||||
def _load_model(
|
||||
self, path: str, model_kwargs: dict, peft_path: Optional[str] = None, shard_config: ShardConfig = None
|
||||
):
|
||||
"""
|
||||
Load model.
|
||||
|
||||
@ -551,17 +583,29 @@ class HuggingFaceCausalLM(HuggingFaceModel):
|
||||
path: The path to the model.
|
||||
model_kwargs: Keyword arguments for the model.
|
||||
peft_path: The path to the peft model.
|
||||
shard_config: Shard config for tensor parallel.
|
||||
|
||||
"""
|
||||
|
||||
model_kwargs.setdefault("torch_dtype", torch.float16)
|
||||
|
||||
if "torch_dtype" in model_kwargs:
|
||||
model_kwargs["torch_dtype"] = eval(model_kwargs["torch_dtype"])
|
||||
|
||||
if "config" in model_kwargs:
|
||||
model_kwargs["config"] = AutoConfig.from_pretrained(model_kwargs["config"])
|
||||
|
||||
model_kwargs.setdefault("torch_dtype", torch.float16)
|
||||
self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs).to(torch.cuda.current_device())
|
||||
if peft_path is not None:
|
||||
self.model = PeftModel.from_pretrained(self.model, peft_path, is_trainable=False)
|
||||
if shard_config is not None:
|
||||
self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs)
|
||||
shard_former = ShardFormer(shard_config)
|
||||
self.model, sharded_parameters = shard_former.optimize(self.model)
|
||||
self.model.to(torch.cuda.current_device())
|
||||
|
||||
if peft_path is not None:
|
||||
raise NotImplementedError("ShardFormer for PEFT models is not implemented.")
|
||||
else:
|
||||
self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs).to(torch.cuda.current_device())
|
||||
if peft_path is not None:
|
||||
self.model = PeftModel.from_pretrained(self.model, peft_path, is_trainable=False)
|
||||
|
||||
self.model.eval()
|
||||
|
@ -9,6 +9,7 @@ class SeparatorStyle(Enum):
|
||||
ADD_BOS_EOS_TOKEN = auto()
|
||||
ALPACA = auto()
|
||||
PLAIN = auto()
|
||||
YAYI = auto()
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@ -48,6 +49,14 @@ class Conversation:
|
||||
else:
|
||||
ret += ""
|
||||
return ret
|
||||
elif self.sep_style == SeparatorStyle.YAYI:
|
||||
ret = self.system
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
ret += role + ":\n" + message + self.sep
|
||||
else:
|
||||
ret += role + ":\n"
|
||||
return ret
|
||||
else:
|
||||
raise ValueError(f"Invalid style: {self.sep_style}")
|
||||
|
||||
@ -71,6 +80,8 @@ class Conversation:
|
||||
prompt_with_target.append(prompt + target_answer)
|
||||
elif self.sep_style == SeparatorStyle.PLAIN:
|
||||
prompt_with_target.append(prompt + target_answer)
|
||||
elif self.sep_style == SeparatorStyle.YAYI:
|
||||
prompt_with_target.append(prompt + target_answer)
|
||||
else:
|
||||
raise ValueError(f"Invalid style: {self.sep_style}")
|
||||
|
||||
@ -126,13 +137,11 @@ def get_few_shot_prefix(
|
||||
Few shot prompt prefix.
|
||||
"""
|
||||
|
||||
if language == "English":
|
||||
few_shot_prefix = f"The following are answers for questions in an exam.\n\n"
|
||||
elif language == "Chinese":
|
||||
few_shot_prefix = f"以下是考试中各个问题的答案。\n\n"
|
||||
# First few shot data is something like "The following are questions about xxx".
|
||||
few_shot_prefix = few_shot_data[0] + "\n\n"
|
||||
|
||||
output = None
|
||||
for i in range(len(few_shot_data)):
|
||||
for i in range(1, len(few_shot_data)):
|
||||
few_shot_prefix = few_shot_prefix + few_shot_data[i] + "\n\n"
|
||||
|
||||
if len(tokenizer([few_shot_prefix]).input_ids[0]) <= max_tokens:
|
||||
@ -189,9 +198,10 @@ def get_batch_prompt(
|
||||
conv.append_message(conv.roles[1], None)
|
||||
else:
|
||||
if not isinstance(b["instruction"], list):
|
||||
query_text = (
|
||||
b["instruction"] + "\n\n" + b["input"] if b.get("input", "") != "" else b["instruction"]
|
||||
)
|
||||
if b["instruction"] != "":
|
||||
query_text = b["instruction"] + "\n\n" + b["input"] if b["input"] != "" else b["instruction"]
|
||||
else:
|
||||
query_text = b["input"]
|
||||
conv.append_message(conv.roles[0], query_text)
|
||||
conv.append_message(conv.roles[1], None)
|
||||
else:
|
||||
@ -244,4 +254,13 @@ conv_plain = Conversation(
|
||||
sep="",
|
||||
)
|
||||
|
||||
prompt_templates = {"coati": conv_coati, "alpaca": conv_alpaca, "plain": conv_plain}
|
||||
conv_yayi = Conversation(
|
||||
system="<|System|>:\nYou are a helpful, respectful and honest assistant named YaYi developed by Beijing Wenge Technology Co.,Ltd. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n\n",
|
||||
roles=("<|Human|>", "<|YaYi|>"),
|
||||
messages=[],
|
||||
offset=0,
|
||||
sep_style=SeparatorStyle.YAYI,
|
||||
sep="\n\n",
|
||||
)
|
||||
|
||||
prompt_templates = {"coati": conv_coati, "alpaca": conv_alpaca, "plain": conv_plain, "yayi": conv_yayi}
|
||||
|
@ -8,17 +8,19 @@ import torch.distributed as dist
|
||||
from colossal_eval import dataset, models, utils
|
||||
|
||||
import colossalai
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.shardformer import ShardConfig
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
|
||||
def rm_and_merge(world_size: int, save_path: str, model_names: List[str], dataset_names: Dict[str, List]) -> None:
|
||||
def rm_and_merge(dp_size: int, save_path: str, model_names: List[str], dataset_names: Dict[str, List]) -> None:
|
||||
"""
|
||||
Remove inference result per rank and merge them into one file.
|
||||
|
||||
Args:
|
||||
world_size: Number of processes for inference.
|
||||
dp_size: Number of groups for data parallel.
|
||||
save_path: The folder for storing inference results.
|
||||
model_names: Names of models for inference.
|
||||
dataset_names: Names of dataset for inference.
|
||||
@ -32,9 +34,9 @@ def rm_and_merge(world_size: int, save_path: str, model_names: List[str], datase
|
||||
all_answers[category] = {"data": []}
|
||||
answers = {"data": []}
|
||||
|
||||
for r in range(world_size):
|
||||
for r in range(dp_size):
|
||||
directory = os.path.join(
|
||||
save_path, model_name, f"{dataset_name}_{category}_inference_results_rank{r}.json"
|
||||
save_path, model_name, f"{dataset_name}_{category}_inference_results_dp_rank{r}.json"
|
||||
)
|
||||
if not os.path.exists(directory):
|
||||
raise Exception(
|
||||
@ -45,10 +47,10 @@ def rm_and_merge(world_size: int, save_path: str, model_names: List[str], datase
|
||||
answers["data"].extend(rank_answers["data"])
|
||||
answers["inference_kwargs"] = rank_answers["inference_kwargs"]
|
||||
|
||||
for r in range(world_size):
|
||||
for r in range(dp_size):
|
||||
try:
|
||||
directory = os.path.join(
|
||||
save_path, model_name, f"{dataset_name}_{category}_inference_results_rank{r}.json"
|
||||
save_path, model_name, f"{dataset_name}_{category}_inference_results_dp_rank{r}.json"
|
||||
)
|
||||
os.remove(directory)
|
||||
except Exception as e:
|
||||
@ -66,7 +68,34 @@ def rm_and_merge(world_size: int, save_path: str, model_names: List[str], datase
|
||||
def main(args):
|
||||
colossalai.launch_from_torch(config={}, seed=42)
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
rank = dist.get_rank()
|
||||
DP_AXIS = 0
|
||||
TP_AXIS = 1
|
||||
|
||||
dp_size = world_size // args.tp_size
|
||||
|
||||
if rank == 0:
|
||||
logger.info("Setting TP and DP...")
|
||||
logger.info(f"TP size: {args.tp_size}, DP size: {dp_size}")
|
||||
|
||||
if world_size % args.tp_size != 0:
|
||||
raise Exception(
|
||||
f"TP size is {args.tp_size} while world size is {world_size}! Please make sure world size is a multiple of TP size!"
|
||||
)
|
||||
|
||||
pg_mesh = ProcessGroupMesh(dp_size, args.tp_size)
|
||||
tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
|
||||
|
||||
coordinates = pg_mesh._coord
|
||||
dp_rank = coordinates[DP_AXIS]
|
||||
tp_rank = coordinates[TP_AXIS]
|
||||
|
||||
shard_config = (
|
||||
ShardConfig(tensor_parallel_process_group=tp_group, enable_tensor_parallelism=args.tp_size > 1)
|
||||
if args.tp_size > 1
|
||||
else None
|
||||
)
|
||||
|
||||
inference_data = {}
|
||||
debug_args = {}
|
||||
@ -84,6 +113,9 @@ def main(args):
|
||||
dataset_name = dataset_parameter["name"]
|
||||
debug_args[dataset_name] = dataset_parameter["debug"]
|
||||
few_shot_args[dataset_name] = dataset_parameter["few_shot"]
|
||||
forward_only = dataset_parameter.get("forward_only", False)
|
||||
load_train = dataset_parameter.get("load_train", False)
|
||||
load_reference = dataset_parameter.get("load_reference", False)
|
||||
|
||||
if not args.load_dataset:
|
||||
if os.path.exists(save_path):
|
||||
@ -100,7 +132,7 @@ def main(args):
|
||||
if not issubclass(dataset_class, dataset.BaseDataset):
|
||||
raise ValueError(f"Dataset class {dataset_parameter['dataset_class']} is not a subclass of BaseDataset.")
|
||||
|
||||
dataset_ = dataset_class(path, logger, dataset_parameter["few_shot"])
|
||||
dataset_ = dataset_class(path, logger, dataset_parameter["few_shot"], forward_only, load_train, load_reference)
|
||||
|
||||
dataset_.save(save_path)
|
||||
|
||||
@ -112,12 +144,28 @@ def main(args):
|
||||
|
||||
inference_data[dataset_name] = dataset_.dataset["test"]
|
||||
|
||||
if load_train and "train" in dataset_.dataset:
|
||||
new_dataset_name = f"{dataset_name}_train"
|
||||
debug_args[new_dataset_name] = dataset_parameter["debug"]
|
||||
few_shot_args[new_dataset_name] = dataset_parameter["few_shot"]
|
||||
inference_data[new_dataset_name] = dataset_.dataset["train"]
|
||||
|
||||
if load_reference and "reference" in dataset_.dataset:
|
||||
new_dataset_name = f"{dataset_name}_reference"
|
||||
debug_args[new_dataset_name] = dataset_parameter["debug"]
|
||||
few_shot_args[new_dataset_name] = dataset_parameter["few_shot"]
|
||||
inference_data[new_dataset_name] = dataset_.dataset["reference"]
|
||||
|
||||
if rank == 0:
|
||||
logger.info(f"Dataset for inference are: {list(inference_data.keys())}")
|
||||
|
||||
for model_parameter in model_parameters:
|
||||
model_name = model_parameter["name"]
|
||||
model_class = eval(f"models.{model_parameter['model_class']}")
|
||||
paramerters = model_parameter["parameters"]
|
||||
paramerters.update({"logger": logger})
|
||||
paramerters.update({"prompt_template": utils.prompt_templates[paramerters["prompt_template"]]})
|
||||
paramerters.update({"shard_config": shard_config})
|
||||
|
||||
model_ = model_class(**paramerters)
|
||||
if not issubclass(model_class, models.BaseModel):
|
||||
@ -133,19 +181,21 @@ def main(args):
|
||||
raise Exception(f"Dataset {dataset_name} doesn't have few-shot data for category {category}!")
|
||||
|
||||
answers_to_dump = copy.deepcopy(category_data)
|
||||
partition_size = len(category_data["data"]) // world_size
|
||||
redundant = len(category_data["data"]) % world_size
|
||||
partition_size = len(category_data["data"]) // dp_size
|
||||
redundant = len(category_data["data"]) % dp_size
|
||||
|
||||
# Ensure that the amount of data for inference is as consistent as possible across different processes.
|
||||
lengths = [partition_size for _ in range(world_size)]
|
||||
lengths = [partition_size for _ in range(dp_size)]
|
||||
for j in range(redundant):
|
||||
lengths[(j + start) % world_size] += 1
|
||||
lengths[(j + start) % dp_size] += 1
|
||||
|
||||
start = (start + redundant) % world_size
|
||||
start = (start + redundant) % dp_size
|
||||
|
||||
for turn in range(num_turn):
|
||||
if turn == 0:
|
||||
questions = category_data["data"][sum(lengths[0:rank]) : sum(lengths[0:rank]) + lengths[rank]]
|
||||
questions = category_data["data"][
|
||||
sum(lengths[0:dp_rank]) : sum(lengths[0:dp_rank]) + lengths[dp_rank]
|
||||
]
|
||||
else:
|
||||
questions = prev_questions
|
||||
|
||||
@ -156,14 +206,15 @@ def main(args):
|
||||
|
||||
answers_to_dump["data"] = answers_per_rank
|
||||
|
||||
utils.jdump(
|
||||
answers_to_dump,
|
||||
os.path.join(
|
||||
args.inference_save_path,
|
||||
model_name,
|
||||
f"{dataset_name}_{category}_inference_results_rank{rank}.json",
|
||||
),
|
||||
)
|
||||
if tp_rank == 0:
|
||||
utils.jdump(
|
||||
answers_to_dump,
|
||||
os.path.join(
|
||||
args.inference_save_path,
|
||||
model_name,
|
||||
f"{dataset_name}_{category}_inference_results_dp_rank{dp_rank}.json",
|
||||
),
|
||||
)
|
||||
|
||||
logger.info(f"Rank {rank} peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.3f} GB")
|
||||
|
||||
@ -174,7 +225,7 @@ def main(args):
|
||||
if rank == 0:
|
||||
model_names = [model_parameter["name"] for model_parameter in model_parameters]
|
||||
dataset_names = {key: list(inference_data[key].keys()) for key in inference_data}
|
||||
rm_and_merge(world_size, args.inference_save_path, model_names, dataset_names)
|
||||
rm_and_merge(dp_size, args.inference_save_path, model_names, dataset_names)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@ -182,6 +233,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--config", type=str, default=None, required=True, help="path to config file")
|
||||
parser.add_argument("--load_dataset", default=False, action="store_true")
|
||||
parser.add_argument("--inference_save_path", type=str, default=None, help="path to save inference results")
|
||||
parser.add_argument("--tp_size", type=int, default=1, help="tensor parallel size, used for large model inference")
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
|
@ -1,4 +1,5 @@
|
||||
torchrun --nproc_per_node=1 inference.py \
|
||||
--config "path to config file" \
|
||||
--load_dataset \
|
||||
--tp_size 1 \
|
||||
--inference_save_path "path to save inference results"
|
||||
|
@ -8,17 +8,19 @@ import torch.distributed as dist
|
||||
from colossal_eval import dataset, models, utils
|
||||
|
||||
import colossalai
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.shardformer import ShardConfig
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
|
||||
def rm_and_merge(world_size: int, save_path: str, model_names: List[str], dataset_names: Dict[str, List]) -> None:
|
||||
def rm_and_merge(dp_size: int, save_path: str, model_names: List[str], dataset_names: Dict[str, List]) -> None:
|
||||
"""
|
||||
Remove inference result per rank and merge them into one file.
|
||||
|
||||
Args:
|
||||
world_size: Number of processes for inference.
|
||||
dp_size: Number of groups for data parallel.
|
||||
save_path: The folder for storing inference results.
|
||||
model_names: Names of models for inference.
|
||||
dataset_names: Names of dataset for inference.
|
||||
@ -32,9 +34,9 @@ def rm_and_merge(world_size: int, save_path: str, model_names: List[str], datase
|
||||
all_answers[category] = {"data": []}
|
||||
answers = {"data": []}
|
||||
|
||||
for r in range(world_size):
|
||||
for r in range(dp_size):
|
||||
directory = os.path.join(
|
||||
save_path, model_name, f"{dataset_name}_{category}_inference_results_rank{r}.json"
|
||||
save_path, model_name, f"{dataset_name}_{category}_inference_results_dp_rank{r}.json"
|
||||
)
|
||||
if not os.path.exists(directory):
|
||||
raise Exception(
|
||||
@ -45,10 +47,10 @@ def rm_and_merge(world_size: int, save_path: str, model_names: List[str], datase
|
||||
answers["data"].extend(rank_answers["data"])
|
||||
answers["inference_kwargs"] = rank_answers["inference_kwargs"]
|
||||
|
||||
for r in range(world_size):
|
||||
for r in range(dp_size):
|
||||
try:
|
||||
directory = os.path.join(
|
||||
save_path, model_name, f"{dataset_name}_{category}_inference_results_rank{r}.json"
|
||||
save_path, model_name, f"{dataset_name}_{category}_inference_results_dp_rank{r}.json"
|
||||
)
|
||||
os.remove(directory)
|
||||
except Exception as e:
|
||||
@ -66,11 +68,39 @@ def rm_and_merge(world_size: int, save_path: str, model_names: List[str], datase
|
||||
def main(args):
|
||||
colossalai.launch_from_torch(config={}, seed=42)
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
rank = dist.get_rank()
|
||||
DP_AXIS = 0
|
||||
TP_AXIS = 1
|
||||
|
||||
dp_size = world_size // args.tp_size
|
||||
|
||||
if rank == 0:
|
||||
logger.info("Setting TP and DP...")
|
||||
logger.info(f"TP size: {args.tp_size}, DP size: {dp_size}")
|
||||
|
||||
if world_size % args.tp_size != 0:
|
||||
raise Exception(
|
||||
f"TP size is {args.tp_size} while world size is {world_size}! Please make sure world size is a multiple of TP size!"
|
||||
)
|
||||
|
||||
pg_mesh = ProcessGroupMesh(dp_size, args.tp_size)
|
||||
tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
|
||||
|
||||
coordinates = pg_mesh._coord
|
||||
dp_rank = coordinates[DP_AXIS]
|
||||
tp_rank = coordinates[TP_AXIS]
|
||||
|
||||
shard_config = (
|
||||
ShardConfig(tensor_parallel_process_group=tp_group, enable_tensor_parallelism=args.tp_size > 1)
|
||||
if args.tp_size > 1
|
||||
else None
|
||||
)
|
||||
|
||||
inference_data = {}
|
||||
debug_args = {}
|
||||
few_shot_args = {}
|
||||
multiturn_args = {}
|
||||
|
||||
config = utils.jload(args.config)
|
||||
|
||||
@ -83,6 +113,9 @@ def main(args):
|
||||
dataset_name = dataset_parameter["name"]
|
||||
debug_args[dataset_name] = dataset_parameter["debug"]
|
||||
few_shot_args[dataset_name] = dataset_parameter["few_shot"]
|
||||
forward_only = dataset_parameter.get("forward_only", False)
|
||||
load_train = dataset_parameter.get("load_train", False)
|
||||
load_reference = dataset_parameter.get("load_reference", False)
|
||||
|
||||
if not args.load_dataset:
|
||||
if os.path.exists(save_path):
|
||||
@ -99,17 +132,40 @@ def main(args):
|
||||
if not issubclass(dataset_class, dataset.BaseDataset):
|
||||
raise ValueError(f"Dataset class {dataset_parameter['dataset_class']} is not a subclass of BaseDataset.")
|
||||
|
||||
dataset_ = dataset_class(path, logger, dataset_parameter["few_shot"])
|
||||
dataset_ = dataset_class(path, logger, dataset_parameter["few_shot"], forward_only, load_train, load_reference)
|
||||
|
||||
dataset_.save(save_path)
|
||||
|
||||
if hasattr(dataset_, "multiturn") and dataset_.multiturn:
|
||||
multiturn_args[dataset_name] = True
|
||||
logger.info(f"{dataset_parameter['dataset_class']} is a multiturn dataset.")
|
||||
else:
|
||||
multiturn_args[dataset_name] = False
|
||||
|
||||
inference_data[dataset_name] = dataset_.dataset["test"]
|
||||
|
||||
if load_train and "train" in dataset_.dataset:
|
||||
new_dataset_name = f"{dataset_name}_train"
|
||||
debug_args[new_dataset_name] = dataset_parameter["debug"]
|
||||
few_shot_args[new_dataset_name] = dataset_parameter["few_shot"]
|
||||
inference_data[new_dataset_name] = dataset_.dataset["train"]
|
||||
|
||||
if load_reference and "reference" in dataset_.dataset:
|
||||
new_dataset_name = f"{dataset_name}_reference"
|
||||
debug_args[new_dataset_name] = dataset_parameter["debug"]
|
||||
few_shot_args[new_dataset_name] = dataset_parameter["few_shot"]
|
||||
inference_data[new_dataset_name] = dataset_.dataset["reference"]
|
||||
|
||||
if rank == 0:
|
||||
logger.info(f"Dataset for inference are: {list(inference_data.keys())}")
|
||||
|
||||
for model_parameter in model_parameters:
|
||||
model_name = model_parameter["name"]
|
||||
model_class = eval(f"models.{model_parameter['model_class']}")
|
||||
paramerters = model_parameter["parameters"]
|
||||
paramerters.update({"logger": logger})
|
||||
paramerters.update({"prompt_template": utils.prompt_templates[paramerters["prompt_template"]]})
|
||||
paramerters.update({"shard_config": shard_config})
|
||||
|
||||
model_ = model_class(**paramerters)
|
||||
if not issubclass(model_class, models.BaseModel):
|
||||
@ -117,37 +173,48 @@ def main(args):
|
||||
|
||||
for dataset_name, split_data in inference_data.items():
|
||||
start = 0
|
||||
prev_questions = None
|
||||
for category, category_data in split_data.items():
|
||||
num_turn = category_data["inference_kwargs"].get("turns", 1)
|
||||
|
||||
if few_shot_args[dataset_name] and category_data["inference_kwargs"].get("few_shot_data", None) is None:
|
||||
raise Exception(f"Dataset {dataset_name} doesn't have few-shot data for category {category}!")
|
||||
|
||||
answers_to_dump = copy.deepcopy(category_data)
|
||||
partition_size = len(category_data["data"]) // world_size
|
||||
redundant = len(category_data["data"]) % world_size
|
||||
partition_size = len(category_data["data"]) // dp_size
|
||||
redundant = len(category_data["data"]) % dp_size
|
||||
|
||||
# Ensure that the amount of data for inference is as consistent as possible across different processes.
|
||||
lengths = [partition_size for _ in range(world_size)]
|
||||
lengths = [partition_size for _ in range(dp_size)]
|
||||
for j in range(redundant):
|
||||
lengths[(j + start) % world_size] += 1
|
||||
lengths[(j + start) % dp_size] += 1
|
||||
|
||||
start = (start + redundant) % world_size
|
||||
start = (start + redundant) % dp_size
|
||||
|
||||
questions = category_data["data"][sum(lengths[0:rank]) : sum(lengths[0:rank]) + lengths[rank]]
|
||||
for turn in range(num_turn):
|
||||
if turn == 0:
|
||||
questions = category_data["data"][
|
||||
sum(lengths[0:dp_rank]) : sum(lengths[0:dp_rank]) + lengths[dp_rank]
|
||||
]
|
||||
else:
|
||||
questions = prev_questions
|
||||
|
||||
answers_per_rank = model_.inference(
|
||||
questions, inference_kwargs=category_data["inference_kwargs"], debug=debug_args[dataset_name]
|
||||
)
|
||||
answers_per_rank = model_.inference(
|
||||
questions, inference_kwargs=category_data["inference_kwargs"], debug=debug_args[dataset_name]
|
||||
)
|
||||
prev_questions = answers_per_rank
|
||||
|
||||
answers_to_dump["data"] = answers_per_rank
|
||||
|
||||
utils.jdump(
|
||||
answers_to_dump,
|
||||
os.path.join(
|
||||
args.inference_save_path,
|
||||
model_name,
|
||||
f"{dataset_name}_{category}_inference_results_rank{rank}.json",
|
||||
),
|
||||
)
|
||||
if tp_rank == 0:
|
||||
utils.jdump(
|
||||
answers_to_dump,
|
||||
os.path.join(
|
||||
args.inference_save_path,
|
||||
model_name,
|
||||
f"{dataset_name}_{category}_inference_results_dp_rank{dp_rank}.json",
|
||||
),
|
||||
)
|
||||
|
||||
logger.info(f"Rank {rank} peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.3f} GB")
|
||||
|
||||
@ -158,7 +225,7 @@ def main(args):
|
||||
if rank == 0:
|
||||
model_names = [model_parameter["name"] for model_parameter in model_parameters]
|
||||
dataset_names = {key: list(inference_data[key].keys()) for key in inference_data}
|
||||
rm_and_merge(world_size, args.inference_save_path, model_names, dataset_names)
|
||||
rm_and_merge(dp_size, args.inference_save_path, model_names, dataset_names)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@ -166,6 +233,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--config", type=str, default=None, required=True, help="path to config file")
|
||||
parser.add_argument("--load_dataset", default=False, action="store_true")
|
||||
parser.add_argument("--inference_save_path", type=str, default=None, help="path to save inference results")
|
||||
parser.add_argument("--tp_size", type=int, default=1, help="tensor parallel size, used for large model inference")
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
|
@ -1,4 +1,5 @@
|
||||
torchrun --nproc_per_node=1 inference.py \
|
||||
--config "path to config file" \
|
||||
--load_dataset \
|
||||
--tp_size 1 \
|
||||
--inference_save_path "path to save inference results"
|
||||
|
@ -1,5 +1,5 @@
|
||||
transformers>=4.32.0
|
||||
colossalai>=0.3.1
|
||||
colossalai>=0.3.4
|
||||
peft
|
||||
tabulate
|
||||
jieba
|
||||
|
@ -19,7 +19,7 @@ setup(
|
||||
long_description=fetch_readme(),
|
||||
long_description_content_type="text/markdown",
|
||||
license="Apache Software License 2.0",
|
||||
url="https://github.com/hpcaitech/LLM-Evaluation",
|
||||
url="https://github.com/hpcaitech/ColossalAI/tree/main/applications/ColossalEval",
|
||||
install_requires=fetch_requirements("requirements.txt"),
|
||||
python_requires=">=3.6",
|
||||
classifiers=[
|
||||
|
Loading…
Reference in New Issue
Block a user