mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-06 18:43:58 +00:00
add overlength sample count (#6332)
Co-authored-by: Tong Li <tong.li35271158@gmail.com>
This commit is contained in:
parent
de2ad3b206
commit
c8b368c294
@ -85,6 +85,8 @@ class GRPOConsumer(BaseConsumer):
|
||||
self.effective_sample_count = 0
|
||||
self.effective_prompt_count = 0
|
||||
self.total_sample_count = 0
|
||||
self.overlength_samples = 0
|
||||
self.total_overlength_samples = 0
|
||||
self.project_name = project_name
|
||||
self.run_name = run_name
|
||||
self.wandb_group_name = wandb_group_name
|
||||
@ -227,10 +229,18 @@ class GRPOConsumer(BaseConsumer):
|
||||
|
||||
# filter out overlength samples
|
||||
if self.filter_truncated_response and action_mask.size(1) == self.max_length:
|
||||
old_loss_mask = loss_mask.clone()
|
||||
loss_mask = torch.logical_and(
|
||||
loss_mask,
|
||||
action_mask[:, -1] == False,
|
||||
)
|
||||
|
||||
self.overlength_samples = (old_loss_mask & ~loss_mask).sum().item()
|
||||
self.overlength_samples = all_reduce_sum(
|
||||
torch.tensor(self.overlength_samples, device=loss_mask.device), self.plugin
|
||||
)
|
||||
self.total_overlength_samples += self.overlength_samples.item()
|
||||
|
||||
prompt_level_mask = loss_mask.view(self.minibatch_size, self.num_generations)
|
||||
|
||||
# [minibatch_size] -> calculate the number of effective prompts
|
||||
@ -484,9 +494,11 @@ class GRPOConsumer(BaseConsumer):
|
||||
self.optimizer.zero_grad()
|
||||
self.global_step += 1
|
||||
sample_utilization = self.effective_sample_count / self.total_sample_count
|
||||
overlength_samples_percentage = self.total_overlength_samples / self.total_sample_count
|
||||
self.effective_prompt_count = 0
|
||||
self.effective_sample_count = 0
|
||||
self.total_sample_count = 0
|
||||
self.total_overlength_samples = 0
|
||||
loss_scalar = self.accum_loss.item()
|
||||
if not self.plugin.pp_size > 1 or (
|
||||
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
|
||||
@ -502,6 +514,7 @@ class GRPOConsumer(BaseConsumer):
|
||||
f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
|
||||
f"Response Length: {self.accum_response_length.item() / self.accum_count:.4f}",
|
||||
f"Sample_utilization: {sample_utilization:.4f}",
|
||||
f"Percentage of overlength samples: {overlength_samples_percentage:.4f}",
|
||||
] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
|
||||
print("\n".join(to_log_msg))
|
||||
metrics = {
|
||||
@ -513,6 +526,7 @@ class GRPOConsumer(BaseConsumer):
|
||||
"train/advantages": self.accum_advantages.item() / self.accum_count,
|
||||
"train/learning_rate": self.lr_scheduler.get_last_lr()[0],
|
||||
"train/sample_utilization": sample_utilization,
|
||||
"train/percentage_overlength_samples": overlength_samples_percentage,
|
||||
"rollout/temperature": data["temperature"].cpu().numpy()[0][0],
|
||||
}
|
||||
if self.policy_loss_fn.beta > 0:
|
||||
|
Loading…
Reference in New Issue
Block a user