[ColossalEval] Support GSM, Data Leakage Evaluation and Tensor Parallel (#5169)

* Support GSM, Data Leakage Evaluation and Tensor Parallel

* remove redundant code and update inference.py in examples/gpt_evaluation

---------

Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
This commit is contained in:
Yuanchen
2023-12-12 14:47:35 +08:00
committed by GitHub
parent b07a6f4e27
commit cefdc32615
19 changed files with 578 additions and 100 deletions

View File

@@ -1,6 +1,7 @@
# Code adapted from https://github.com/THUDM/LongBench/blob/main/metrics.py
# Code adapted from https://github.com/hendrycks/math/blob/main/modeling/math_equivalence.py
# Code adapted from https://github.com/ruixiangcui/AGIEval/blob/main/src/evaluation.py
# Code adapted from https://github.com/SkyworkAI/Skywork/blob/main/eval/eval_gsm8k.py
import difflib
import re
@@ -11,6 +12,11 @@ import jieba
from fuzzywuzzy import fuzz
from rouge import Rouge
# Matches the "#### <number>" marker and captures the number text
# (optional minus sign followed by digits, dots and commas).
ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
# Sentinel returned by the extraction helpers when no answer can be produced.
INVALID_ANS = "[invalid]"
# Captures any number-like token: optional minus sign, a leading digit, then
# digits, dots and commas.
ans_re1 = re.compile(r"(\-?[0-9][0-9\.\,]*)")
# Captures a number-like token that follows "=" (optionally prefixed by "$").
ans_re2 = re.compile(r"=\s*(\$?-?[0-9][0-9\.\,]*)")
metrics4subcategory = {
"pretrain": {
"perplexity": ["ALL"],
@@ -189,6 +195,10 @@ metrics4subcategory = {
"cvalues": {"first_token_accuracy": ["ALL"]},
"safetybench_zh": {"first_token_accuracy": ["ALL"]},
"safetybench_en": {"first_token_accuracy": ["ALL"]},
"gsm": {
"loss_over_all_tokens": ["ALL"],
"gsm_accuracy": ["ALL"],
},
}
@@ -639,3 +649,61 @@ def f1_zh_score(prediction, reference, **kwargs):
prediction_tokens = [token for token in prediction_tokens if len(token) > 0]
ground_truth_tokens = [token for token in ground_truth_tokens if len(token) > 0]
return _f1_score(prediction_tokens, ground_truth_tokens)
def extract_answer_hf(completion):
    """Extract the number that follows the ``####`` marker in ``completion``.

    Args:
        completion: Text searched with ``ANS_RE``.

    Returns:
        The captured number as an ``int`` when it is a plain integer literal,
        otherwise as a ``float``. ``INVALID_ANS`` is returned when the marker
        is absent or the captured text is not a valid number.
    """
    found = ANS_RE.search(completion)
    if not found:
        return INVALID_ANS

    # Drop thousands separators before parsing ("1,234" -> "1234").
    number_text = found.group(1).strip().replace(",", "")

    # Parse with int()/float() instead of eval(): eval() raises SyntaxError on
    # captures such as "007", "1.2.3" or "." and should not be used as a
    # number parser for free-form text.
    try:
        return int(number_text)
    except ValueError:
        pass
    try:
        return float(number_text)
    except ValueError:
        return INVALID_ANS
def get_match_str(match, idx):
    """Return ``match[idx]`` with commas removed and trailing zero decimals trimmed.

    Args:
        match: Indexable collection of matched strings.
        idx: Index of the entry to normalize.

    Returns:
        The selected string without commas, after stripping (in this order and
        at most once each) a trailing ".", a trailing ".00" and a trailing ".0".
    """
    text = match[idx].replace(",", "")
    # Order matters: each suffix is tested against the result of the previous
    # trim, so "5.00." becomes "5.00" and then "5".
    for suffix in (".", ".00", ".0"):
        if text.endswith(suffix):
            text = text[: -len(suffix)]
    return text
def extract_answer(completion):
    """Extract the final numeric answer from ``completion``.

    The last number-like token found by ``ans_re1`` is a candidate. If the
    text also contains an ``= <number>`` expression matched by ``ans_re2``,
    the last such number takes precedence.

    Args:
        completion: Text to search for a numeric answer.

    Returns:
        The chosen number as an ``int`` when it is a plain integer literal,
        otherwise as a ``float``. ``INVALID_ANS`` is returned when no number
        is found or the chosen text is not a valid number (e.g. "1.2.3").
    """
    candidates = []

    plain_numbers = re.findall(ans_re1, completion)
    if plain_numbers:
        candidates.append(get_match_str(plain_numbers, -1))

    equation_results = re.findall(ans_re2, completion)
    if equation_results:
        # The "=" pattern may capture a leading currency sign; drop it.
        candidates.append(get_match_str(equation_results, -1).replace("$", ""))

    if not candidates:
        return INVALID_ANS

    number_text = candidates[-1]
    # Parse with int()/float() instead of eval(): eval() should not be run on
    # model output, and it rejects valid numbers such as "007". Unparseable
    # text yields INVALID_ANS without printing the exception.
    try:
        return int(number_text)
    except ValueError:
        pass
    try:
        return float(number_text)
    except ValueError:
        return INVALID_ANS
def is_correct(completion, answer):
    """Check whether the number extracted from ``completion`` equals the reference.

    Args:
        completion: Model output; only the text after the last "answer is"
            is searched for a number.
        answer: Reference text from which the gold number is extracted with
            ``extract_answer_hf``.

    Returns:
        ``True`` when the extracted prediction equals the gold number.

    Raises:
        AssertionError: If no gold number can be extracted from ``answer``.
    """
    gold = extract_answer_hf(answer)
    if gold == INVALID_ANS:
        # Raised explicitly rather than via `assert` so the check is not
        # stripped under `python -O`; otherwise an invalid gold compared with
        # an invalid prediction would be scored as correct.
        raise AssertionError("No ground truth answer found in the document.")
    tail = completion.split("answer is")[-1]
    return extract_answer(tail) == gold
def gsm_accuracy(prediction, reference, **kwargs):
    """Score one prediction against its reference: 1.0 if correct, else 0.0.

    Args:
        prediction: Model output. It is truncated at the first occurrence of
            each delimiter below, in order, before being checked.
        reference: Reference answer text passed to ``is_correct``.
        **kwargs: Accepted and ignored.

    Returns:
        ``1.0`` when ``is_correct`` accepts the truncated prediction,
        otherwise ``0.0``.
    """
    truncated = prediction
    # Keep only the text before each delimiter, applied in this order.
    for delimiter in ("\n\n\n", "\n\n", "Question:"):
        truncated = truncated.partition(delimiter)[0]
    if is_correct(truncated, reference):
        return 1.0
    return 0.0