Support mtbench (#5025)

Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
This commit is contained in:
Yuanchen
2023-11-09 13:41:50 +08:00
committed by GitHub
parent f71e63b0f3
commit 239cd92eff
9 changed files with 312 additions and 13 deletions

View File

@@ -71,6 +71,7 @@ def main(args):
inference_data = {}
debug_args = {}
few_shot_args = {}
multiturn_args = {}
config = utils.jload(args.config)
@@ -102,6 +103,13 @@ def main(args):
dataset_ = dataset_class(path, logger, dataset_parameter["few_shot"])
dataset_.save(save_path)
if hasattr(dataset_, "multiturn") and dataset_.multiturn:
multiturn_args[dataset_name] = True
logger.info(f"{dataset_parameter['dataset_class']} is a multiturn dataset.")
else:
multiturn_args[dataset_name] = False
inference_data[dataset_name] = dataset_.dataset["test"]
for model_parameter in model_parameters:
@@ -117,7 +125,10 @@ def main(args):
for dataset_name, split_data in inference_data.items():
start = 0
prev_questions = None
for category, category_data in split_data.items():
num_turn = category_data["inference_kwargs"].get("turns", 1)
if few_shot_args[dataset_name] and category_data["inference_kwargs"].get("few_shot_data", None) is None:
raise Exception(f"Dataset {dataset_name} doesn't have few-shot data for category {category}!")
@@ -132,11 +143,16 @@ def main(args):
start = (start + redundant) % world_size
questions = category_data["data"][sum(lengths[0:rank]) : sum(lengths[0:rank]) + lengths[rank]]
for turn in range(num_turn):
if turn == 0:
questions = category_data["data"][sum(lengths[0:rank]) : sum(lengths[0:rank]) + lengths[rank]]
else:
questions = prev_questions
answers_per_rank = model_.inference(
questions, inference_kwargs=category_data["inference_kwargs"], debug=debug_args[dataset_name]
)
answers_per_rank = model_.inference(
questions, inference_kwargs=category_data["inference_kwargs"], debug=debug_args[dataset_name]
)
prev_questions = answers_per_rank
answers_to_dump["data"] = answers_per_rank