Support mtbench (#5025)

Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
This commit is contained in:
Yuanchen
2023-11-09 13:41:50 +08:00
committed by GitHub
parent f71e63b0f3
commit 239cd92eff
9 changed files with 312 additions and 13 deletions

View File

@@ -333,9 +333,12 @@ class HuggingFaceModel(BaseModel):
self.str_label_map = {choice: idx for idx, choice in enumerate(self.choices)}
turn = 0 if not isinstance(data[0]["output"], list) else len(data[0]["output"]) + 1
turn_desc = "" if turn == 0 else f"-turn{turn}"
bar = tqdm(
range(math.ceil(len(data) / self.batch_size)),
desc=f"{data[0]['dataset']}-{data[0]['category']} Inference steps",
desc=f"{data[0]['dataset']}-{data[0]['category']}{turn_desc} Inference steps",
disable=not is_rank_0(),
)
loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
@@ -384,7 +387,10 @@ class HuggingFaceModel(BaseModel):
for j in range(len(batch_prompt)):
if not pretrain:
answers[i + j]["output"] = batch_decodes[j].strip()
if isinstance(answers[i + j]["output"], list):
answers[i + j]["output"].append(batch_decodes[j].strip())
else:
answers[i + j]["output"] = batch_decodes[j].strip()
if isinstance(scores, torch.Tensor):
answers[i + j]["softmax_over_choices"] = probs[j]