mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 13:00:52 +00:00
fix scheduling for multi-node training
This commit is contained in:
@@ -14,6 +14,10 @@ def verify_math_representation(completion, gt_answer):
|
||||
"""
|
||||
Verify if the completion is a valid math representation of the gt_answer.
|
||||
"""
|
||||
if not completion.startswith("\\boxed{"):
|
||||
completion = "\\boxed{" + completion + "}"
|
||||
if not gt_answer.startswith("\\boxed{"):
|
||||
gt_answer = "\\boxed{" + gt_answer + "}"
|
||||
target = (
|
||||
ExprExtractionConfig(),
|
||||
LatexExtractionConfig(
|
||||
@@ -59,7 +63,7 @@ def verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, rew
|
||||
if math_verify_result == CANNOT_PARSE_GT_ANSWER:
|
||||
# plain text answer cannot be parsed, but is correct
|
||||
reward += acc_score
|
||||
else:
|
||||
elif math_verify_result == CANNOT_PARSE_PREDICTION:
|
||||
reward += (
|
||||
acc_score / 2
|
||||
) # not a valid latex math representation, but the answer is correct, receive half of the score
|
||||
@@ -140,9 +144,15 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
|
||||
return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
|
||||
|
||||
decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
|
||||
|
||||
gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
|
||||
final_answer = extract_boxed_solution(decoded_final_answer)
|
||||
format_valid = final_answer is not None
|
||||
if "tags" in kwargs:
|
||||
tags = kwargs["tags"]
|
||||
format_valid = format_valid and all(
|
||||
[decoded_final_answer.count(tags[tag]["text"]) == tags[tag]["num_occur"] for tag in tags]
|
||||
)
|
||||
# Check format accuracy
|
||||
if format_valid:
|
||||
format_acc += 1
|
||||
|
Reference in New Issue
Block a user