mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-10 12:22:28 +00:00
add overlength sample count (#6332)
Co-authored-by: Tong Li <tong.li35271158@gmail.com>
This commit is contained in:
parent
de2ad3b206
commit
c8b368c294
@ -85,6 +85,8 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
self.effective_sample_count = 0
|
self.effective_sample_count = 0
|
||||||
self.effective_prompt_count = 0
|
self.effective_prompt_count = 0
|
||||||
self.total_sample_count = 0
|
self.total_sample_count = 0
|
||||||
|
self.overlength_samples = 0
|
||||||
|
self.total_overlength_samples = 0
|
||||||
self.project_name = project_name
|
self.project_name = project_name
|
||||||
self.run_name = run_name
|
self.run_name = run_name
|
||||||
self.wandb_group_name = wandb_group_name
|
self.wandb_group_name = wandb_group_name
|
||||||
@ -227,10 +229,18 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
|
|
||||||
# filter out overlength samples
|
# filter out overlength samples
|
||||||
if self.filter_truncated_response and action_mask.size(1) == self.max_length:
|
if self.filter_truncated_response and action_mask.size(1) == self.max_length:
|
||||||
|
old_loss_mask = loss_mask.clone()
|
||||||
loss_mask = torch.logical_and(
|
loss_mask = torch.logical_and(
|
||||||
loss_mask,
|
loss_mask,
|
||||||
action_mask[:, -1] == False,
|
action_mask[:, -1] == False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.overlength_samples = (old_loss_mask & ~loss_mask).sum().item()
|
||||||
|
self.overlength_samples = all_reduce_sum(
|
||||||
|
torch.tensor(self.overlength_samples, device=loss_mask.device), self.plugin
|
||||||
|
)
|
||||||
|
self.total_overlength_samples += self.overlength_samples.item()
|
||||||
|
|
||||||
prompt_level_mask = loss_mask.view(self.minibatch_size, self.num_generations)
|
prompt_level_mask = loss_mask.view(self.minibatch_size, self.num_generations)
|
||||||
|
|
||||||
# [minibatch_size] -> calculate the number of effective prompts
|
# [minibatch_size] -> calculate the number of effective prompts
|
||||||
@ -484,9 +494,11 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
self.global_step += 1
|
self.global_step += 1
|
||||||
sample_utilization = self.effective_sample_count / self.total_sample_count
|
sample_utilization = self.effective_sample_count / self.total_sample_count
|
||||||
|
overlength_samples_percentage = self.total_overlength_samples / self.total_sample_count
|
||||||
self.effective_prompt_count = 0
|
self.effective_prompt_count = 0
|
||||||
self.effective_sample_count = 0
|
self.effective_sample_count = 0
|
||||||
self.total_sample_count = 0
|
self.total_sample_count = 0
|
||||||
|
self.total_overlength_samples = 0
|
||||||
loss_scalar = self.accum_loss.item()
|
loss_scalar = self.accum_loss.item()
|
||||||
if not self.plugin.pp_size > 1 or (
|
if not self.plugin.pp_size > 1 or (
|
||||||
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
|
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
|
||||||
@ -502,6 +514,7 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
|
f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
|
||||||
f"Response Length: {self.accum_response_length.item() / self.accum_count:.4f}",
|
f"Response Length: {self.accum_response_length.item() / self.accum_count:.4f}",
|
||||||
f"Sample_utilization: {sample_utilization:.4f}",
|
f"Sample_utilization: {sample_utilization:.4f}",
|
||||||
|
f"Percentage of overlength samples: {overlength_samples_percentage:.4f}",
|
||||||
] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
|
] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
|
||||||
print("\n".join(to_log_msg))
|
print("\n".join(to_log_msg))
|
||||||
metrics = {
|
metrics = {
|
||||||
@ -513,6 +526,7 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
"train/advantages": self.accum_advantages.item() / self.accum_count,
|
"train/advantages": self.accum_advantages.item() / self.accum_count,
|
||||||
"train/learning_rate": self.lr_scheduler.get_last_lr()[0],
|
"train/learning_rate": self.lr_scheduler.get_last_lr()[0],
|
||||||
"train/sample_utilization": sample_utilization,
|
"train/sample_utilization": sample_utilization,
|
||||||
|
"train/percentage_overlength_samples": overlength_samples_percentage,
|
||||||
"rollout/temperature": data["temperature"].cpu().numpy()[0][0],
|
"rollout/temperature": data["temperature"].cpu().numpy()[0][0],
|
||||||
}
|
}
|
||||||
if self.policy_loss_fn.beta > 0:
|
if self.policy_loss_fn.beta > 0:
|
||||||
|
Loading…
Reference in New Issue
Block a user