merge grpo-latest

YeAnbang committed 2025-05-29 18:14:43 +08:00
2 changed files with 32 additions and 21 deletions


@@ -211,6 +211,17 @@ class GRPOConsumer(BaseConsumer):
loss_mask,
action_mask[:, -1] == False,
)
if self.filter_range is not None and not self.grpo_config.get("dynamic_batching", False):
# filter out samples whose group mean accuracy falls outside filter_range
# when dynamic batching is enabled, out-of-range groups are already filtered out before training
group_ans_acc_mean = ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=-1)
loss_mask = torch.logical_and(
loss_mask,
torch.logical_and(
group_ans_acc_mean > self.filter_range[0],
group_ans_acc_mean < self.filter_range[1],
),
)
self.effective_prompt_count += group_reward.size(0) * self.dp_size
mean_kl, mean_loss = [], []
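The new branch keeps a prompt's whole group of rollouts only when the group's mean answer accuracy lies strictly inside `filter_range`. A minimal standalone sketch of that masking step, using illustrative names rather than the actual `GRPOConsumer` attributes:

import torch

def group_filter_mask(ans_acc: torch.Tensor, num_generations: int, filter_range) -> torch.Tensor:
    # ans_acc is a flat tensor of per-sample accuracies, length = num_prompts * num_generations
    group_mean = ans_acc.view(-1, num_generations).mean(dim=1)          # one mean per prompt group
    group_mean = group_mean.repeat_interleave(num_generations, dim=-1)  # broadcast back to samples
    return (group_mean > filter_range[0]) & (group_mean < filter_range[1])

# Example: 2 prompts x 2 generations, keep groups with mean accuracy in (0.1, 0.9).
# The all-correct group falls outside the range and is masked out.
mask = group_filter_mask(torch.tensor([1.0, 1.0, 0.0, 1.0]), num_generations=2, filter_range=(0.1, 0.9))
# mask -> tensor([False, False, True, True])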
@@ -229,8 +240,7 @@ class GRPOConsumer(BaseConsumer):
pbar.set_postfix(
{
"Global Step": self.global_step,
"Effective prompts": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size}",
"Effective samples": f"{self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations}",
"Gradient Accumulation on": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size} effective prompts, {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations} effective samples",
}
)
@@ -428,6 +438,7 @@ class GRPOConsumer(BaseConsumer):
self.optimizer.step()
self.optimizer.zero_grad()
self.global_step += 1
# no need to run all-reduce here, as raw_train_batch_* are not split across DP ranks
sample_utilization = self.effective_sample_count / len(self.raw_train_batch_reward) / self.num_generations
self.effective_prompt_count = 0
self.effective_sample_count = 0
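For context, `sample_utilization` above is the fraction of generated samples that remain effective after masking. A toy calculation with made-up numbers, assuming the denominator counts one entry per prompt times `num_generations`:

# Illustrative numbers only: 16 prompts accumulated in raw_train_batch_* on this rank,
# 8 generations per prompt, 96 samples still effective after masking.
effective_sample_count = 96
num_raw_prompts = 16          # stands in for len(self.raw_train_batch_reward)
num_generations = 8
sample_utilization = effective_sample_count / num_raw_prompts / num_generations  # 0.75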
@@ -438,14 +449,12 @@ class GRPOConsumer(BaseConsumer):
if (not self.plugin.pp_size > 1 and self.rank == 0) or (
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
):
raw_batch_reward_mean = sum(self.raw_train_batch_reward) / len(self.raw_train_batch_reward)
raw_batch_format_acc_mean = sum(self.raw_train_batch_format_acc) / len(
self.raw_train_batch_format_acc
)
raw_batch_ans_acc_mean = sum(self.raw_train_batch_ans_acc) / len(self.raw_train_batch_ans_acc)
raw_batch_response_len_mean = sum(self.raw_train_batch_response_len) / len(
self.raw_train_batch_response_len
)
raw_batch_reward_mean = torch.cat(self.raw_train_batch_reward, dim=0).mean().cpu().item()
raw_batch_format_acc_mean = torch.cat(self.raw_train_batch_format_acc, dim=0).mean().cpu().item()
raw_batch_ans_acc_mean = torch.cat(self.raw_train_batch_ans_acc, dim=0).mean().cpu().item()
raw_batch_response_len = torch.cat(self.raw_train_batch_response_len, dim=0)
raw_batch_response_len_mean = raw_batch_response_len.mean().cpu().item()
overlength_samples_ratio = (raw_batch_response_len >= action_mask.size(-1)).to(float).mean().cpu().item() # not an exact figure, but a close estimate
self.raw_train_batch_reward = []
self.raw_train_batch_format_acc = []
self.raw_train_batch_ans_acc = []
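The rewritten block computes the per-step means by concatenating the accumulated tensors, and adds an overlength estimate: a response counts as overlength when its recorded length reaches the action-mask width (roughly the generation cap, hence the "close estimate" comment). A small self-contained sketch with hypothetical inputs:

import torch

# Hypothetical inputs: per-sample response lengths, and a cap of 8 new tokens
# standing in for action_mask.size(-1) in the consumer.
raw_batch_response_len = torch.tensor([3, 8, 8, 5, 7, 8])
max_new_tokens = 8

raw_batch_response_len_mean = raw_batch_response_len.float().mean().item()                           # 6.5
overlength_samples_ratio = (raw_batch_response_len >= max_new_tokens).to(torch.float).mean().item()  # 3/6 = 0.5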
@@ -458,6 +467,7 @@ class GRPOConsumer(BaseConsumer):
f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
f"Response Length: {raw_batch_response_len_mean:.4f}",
f"Sample_utilization: {sample_utilization:.4f}",
f"Overlength samples ratio: {overlength_samples_ratio:.4f}",
] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
print("\n".join(to_log_msg))
metrics = {
@@ -469,6 +479,7 @@ class GRPOConsumer(BaseConsumer):
"train/advantages": self.accum_advantages.item() / self.accum_count,
"train/learning_rate": self.lr_scheduler.get_last_lr()[0],
"train/sample_utilization": sample_utilization,
"train/overlength_samples_ratio": overlength_samples_ratio,
"rollout/temperature": data["temperature"].cpu().numpy()[0][0],
}
if self.policy_loss_fn.beta > 0: