mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-13 05:01:44 +00:00
[feat] Support boxed math reward (#6284)
* fix pp+tp, fix dataloader
* fixed plugin micro-batch size
* support boxed reward
* add boxed reward
* fix pp state dict incomplete issue
* Revert "fix pp state dict incomplete issue"
This reverts commit 6c1b3b694f
.
This commit is contained in:
@@ -71,7 +71,7 @@ class BaseConsumer:
|
||||
and "num_microbatches" not in self.plugin_config
|
||||
and "microbatch_size" not in self.plugin_config
|
||||
):
|
||||
plugin_config["microbatch_size"] = self.minibatch_size
|
||||
plugin_config["microbatch_size"] = max(1, self.minibatch_size // plugin_config.get("pp_size", 1))
|
||||
plugin_config.update(self.plugin_config)
|
||||
self.plugin = HybridParallelPlugin(**plugin_config)
|
||||
self.booster = Booster(plugin=self.plugin)
|
||||
|
@@ -7,7 +7,7 @@ import torch
|
||||
import wandb
|
||||
from coati.distributed.consumer import BaseConsumer
|
||||
from coati.distributed.loss import PolicyLoss
|
||||
from coati.distributed.reward.reward_fn import math_reward_fn
|
||||
from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
|
||||
from coati.distributed.reward.verifiable_reward import VerifiableReward
|
||||
from coati.distributed.utils import calc_action_log_probs
|
||||
from coati.trainer.utils import all_reduce_mean, all_reduce_sum
|
||||
@@ -54,7 +54,9 @@ class GRPOConsumer(BaseConsumer):
|
||||
and "num_microbatches" not in plugin_config
|
||||
and "microbatch_size" not in plugin_config
|
||||
):
|
||||
plugin_config["microbatch_size"] = max(1, grpo_config.get("train_microbatch_size") // 2)
|
||||
plugin_config["microbatch_size"] = max(
|
||||
1, grpo_config.get("train_microbatch_size") // plugin_config.get("pp_size", 1)
|
||||
)
|
||||
super().__init__(
|
||||
num_producers,
|
||||
num_episodes,
|
||||
@@ -131,7 +133,12 @@ class GRPOConsumer(BaseConsumer):
|
||||
k: v for k, v in grpo_config.items() if k in ["soft_over_length_punishment", "max_length", "cache_length"]
|
||||
}
|
||||
self.reward_model = VerifiableReward(
|
||||
reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags, **reward_model_kwargs
|
||||
reward_fns=[
|
||||
math_reward_fn if grpo_config.get("reward_fn_type") == "think_answer_tags" else boxed_math_reward_fn
|
||||
],
|
||||
tokenizer=self.tokenizer,
|
||||
tags=response_format_tags,
|
||||
**reward_model_kwargs,
|
||||
)
|
||||
self.global_step = 0
|
||||
self.use_wandb = use_wandb
|
||||
|
@@ -1,6 +1,6 @@
|
||||
import torch
|
||||
|
||||
from .reward_utils import extract_solution, validate_response_structure
|
||||
from .reward_utils import extract_boxed_solution, extract_solution, validate_response_structure
|
||||
|
||||
|
||||
def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
|
||||
@@ -70,3 +70,43 @@ def gsm8k_reward_fn(input_ids, **kwargs):
|
||||
if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower():
|
||||
reward = reward + 9.0
|
||||
return reward
|
||||
|
||||
|
||||
def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
|
||||
tokenizer = kwargs["tokenizer"]
|
||||
soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False)
|
||||
format_score = 0.0
|
||||
acc_score = 10.0
|
||||
reward = torch.tensor(0.0)
|
||||
format_acc = torch.tensor(0.0)
|
||||
ans_acc = torch.tensor(0.0)
|
||||
s, e = response_idx[0], response_idx[1]
|
||||
|
||||
length_reward = 0.0
|
||||
if soft_over_length_punishment:
|
||||
max_length = kwargs.get("max_length", 1024 * 4)
|
||||
cache_length = kwargs.get("cache_length", 512)
|
||||
res_length = e.item() - s.item() + 1
|
||||
if max_length - cache_length < res_length < max_length:
|
||||
length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score
|
||||
|
||||
if gt_answer is None:
|
||||
return reward
|
||||
|
||||
decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
|
||||
gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
|
||||
final_answer = extract_boxed_solution(decoded_final_answer)
|
||||
format_valid = final_answer is not None
|
||||
# Check format accuracy
|
||||
if format_valid:
|
||||
format_acc += 1
|
||||
reward += format_score
|
||||
|
||||
# Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
|
||||
if format_valid and final_answer is not None and gt_answer.strip().lower() == final_answer.strip().lower():
|
||||
ans_acc += 1
|
||||
reward += acc_score
|
||||
|
||||
reward = reward + length_reward
|
||||
|
||||
return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
|
||||
|
@@ -74,3 +74,51 @@ def extract_solution(solution_str: str) -> Tuple[Optional[str], str]:
|
||||
|
||||
final_answer = matches[-1].group(1).strip()
|
||||
return final_answer, solution_str
|
||||
|
||||
|
||||
def extract_boxed_solution(text: str) -> Optional[str]:
|
||||
"""
|
||||
Modified from: https://gist.github.com/lewtun/9c2ce1937b741404090a3dc4c7c022b3
|
||||
Retrieves the content from the last occurrence of `\boxed{}` in a LaTeX-like string.
|
||||
|
||||
Args:
|
||||
text (str): A string potentially containing LaTeX-style boxed expressions.
|
||||
|
||||
Returns:
|
||||
Optional[str]: The text inside the final `\boxed{}` if successfully extracted;
|
||||
returns `None` if no properly closed box is found.
|
||||
|
||||
Examples:
|
||||
>>> extract_boxed_solution("The answer is \\boxed{42}.")
|
||||
'42'
|
||||
>>> extract_boxed_solution("Here is an unmatched \\boxed{42")
|
||||
None
|
||||
"""
|
||||
try:
|
||||
# Find the last occurrence of "\boxed{"
|
||||
start_idx = text.rindex("\\boxed{")
|
||||
# Move past "\boxed{" to find the start of the content
|
||||
content_start = start_idx + len("\\boxed{")
|
||||
open_braces = 1
|
||||
pos = content_start
|
||||
|
||||
# Traverse the string to find the matching closing brace
|
||||
while open_braces > 0 and pos < len(text):
|
||||
if text[pos] == "{":
|
||||
open_braces += 1
|
||||
elif text[pos] == "}":
|
||||
open_braces -= 1
|
||||
pos += 1
|
||||
|
||||
# If all braces are matched, extract and return the content
|
||||
if open_braces == 0:
|
||||
return text[content_start : pos - 1].strip()
|
||||
else:
|
||||
return None
|
||||
|
||||
except ValueError:
|
||||
# "\boxed{" not found
|
||||
return None
|
||||
except Exception:
|
||||
# Any other unexpected error
|
||||
return None
|
||||
|
Reference in New Issue
Block a user