[feat] Support boxed math reward (#6284)

* fix pp+tp, fix dataloader

* fixed plugin micro-batch size

* support boxed reward

* add boxed reward

* fix pp state dict incomplete issue

* Revert "fix pp state dict incomplete issue"

This reverts commit 6c1b3b694f.
This commit is contained in:
YeAnbang
2025-04-29 16:46:47 +08:00
committed by GitHub
parent 2ca1e3c630
commit 14f237ce7e
5 changed files with 118 additions and 12 deletions

View File

@@ -7,7 +7,7 @@ import torch
import wandb
from coati.distributed.consumer import BaseConsumer
from coati.distributed.loss import PolicyLoss
from coati.distributed.reward.reward_fn import math_reward_fn
from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
from coati.distributed.reward.verifiable_reward import VerifiableReward
from coati.distributed.utils import calc_action_log_probs
from coati.trainer.utils import all_reduce_mean, all_reduce_sum
@@ -54,7 +54,9 @@ class GRPOConsumer(BaseConsumer):
and "num_microbatches" not in plugin_config
and "microbatch_size" not in plugin_config
):
plugin_config["microbatch_size"] = max(1, grpo_config.get("train_microbatch_size") // 2)
plugin_config["microbatch_size"] = max(
1, grpo_config.get("train_microbatch_size") // plugin_config.get("pp_size", 1)
)
super().__init__(
num_producers,
num_episodes,
@@ -131,7 +133,12 @@ class GRPOConsumer(BaseConsumer):
k: v for k, v in grpo_config.items() if k in ["soft_over_length_punishment", "max_length", "cache_length"]
}
self.reward_model = VerifiableReward(
reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags, **reward_model_kwargs
reward_fns=[
math_reward_fn if grpo_config.get("reward_fn_type") == "think_answer_tags" else boxed_math_reward_fn
],
tokenizer=self.tokenizer,
tags=response_format_tags,
**reward_model_kwargs,
)
self.global_step = 0
self.use_wandb = use_wandb