diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index eae4ff54e..8de8b774e 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -127,7 +127,9 @@ class GRPOConsumer(BaseConsumer):
             "answer_end": {"text": "</answer>", "num_occur": 1},
         }
         reward_model_kwargs = {
-            k: v for k, v in grpo_config.items() if k in ["soft_over_length_punishment", "max_length", "cache_length"]
+            k: v
+            for k, v in grpo_config.items()
+            if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length"]
         }
         self.reward_model = VerifiableReward(
             reward_fns=[
diff --git a/applications/ColossalChat/coati/distributed/reward/reward_fn.py b/applications/ColossalChat/coati/distributed/reward/reward_fn.py
index 14d340dc4..a4042ae97 100644
--- a/applications/ColossalChat/coati/distributed/reward/reward_fn.py
+++ b/applications/ColossalChat/coati/distributed/reward/reward_fn.py
@@ -1,7 +1,77 @@
 import torch
+from latex2sympy2_extended import NormalizationConfig
+from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify
 
 from .reward_utils import extract_boxed_solution, extract_solution, validate_response_structure
 
+CANNOT_PARSE_GT_ANSWER = -1
+CANNOT_PARSE_PREDICTION = -2
+SUCCESS = 1
+MATCHING_FAIL = 0
+
+
+def verify_math_representation(completion, gt_answer):
+    """
+    Verify whether the completion is a valid math representation of the gt_answer.
+    """
+    if not completion.startswith("\\boxed{"):
+        completion = "\\boxed{" + completion + "}"
+    if not gt_answer.startswith("\\boxed{"):
+        gt_answer = "\\boxed{" + gt_answer + "}"
+    target = (
+        ExprExtractionConfig(),
+        LatexExtractionConfig(
+            normalization_config=NormalizationConfig(
+                nits=False,
+                malformed_operators=False,
+                basic_latex=True,
+                boxed="all",
+                units=True,
+            ),
+            boxed_match_priority=0,
+        ),
+    )
+    if not isinstance(gt_answer, str) or len(gt_answer) == 0:
+        raise ValueError("gt_answer should be a string, please verify your training data.")
+    if not isinstance(completion, str) or len(completion) == 0:
+        return MATCHING_FAIL
+    try:
+        parsed_gt_answer = parse(gt_answer, extraction_config=target)
+        if len(parsed_gt_answer) == 0:
+            return CANNOT_PARSE_GT_ANSWER
+        parsed_completion = parse(completion, extraction_config=target)
+        if len(parsed_completion) == 0:
+            return CANNOT_PARSE_PREDICTION
+        if verify(parsed_gt_answer, parsed_completion):
+            return SUCCESS
+        else:
+            return MATCHING_FAIL
+    except Exception:
+        return MATCHING_FAIL
+
+
+def verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward):
+    math_verify_result = verify_math_representation(decoded_final_answer, gt_answer)
+    exact_match_result = (
+        SUCCESS
+        if decoded_final_answer.strip().replace(" ", "").replace("{", "").replace("}", "").replace(",", "")
+        == gt_answer.strip().replace(" ", "").replace("{", "").replace("}", "").replace(",", "")
+        else MATCHING_FAIL
+    )
+    if math_verify_result == SUCCESS:
+        ans_acc += 1
+        reward += acc_score
+    elif exact_match_result == SUCCESS:
+        # sometimes, for answers that are not (valid) math expressions, math_verify will fail
+        ans_acc += 1
+        if math_verify_result == CANNOT_PARSE_PREDICTION:
+            reward += (
+                acc_score / 2
+            )  # not a valid latex math representation, but the answer is correct, so receive half of the score
+        else:
+            reward += acc_score
+    return reward, ans_acc
+
 
 def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     tokenizer = kwargs["tokenizer"]
@@ -14,15 +84,18 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     s, e = response_idx[0], response_idx[1]
 
     length_reward = 0.0
-    if soft_over_length_punishment:
-        max_length = kwargs.get("max_length", 1024 * 4)
-        cache_length = kwargs.get("cache_length", 512)
-        res_length = e.item() - s.item() + 1
-        if max_length - cache_length < res_length < max_length:
-            length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score
+    res_length = e.item() - s.item() + 1
+    if not eval_mode:
+        max_new_tokens = kwargs["max_new_tokens"]
+    else:
+        max_new_tokens = -1  # for eval mode, we don't need to check the length
+    if not eval_mode and soft_over_length_punishment:
+        cache_length = kwargs["cache_length"]
+        if max_new_tokens - cache_length < res_length < max_new_tokens:
+            length_reward = ((max_new_tokens - cache_length) - res_length) / cache_length * acc_score
 
     if gt_answer is None:
-        return reward
+        raise ValueError("no gt_answer is provided, please check your training dataset.")
 
     decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
     gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
@@ -35,15 +108,15 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
         format_acc += 1
 
     # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
-    if (
-        format_valid
-        and final_answer is not None
-        and gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower()
-    ):
-        ans_acc += 1
-        reward += acc_score
+    if final_answer is not None:
+        if eval_mode or format_valid:
+            reward, ans_acc = verify_model_answer(final_answer, gt_answer, ans_acc, acc_score, reward)
+        if not eval_mode:
+            reward = reward + length_reward
 
-    reward = reward + length_reward
+    # Check if the sequence is over length
+    if not eval_mode and res_length >= max_new_tokens:
+        reward *= 0.0
 
     if not eval_mode:
         return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
@@ -56,6 +129,8 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
             "parsed": final_answer,
             "format_valid": format_acc.item(),
             "ans_valid": ans_acc.item(),
+            "response_length": res_length,
+            "reward": reward.item(),
         }
 
 
@@ -71,31 +146,45 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     s, e = response_idx[0], response_idx[1]
 
     length_reward = 0.0
-    if soft_over_length_punishment:
-        max_length = kwargs.get("max_length", 1024 * 4)
-        cache_length = kwargs.get("cache_length", 512)
-        res_length = e.item() - s.item() + 1
-        if max_length - cache_length < res_length < max_length:
-            length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score
+    res_length = e.item() - s.item() + 1
+    if not eval_mode:
+        max_new_tokens = kwargs["max_new_tokens"]
+    else:
+        max_new_tokens = -1  # for eval mode, we don't need to check the length
+    if not eval_mode and soft_over_length_punishment:
+        cache_length = kwargs["cache_length"]
+        if max_new_tokens - cache_length < res_length < max_new_tokens:
+            length_reward = ((max_new_tokens - cache_length) - res_length) / cache_length * acc_score
 
     if gt_answer is None:
-        return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
+        raise ValueError("no gt_answer is provided, please check your training dataset.")
+
     decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
+
     gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
     final_answer = extract_boxed_solution(decoded_final_answer)
     format_valid = final_answer is not None
+    if "tags" in kwargs and kwargs["tags"]:
+        tags = kwargs["tags"]
+        format_valid = format_valid and all(
+            [decoded_final_answer.count(tags[tag]["text"]) == tags[tag]["num_occur"] for tag in tags]
+        )
 
     # Check format accuracy
     if format_valid:
         format_acc += 1
         reward += format_score
 
     # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
-    if format_valid and final_answer is not None and gt_answer.strip().lower() == final_answer.strip().lower():
-        ans_acc += 1
-        reward += acc_score
+    if final_answer is not None:
+        if eval_mode or format_valid:
+            reward, ans_acc = verify_model_answer(final_answer, gt_answer, ans_acc, acc_score, reward)
+        if not eval_mode:
+            reward = reward + length_reward
+
+    # Check if the sequence is over length
+    if not eval_mode and res_length >= max_new_tokens:
+        reward *= 0.0
 
-    reward = reward + length_reward
 
     if not eval_mode:
         return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
     else:
@@ -107,4 +196,6 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
             "parsed": final_answer,
             "format_valid": format_acc.item(),
             "ans_valid": ans_acc.item(),
+            "response_length": res_length,
+            "reward": reward.item(),
         }
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index 8d1f25e74..071912ddf 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -198,6 +198,8 @@ if __name__ == "__main__":
            "beta": args.kl_coeff,  # KL penalty coefficient
            "loss_variation": "sample_level",
            "reward_fn_type": args.reward_type,
+           "max_length": args.max_new_tokens + args.max_prompt_tokens,
+           "max_new_tokens": args.max_new_tokens,
        }
    elif args.algo == "DAPO":
        # DAPO variant settings
@@ -213,6 +215,7 @@ if __name__ == "__main__":
            "loss_variation": "token_level",
            "soft_over_length_punishment": True,
            "max_length": args.max_new_tokens + args.max_prompt_tokens,
+           "max_new_tokens": args.max_new_tokens,
            "cache_length": min(1024, int(args.max_new_tokens / 4)),
            "filter_truncated_response": True,
            "reward_fn_type": args.reward_type,
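
Note on the reward shaping above: the patch switches the length budget from "max_length" (prompt + generation) to "max_new_tokens" (generation only), applies a linear soft over-length penalty inside the last "cache_length" tokens of that budget, and zeroes out the reward for fully truncated responses. The following is a minimal standalone sketch, not part of the patch; the names length_penalty and shaped_reward are hypothetical, and the cache_length default mirrors the DAPO setting from rl_example.py (min(1024, max_new_tokens // 4)).

    # Standalone sketch (not from the patch) of the new length shaping.
    # `length_penalty` and `shaped_reward` are hypothetical illustration names.

    def length_penalty(res_length: int, max_new_tokens: int, cache_length: int, acc_score: float) -> float:
        # Mirrors the diff: a linear penalty that ramps from 0 down toward
        # -acc_score once the response enters the last `cache_length` tokens
        # of the generation budget.
        if max_new_tokens - cache_length < res_length < max_new_tokens:
            return ((max_new_tokens - cache_length) - res_length) / cache_length * acc_score
        return 0.0

    def shaped_reward(res_length: int, max_new_tokens: int, cache_length: int, acc_score: float) -> float:
        # A correct answer earns acc_score plus the (non-positive) length
        # penalty; a response at or over the budget is zeroed out entirely,
        # matching `reward *= 0.0` for truncated sequences in the diff.
        if res_length >= max_new_tokens:
            return 0.0
        return acc_score + length_penalty(res_length, max_new_tokens, cache_length, acc_score)

    if __name__ == "__main__":
        acc_score, max_new_tokens = 10.0, 4096
        cache_length = min(1024, max_new_tokens // 4)  # DAPO default from rl_example.py
        for res_length in (2048, 3072, 3584, 4095, 4096):
            # Prints 10.0, 10.0, 5.0, ~0.01, 0.0 respectively.
            print(res_length, shaped_reward(res_length, max_new_tokens, cache_length, acc_score))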