update consumer

This commit is contained in:
Chen Li 2025-05-14 18:19:47 +08:00
parent 13c2676612
commit 18f2247a10
3 changed files with 21 additions and 10 deletions

View File

@@ -7,7 +7,8 @@ import torch
import wandb
from coati.distributed.consumer import BaseConsumer
from coati.distributed.loss import PolicyLoss
from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
from coati.distributed.reward.agpo_reward import AGPOReward
from coati.distributed.reward.reward_fn import agpo_boxed_math_reward_fn, boxed_math_reward_fn, math_reward_fn
from coati.distributed.reward.verifiable_reward import VerifiableReward
from coati.distributed.utils import calc_action_log_probs
from coati.trainer.utils import all_reduce_mean, all_reduce_sum
@@ -132,14 +133,23 @@ class GRPOConsumer(BaseConsumer):
reward_model_kwargs = {
k: v for k, v in grpo_config.items() if k in ["soft_over_length_punishment", "max_length", "cache_length"]
}
self.reward_model = VerifiableReward(
reward_fns=[
math_reward_fn if grpo_config.get("reward_fn_type") == "think_answer_tags" else boxed_math_reward_fn
],
tokenizer=self.tokenizer,
tags=response_format_tags,
**reward_model_kwargs,
)
if self.grpo_config.get("correct_sample_length_reward", False):
self.reward_model = AGPOReward(
reward_fn=agpo_boxed_math_reward_fn,
num_generations=self.num_generations,
tokenizer=self.tokenizer,
tags=response_format_tags,
**reward_model_kwargs,
)
else:
self.reward_model = VerifiableReward(
reward_fns=[
math_reward_fn if grpo_config.get("reward_fn_type") == "think_answer_tags" else boxed_math_reward_fn
],
tokenizer=self.tokenizer,
tags=response_format_tags,
**reward_model_kwargs,
)
self.global_step = 0
self.use_wandb = use_wandb

View File

@@ -20,7 +20,7 @@ class AGPOReward:
) -> torch.Tensor:
# Get batch size
bs = input_ids.size(0)
num_generations = self.kwargs.get("num_generations")
num_generations = self.kwargs["num_generations"]
# Apply the reward function to the entire batch at once
reward_infos = [self.reward_fn(input_ids[i], gt_answer=gt_answer[i], response_idx=response_idx[i], **self.kwargs) for i in range(bs)]

View File

@@ -209,6 +209,7 @@ if __name__ == "__main__":
"loss_variation": "token_level",
"max_length": args.max_new_tokens + args.max_prompt_tokens,
"reward_fn_type": args.reward_type,
"correct_sample_length_reward": True
}
else:
raise ValueError(f"Unsupported algorithm: {args.algo}")