mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-10 12:22:28 +00:00
upgrade reward functions
This commit is contained in:
parent
021914c565
commit
03b41d6fb5
@ -127,7 +127,9 @@ class GRPOConsumer(BaseConsumer):
|
||||
"answer_end": {"text": "</answer>", "num_occur": 1},
|
||||
}
|
||||
reward_model_kwargs = {
|
||||
k: v for k, v in grpo_config.items() if k in ["soft_over_length_punishment", "max_length", "cache_length"]
|
||||
k: v
|
||||
for k, v in grpo_config.items()
|
||||
if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length"]
|
||||
}
|
||||
self.reward_model = VerifiableReward(
|
||||
reward_fns=[
|
||||
|
@ -1,7 +1,77 @@
|
||||
import torch
|
||||
from latex2sympy2_extended import NormalizationConfig
|
||||
from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify
|
||||
|
||||
from .reward_utils import extract_boxed_solution, extract_solution, validate_response_structure
|
||||
|
||||
CANNOT_PARSE_GT_ANSWER = -1
|
||||
CANNOT_PARSE_PREDICTION = -2
|
||||
SUCCESS = 1
|
||||
MATCHING_FAIL = 0
|
||||
|
||||
|
||||
def verify_math_representation(completion, gt_answer):
|
||||
"""
|
||||
Verify if the completion is a valid math representation of the gt_answer.
|
||||
"""
|
||||
if not completion.startswith("\\boxed{"):
|
||||
completion = "\\boxed{" + completion + "}"
|
||||
if not gt_answer.startswith("\\boxed{"):
|
||||
gt_answer = "\\boxed{" + gt_answer + "}"
|
||||
target = (
|
||||
ExprExtractionConfig(),
|
||||
LatexExtractionConfig(
|
||||
normalization_config=NormalizationConfig(
|
||||
nits=False,
|
||||
malformed_operators=False,
|
||||
basic_latex=True,
|
||||
boxed="all",
|
||||
units=True,
|
||||
),
|
||||
boxed_match_priority=0,
|
||||
),
|
||||
)
|
||||
if not isinstance(gt_answer, str) or len(gt_answer) == 0:
|
||||
raise ValueError("gt_answer should be a string, please verify your training data.")
|
||||
if not isinstance(completion, str) or len(completion) == 0:
|
||||
return MATCHING_FAIL
|
||||
try:
|
||||
parsed_gt_answer = parse(gt_answer, extraction_config=target)
|
||||
if len(parsed_gt_answer) == 0:
|
||||
return CANNOT_PARSE_GT_ANSWER
|
||||
parsed_completion = parse(completion, extraction_config=target)
|
||||
if len(parsed_completion) == 0:
|
||||
return CANNOT_PARSE_PREDICTION
|
||||
if verify(parsed_gt_answer, parsed_completion):
|
||||
return SUCCESS
|
||||
else:
|
||||
return MATCHING_FAIL
|
||||
except Exception:
|
||||
return MATCHING_FAIL
|
||||
|
||||
|
||||
def verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward):
|
||||
math_verify_result = verify_math_representation(decoded_final_answer, gt_answer)
|
||||
exact_match_result = (
|
||||
SUCCESS
|
||||
if decoded_final_answer.strip().replace(" ", "").replace("{", "").replace("}", "").replace(",", "")
|
||||
== gt_answer.strip().replace(" ", "").replace("{", "").replace("}", "").replace(",", "")
|
||||
else MATCHING_FAIL
|
||||
)
|
||||
if math_verify_result == SUCCESS:
|
||||
ans_acc += 1
|
||||
reward += acc_score
|
||||
elif exact_match_result == SUCCESS:
|
||||
# sometimes for answers that's not a (valid) math expression, math_verify will fail
|
||||
ans_acc += 1
|
||||
if math_verify_result == CANNOT_PARSE_PREDICTION:
|
||||
reward += (
|
||||
acc_score / 2
|
||||
) # not a valid latex math representation, but the answer is correct, receive half of the score
|
||||
else:
|
||||
reward += acc_score
|
||||
return reward, ans_acc
|
||||
|
||||
|
||||
def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
|
||||
tokenizer = kwargs["tokenizer"]
|
||||
@ -14,15 +84,18 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
|
||||
s, e = response_idx[0], response_idx[1]
|
||||
|
||||
length_reward = 0.0
|
||||
if soft_over_length_punishment:
|
||||
max_length = kwargs.get("max_length", 1024 * 4)
|
||||
cache_length = kwargs.get("cache_length", 512)
|
||||
res_length = e.item() - s.item() + 1
|
||||
if max_length - cache_length < res_length < max_length:
|
||||
length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score
|
||||
res_length = e.item() - s.item() + 1
|
||||
if not eval_mode:
|
||||
max_new_tokens = kwargs["max_new_tokens"]
|
||||
else:
|
||||
max_new_tokens = -1 # for eval mode, we don't need to check the length
|
||||
if not eval_mode and soft_over_length_punishment:
|
||||
cache_length = kwargs["cache_length"]
|
||||
if max_new_tokens - cache_length < res_length < max_new_tokens:
|
||||
length_reward = ((max_new_tokens - cache_length) - res_length) / cache_length * acc_score
|
||||
|
||||
if gt_answer is None:
|
||||
return reward
|
||||
raise ValueError("no gt_answer is provided, please check your training dataset.")
|
||||
|
||||
decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
|
||||
gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
|
||||
@ -35,15 +108,15 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
|
||||
format_acc += 1
|
||||
|
||||
# Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
|
||||
if (
|
||||
format_valid
|
||||
and final_answer is not None
|
||||
and gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower()
|
||||
):
|
||||
ans_acc += 1
|
||||
reward += acc_score
|
||||
if final_answer is not None:
|
||||
if eval_mode or format_valid:
|
||||
reward, ans_acc = verify_model_answer(final_answer, gt_answer, ans_acc, acc_score, reward)
|
||||
if not eval_mode:
|
||||
reward = reward + length_reward
|
||||
|
||||
reward = reward + length_reward
|
||||
# Check if the sequence is over length
|
||||
if not eval_mode and res_length >= max_new_tokens:
|
||||
reward *= 0.0
|
||||
|
||||
if not eval_mode:
|
||||
return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
|
||||
@ -56,6 +129,8 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
|
||||
"parsed": final_answer,
|
||||
"format_valid": format_acc.item(),
|
||||
"ans_valid": ans_acc.item(),
|
||||
"response_length": res_length,
|
||||
"reward": reward.item(),
|
||||
}
|
||||
|
||||
|
||||
@ -71,31 +146,45 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
|
||||
s, e = response_idx[0], response_idx[1]
|
||||
|
||||
length_reward = 0.0
|
||||
if soft_over_length_punishment:
|
||||
max_length = kwargs.get("max_length", 1024 * 4)
|
||||
cache_length = kwargs.get("cache_length", 512)
|
||||
res_length = e.item() - s.item() + 1
|
||||
if max_length - cache_length < res_length < max_length:
|
||||
length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score
|
||||
res_length = e.item() - s.item() + 1
|
||||
if not eval_mode:
|
||||
max_new_tokens = kwargs["max_new_tokens"]
|
||||
else:
|
||||
max_new_tokens = -1 # for eval mode, we don't need to check the length
|
||||
if not eval_mode and soft_over_length_punishment:
|
||||
cache_length = kwargs["cache_length"]
|
||||
if max_new_tokens - cache_length < res_length < max_new_tokens:
|
||||
length_reward = ((max_new_tokens - cache_length) - res_length) / cache_length * acc_score
|
||||
|
||||
if gt_answer is None:
|
||||
return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
|
||||
raise ValueError("no gt_answer is provided, please check your training dataset.")
|
||||
|
||||
decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
|
||||
|
||||
gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
|
||||
final_answer = extract_boxed_solution(decoded_final_answer)
|
||||
format_valid = final_answer is not None
|
||||
if "tags" in kwargs and kwargs["tags"]:
|
||||
tags = kwargs["tags"]
|
||||
format_valid = format_valid and all(
|
||||
[decoded_final_answer.count(tags[tag]["text"]) == tags[tag]["num_occur"] for tag in tags]
|
||||
)
|
||||
# Check format accuracy
|
||||
if format_valid:
|
||||
format_acc += 1
|
||||
reward += format_score
|
||||
|
||||
# Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
|
||||
if format_valid and final_answer is not None and gt_answer.strip().lower() == final_answer.strip().lower():
|
||||
ans_acc += 1
|
||||
reward += acc_score
|
||||
if final_answer is not None:
|
||||
if eval_mode or format_valid:
|
||||
reward, ans_acc = verify_model_answer(final_answer, gt_answer, ans_acc, acc_score, reward)
|
||||
if not eval_mode:
|
||||
reward = reward + length_reward
|
||||
|
||||
# Check if the sequence is over length
|
||||
if not eval_mode and res_length >= max_new_tokens:
|
||||
reward *= 0.0
|
||||
|
||||
reward = reward + length_reward
|
||||
if not eval_mode:
|
||||
return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
|
||||
else:
|
||||
@ -107,4 +196,6 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
|
||||
"parsed": final_answer,
|
||||
"format_valid": format_acc.item(),
|
||||
"ans_valid": ans_acc.item(),
|
||||
"response_length": res_length,
|
||||
"reward": reward.item(),
|
||||
}
|
||||
|
@ -198,6 +198,8 @@ if __name__ == "__main__":
|
||||
"beta": args.kl_coeff, # KL penalty coefficient
|
||||
"loss_variation": "sample_level",
|
||||
"reward_fn_type": args.reward_type,
|
||||
"max_length": args.max_new_tokens + args.max_prompt_tokens,
|
||||
"max_new_tokens": args.max_new_tokens,
|
||||
}
|
||||
elif args.algo == "DAPO":
|
||||
# DAPO variant settings
|
||||
@ -213,6 +215,7 @@ if __name__ == "__main__":
|
||||
"loss_variation": "token_level",
|
||||
"soft_over_length_punishment": True,
|
||||
"max_length": args.max_new_tokens + args.max_prompt_tokens,
|
||||
"max_new_tokens": args.max_new_tokens,
|
||||
"cache_length": min(1024, int(args.max_new_tokens / 4)),
|
||||
"filter_truncated_response": True,
|
||||
"reward_fn_type": args.reward_type,
|
||||
|
Loading…
Reference in New Issue
Block a user