[feat] Support boxed math reward (#6284)

* fix pp+tp, fix dataloader

* fix plugin micro-batch size

* support boxed reward

* add boxed reward

* fix pp state dict incomplete issue

* Revert "fix pp state dict incomplete issue"

This reverts commit 6c1b3b694f.
Author: YeAnbang
Date: 2025-04-29 16:46:47 +08:00
Committed by: GitHub
Parent: 2ca1e3c630
Commit: 14f237ce7e

5 changed files with 118 additions and 12 deletions


@@ -71,7 +71,7 @@ class BaseConsumer:
             and "num_microbatches" not in self.plugin_config
             and "microbatch_size" not in self.plugin_config
         ):
-            plugin_config["microbatch_size"] = self.minibatch_size
+            plugin_config["microbatch_size"] = max(1, self.minibatch_size // plugin_config.get("pp_size", 1))
         plugin_config.update(self.plugin_config)
         self.plugin = HybridParallelPlugin(**plugin_config)
         self.booster = Booster(plugin=self.plugin)
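
For context: with pipeline parallelism enabled, the consumer now derives the plugin micro-batch size by splitting the mini-batch across pipeline stages rather than using the mini-batch size directly. A minimal sketch of the rule, using a hypothetical helper name (`pp_size` defaults to 1 when pipeline parallelism is off, exactly as in the diff):

def derive_microbatch_size(minibatch_size: int, plugin_config: dict) -> int:
    # Split the mini-batch across pipeline stages; clamp at 1 so a small
    # mini-batch combined with a large pp_size still yields a valid micro-batch.
    return max(1, minibatch_size // plugin_config.get("pp_size", 1))

assert derive_microbatch_size(8, {"pp_size": 2}) == 4  # 8 // 2
assert derive_microbatch_size(8, {}) == 8              # no pipeline parallelism
assert derive_microbatch_size(2, {"pp_size": 4}) == 1  # clamped to 1

The same rule appears again in GRPOConsumer below, applied to `train_microbatch_size` instead of the fixed `// 2` divisor.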


@@ -7,7 +7,7 @@ import torch
 import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
-from coati.distributed.reward.reward_fn import math_reward_fn
+from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
 from coati.distributed.utils import calc_action_log_probs
 from coati.trainer.utils import all_reduce_mean, all_reduce_sum
@@ -54,7 +54,9 @@ class GRPOConsumer(BaseConsumer):
             and "num_microbatches" not in plugin_config
             and "microbatch_size" not in plugin_config
         ):
-            plugin_config["microbatch_size"] = max(1, grpo_config.get("train_microbatch_size") // 2)
+            plugin_config["microbatch_size"] = max(
+                1, grpo_config.get("train_microbatch_size") // plugin_config.get("pp_size", 1)
+            )
         super().__init__(
             num_producers,
             num_episodes,
@@ -131,7 +133,12 @@ class GRPOConsumer(BaseConsumer):
             k: v for k, v in grpo_config.items() if k in ["soft_over_length_punishment", "max_length", "cache_length"]
         }
         self.reward_model = VerifiableReward(
-            reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags, **reward_model_kwargs
+            reward_fns=[
+                math_reward_fn if grpo_config.get("reward_fn_type") == "think_answer_tags" else boxed_math_reward_fn
+            ],
+            tokenizer=self.tokenizer,
+            tags=response_format_tags,
+            **reward_model_kwargs,
         )
         self.global_step = 0
         self.use_wandb = use_wandb
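
The reward function is now selected from `grpo_config["reward_fn_type"]`: `"think_answer_tags"` keeps the original tag-based `math_reward_fn`, and any other value falls back to `boxed_math_reward_fn`. A minimal sketch of the branch, reusing the import added above (the `"boxed"` value is illustrative; the diff only tests for `"think_answer_tags"`):

from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn

grpo_config = {"reward_fn_type": "boxed"}  # illustrative; anything other than "think_answer_tags"

# Same conditional as in GRPOConsumer.__init__ above.
reward_fn = (
    math_reward_fn
    if grpo_config.get("reward_fn_type") == "think_answer_tags"
    else boxed_math_reward_fn
)
assert reward_fn is boxed_math_reward_fn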


@@ -1,6 +1,6 @@
 import torch
 
-from .reward_utils import extract_solution, validate_response_structure
+from .reward_utils import extract_boxed_solution, extract_solution, validate_response_structure
 
 
 def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
@@ -70,3 +70,43 @@ def gsm8k_reward_fn(input_ids, **kwargs):
     if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower():
         reward = reward + 9.0
     return reward
+
+
+def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
+    tokenizer = kwargs["tokenizer"]
+    soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False)
+    format_score = 0.0
+    acc_score = 10.0
+    reward = torch.tensor(0.0)
+    format_acc = torch.tensor(0.0)
+    ans_acc = torch.tensor(0.0)
+    s, e = response_idx[0], response_idx[1]
+
+    length_reward = 0.0
+    if soft_over_length_punishment:
+        max_length = kwargs.get("max_length", 1024 * 4)
+        cache_length = kwargs.get("cache_length", 512)
+        res_length = e.item() - s.item() + 1
+        if max_length - cache_length < res_length < max_length:
+            length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score
+
+    if gt_answer is None:
+        return reward
+
+    decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
+    gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
+    final_answer = extract_boxed_solution(decoded_final_answer)
+    format_valid = final_answer is not None
+    # Check format: a properly closed \boxed{...} must be present
+    if format_valid:
+        format_acc += 1
+        reward += format_score
+    # Check answer: it only counts as correct when the format is also valid
+    if format_valid and gt_answer.strip().lower() == final_answer.strip().lower():
+        ans_acc += 1
+        reward += acc_score
+    reward = reward + length_reward
+    return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
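
The soft over-length penalty above ramps linearly inside the last `cache_length` tokens before `max_length`: a response at the window edge loses nothing, while one just under the hard limit loses almost the full `acc_score`. A standalone sketch of the arithmetic (hypothetical helper; defaults mirror the diff):

def soft_length_penalty(res_length: int, max_length: int = 4096,
                        cache_length: int = 512, acc_score: float = 10.0) -> float:
    # Negative reward that grows linearly across the cache window
    # (max_length - cache_length, max_length); zero outside it.
    if max_length - cache_length < res_length < max_length:
        return ((max_length - cache_length) - res_length) / cache_length * acc_score
    return 0.0

print(soft_length_penalty(3584))  # 0.0    (at the window edge, no penalty yet)
print(soft_length_penalty(3840))  # -5.0   (halfway into the window)
print(soft_length_penalty(4095))  # -9.98  (just under the hard limit)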


@@ -74,3 +74,51 @@ def extract_solution(solution_str: str) -> Tuple[Optional[str], str]:
     final_answer = matches[-1].group(1).strip()
     return final_answer, solution_str
+
+
+def extract_boxed_solution(text: str) -> Optional[str]:
+    """
+    Modified from: https://gist.github.com/lewtun/9c2ce1937b741404090a3dc4c7c022b3
+    Retrieves the content from the last occurrence of `\boxed{}` in a LaTeX-like string.
+
+    Args:
+        text (str): A string potentially containing LaTeX-style boxed expressions.
+
+    Returns:
+        Optional[str]: The text inside the final `\boxed{}` if successfully extracted;
+            returns `None` if no properly closed box is found.
+
+    Examples:
+        >>> extract_boxed_solution("The answer is \\boxed{42}.")
+        '42'
+        >>> extract_boxed_solution("Here is an unmatched \\boxed{42")
+        None
+    """
+    try:
+        # Find the last occurrence of "\boxed{"
+        start_idx = text.rindex("\\boxed{")
+        # Move past "\boxed{" to find the start of the content
+        content_start = start_idx + len("\\boxed{")
+        open_braces = 1
+        pos = content_start
+        # Traverse the string to find the matching closing brace
+        while open_braces > 0 and pos < len(text):
+            if text[pos] == "{":
+                open_braces += 1
+            elif text[pos] == "}":
+                open_braces -= 1
+            pos += 1
+        # If all braces are matched, extract and return the content
+        if open_braces == 0:
+            return text[content_start : pos - 1].strip()
+        else:
+            return None
+    except ValueError:
+        # "\boxed{" not found
+        return None
+    except Exception:
+        # Any other unexpected error
+        return None
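
Because `extract_boxed_solution` counts braces rather than matching a regex up to the first `}`, nested LaTeX groups survive intact and unterminated boxes are rejected. A short usage sketch, assuming the module path implied by the relative import in `reward_fn.py`:

from coati.distributed.reward.reward_utils import extract_boxed_solution

# Brace counting keeps nested groups intact.
print(extract_boxed_solution("Thus x = \\boxed{\\frac{1}{2}}."))   # \frac{1}{2}
# Only the last \boxed{} is considered.
print(extract_boxed_solution("\\boxed{1} ... final: \\boxed{2}"))  # 2
# Unclosed boxes (e.g. a truncated generation) yield None.
print(extract_boxed_solution("ran out of tokens \\boxed{42"))      # None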