fix empty tensor (#6319)

Co-authored-by: Tong Li <tong.li35271158@gmail.com>
This commit is contained in:
Tong Li 2025-05-20 17:41:44 +08:00 committed by YeAnbang
parent 70c3daa4ee
commit 5bbfe1567f

View File

@ -238,7 +238,7 @@ class GRPOConsumer(BaseConsumer):
true_indices = torch.nonzero(effective_prompts_mask) true_indices = torch.nonzero(effective_prompts_mask)
# Make sure the indices are not empty. # Make sure the indices are not empty.
if true_indices.numel() > 0: if true_indices.numel() > 0:
true_indices = true_indices.squeeze() true_indices = true_indices.squeeze(-1)
if excessive_prompts_per_rank <= len(true_indices): if excessive_prompts_per_rank <= len(true_indices):
excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:] excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
else: else: