mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 11:32:10 +00:00
handle empty index
This commit is contained in:
@@ -245,17 +245,22 @@ class GRPOConsumer(BaseConsumer):
|
||||
# TODO: customize excessive prompts calculation.
|
||||
if excessive_prompts_per_rank != 0:
|
||||
# Mask excessive prompts to False
|
||||
true_indices = torch.nonzero(effective_prompts_mask).squeeze()
|
||||
if excessive_prompts_per_rank <= len(true_indices):
|
||||
excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
|
||||
else:
|
||||
excessive_prompts_idx = true_indices
|
||||
effective_prompts_mask[excessive_prompts_idx] = False
|
||||
true_indices = torch.nonzero(effective_prompts_mask)
|
||||
# Make sure the indices are not empty.
|
||||
if true_indices.numel() > 0:
|
||||
true_indices = true_indices.squeeze()
|
||||
if excessive_prompts_per_rank <= len(true_indices):
|
||||
excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
|
||||
else:
|
||||
excessive_prompts_idx = true_indices
|
||||
effective_prompts_mask[excessive_prompts_idx] = False
|
||||
|
||||
for mask_idx in range(len(effective_prompts_mask)):
|
||||
if effective_prompts_mask[mask_idx] == False:
|
||||
# Update loss mask.
|
||||
loss_mask[mask_idx] = False
|
||||
for mask_idx in range(len(effective_prompts_mask)):
|
||||
if effective_prompts_mask[mask_idx] == False:
|
||||
# Update loss mask.
|
||||
loss_mask[mask_idx] = False
|
||||
else:
|
||||
excessive_prompts_idx = torch.empty([0])
|
||||
else:
|
||||
# If dynamic batching is disabled, we need to use all samples for training.
|
||||
need_update = (step_idx + 1) % self.num_microbatches == 0
|
||||
|
Reference in New Issue
Block a user