Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-08-24 10:41:07 +00:00
add overlength sample count (#6332)
Co-authored-by: Tong Li <tong.li35271158@gmail.com>
parent 60510010d1
commit a246bf25c3
@@ -84,6 +84,12 @@ class GRPOConsumer(BaseConsumer):
        self.project_name = project_name
        self.effective_sample_count = 0
        self.effective_prompt_count = 0
<<<<<<< HEAD
=======
        self.total_sample_count = 0
        self.overlength_samples = 0
        self.total_overlength_samples = 0
>>>>>>> c8b368c2 (add overlength sample count (#6332))
        self.project_name = project_name
        self.run_name = run_name
        self.wandb_group_name = wandb_group_name
@@ -207,11 +213,25 @@ class GRPOConsumer(BaseConsumer):

            # filter out overlength samples
            if self.filter_truncated_response and action_mask.size(1) == self.max_length:
                old_loss_mask = loss_mask.clone()
                loss_mask = torch.logical_and(
                    loss_mask,
                    action_mask[:, -1] == False,
                )
                self.effective_prompt_count += group_reward.size(0) * self.dp_size

                self.overlength_samples = (old_loss_mask & ~loss_mask).sum().item()
                self.overlength_samples = all_reduce_sum(
                    torch.tensor(self.overlength_samples, device=loss_mask.device), self.plugin
                )
                self.total_overlength_samples += self.overlength_samples.item()

            prompt_level_mask = loss_mask.view(self.minibatch_size, self.num_generations)

            # [minibatch_size] -> calculate the number of effective prompts
            effective_prompts_mask = prompt_level_mask.any(dim=1)
            effective_prompts = all_reduce_sum(torch.sum(effective_prompts_mask), self.plugin)
            self.effective_prompt_count += effective_prompts.item()
            excessive_prompts_idx = None

            mean_kl, mean_loss = [], []
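For context on the hunk above: a response is treated as overlength when its action mask is still active at the final position, i.e. generation stopped because max_length was reached rather than because the model finished. Below is a minimal, standalone sketch of that masking-and-counting step, assuming hypothetical tensor shapes and running on a single process (so the all_reduce_sum / plugin machinery from the diff is omitted):

import torch

# Hypothetical shapes: 4 samples, max response length 8.
loss_mask = torch.tensor([True, True, True, True])   # samples currently kept
action_mask = torch.zeros(4, 8, dtype=torch.bool)
action_mask[0, :3] = True   # stopped early -> kept
action_mask[1, :] = True    # still generating at the last position -> overlength
action_mask[2, :5] = True
action_mask[3, :] = True    # overlength

old_loss_mask = loss_mask.clone()
# Drop samples whose last action-mask entry is still active (truncated responses).
loss_mask = torch.logical_and(loss_mask, action_mask[:, -1] == False)

# Samples kept before but masked out now were dropped for being overlength.
overlength_samples = (old_loss_mask & ~loss_mask).sum().item()
print(loss_mask)            # tensor([ True, False,  True, False])
print(overlength_samples)   # 2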
@@ -428,9 +448,18 @@ class GRPOConsumer(BaseConsumer):
                self.optimizer.step()
                self.optimizer.zero_grad()
                self.global_step += 1
<<<<<<< HEAD
                sample_utilization = self.effective_sample_count / len(self.raw_train_batch_reward) / self.num_generations
                self.effective_prompt_count = 0
                self.effective_sample_count = 0
=======
                sample_utilization = self.effective_sample_count / self.total_sample_count
                overlength_samples_percentage = self.total_overlength_samples / self.total_sample_count
                self.effective_prompt_count = 0
                self.effective_sample_count = 0
                self.total_sample_count = 0
                self.total_overlength_samples = 0
>>>>>>> c8b368c2 (add overlength sample count (#6332))
                loss_scalar = self.accum_loss.item()
                if not self.plugin.pp_size > 1 or (
                    self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
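On the incoming side of the conflict in the hunk above, the two logged ratios are plain quotients of the per-step counters, which are then reset before the next optimizer step. A small illustration of that bookkeeping, with attribute names mirroring the diff and counter values invented purely for the example:

# Made-up per-step counters in the style of the diff above.
effective_sample_count = 192      # samples that survived all filtering
total_sample_count = 256          # all samples seen since the last optimizer step
total_overlength_samples = 16     # samples dropped for hitting max_length

sample_utilization = effective_sample_count / total_sample_count                # 0.75
overlength_samples_percentage = total_overlength_samples / total_sample_count   # 0.0625

print(f"Sample_utilization: {sample_utilization:.4f}")
print(f"Percentage of overlength samples: {overlength_samples_percentage:.4f}")

# After logging, the counters are zeroed so the next step starts fresh.
effective_sample_count = total_sample_count = total_overlength_samples = 0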
@@ -458,6 +487,7 @@ class GRPOConsumer(BaseConsumer):
                        f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
                        f"Response Length: {raw_batch_response_len_mean:.4f}",
                        f"Sample_utilization: {sample_utilization:.4f}",
                        f"Percentage of overlength samples: {overlength_samples_percentage:.4f}",
                    ] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
                    print("\n".join(to_log_msg))
                    metrics = {
@@ -469,6 +499,7 @@ class GRPOConsumer(BaseConsumer):
                        "train/advantages": self.accum_advantages.item() / self.accum_count,
                        "train/learning_rate": self.lr_scheduler.get_last_lr()[0],
                        "train/sample_utilization": sample_utilization,
                        "train/percentage_overlength_samples": overlength_samples_percentage,
                        "rollout/temperature": data["temperature"].cpu().numpy()[0][0],
                    }
                    if self.policy_loss_fn.beta > 0: