support evaluation for english (#3880)
Co-authored-by: Yuanchen Xu <yuanchen.xu00@gmail.com>
@@ -1,6 +1,15 @@
 import io
 import json
+import os
+import re
+import string
 from typing import Dict
+
+import matplotlib.pyplot as plt
+import pandas as pd
+import seaborn as sns
+import tqdm
+from zhon import hanzi
 
 
 def _make_w_io_base(f, mode: str):
@@ -29,7 +38,7 @@ def jdump(obj, f, mode="w", indent=4, default=str):
     """
     f = _make_w_io_base(f, mode)
     if isinstance(obj, (dict, list)):
-        json.dump(obj, f, indent=indent, default=default)
+        json.dump(obj, f, indent=indent, default=default, ensure_ascii=False)
     elif isinstance(obj, str):
         f.write(obj)
     else:
@@ -61,3 +70,149 @@ def get_data_per_category(data, categories):
         data_per_category[category].append(item)
 
     return data_per_category
+
+
+def remove_articles(text: str) -> str:
+    """
+    Remove the articles "a", "an" and "the" from the given text.
+    It is used in the evaluation of automatic metrics.
+
+    """
+
+    pattern = re.compile(r"\b(a|an|the)\b", re.UNICODE)
+    return re.sub(pattern, " ", text)
+
+
+def remove_punctuations(text: str) -> str:
+    """
+    Remove punctuation from the given text, keeping a small set of characters.
+    It is used in the evaluation of automatic metrics.
+
+    """
+
+    punctuation = string.punctuation + hanzi.punctuation
+    punctuation = set(punctuation)
+    punctuation.difference_update(set("!@#$%&()<>?|,.\"'"))
+
+    out = []
+    for char in text:
+        if char in punctuation:
+            continue
+        else:
+            out.append(char)
+
+    return "".join(out)
+
+
+def remove_redundant_space(text: str) -> str:
+    """
+    Remove redundant spaces in the given text.
+    It is used in the evaluation of automatic metrics.
+
+    """
+
+    return " ".join(text.split())
+
+
+def preprocessing_text(text: str) -> str:
+    """
+    Preprocess the given text: lowercase it, then strip punctuation, articles and redundant spaces.
+    It is used in the evaluation of automatic metrics.
+
+    """
+
+    return remove_redundant_space(remove_articles(remove_punctuations(text.lower())))
+
+
+def save_automatic_results(model_name: str, automatic_metric_stats: Dict[str, Dict], save_path: str) -> None:
+    """
+    Save automatic evaluation results of different categories for one model.
+
+    """
+
+    if not os.path.exists(save_path):
+        os.makedirs(save_path)
+
+    automatic_df = pd.DataFrame(automatic_metric_stats)
+    automatic_df.to_csv(os.path.join(save_path, f"{model_name}_results.csv"), index=True)
+
+
+def read_automatic_results(results_path: str, file_name: str) -> Dict[str, Dict]:
+    """
+    Read a csv file and return a dictionary which stores scores per metric.
+
+    """
+
+    results = pd.read_csv(os.path.join(results_path, file_name), index_col=0)
+
+    results_dict = {metric: {} for metric in list(results.index)}
+    for i, metric in enumerate(results_dict.keys()):
+        for j, category in enumerate(list(results.columns)):
+            if pd.isnull(results.iloc[i, j]):
+                continue
+            results_dict[metric][category] = results.iloc[i, j]
+
+    return results_dict
+
+
+def analyze_automatic_results(results_path: str, save_path: str) -> None:
+    """
+    Analyze and visualize all csv files in the given folder.
+
+    """
+
+    if not os.path.exists(results_path):
+        raise Exception(f'The given directory "{results_path}" doesn\'t exist! No results found!')
+
+    all_statistics = {}
+
+    for file_name in os.listdir(results_path):
+        if file_name.endswith("_results.csv"):
+            model_name = file_name.split("_results.csv")[0]
+            all_statistics[model_name] = read_automatic_results(results_path, file_name)
+
+    if len(all_statistics) == 0:
+        raise Exception(f'There are no csv files in the given directory "{results_path}"!')
+
+    frame_all = {"model": [], "category": [], "metric": [], "score": []}
+    frame_per_metric = {}
+    for model_name, model_statistics in all_statistics.items():
+        for metric, metric_statistics in model_statistics.items():
+            if frame_per_metric.get(metric) is None:
+                frame_per_metric[metric] = {"model": [], "category": [], "score": []}
+
+            for category, category_score in metric_statistics.items():
+                frame_all["model"].append(model_name)
+                frame_all["category"].append(category)
+                frame_all["metric"].append(metric)
+                frame_all["score"].append(category_score)
+
+                frame_per_metric[metric]["model"].append(model_name)
+                frame_per_metric[metric]["category"].append(category)
+                frame_per_metric[metric]["score"].append(category_score)
+
+    if not os.path.exists(save_path):
+        os.makedirs(save_path)
+
+    frame_all = pd.DataFrame(frame_all)
+    frame_all.to_csv(os.path.join(save_path, "automatic_evaluation_statistics.csv"))
+
+    for metric in tqdm.tqdm(
+        frame_per_metric.keys(),
+        desc="metric: ",
+        total=len(frame_per_metric.keys()),
+    ):
+        data = pd.DataFrame(frame_per_metric[metric])
+
+        sns.set()
+        fig = plt.figure(figsize=(16, 10))
+
+        fig = sns.barplot(x="category", y="score", hue="model", data=data, dodge=True)
+        fig.set_title(f"Comparison between Different Models for Metric {metric.title()}")
+        plt.xlabel("Evaluation Category")
+        plt.ylabel("Score")
+
+        figure = fig.get_figure()
+        figure.savefig(os.path.join(save_path, f"{metric}.png"), dpi=400)
+
+        plt.close()
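
The ensure_ascii=False added to jdump is what keeps non-ASCII (e.g. Chinese) evaluation output readable on disk; a minimal standalone sketch of the difference, using only the standard library:

import json

record = {"question": "你好吗?", "score": 0.9}

# Default behaviour: non-ASCII characters are escaped to \uXXXX sequences.
print(json.dumps(record))
# {"question": "\u4f60\u597d\u5417?", "score": 0.9}

# With ensure_ascii=False, the characters are written as-is (UTF-8).
print(json.dumps(record, ensure_ascii=False))
# {"question": "你好吗?", "score": 0.9}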
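The new text helpers are chained through preprocessing_text before string-overlap metrics are computed; a sketch of the expected normalization, assuming this file is importable as utils (a hypothetical module name):

from utils import preprocessing_text  # hypothetical import path

prediction = 'The quick, brown fox -- "jumps"!'
print(preprocessing_text(prediction))
# quick, brown fox "jumps"!
# Lowercased, "--" stripped as punctuation, the leading article dropped and
# spaces collapsed; , . " ' ! survive because they are in the kept set.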
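save_automatic_results and read_automatic_results round-trip a per-model score table through a CSV whose rows are metrics and whose columns are categories; a sketch with made-up numbers, again under the hypothetical utils import:

from utils import read_automatic_results, save_automatic_results  # hypothetical

# Scores for one model, keyed as {category: {metric: score}}.
stats = {
    "open_qa": {"bleu": 0.31, "rouge": 0.42},
    "summarization": {"bleu": 0.18},  # no rouge score -> NaN cell in the CSV
}
save_automatic_results("my_model", stats, "results")  # writes results/my_model_results.csv

# Reading transposes the view to {metric: {category: score}} and skips NaN cells:
scores = read_automatic_results("results", "my_model_results.csv")
# {'bleu': {'open_qa': 0.31, 'summarization': 0.18}, 'rouge': {'open_qa': 0.42}}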
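Finally, analyze_automatic_results aggregates every <model>_results.csv in a folder into one combined CSV plus one grouped bar chart per metric; a minimal driver under the same assumption:

from utils import analyze_automatic_results  # hypothetical import path

# Reads results/<model>_results.csv for each evaluated model, then writes
# figures/automatic_evaluation_statistics.csv and one figures/<metric>.png
# comparing all models per evaluation category.
analyze_automatic_results(results_path="results", save_path="figures")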