handle empty index (#6311)

Co-authored-by: Tong Li <tong.li35271158@gmail.com>
Tong Li 2025-05-16 10:00:10 +08:00 committed by GitHub
parent aca547623f
commit ab95624915
2 changed files with 37 additions and 36 deletions


@@ -114,7 +114,6 @@ class BaseConsumer:
         ) as pbar:
             for step in pbar:
                 i = 0
-                allow_sync_model = False
                 for _ in range(self.num_recv_per_update):
                     # receive data from producers
                     for r in range(self.num_producers):
@@ -140,7 +139,6 @@
                         else:
                             self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
                         if loss is not None:
-                            allow_sync_model = True
                             pbar.set_postfix({"loss": loss})
                         i += 1
                 if self.lr_scheduler is not None:
@@ -154,7 +152,6 @@
                     print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")
                 if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
-                    if allow_sync_model:
                     if self.pp_size > 1:
                         print(
                             f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
@@ -178,7 +175,6 @@
                             )
                         del state_dict
                         torch.cuda.empty_cache()
-                    allow_sync_model = False
 @ray.remote
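
Taken together, the four hunks above remove the allow_sync_model flag from the training loop in BaseConsumer: weight synchronization to the producers is no longer gated on a completed optimizer step. A minimal before/after sketch of the control flow, with hypothetical helper names (try_train_minibatch, is_last_step, sync_model_to_producers) standing in for the real logic:

    # Before (simplified): sync only once some minibatch produced a loss.
    allow_sync_model = False
    for step in range(num_update_per_episode):
        loss = try_train_minibatch()          # hypothetical; may return None
        if loss is not None:
            allow_sync_model = True
        if not is_last_step(step):
            if allow_sync_model:
                sync_model_to_producers()     # hypothetical helper
                allow_sync_model = False

    # After (simplified): every non-final step triggers a sync.
    for step in range(num_update_per_episode):
        loss = try_train_minibatch()
        if not is_last_step(step):
            sync_model_to_producers()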


@@ -239,7 +239,10 @@ class GRPOConsumer(BaseConsumer):
             # TODO: customize excessive prompts calculation.
             if excessive_prompts_per_rank != 0:
                 # Mask excessive prompts to False
-                true_indices = torch.nonzero(effective_prompts_mask).squeeze()
+                true_indices = torch.nonzero(effective_prompts_mask)
+                # Make sure the indices are not empty.
+                if true_indices.numel() > 0:
+                    true_indices = true_indices.squeeze()
                     if excessive_prompts_per_rank <= len(true_indices):
                         excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
                     else:
@@ -250,6 +253,8 @@
                         if effective_prompts_mask[mask_idx] == False:
                             # Update loss mask.
                             loss_mask[mask_idx] = False
+                else:
+                    excessive_prompts_idx = torch.empty([0])
         else:
             # If dynamic batching is disabled, we need to use all samples for training.
             need_update = (step_idx + 1) % self.num_microbatches == 0
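
The GRPOConsumer change splits torch.nonzero() from .squeeze() and only squeezes when the result is non-empty, with a new fallback that leaves excessive_prompts_idx bound to a well-defined empty tensor when the mask selects nothing. A standalone sketch of the tensor shapes involved, in plain PyTorch and independent of the surrounding class:

    import torch

    # Edge case the commit guards against: the mask selects nothing.
    effective_prompts_mask = torch.zeros(4, dtype=torch.bool)
    true_indices = torch.nonzero(effective_prompts_mask)
    print(true_indices.shape)        # torch.Size([0, 1]) -- empty result
    print(true_indices.numel() > 0)  # False -> take the new fallback branch:
    excessive_prompts_idx = torch.empty([0])

    # Normal case: several effective prompts survive.
    effective_prompts_mask = torch.tensor([True, False, True, True])
    true_indices = torch.nonzero(effective_prompts_mask).squeeze()
    print(true_indices)              # tensor([0, 2, 3]) -- 1-D, so len() and
                                     # negative slicing behave as expected
    print(true_indices[-2:])         # tensor([2, 3]) -- the "excessive" tail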