fix schedualing for multi-node training

This commit is contained in:
YeAnbang
2025-05-02 19:45:07 +08:00
parent d06042b434
commit 7d658402da
7 changed files with 124 additions and 38 deletions

View File

@@ -14,6 +14,10 @@ def verify_math_representation(completion, gt_answer):
"""
Verify if the completion is a valid math representation of the gt_answer.
"""
if not completion.startswith("\\boxed{"):
completion = "\\boxed{" + completion + "}"
if not gt_answer.startswith("\\boxed{"):
gt_answer = "\\boxed{" + gt_answer + "}"
target = (
ExprExtractionConfig(),
LatexExtractionConfig(
@@ -59,7 +63,7 @@ def verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, rew
if math_verify_result == CANNOT_PARSE_GT_ANSWER:
# plain text answer cannot be parsed, but is correct
reward += acc_score
else:
elif math_verify_result == CANNOT_PARSE_PREDICTION:
reward += (
acc_score / 2
) # not a valid latex math representation, but the answer is correct, receive half of the score
@@ -140,9 +144,15 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
final_answer = extract_boxed_solution(decoded_final_answer)
format_valid = final_answer is not None
if "tags" in kwargs:
tags = kwargs["tags"]
format_valid = format_valid and all(
[decoded_final_answer.count(tags[tag]["text"]) == tags[tag]["num_occur"] for tag in tags]
)
# Check format accuracy
if format_valid:
format_acc += 1