handle empty index

This commit is contained in:
Tong Li
2025-05-15 18:30:27 +08:00
committed by YeAnbang
parent 957e3a521a
commit 1644adf684
2 changed files with 37 additions and 36 deletions

View File

@@ -245,17 +245,22 @@ class GRPOConsumer(BaseConsumer):
# TODO: customize excessive prompts calculation.
if excessive_prompts_per_rank != 0:
# Mask excessive prompts to False
true_indices = torch.nonzero(effective_prompts_mask).squeeze()
if excessive_prompts_per_rank <= len(true_indices):
excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
else:
excessive_prompts_idx = true_indices
effective_prompts_mask[excessive_prompts_idx] = False
true_indices = torch.nonzero(effective_prompts_mask)
# Make sure the indices are not empty.
if true_indices.numel() > 0:
true_indices = true_indices.squeeze()
if excessive_prompts_per_rank <= len(true_indices):
excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
else:
excessive_prompts_idx = true_indices
effective_prompts_mask[excessive_prompts_idx] = False
for mask_idx in range(len(effective_prompts_mask)):
if effective_prompts_mask[mask_idx] == False:
# Update loss mask.
loss_mask[mask_idx] = False
for mask_idx in range(len(effective_prompts_mask)):
if effective_prompts_mask[mask_idx] == False:
# Update loss mask.
loss_mask[mask_idx] = False
else:
excessive_prompts_idx = torch.empty([0])
else:
# If dynamic batching is disabled, we need to use all samples for training.
need_update = (step_idx + 1) % self.num_microbatches == 0