mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 13:00:52 +00:00
[ColossalEval] Support GSM, Data Leakage Evaluation and Tensor Parallel (#5169)
* Support GSM, Data Leakage Evaluation and Tensor Parallel * remove redundant code and update inference.py in examples/gpt_evaluation --------- Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
# Code adapted from https://github.com/THUDM/LongBench/blob/main/metrics.py
|
||||
# Code adapted from https://github.com/hendrycks/math/blob/main/modeling/math_equivalence.py
|
||||
# Code adapted from https://github.com/ruixiangcui/AGIEval/blob/main/src/evaluation.py
|
||||
# https://github.com/SkyworkAI/Skywork/blob/main/eval/eval_gsm8k.py
|
||||
|
||||
import difflib
|
||||
import re
|
||||
@@ -11,6 +12,11 @@ import jieba
|
||||
from fuzzywuzzy import fuzz
|
||||
from rouge import Rouge
|
||||
|
||||
# Sentinel returned by the answer extractors when no usable number is found.
INVALID_ANS = "[invalid]"

# Captures the number written after a "#### " marker (used on reference answers).
ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)")

# Captures any number-like token: optional minus sign, a digit, then digits,
# dots or commas.
ans_re1 = re.compile(r"(\-?[0-9][0-9\.\,]*)")

# Captures a number-like token that follows "=", optionally prefixed with "$".
ans_re2 = re.compile(r"=\s*(\$?-?[0-9][0-9\.\,]*)")
|
||||
|
||||
metrics4subcategory = {
|
||||
"pretrain": {
|
||||
"perplexity": ["ALL"],
|
||||
@@ -189,6 +195,10 @@ metrics4subcategory = {
|
||||
"cvalues": {"first_token_accuracy": ["ALL"]},
|
||||
"safetybench_zh": {"first_token_accuracy": ["ALL"]},
|
||||
"safetybench_en": {"first_token_accuracy": ["ALL"]},
|
||||
"gsm": {
|
||||
"loss_over_all_tokens": ["ALL"],
|
||||
"gsm_accuracy": ["ALL"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -639,3 +649,61 @@ def f1_zh_score(prediction, reference, **kwargs):
|
||||
prediction_tokens = [token for token in prediction_tokens if len(token) > 0]
|
||||
ground_truth_tokens = [token for token in ground_truth_tokens if len(token) > 0]
|
||||
return _f1_score(prediction_tokens, ground_truth_tokens)
|
||||
|
||||
|
||||
def extract_answer_hf(completion):
    """Extract the numeric answer that follows the ``#### `` marker.

    Args:
        completion: Text searched with ``ANS_RE``; the first occurrence of
            the marker is used.

    Returns:
        The captured number as an ``int`` when it is written as an integer,
        otherwise as a ``float``. ``INVALID_ANS`` is returned when the marker
        is absent or the captured text is not a well-formed number.
    """
    match = ANS_RE.search(completion)
    if not match:
        return INVALID_ANS

    # Thousands separators are dropped before parsing ("1,234" -> "1234").
    number_text = match.group(1).strip().replace(",", "")

    # Parse with int()/float() rather than eval(): ANS_RE also accepts strings
    # such as "1.2.3", "." or "," that are not numbers, and eval() raised
    # SyntaxError on those instead of reporting an invalid answer.
    try:
        return int(number_text)
    except ValueError:
        pass
    try:
        return float(number_text)
    except ValueError:
        return INVALID_ANS
|
||||
|
||||
|
||||
def get_match_str(match, idx):
    """Return ``match[idx]`` normalised for numeric parsing.

    Commas are removed, then a trailing ".", a trailing ".00" and a trailing
    ".0" are each stripped at most once, in that order.

    Args:
        match: Sequence of matched strings (e.g. the result of ``findall``).
        idx: Index of the element to normalise.

    Returns:
        The cleaned string.
    """
    cleaned = match[idx].replace(",", "")
    # The order is significant: removing an earlier suffix can expose a later
    # one (e.g. "3.00." -> "3.00" -> "3").
    for suffix in (".", ".00", ".0"):
        if cleaned.endswith(suffix):
            cleaned = cleaned[: -len(suffix)]
    return cleaned
|
||||
|
||||
|
||||
def extract_answer(completion):
    """Extract the final numeric answer from a model completion.

    The last number that follows an "=" sign (``ans_re2``) is preferred; when
    there is none, the last number-like token in the text (``ans_re1``) is
    used. The chosen text is normalised with ``get_match_str`` (and stripped
    of "$" in the "=" case) before being parsed.

    Args:
        completion: Model output text to scan.

    Returns:
        The answer as an ``int`` or ``float``, or ``INVALID_ANS`` when no
        number is found or the candidate text is not a well-formed number.
    """
    equation_hits = ans_re2.findall(completion)
    if equation_hits:
        candidate = get_match_str(equation_hits, -1).replace("$", "")
    else:
        number_hits = ans_re1.findall(completion)
        if not number_hits:
            return INVALID_ANS
        candidate = get_match_str(number_hits, -1)

    # Parse with int()/float() instead of eval(): the result for well-formed
    # numbers is the same, without executing text produced by a model and
    # without a blanket ``except Exception``. Malformed candidates such as
    # "1.2.3" map to INVALID_ANS.
    try:
        return int(candidate)
    except ValueError:
        pass
    try:
        return float(candidate)
    except ValueError:
        return INVALID_ANS
|
||||
|
||||
|
||||
def is_correct(completion, answer):
    """Check whether a completion's extracted number equals the reference's.

    Args:
        completion: Model output. Only the text after the last occurrence of
            "answer is" is inspected (the whole text when the phrase is
            absent).
        answer: Reference text, parsed with ``extract_answer_hf``.

    Returns:
        ``True`` when the number extracted from ``completion`` equals the one
        extracted from ``answer``; ``False`` otherwise.

    Raises:
        AssertionError: If no reference answer can be extracted from
            ``answer``.
    """
    gold = extract_answer_hf(answer)
    # Raised explicitly rather than via ``assert`` so the check survives
    # ``python -O``; otherwise an unparsable reference and an unparsable
    # prediction would both be INVALID_ANS and compare as equal.
    if gold == INVALID_ANS:
        raise AssertionError("No ground truth answer found in the document.")

    tail = completion.split("answer is")[-1]
    return extract_answer(tail) == gold
|
||||
|
||||
|
||||
def gsm_accuracy(prediction, reference, **kwargs):
    """Score one GSM prediction against its reference.

    The prediction is truncated at the first occurrence of each stop marker
    in turn (three newlines, two newlines, "Question:") before being compared
    with the reference through ``is_correct``.

    Args:
        prediction: Raw model output.
        reference: Reference answer text.
        **kwargs: Accepted and ignored.

    Returns:
        ``1.0`` when the prediction is judged correct, ``0.0`` otherwise.
    """
    for stop_marker in ("\n\n\n", "\n\n", "Question:"):
        prediction = prediction.partition(stop_marker)[0]

    if is_correct(prediction, reference):
        return 1.0
    return 0.0
|
||||
|
Reference in New Issue
Block a user