From 94743161326d2ede61326950d59efc8ed2119d0f Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Wed, 16 Apr 2025 15:59:25 +0800 Subject: [PATCH] small fix --- .../coati/distributed/grpo_consumer.py | 26 ++++++++----------- .../coati/distributed/reward/reward_fn.py | 5 ++-- 2 files changed, 14 insertions(+), 17 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index dee4e648b..5886cc7fe 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -206,12 +206,6 @@ class GRPOConsumer(BaseConsumer): total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin) self.effective_sample_count += effective_samples.item() self.total_sample_count += total_samples.item() - print( - loss_mask, - self.effective_sample_count, - self.total_sample_count, - self.batch_size * self.dp_size * self.num_generations * 0.75, - ) mean_kl, mean_loss = [], [] @@ -426,17 +420,19 @@ class GRPOConsumer(BaseConsumer): self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0 ): to_log_msg = ( - f"Loss: {self.accum_loss.item() / self.accum_count:.4f} \ - Reward: {self.accum_reward.item() / self.accum_count:.4f} \ - Format Reward: {self.accum_format_acc.item() / self.accum_count:.4f} \ - Acc Reward: {self.accum_ans_acc.item() / self.accum_count:.4f} \ - Advantages: {self.accum_advantages.item() / self.accum_count:.4f} \ - Response Length: {self.accum_response_length.item() / self.accum_count:.4f}" - + f" KL: {self.accum_kl.item() / self.accum_count:.4f}" + [ + f"Loss: {self.accum_loss.item() / self.accum_count:.4f}", + f"Reward: {self.accum_reward.item() / self.accum_count:.4f}", + f"Format Reward: {self.accum_format_acc.item() / self.accum_count:.4f}", + f"Acc Reward: {self.accum_ans_acc.item() / self.accum_count:.4f}", + f"Advantages: 
{self.accum_advantages.item() / self.accum_count:.4f}", + f"Response Length: {self.accum_response_length.item() / self.accum_count:.4f}", + ] + + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 - else "" + else []) ) - print(to_log_msg) + print("\n".join(to_log_msg)) metrics = { "metrics/reward": self.accum_reward.item() / self.accum_count, "metrics/format_acc": self.accum_format_acc.item() / self.accum_count, diff --git a/applications/ColossalChat/coati/distributed/reward/reward_fn.py b/applications/ColossalChat/coati/distributed/reward/reward_fn.py index b1ac02fcd..3cf7a1af3 100644 --- a/applications/ColossalChat/coati/distributed/reward/reward_fn.py +++ b/applications/ColossalChat/coati/distributed/reward/reward_fn.py @@ -35,9 +35,10 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): format_acc += 1 reward += format_score - # Check answer accuracy + # Check answer accuracy: an answer only counts as correct when the format is valid and it matches the ground truth if ( - final_answer is not None + format_valid + and final_answer is not None and gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower() ): ans_acc += 1