move prompt-level-filtering to buffer side

This commit is contained in:
YeAnbang 2025-05-15 18:30:32 +08:00
parent 55eee129d2
commit a528921944

View File

@ -254,7 +254,6 @@ class GRPOConsumer(BaseConsumer):
total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin) total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin)
total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin) total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
self.effective_sample_count += effective_samples.item() self.effective_sample_count += effective_samples.item()
self.total_sample_count += total_samples.item()
pbar.set_postfix( pbar.set_postfix(
{ {
"Global Step": self.global_step, "Global Step": self.global_step,
@ -461,6 +460,9 @@ class GRPOConsumer(BaseConsumer):
self.optimizer.step() self.optimizer.step()
self.optimizer.zero_grad() self.optimizer.zero_grad()
self.global_step += 1 self.global_step += 1
self.total_sample_count = all_reduce_sum(
torch.tensor(self.total_sample_count).to(self.accum_loss.device), self.plugin
).item()
sample_utilization = self.effective_sample_count / self.total_sample_count sample_utilization = self.effective_sample_count / self.total_sample_count
self.effective_prompt_count = 0 self.effective_prompt_count = 0
self.effective_sample_count = 0 self.effective_sample_count = 0
@ -564,6 +566,7 @@ class GRPOConsumer(BaseConsumer):
"format_acc": torch.Tensor, [num_of_generation] "format_acc": torch.Tensor, [num_of_generation]
"ans_acc": torch.Tensor, [num_of_generation] "ans_acc": torch.Tensor, [num_of_generation]
""" """
self.total_sample_count += rollout_group["input_ids"].size(0)
if self.filter_range is not None: if self.filter_range is not None:
# filter prompt whoes accuracy is too high or too low (out of range) # filter prompt whoes accuracy is too high or too low (out of range)
group_ans_acc = torch.mean(rollout_group["ans_acc"]) group_ans_acc = torch.mean(rollout_group["ans_acc"])