add response length

Tong Li 2025-03-11 13:06:09 +08:00
parent abca66e69f
commit 47d6493778


@@ -59,6 +59,7 @@ class GRPOConsumer(BaseConsumer):
         self.accum_format_reward = torch.zeros(1, device=self.device)
         self.accum_acc_reward = torch.zeros(1, device=self.device)
         self.accum_advantages = torch.zeros(1, device=self.device)
+        self.accum_response_length = torch.zeros(1, device=self.device)
         self.accum_count = 0

         # Reference model is initialized from policy model.
@ -83,7 +84,7 @@ class GRPOConsumer(BaseConsumer):
self.policy_loss_fn = PolicyLoss() self.policy_loss_fn = PolicyLoss()
self.global_step = 0 self.global_step = 0
if use_wandb and self.rank == 0: if use_wandb and self.rank == 0:
self.wandb_run = wandb.init(project="GRPO-Test", sync_tensorboard=True) self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True)
def setup(self): def setup(self):
super().setup() super().setup()
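
Note: the W&B run is created only when use_wandb is set and the process is rank 0, so the wandb_run.log calls further down execute on a single process. A minimal, self-contained sketch of that pattern (the rank and metric values below are placeholders, not the consumer's actual state):

import wandb

rank = 0          # in the trainer this comes from the distributed setup
use_wandb = True  # placeholder flag for illustration

if use_wandb and rank == 0:
    # One run per training job; sync_tensorboard mirrors any TensorBoard scalars into W&B.
    wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True)
    wandb_run.log({"train/response_length": 128.0})  # placeholder value
    wandb_run.finish()
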
@ -109,6 +110,7 @@ class GRPOConsumer(BaseConsumer):
action_mask = data["action_mask"] action_mask = data["action_mask"]
num_action = action_mask.shape[1] num_action = action_mask.shape[1]
old_action_log_probs = data["action_log_probs"] old_action_log_probs = data["action_log_probs"]
response_length = torch.sum(action_mask, dim=1).to(torch.float32)
need_update = (step_idx + 1) % self.num_microbatches == 0 need_update = (step_idx + 1) % self.num_microbatches == 0
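
Note: the new response_length statistic is simply the number of non-masked action tokens per sample. A minimal sketch of that computation, using a made-up action_mask for illustration (the mask values and shape below are assumptions):

import torch

# Hypothetical action mask: 1 marks a generated (response) token, 0 marks padding.
# Shape: (batch_size, num_actions).
action_mask = torch.tensor(
    [
        [1, 1, 1, 0, 0],  # sample 0 produced 3 response tokens
        [1, 1, 1, 1, 1],  # sample 1 produced 5 response tokens
    ]
)

# Same computation as in the diff: per-sample response length as a float tensor.
response_length = torch.sum(action_mask, dim=1).to(torch.float32)
print(response_length)  # tensor([3., 5.])
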
@@ -168,6 +170,7 @@
             format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
             acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin)
             advantages = all_reduce_mean(advantages.mean(), self.plugin)
+            response_length = all_reduce_mean(response_length.mean(), self.plugin)
             # Calculate accumulate value.
             self.accum_loss.add_(loss.data)
             self.accum_reward.add_(reward.data)
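
Note: like the other metrics, response_length is averaged within the batch and then across data-parallel ranks via all_reduce_mean. The sketch below shows one common way such a helper is implemented with torch.distributed; it is an illustration under that assumption, not the project's actual all_reduce_mean (which, as the diff shows, also takes the booster plugin as an argument):

import torch
import torch.distributed as dist

def all_reduce_mean_sketch(tensor: torch.Tensor) -> torch.Tensor:
    # Sum the (already batch-averaged) scalar across all ranks, then divide by
    # the world size so every rank ends up with the global mean.
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
        tensor = tensor / dist.get_world_size()
    return tensor

# Usage mirroring the diff (response_length assumed to hold per-sample lengths):
# response_length = all_reduce_mean_sketch(response_length.mean())
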
@@ -175,6 +178,7 @@
             self.accum_format_reward.add_(format_reward.data)
             self.accum_acc_reward.add_(acc_reward.data)
             self.accum_advantages.add_(advantages.data)
+            self.accum_response_length.add_(response_length.data)
             self.accum_count += 1
             if need_update:
                 self.optimizer.step()
@@ -184,32 +188,38 @@
                 print(
                     "Loss:",
                     self.accum_loss.item() / self.accum_count,
-                    "Reward:",
+                    "\nReward:",
                     self.accum_reward.item() / self.accum_count,
-                    "KL:",
-                    self.accum_kl.item() / self.accum_count,
-                    "Format Reward:",
+                    "\nFormat Reward:",
                     self.accum_format_reward.item() / self.accum_count,
-                    "Acc Reward:",
+                    "\nAcc Reward:",
                     self.accum_acc_reward.item() / self.accum_count,
-                    "Advantages:",
+                    "\nKL:",
+                    self.accum_kl.item() / self.accum_count,
+                    "\nAdvantages:",
                     self.accum_advantages.item() / self.accum_count,
+                    "\nResponse Length:",
+                    self.accum_response_length.item() / self.accum_count,
                 )
                 self.wandb_run.log(
                     {
                         "train/loss": self.accum_loss.item() / self.accum_count,
                         "train/reward": self.accum_reward.item() / self.accum_count,
-                        "train/kl": self.accum_kl.item() / self.accum_count,
                         "train/format_reward": self.accum_format_reward.item() / self.accum_count,
                         "train/acc_reward": self.accum_acc_reward.item() / self.accum_count,
+                        "train/kl": self.accum_kl.item() / self.accum_count,
                         "train/advantages": self.accum_advantages.item() / self.accum_count,
+                        "train/response_length": self.accum_response_length.item() / self.accum_count,
                     }
                 )
                 self.accum_loss.zero_()
                 self.accum_reward.zero_()
-                self.accum_kl.zero_()
                 self.accum_acc_reward.zero_()
                 self.accum_format_reward.zero_()
+                self.accum_kl.zero_()
+                self.accum_advantages.zero_()
+                self.accum_response_length.zero_()
                 self.accum_count = 0

         return loss_scalar
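
Note: putting the pieces together, the metric path added by this commit follows an accumulate-then-average pattern: add each micro-batch's mean into a buffer, divide by the count when the optimizer steps, then reset the buffer. A stripped-down, single-process sketch of that pattern (the MetricAccumulator class and its method names are illustrative, not the actual GRPOConsumer API):

import torch

class MetricAccumulator:
    """Accumulate per-micro-batch means and report their average at each optimizer step."""

    def __init__(self, device="cpu"):
        self.accum_response_length = torch.zeros(1, device=device)
        self.accum_count = 0

    def update(self, action_mask: torch.Tensor) -> None:
        # Per-sample response length, averaged over the micro-batch.
        response_length = torch.sum(action_mask, dim=1).to(torch.float32)
        self.accum_response_length.add_(response_length.mean())
        self.accum_count += 1

    def report_and_reset(self) -> float:
        # Average over the accumulated micro-batches, then clear the buffer.
        avg = self.accum_response_length.item() / self.accum_count
        self.accum_response_length.zero_()
        self.accum_count = 0
        return avg

# Usage: call update() per micro-batch, report_and_reset() when need_update is true.
acc = MetricAccumulator()
acc.update(torch.tensor([[1, 1, 0], [1, 1, 1]]))
print(acc.report_and_reset())  # 2.5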