mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 09:07:51 +00:00
Improve logic for selecting metrics (#5196)
Co-authored-by: Xu <yuanchen.xu00@gmail.com>
This commit is contained in:
@@ -15,7 +15,13 @@ from colossalai.shardformer import ShardConfig
|
||||
logger = get_dist_logger()
|
||||
|
||||
|
||||
def rm_and_merge(dp_size: int, save_path: str, model_names: List[str], dataset_names: Dict[str, List]) -> None:
|
||||
def rm_and_merge(
|
||||
dp_size: int,
|
||||
save_path: str,
|
||||
model_names: List[str],
|
||||
dataset_names: Dict[str, List],
|
||||
dataset_classes: Dict[str, List],
|
||||
) -> None:
|
||||
"""
|
||||
Remove inference result per rank and merge them into one file.
|
||||
|
||||
@@ -24,11 +30,15 @@ def rm_and_merge(dp_size: int, save_path: str, model_names: List[str], dataset_n
|
||||
save_path: The folder for storing inference results.
|
||||
model_names: Names of models for inference.
|
||||
dataset_names: Names of dataset for inference.
|
||||
dataset_classes: Dataset class for different inference results. We need to save dataset class to smooth the evaluation process.
|
||||
|
||||
"""
|
||||
|
||||
for model_name in model_names:
|
||||
for dataset_name, categories in dataset_names.items():
|
||||
all_answers_with_dataset_class = {}
|
||||
all_answers_with_dataset_class["dataset_class"] = dataset_classes[dataset_name]
|
||||
|
||||
all_answers = {}
|
||||
for category in categories:
|
||||
all_answers[category] = {"data": []}
|
||||
@@ -58,8 +68,13 @@ def rm_and_merge(dp_size: int, save_path: str, model_names: List[str], dataset_n
|
||||
|
||||
all_answers[category] = answers
|
||||
|
||||
all_answers_with_dataset_class["inference_results"] = all_answers
|
||||
|
||||
logger.info(f"Save inference results of model {model_name} on dataset {dataset_name}.")
|
||||
utils.jdump(all_answers, os.path.join(save_path, model_name, f"{dataset_name}_inference_results.json"))
|
||||
utils.jdump(
|
||||
all_answers_with_dataset_class,
|
||||
os.path.join(save_path, model_name, f"{dataset_name}_inference_results.json"),
|
||||
)
|
||||
|
||||
logger.info(f"Save inference results of model {model_name} for all dataset.")
|
||||
logger.info(f"Save inference results of all models for all dataset.")
|
||||
@@ -98,6 +113,7 @@ def main(args):
|
||||
)
|
||||
|
||||
inference_data = {}
|
||||
dataset_classes = {}
|
||||
debug_args = {}
|
||||
few_shot_args = {}
|
||||
multiturn_args = {}
|
||||
@@ -128,6 +144,7 @@ def main(args):
|
||||
|
||||
continue
|
||||
|
||||
dataset_classes[dataset_name] = dataset_parameter["dataset_class"]
|
||||
dataset_class = eval(f"dataset.{dataset_parameter['dataset_class']}")
|
||||
if not issubclass(dataset_class, dataset.BaseDataset):
|
||||
raise ValueError(f"Dataset class {dataset_parameter['dataset_class']} is not a subclass of BaseDataset.")
|
||||
@@ -149,12 +166,14 @@ def main(args):
|
||||
debug_args[new_dataset_name] = dataset_parameter["debug"]
|
||||
few_shot_args[new_dataset_name] = dataset_parameter["few_shot"]
|
||||
inference_data[new_dataset_name] = dataset_.dataset["train"]
|
||||
dataset_classes[new_dataset_name] = dataset_parameter["dataset_class"]
|
||||
|
||||
if load_reference and "reference" in dataset_.dataset:
|
||||
new_dataset_name = f"{dataset_name}_reference"
|
||||
debug_args[new_dataset_name] = dataset_parameter["debug"]
|
||||
few_shot_args[new_dataset_name] = dataset_parameter["few_shot"]
|
||||
inference_data[new_dataset_name] = dataset_.dataset["reference"]
|
||||
dataset_classes[new_dataset_name] = dataset_parameter["dataset_class"]
|
||||
|
||||
if rank == 0:
|
||||
logger.info(f"Dataset for inference are: {list(inference_data.keys())}")
|
||||
@@ -225,7 +244,7 @@ def main(args):
|
||||
if rank == 0:
|
||||
model_names = [model_parameter["name"] for model_parameter in model_parameters]
|
||||
dataset_names = {key: list(inference_data[key].keys()) for key in inference_data}
|
||||
rm_and_merge(dp_size, args.inference_save_path, model_names, dataset_names)
|
||||
rm_and_merge(dp_size, args.inference_save_path, model_names, dataset_names, dataset_classes)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Reference in New Issue
Block a user