mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-15 22:53:12 +00:00
small fix
This commit is contained in:
parent
cc4faa7300
commit
9474316132
@ -206,12 +206,6 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
|
total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
|
||||||
self.effective_sample_count += effective_samples.item()
|
self.effective_sample_count += effective_samples.item()
|
||||||
self.total_sample_count += total_samples.item()
|
self.total_sample_count += total_samples.item()
|
||||||
print(
|
|
||||||
loss_mask,
|
|
||||||
self.effective_sample_count,
|
|
||||||
self.total_sample_count,
|
|
||||||
self.batch_size * self.dp_size * self.num_generations * 0.75,
|
|
||||||
)
|
|
||||||
|
|
||||||
mean_kl, mean_loss = [], []
|
mean_kl, mean_loss = [], []
|
||||||
|
|
||||||
@ -426,17 +420,19 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
|
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
|
||||||
):
|
):
|
||||||
to_log_msg = (
|
to_log_msg = (
|
||||||
f"Loss: {self.accum_loss.item() / self.accum_count:.4f} \
|
[
|
||||||
Reward: {self.accum_reward.item() / self.accum_count:.4f} \
|
f"Loss: {self.accum_loss.item() / self.accum_count:.4f}",
|
||||||
Format Reward: {self.accum_format_acc.item() / self.accum_count:.4f} \
|
f"Reward: {self.accum_reward.item() / self.accum_count:.4f}",
|
||||||
Acc Reward: {self.accum_ans_acc.item() / self.accum_count:.4f} \
|
f"ormat Reward: {self.accum_format_acc.item() / self.accum_count:.4f}",
|
||||||
Advantages: {self.accum_advantages.item() / self.accum_count:.4f} \
|
f"Acc Reward: {self.accum_ans_acc.item() / self.accum_count:.4f}",
|
||||||
Response Length: {self.accum_response_length.item() / self.accum_count:.4f}"
|
f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
|
||||||
+ f" KL: {self.accum_kl.item() / self.accum_count:.4f}"
|
f"Response Length: {self.accum_response_length.item() / self.accum_count:.4f}",
|
||||||
|
]
|
||||||
|
+ [f"KL: {self.accum_kl.item() / self.accum_count:.4f}"]
|
||||||
if self.policy_loss_fn.beta > 0
|
if self.policy_loss_fn.beta > 0
|
||||||
else ""
|
else []
|
||||||
)
|
)
|
||||||
print(to_log_msg)
|
print("\n".join(to_log_msg))
|
||||||
metrics = {
|
metrics = {
|
||||||
"metrics/reward": self.accum_reward.item() / self.accum_count,
|
"metrics/reward": self.accum_reward.item() / self.accum_count,
|
||||||
"metrics/format_acc": self.accum_format_acc.item() / self.accum_count,
|
"metrics/format_acc": self.accum_format_acc.item() / self.accum_count,
|
||||||
|
@ -35,9 +35,10 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
|
|||||||
format_acc += 1
|
format_acc += 1
|
||||||
reward += format_score
|
reward += format_score
|
||||||
|
|
||||||
# Check answer accuracy
|
# Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
|
||||||
if (
|
if (
|
||||||
final_answer is not None
|
format_valid
|
||||||
|
and final_answer is not None
|
||||||
and gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower()
|
and gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower()
|
||||||
):
|
):
|
||||||
ans_acc += 1
|
ans_acc += 1
|
||||||
|
Loading…
Reference in New Issue
Block a user