mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 12:30:42 +00:00
[feat] Support boxed math reward (#6284)
* fix pp+tp, fix dataloader
* fixed plugin micro-batch size
* support boxed reward
* add boxed reward
* fix pp state dict incomplete issue
* Revert "fix pp state dict incomplete issue"
This reverts commit 6c1b3b694f
.
This commit is contained in:
@@ -7,7 +7,7 @@ import torch
|
||||
import wandb
|
||||
from coati.distributed.consumer import BaseConsumer
|
||||
from coati.distributed.loss import PolicyLoss
|
||||
from coati.distributed.reward.reward_fn import math_reward_fn
|
||||
from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
|
||||
from coati.distributed.reward.verifiable_reward import VerifiableReward
|
||||
from coati.distributed.utils import calc_action_log_probs
|
||||
from coati.trainer.utils import all_reduce_mean, all_reduce_sum
|
||||
@@ -54,7 +54,9 @@ class GRPOConsumer(BaseConsumer):
|
||||
and "num_microbatches" not in plugin_config
|
||||
and "microbatch_size" not in plugin_config
|
||||
):
|
||||
plugin_config["microbatch_size"] = max(1, grpo_config.get("train_microbatch_size") // 2)
|
||||
plugin_config["microbatch_size"] = max(
|
||||
1, grpo_config.get("train_microbatch_size") // plugin_config.get("pp_size", 1)
|
||||
)
|
||||
super().__init__(
|
||||
num_producers,
|
||||
num_episodes,
|
||||
@@ -131,7 +133,12 @@ class GRPOConsumer(BaseConsumer):
|
||||
k: v for k, v in grpo_config.items() if k in ["soft_over_length_punishment", "max_length", "cache_length"]
|
||||
}
|
||||
self.reward_model = VerifiableReward(
|
||||
reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags, **reward_model_kwargs
|
||||
reward_fns=[
|
||||
math_reward_fn if grpo_config.get("reward_fn_type") == "think_answer_tags" else boxed_math_reward_fn
|
||||
],
|
||||
tokenizer=self.tokenizer,
|
||||
tags=response_format_tags,
|
||||
**reward_model_kwargs,
|
||||
)
|
||||
self.global_step = 0
|
||||
self.use_wandb = use_wandb
|
||||
|
Reference in New Issue
Block a user