Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-06 19:40:28 +00:00
fix metric calculation
@@ -72,12 +72,12 @@ class GRPOConsumer(BaseConsumer):
             self.policy_model.gradient_checkpointing_enable()
         self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
         self.accum_loss = torch.zeros(1, device=self.device)
-        self.accum_reward = torch.zeros(1, device=self.device)
         self.accum_kl = torch.zeros(1, device=self.device)
-        self.accum_format_acc = torch.zeros(1, device=self.device)
-        self.accum_ans_acc = torch.zeros(1, device=self.device)
         self.accum_advantages = torch.zeros(1, device=self.device)
-        self.accum_response_length = torch.zeros(1, device=self.device)
+        self.raw_train_batch_reward = []
+        self.raw_train_batch_format_acc = []
+        self.raw_train_batch_ans_acc = []
+        self.raw_train_batch_response_len = []
         self.accum_count = 0
         self.generate_config = generate_config
         self.grpo_config = grpo_config
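For orientation, a minimal sketch of the bookkeeping this hunk leaves behind (a toy stand-in, not the actual GRPOConsumer constructor): loss, KL, and advantage statistics are still tracked with scalar accumulators divided by accum_count, while reward, format accuracy, answer accuracy, and response length are now pooled as raw per-sample values in plain Python lists.

import torch

class MetricState:
    """Toy stand-in for the metric bookkeeping sketched above (illustrative only)."""

    def __init__(self, device: str = "cpu"):
        # Step-averaged accumulators, still divided by accum_count when logging.
        self.accum_loss = torch.zeros(1, device=device)
        self.accum_kl = torch.zeros(1, device=device)
        self.accum_advantages = torch.zeros(1, device=device)
        self.accum_count = 0
        # Raw per-sample values pooled across minibatches until the next log/flush.
        self.raw_train_batch_reward = []
        self.raw_train_batch_format_acc = []
        self.raw_train_batch_ans_acc = []
        self.raw_train_batch_response_len = []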
@@ -186,7 +186,11 @@ class GRPOConsumer(BaseConsumer):
             [minibatch_size, num_of_generation, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>.
         """
         # Reshape to [minibatch_size x num_of_generation, prompt_length + response_length]
-        data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()}
+        data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items() if "raw_train_mini_batch_" not in k}
+        self.raw_train_batch_reward.extend(kwargs["raw_train_mini_batch_reward"])
+        self.raw_train_batch_format_acc.extend(kwargs["raw_train_mini_batch_format_acc"])
+        self.raw_train_batch_ans_acc.extend(kwargs["raw_train_mini_batch_ans_acc"])
+        self.raw_train_batch_response_len.extend(kwargs["raw_train_mini_batch_response_len"])
         action_mask = data["action_mask"]
         num_action = action_mask.shape[1]
         old_action_log_probs = data["action_log_probs"]
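A small runnable sketch of the data handling this hunk introduces, with made-up shapes and values: tensor entries are flattened from [minibatch_size, num_of_generation, seq_len] to [minibatch_size * num_of_generation, seq_len], while keys prefixed with raw_train_mini_batch_ (assumed here to hold per-sample scalars rather than tensors) are excluded from the reshape and appended to the pooled lists instead.

import torch

raw_train_batch_reward = []  # pooled across minibatches, as in the diff

# Hypothetical minibatch payload: 2 prompts x 4 generations, sequence length 8.
kwargs = {
    "action_mask": torch.ones(2, 4, 8, dtype=torch.bool),
    "action_log_probs": torch.randn(2, 4, 8),
    # Assumed to be one scalar per generated sample (8 in total).
    "raw_train_mini_batch_reward": [1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0],
}

# Only tensor entries are reshaped; the raw_* entries are routed to the pooled lists.
data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items() if "raw_train_mini_batch_" not in k}
raw_train_batch_reward.extend(kwargs["raw_train_mini_batch_reward"])

print(data["action_mask"].shape)    # torch.Size([8, 8])
print(len(raw_train_batch_reward))  # 8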
@@ -430,11 +434,7 @@ class GRPOConsumer(BaseConsumer):
                 self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
                 if self.policy_loss_fn.beta > 0:
                     self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
-                self.accum_reward.add_(reward.data)
-                self.accum_format_acc.add_(format_acc.data)
-                self.accum_ans_acc.add_(ans_acc.data)
                 self.accum_advantages.add_(advantages.data)
-                self.accum_response_length.add_(response_length.data)
                 self.accum_count += 1
         if need_update:
             self.optimizer.step()
@@ -452,21 +452,33 @@ class GRPOConsumer(BaseConsumer):
             if (not self.plugin.pp_size > 1 and self.rank == 0) or (
                 self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
             ):
+                raw_batch_reward_mean = sum(self.raw_train_batch_reward) / len(self.raw_train_batch_reward)
+                raw_batch_format_acc_mean = sum(self.raw_train_batch_format_acc) / len(
+                    self.raw_train_batch_format_acc
+                )
+                raw_batch_ans_acc_mean = sum(self.raw_train_batch_ans_acc) / len(self.raw_train_batch_ans_acc)
+                raw_batch_response_len_mean = sum(self.raw_train_batch_response_len) / len(
+                    self.raw_train_batch_response_len
+                )
+                self.raw_train_batch_reward = []
+                self.raw_train_batch_format_acc = []
+                self.raw_train_batch_ans_acc = []
+                self.raw_train_batch_response_len = []
                 to_log_msg = [
                     f"Loss: {self.accum_loss.item() / self.accum_count:.4f}",
-                    f"Reward: {self.accum_reward.item() / self.accum_count:.4f}",
-                    f"format Reward: {self.accum_format_acc.item() / self.accum_count:.4f}",
-                    f"Acc Reward: {self.accum_ans_acc.item() / self.accum_count:.4f}",
+                    f"Reward: {raw_batch_reward_mean:.4f}",
+                    f"format Reward: {raw_batch_format_acc_mean:.4f}",
+                    f"Acc Reward: {raw_batch_ans_acc_mean:.4f}",
                     f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
-                    f"Response Length: {self.accum_response_length.item() / self.accum_count:.4f}",
+                    f"Response Length: {raw_batch_response_len_mean:.4f}",
                     f"Sample_utilization: {sample_utilization:.4f}",
                 ] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
                 print("\n".join(to_log_msg))
                 metrics = {
-                    "metrics/reward": self.accum_reward.item() / self.accum_count,
-                    "metrics/format_acc": self.accum_format_acc.item() / self.accum_count,
-                    "metrics/ans_acc": self.accum_ans_acc.item() / self.accum_count,
-                    "metrics/response_length": self.accum_response_length.item() / self.accum_count,
+                    "metrics/reward": raw_batch_reward_mean,
+                    "metrics/format_acc": raw_batch_format_acc_mean,
+                    "metrics/ans_acc": raw_batch_ans_acc_mean,
+                    "metrics/response_length": raw_batch_response_len_mean,
                     "train/loss": self.accum_loss.item() / self.accum_count,
                     "train/advantages": self.accum_advantages.item() / self.accum_count,
                     "train/learning_rate": self.lr_scheduler.get_last_lr()[0],
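To make the "fix metric calculation" concrete, here is a toy comparison of the two averaging schemes, with made-up numbers and under the assumption that minibatches can contribute unequal numbers of samples: the old path added one already-averaged value per step to an accumulator and divided by accum_count (weighting steps equally), while the new path takes a single mean over every raw sample pooled since the last log (weighting samples equally).

# Hypothetical per-sample rewards from two minibatches of unequal size.
step1_rewards = [1.0, 1.0, 1.0, 1.0]  # mean 1.0
step2_rewards = [0.0, 0.0]            # mean 0.0

# Old scheme: accumulate one per-step mean, then divide by the step count.
accum_reward = sum(step1_rewards) / len(step1_rewards) + sum(step2_rewards) / len(step2_rewards)
accum_count = 2
print(accum_reward / accum_count)  # 0.5  -- each step counts equally

# New scheme: pool raw per-sample values and take one mean over all of them.
raw_train_batch_reward = step1_rewards + step2_rewards
print(sum(raw_train_batch_reward) / len(raw_train_batch_reward))  # 0.666...  -- each sample counts equally

As in the diff, the pooled raw_train_batch_* lists are reset to empty lists once the means have been logged.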
@@ -478,12 +490,8 @@ class GRPOConsumer(BaseConsumer):
                 if self.wandb_run is not None:
                     self.wandb_run.log(metrics)
                 self.accum_loss.zero_()
-                self.accum_reward.zero_()
-                self.accum_ans_acc.zero_()
-                self.accum_format_acc.zero_()
                 self.accum_kl.zero_()
                 self.accum_advantages.zero_()
-                self.accum_response_length.zero_()
                 self.accum_count = 0
             return loss_scalar
         else: