diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index 70e2201fe..eaf3521b6 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -254,7 +254,7 @@ class GRPOConsumer(BaseConsumer): true_indices = torch.nonzero(effective_prompts_mask) # Make sure the indices are not empty. if true_indices.numel() > 0: - true_indices = true_indices.squeeze() + true_indices = true_indices.squeeze(-1) if excessive_prompts_per_rank <= len(true_indices): excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:] else: