Improve logic for selecting metrics (#5196)

Co-authored-by: Xu <yuanchen.xu00@gmail.com>
This commit is contained in:
Yuanchen
2023-12-22 14:52:50 +08:00
committed by GitHub
parent 4fa689fca1
commit eae01b6740
4 changed files with 62 additions and 23 deletions

View File

@@ -15,7 +15,13 @@ from colossalai.shardformer import ShardConfig
logger = get_dist_logger()
def rm_and_merge(dp_size: int, save_path: str, model_names: List[str], dataset_names: Dict[str, List]) -> None:
def rm_and_merge(
dp_size: int,
save_path: str,
model_names: List[str],
dataset_names: Dict[str, List],
dataset_classes: Dict[str, List],
) -> None:
"""
Remove inference result per rank and merge them into one file.
@@ -24,11 +30,15 @@ def rm_and_merge(dp_size: int, save_path: str, model_names: List[str], dataset_n
save_path: The folder for storing inference results.
model_names: Names of models for inference.
dataset_names: Names of dataset for inference.
dataset_classes: Dataset class for different inference results. We need to save dataset class to smooth the evaluation process.
"""
for model_name in model_names:
for dataset_name, categories in dataset_names.items():
all_answers_with_dataset_class = {}
all_answers_with_dataset_class["dataset_class"] = dataset_classes[dataset_name]
all_answers = {}
for category in categories:
all_answers[category] = {"data": []}
@@ -58,8 +68,13 @@ def rm_and_merge(dp_size: int, save_path: str, model_names: List[str], dataset_n
all_answers[category] = answers
all_answers_with_dataset_class["inference_results"] = all_answers
logger.info(f"Save inference results of model {model_name} on dataset {dataset_name}.")
utils.jdump(all_answers, os.path.join(save_path, model_name, f"{dataset_name}_inference_results.json"))
utils.jdump(
all_answers_with_dataset_class,
os.path.join(save_path, model_name, f"{dataset_name}_inference_results.json"),
)
logger.info(f"Save inference results of model {model_name} for all dataset.")
logger.info(f"Save inference results of all models for all dataset.")
@@ -98,6 +113,7 @@ def main(args):
)
inference_data = {}
dataset_classes = {}
debug_args = {}
few_shot_args = {}
multiturn_args = {}
@@ -128,6 +144,7 @@ def main(args):
continue
dataset_classes[dataset_name] = dataset_parameter["dataset_class"]
dataset_class = eval(f"dataset.{dataset_parameter['dataset_class']}")
if not issubclass(dataset_class, dataset.BaseDataset):
raise ValueError(f"Dataset class {dataset_parameter['dataset_class']} is not a subclass of BaseDataset.")
@@ -149,12 +166,14 @@ def main(args):
debug_args[new_dataset_name] = dataset_parameter["debug"]
few_shot_args[new_dataset_name] = dataset_parameter["few_shot"]
inference_data[new_dataset_name] = dataset_.dataset["train"]
dataset_classes[new_dataset_name] = dataset_parameter["dataset_class"]
if load_reference and "reference" in dataset_.dataset:
new_dataset_name = f"{dataset_name}_reference"
debug_args[new_dataset_name] = dataset_parameter["debug"]
few_shot_args[new_dataset_name] = dataset_parameter["few_shot"]
inference_data[new_dataset_name] = dataset_.dataset["reference"]
dataset_classes[new_dataset_name] = dataset_parameter["dataset_class"]
if rank == 0:
logger.info(f"Dataset for inference are: {list(inference_data.keys())}")
@@ -225,7 +244,7 @@ def main(args):
if rank == 0:
model_names = [model_parameter["name"] for model_parameter in model_parameters]
dataset_names = {key: list(inference_data[key].keys()) for key in inference_data}
rm_and_merge(dp_size, args.inference_save_path, model_names, dataset_names)
rm_and_merge(dp_size, args.inference_save_path, model_names, dataset_names, dataset_classes)
if __name__ == "__main__":