diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 877ff98ec..a86315a24 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -7,7 +7,8 @@ import torch
 import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
-from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
+from coati.distributed.reward.agpo_reward import AGPOReward
+from coati.distributed.reward.reward_fn import agpo_boxed_math_reward_fn, boxed_math_reward_fn, math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
 from coati.distributed.utils import calc_action_log_probs
 from coati.trainer.utils import all_reduce_mean, all_reduce_sum
@@ -132,14 +133,23 @@ class GRPOConsumer(BaseConsumer):
         reward_model_kwargs = {
             k: v for k, v in grpo_config.items() if k in ["soft_over_length_punishment", "max_length", "cache_length"]
         }
-        self.reward_model = VerifiableReward(
-            reward_fns=[
-                math_reward_fn if grpo_config.get("reward_fn_type") == "think_answer_tags" else boxed_math_reward_fn
-            ],
-            tokenizer=self.tokenizer,
-            tags=response_format_tags,
-            **reward_model_kwargs,
-        )
+        if self.grpo_config.get("correct_sample_length_reward", False):
+            self.reward_model = AGPOReward(
+                reward_fn=agpo_boxed_math_reward_fn,
+                num_generations=self.num_generations,
+                tokenizer=self.tokenizer,
+                tags=response_format_tags,
+                **reward_model_kwargs,
+            )
+        else:
+            self.reward_model = VerifiableReward(
+                reward_fns=[
+                    math_reward_fn if grpo_config.get("reward_fn_type") == "think_answer_tags" else boxed_math_reward_fn
+                ],
+                tokenizer=self.tokenizer,
+                tags=response_format_tags,
+                **reward_model_kwargs,
+            )
         self.global_step = 0

         self.use_wandb = use_wandb
diff --git a/applications/ColossalChat/coati/distributed/reward/agpo_reward.py b/applications/ColossalChat/coati/distributed/reward/agpo_reward.py
index ae2b01756..994b36373 100644
--- a/applications/ColossalChat/coati/distributed/reward/agpo_reward.py
+++ b/applications/ColossalChat/coati/distributed/reward/agpo_reward.py
@@ -20,7 +20,7 @@ class AGPOReward:
     ) -> torch.Tensor:
         # Get batch size
         bs = input_ids.size(0)
-        num_generations = self.kwargs.get("num_generations")
+        num_generations = self.kwargs["num_generations"]

         # Apply the reward function to the entire batch at once
         reward_infos = [self.reward_fn(input_ids[i], gt_answer=gt_answer[i], response_idx=response_idx[i], **self.kwargs) for i in range(bs)]
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index d6565b922..3b36a153a 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -209,6 +209,7 @@ if __name__ == "__main__":
             "loss_variation": "token_level",
             "max_length": args.max_new_tokens + args.max_prompt_tokens,
             "reward_fn_type": args.reward_type,
+            "correct_sample_length_reward": True
         }
     else:
         raise ValueError(f"Unsupported algorithm: {args.algo}")
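
For context, a minimal sketch of the dispatch this patch adds to GRPOConsumer.__init__ follows. It is illustrative only: StubAGPOReward, StubVerifiableReward, build_reward_model, and the lambda reward functions are hypothetical stand-ins for the classes and functions the patch imports (AGPOReward, VerifiableReward, agpo_boxed_math_reward_fn, boxed_math_reward_fn), whose bodies are not shown in this diff.

# --- Illustrative sketch (not part of the patch) ---------------------------
from typing import Any, Callable, Dict, List


class StubAGPOReward:
    """Stand-in for coati.distributed.reward.agpo_reward.AGPOReward."""

    def __init__(self, reward_fn: Callable, num_generations: int, **kwargs: Any) -> None:
        self.reward_fn = reward_fn
        self.kwargs = dict(kwargs, num_generations=num_generations)


class StubVerifiableReward:
    """Stand-in for coati.distributed.reward.verifiable_reward.VerifiableReward."""

    def __init__(self, reward_fns: List[Callable], **kwargs: Any) -> None:
        self.reward_fns = reward_fns
        self.kwargs = kwargs


def build_reward_model(grpo_config: Dict[str, Any], num_generations: int, **reward_model_kwargs: Any):
    # Same branch the patch introduces: the new "correct_sample_length_reward"
    # flag routes to the AGPO reward path; otherwise the original
    # VerifiableReward construction is kept unchanged.
    if grpo_config.get("correct_sample_length_reward", False):
        return StubAGPOReward(
            reward_fn=lambda *args, **kw: 0.0,  # placeholder for agpo_boxed_math_reward_fn
            num_generations=num_generations,
            **reward_model_kwargs,
        )
    return StubVerifiableReward(
        reward_fns=[lambda *args, **kw: 0.0],  # placeholder for boxed/math reward fns
        **reward_model_kwargs,
    )


if __name__ == "__main__":
    cfg = {"reward_fn_type": "boxed", "correct_sample_length_reward": True}
    model = build_reward_model(cfg, num_generations=8, soft_over_length_punishment=True, max_length=4096)
    print(type(model).__name__)  # StubAGPOReward

One side effect worth noting: switching AGPOReward.__call__ from self.kwargs.get("num_generations") to self.kwargs["num_generations"] makes a missing num_generations raise a KeyError immediately instead of letting None flow into later computation; since the consumer now always passes num_generations=self.num_generations, the stricter lookup only surfaces misconfiguration earlier.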