add overlength sample count (#6332)

Co-authored-by: Tong Li <tong.li35271158@gmail.com>
Authored by Tong Li on 2025-05-28 19:18:09 +08:00, committed by GitHub
parent de2ad3b206
commit c8b368c294

@@ -85,6 +85,8 @@ class GRPOConsumer(BaseConsumer):
        self.effective_sample_count = 0
        self.effective_prompt_count = 0
        self.total_sample_count = 0
        self.overlength_samples = 0
        self.total_overlength_samples = 0
        self.project_name = project_name
        self.run_name = run_name
        self.wandb_group_name = wandb_group_name
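
Note: the two new fields split the bookkeeping between a per-pass value (`overlength_samples`) and a running total (`total_overlength_samples`) that is cleared on every optimizer step. A minimal sketch of that pattern, with made-up names and counts:

# Hypothetical illustration, not from the diff: a per-micro-batch count
# feeding a window total that is reported and reset once per optimizer step.
window_total = 0
for per_batch_count in [3, 0, 5]:
    window_total += per_batch_count
print(window_total)  # 8 -> logged at the optimizer step, then reset to 0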
@@ -227,10 +229,18 @@ class GRPOConsumer(BaseConsumer):
        # filter out overlength samples
        if self.filter_truncated_response and action_mask.size(1) == self.max_length:
            old_loss_mask = loss_mask.clone()
            loss_mask = torch.logical_and(
                loss_mask,
                action_mask[:, -1] == False,
            )
            self.overlength_samples = (old_loss_mask & ~loss_mask).sum().item()
            self.overlength_samples = all_reduce_sum(
                torch.tensor(self.overlength_samples, device=loss_mask.device), self.plugin
            )
            self.total_overlength_samples += self.overlength_samples.item()
        prompt_level_mask = loss_mask.view(self.minibatch_size, self.num_generations)
        # [minibatch_size] -> calculate the number of effective prompts
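
Note: a response is flagged as overlength when its action mask is still active at the final position, i.e. generation filled the whole `max_length` budget without stopping, and `all_reduce_sum` then (per its name) sums the local counts across the plugin's data-parallel ranks. A self-contained, single-process sketch of the masking step, with made-up tensors:

import torch

# Three responses of max length 4; only the first is truncated.
action_mask = torch.tensor([
    [True, True, True, True],    # active at the last position -> overlength
    [True, True, False, False],  # stopped early -> kept
    [True, True, True, False],   # stopped early -> kept
])
loss_mask = torch.ones(3, dtype=torch.bool)

old_loss_mask = loss_mask.clone()
loss_mask = torch.logical_and(loss_mask, action_mask[:, -1] == False)
overlength_samples = (old_loss_mask & ~loss_mask).sum().item()
print(overlength_samples)  # 1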
@@ -484,9 +494,11 @@ class GRPOConsumer(BaseConsumer):
            self.optimizer.zero_grad()
            self.global_step += 1
            sample_utilization = self.effective_sample_count / self.total_sample_count
            overlength_samples_percentage = self.total_overlength_samples / self.total_sample_count
            self.effective_prompt_count = 0
            self.effective_sample_count = 0
            self.total_sample_count = 0
            self.total_overlength_samples = 0
            loss_scalar = self.accum_loss.item()
            if not self.plugin.pp_size > 1 or (
                self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
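
Note: both ratios share `total_sample_count` as the denominator, and all four counters are zeroed immediately afterwards, so each logged value covers exactly one accumulation window. With made-up counts:

# Illustrative numbers only.
effective_sample_count = 96
total_sample_count = 128
total_overlength_samples = 8

sample_utilization = effective_sample_count / total_sample_count                # 0.75
overlength_samples_percentage = total_overlength_samples / total_sample_count   # 0.0625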
@@ -502,6 +514,7 @@ class GRPOConsumer(BaseConsumer):
                    f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
                    f"Response Length: {self.accum_response_length.item() / self.accum_count:.4f}",
                    f"Sample_utilization: {sample_utilization:.4f}",
                    f"Percentage of overlength samples: {overlength_samples_percentage:.4f}",
                ] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
                print("\n".join(to_log_msg))
                metrics = {
@@ -513,6 +526,7 @@ class GRPOConsumer(BaseConsumer):
                    "train/advantages": self.accum_advantages.item() / self.accum_count,
                    "train/learning_rate": self.lr_scheduler.get_last_lr()[0],
                    "train/sample_utilization": sample_utilization,
                    "train/percentage_overlength_samples": overlength_samples_percentage,
                    "rollout/temperature": data["temperature"].cpu().numpy()[0][0],
                }
                if self.policy_loss_fn.beta > 0:
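
Note: the new `train/percentage_overlength_samples` key rides along in the existing metrics dict. A hedged sketch of how such a dict could be logged, assuming a Weights & Biases run (the run object and project name here are illustrative, not from the diff):

import wandb

run = wandb.init(project="grpo-demo", mode="offline")  # offline keeps this self-contained
run.log({
    "train/sample_utilization": 0.75,
    "train/percentage_overlength_samples": 0.0625,
})
run.finish()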