mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 10:34:41 +00:00
Support mtbench (#5025)
Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
This commit is contained in:
@@ -71,6 +71,7 @@ def main(args):
|
||||
inference_data = {}
|
||||
debug_args = {}
|
||||
few_shot_args = {}
|
||||
multiturn_args = {}
|
||||
|
||||
config = utils.jload(args.config)
|
||||
|
||||
@@ -102,6 +103,13 @@ def main(args):
|
||||
dataset_ = dataset_class(path, logger, dataset_parameter["few_shot"])
|
||||
|
||||
dataset_.save(save_path)
|
||||
|
||||
if hasattr(dataset_, "multiturn") and dataset_.multiturn:
|
||||
multiturn_args[dataset_name] = True
|
||||
logger.info(f"{dataset_parameter['dataset_class']} is a multiturn dataset.")
|
||||
else:
|
||||
multiturn_args[dataset_name] = False
|
||||
|
||||
inference_data[dataset_name] = dataset_.dataset["test"]
|
||||
|
||||
for model_parameter in model_parameters:
|
||||
@@ -117,7 +125,10 @@ def main(args):
|
||||
|
||||
for dataset_name, split_data in inference_data.items():
|
||||
start = 0
|
||||
prev_questions = None
|
||||
for category, category_data in split_data.items():
|
||||
num_turn = category_data["inference_kwargs"].get("turns", 1)
|
||||
|
||||
if few_shot_args[dataset_name] and category_data["inference_kwargs"].get("few_shot_data", None) is None:
|
||||
raise Exception(f"Dataset {dataset_name} doesn't have few-shot data for category {category}!")
|
||||
|
||||
@@ -132,11 +143,16 @@ def main(args):
|
||||
|
||||
start = (start + redundant) % world_size
|
||||
|
||||
questions = category_data["data"][sum(lengths[0:rank]) : sum(lengths[0:rank]) + lengths[rank]]
|
||||
for turn in range(num_turn):
|
||||
if turn == 0:
|
||||
questions = category_data["data"][sum(lengths[0:rank]) : sum(lengths[0:rank]) + lengths[rank]]
|
||||
else:
|
||||
questions = prev_questions
|
||||
|
||||
answers_per_rank = model_.inference(
|
||||
questions, inference_kwargs=category_data["inference_kwargs"], debug=debug_args[dataset_name]
|
||||
)
|
||||
answers_per_rank = model_.inference(
|
||||
questions, inference_kwargs=category_data["inference_kwargs"], debug=debug_args[dataset_name]
|
||||
)
|
||||
prev_questions = answers_per_rank
|
||||
|
||||
answers_to_dump["data"] = answers_per_rank
|
||||
|
||||
|
Reference in New Issue
Block a user