Mirror of https://github.com/hpcaitech/ColossalAI.git
update consumer
This commit is contained in:
parent 13c2676612
commit 18f2247a10
@@ -7,7 +7,8 @@ import torch
 import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
-from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
+from coati.distributed.reward.agpo_reward import AGPOReward
+from coati.distributed.reward.reward_fn import agpo_boxed_math_reward_fn, boxed_math_reward_fn, math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
 from coati.distributed.utils import calc_action_log_probs
 from coati.trainer.utils import all_reduce_mean, all_reduce_sum
@@ -132,14 +133,23 @@ class GRPOConsumer(BaseConsumer):
         reward_model_kwargs = {
             k: v for k, v in grpo_config.items() if k in ["soft_over_length_punishment", "max_length", "cache_length"]
         }
-        self.reward_model = VerifiableReward(
-            reward_fns=[
-                math_reward_fn if grpo_config.get("reward_fn_type") == "think_answer_tags" else boxed_math_reward_fn
-            ],
-            tokenizer=self.tokenizer,
-            tags=response_format_tags,
-            **reward_model_kwargs,
-        )
+        if self.grpo_config.get("correct_sample_length_reward", False):
+            self.reward_model = AGPOReward(
+                reward_fn=agpo_boxed_math_reward_fn,
+                num_generations=self.num_generations,
+                tokenizer=self.tokenizer,
+                tags=response_format_tags,
+                **reward_model_kwargs,
+            )
+        else:
+            self.reward_model = VerifiableReward(
+                reward_fns=[
+                    math_reward_fn if grpo_config.get("reward_fn_type") == "think_answer_tags" else boxed_math_reward_fn
+                ],
+                tokenizer=self.tokenizer,
+                tags=response_format_tags,
+                **reward_model_kwargs,
+            )
         self.global_step = 0
         self.use_wandb = use_wandb
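For context, the hunk above replaces the unconditional VerifiableReward construction with a flag-controlled choice between AGPOReward and VerifiableReward. The sketch below isolates that selection logic; the Dummy* classes are hypothetical stand-ins for the real coati.distributed.reward classes (their definitions are not part of this diff), and only the branching on correct_sample_length_reward mirrors the change.

# Minimal sketch of the selection logic introduced above. The Dummy* classes
# are hypothetical placeholders, not the real AGPOReward / VerifiableReward.

class DummyAGPOReward:
    def __init__(self, reward_fn, num_generations, tokenizer=None, tags=None, **kwargs):
        self.reward_fn = reward_fn
        self.num_generations = num_generations
        self.extra = kwargs

class DummyVerifiableReward:
    def __init__(self, reward_fns, tokenizer=None, tags=None, **kwargs):
        self.reward_fns = reward_fns
        self.extra = kwargs

def build_reward_model(grpo_config, reward_model_kwargs, num_generations=8):
    # New path: the flag routes reward computation through AGPOReward.
    if grpo_config.get("correct_sample_length_reward", False):
        return DummyAGPOReward(
            reward_fn=lambda *a, **kw: 0.0,  # stands in for agpo_boxed_math_reward_fn
            num_generations=num_generations,
            **reward_model_kwargs,
        )
    # Old path, kept unchanged: a list of verifiable reward functions.
    return DummyVerifiableReward(
        reward_fns=[lambda *a, **kw: 0.0],  # stands in for boxed_math_reward_fn
        **reward_model_kwargs,
    )

print(type(build_reward_model({"correct_sample_length_reward": True}, {"max_length": 4096})).__name__)
# -> DummyAGPOReward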
@@ -20,7 +20,7 @@ class AGPOReward:
     ) -> torch.Tensor:
         # Get batch size
         bs = input_ids.size(0)
-        num_generations = self.kwargs.get("num_generations")
+        num_generations = self.kwargs["num_generations"]

         # Apply the reward function to the entire batch at once
         reward_infos = [self.reward_fn(input_ids[i], gt_answer=gt_answer[i], response_idx=response_idx[i], **self.kwargs) for i in range(bs)]
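The only functional change in this hunk is indexing self.kwargs["num_generations"] instead of calling .get(): a missing argument now raises KeyError at the call site rather than letting None flow into the later per-group arithmetic. The skeleton below is an assumption reconstructed from this hunk and the constructor call shown earlier, for illustration only; the real AGPOReward in coati.distributed.reward.agpo_reward may differ.

import torch

class AGPORewardSketch:
    # Hypothetical skeleton, reconstructed from the diff for illustration only.
    def __init__(self, reward_fn, **kwargs):
        self.reward_fn = reward_fn
        self.kwargs = kwargs

    def __call__(self, input_ids, gt_answer, response_idx) -> torch.Tensor:
        bs = input_ids.size(0)
        # Indexing (not .get) fails fast with KeyError if num_generations was
        # never passed, instead of silently binding None.
        num_generations = self.kwargs["num_generations"]
        assert bs % num_generations == 0, "batch must hold whole generation groups"
        rewards = [
            self.reward_fn(input_ids[i], gt_answer=gt_answer[i], response_idx=response_idx[i], **self.kwargs)
            for i in range(bs)
        ]
        return torch.tensor(rewards, dtype=torch.float32)

# Usage with dummy inputs and a constant reward function:
rm = AGPORewardSketch(lambda ids, **kw: 1.0, num_generations=4)
out = rm(torch.zeros(4, 8, dtype=torch.long), ["42"] * 4, torch.zeros(4, 2, dtype=torch.long))
print(out.shape)  # torch.Size([4])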
@@ -209,6 +209,7 @@ if __name__ == "__main__":
             "loss_variation": "token_level",
             "max_length": args.max_new_tokens + args.max_prompt_tokens,
             "reward_fn_type": args.reward_type,
+            "correct_sample_length_reward": True
         }
     else:
         raise ValueError(f"Unsupported algorithm: {args.algo}")
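A sketch of the config branch the last hunk touches, limited to the keys visible in the diff. The args values and the algorithm-name check are assumptions for illustration; the real launch script parses these with argparse and builds a larger grpo_config than shown here.

from types import SimpleNamespace

# Hypothetical argument values standing in for the script's argparse results.
args = SimpleNamespace(max_new_tokens=1024, max_prompt_tokens=512, reward_type="boxed", algo="AGPO")

if args.algo == "AGPO":  # assumed branch condition; the diff only shows the body and the else
    grpo_config = {
        "loss_variation": "token_level",
        "max_length": args.max_new_tokens + args.max_prompt_tokens,
        "reward_fn_type": args.reward_type,
        "correct_sample_length_reward": True,  # new flag read by GRPOConsumer to enable AGPOReward
    }
else:
    raise ValueError(f"Unsupported algorithm: {args.algo}")

print(grpo_config["max_length"])  # 1536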