fix missing or wrong file during rebase

This commit is contained in:
YeAnbang 2025-08-05 14:41:12 +08:00
parent 118a66fd46
commit 3746f73854
2 changed files with 6 additions and 20 deletions

View File

@ -217,6 +217,7 @@ class BaseConsumer:
effective_group_mask = None
if self.filter_range is not None and self.grpo_config.get("dynamic_batching", True):
# filter the group based on the reward and accuracy
group_ans_acc_mean = ans_acc.mean(dim=1)
effective_group_mask = torch.logical_and(
group_ans_acc_mean > self.filter_range[0], group_ans_acc_mean < self.filter_range[1]
)

View File

@ -91,9 +91,6 @@ class GRPOConsumer(BaseConsumer):
self.project_name = project_name
self.effective_sample_count = 0
self.effective_prompt_count = 0
self.total_sample_count = 0
self.overlength_samples = 0
self.total_overlength_samples = 0
self.project_name = project_name
self.run_name = run_name
self.wandb_group_name = wandb_group_name
@ -207,7 +204,6 @@ class GRPOConsumer(BaseConsumer):
# filter out overlength samples
if self.filter_truncated_response and action_mask.size(1) == self.max_length:
old_loss_mask = loss_mask.clone()
loss_mask = torch.logical_and(
loss_mask,
action_mask[:, -1] == False,
@ -225,15 +221,7 @@ class GRPOConsumer(BaseConsumer):
group_ans_acc_mean < self.filter_range[1],
),
)
self.total_overlength_samples += self.overlength_samples.item()
prompt_level_mask = loss_mask.view(self.minibatch_size, self.num_generations)
# [minibatch_size] -> calculate the number of effective prompts
effective_prompts_mask = prompt_level_mask.any(dim=1)
effective_prompts = all_reduce_sum(torch.sum(effective_prompts_mask), self.plugin)
self.effective_prompt_count += effective_prompts.item()
excessive_prompts_idx = None
self.effective_prompt_count += group_reward.size(0) * self.dp_size
mean_kl, mean_loss = [], []
@ -250,8 +238,7 @@ class GRPOConsumer(BaseConsumer):
pbar.set_postfix(
{
"Global Step": self.global_step,
"Effective prompts": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size}",
"Effective samples": f"{self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations}",
"Gradient Accumulation on": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size} effective prompts, {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations} effective samples",
}
)
@ -477,12 +464,10 @@ class GRPOConsumer(BaseConsumer):
self.optimizer.step()
self.optimizer.zero_grad()
self.global_step += 1
sample_utilization = self.effective_sample_count / self.total_sample_count
overlength_samples_percentage = self.total_overlength_samples / self.total_sample_count
# no need to run all reduce as raw_train_batch_* are not splited across dp rank
sample_utilization = self.effective_sample_count / len(self.raw_train_batch_reward) / self.num_generations
self.effective_prompt_count = 0
self.effective_sample_count = 0
self.total_sample_count = 0
self.total_overlength_samples = 0
loss_scalar = self.accum_loss.item()
if not self.plugin.pp_size > 1 or (
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
@ -545,4 +530,4 @@ class GRPOConsumer(BaseConsumer):
model = self.policy_model.unwrap()
state_dict = model.state_dict()
state_dict["consumer_global_step"] = torch.tensor([self.global_step], device=self.device)
return state_dict
return state_dict