From 1644adf6843b735af84e76c3346bea2e7ceee85d Mon Sep 17 00:00:00 2001
From: Tong Li
Date: Thu, 15 May 2025 18:30:27 +0800
Subject: [PATCH] handle empty index

---
 .../coati/distributed/consumer.py      | 48 +++++++++----------
 .../coati/distributed/grpo_consumer.py | 25 ++++++----
 2 files changed, 37 insertions(+), 36 deletions(-)

diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index ecfd65458..816cab50a 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -113,7 +113,6 @@ class BaseConsumer:
             ) as pbar:
                 for step in pbar:
                     i = 0
-                    allow_sync_model = False
                     for _ in range(self.num_recv_per_update):
                         # receive data from producers
                         for r in range(self.num_producers):
@@ -139,7 +138,6 @@ class BaseConsumer:
                             else:
                                 self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
                             if loss is not None:
-                                allow_sync_model = True
                                 pbar.set_postfix({"loss": loss})
                             i += 1
                     if self.lr_scheduler is not None:
@@ -153,31 +151,29 @@ class BaseConsumer:
                             print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")
 
                     if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
-                        if allow_sync_model:
-                            if self.pp_size > 1:
-                                print(
-                                    f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
+                        if self.pp_size > 1:
+                            print(
+                                f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
+                            )
+                        else:
+                            print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
+                        torch.cuda.empty_cache()
+                        state_dict = self.state_dict()
+                        if self.pp_size > 1:
+                            if self.tp_rank == 0 and self.dp_rank == 0:
+                                ray_broadcast_tensor_dict(
+                                    state_dict,
+                                    src=self.num_producers,
+                                    device=self.device,
+                                    group_name=f"sync_model_{self.pp_rank}",
                                 )
-                            else:
-                                print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
-                            torch.cuda.empty_cache()
-                            state_dict = self.state_dict()
-                            if self.pp_size > 1:
-                                if self.tp_rank == 0 and self.dp_rank == 0:
-                                    ray_broadcast_tensor_dict(
-                                        state_dict,
-                                        src=self.num_producers,
-                                        device=self.device,
-                                        group_name=f"sync_model_{self.pp_rank}",
-                                    )
-                            else:
-                                if self.rank == 0:
-                                    ray_broadcast_tensor_dict(
-                                        state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
-                                    )
-                            del state_dict
-                            torch.cuda.empty_cache()
-                            allow_sync_model = False
+                        else:
+                            if self.rank == 0:
+                                ray_broadcast_tensor_dict(
+                                    state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
+                                )
+                        del state_dict
+                        torch.cuda.empty_cache()
 
 
 @ray.remote
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 7ea2dbf18..eae4ff54e 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -245,17 +245,22 @@ class GRPOConsumer(BaseConsumer):
             # TODO: customize excessive prompts calculation.
             if excessive_prompts_per_rank != 0:
                 # Mask excessive prompts to False
-                true_indices = torch.nonzero(effective_prompts_mask).squeeze()
-                if excessive_prompts_per_rank <= len(true_indices):
-                    excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
-                else:
-                    excessive_prompts_idx = true_indices
-                effective_prompts_mask[excessive_prompts_idx] = False
+                true_indices = torch.nonzero(effective_prompts_mask)
+                # Make sure the indices are not empty.
+                if true_indices.numel() > 0:
+                    true_indices = true_indices.squeeze()
+                    if excessive_prompts_per_rank <= len(true_indices):
+                        excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
+                    else:
+                        excessive_prompts_idx = true_indices
+                    effective_prompts_mask[excessive_prompts_idx] = False
 
-                for mask_idx in range(len(effective_prompts_mask)):
-                    if effective_prompts_mask[mask_idx] == False:
-                        # Update loss mask.
-                        loss_mask[mask_idx] = False
+                    for mask_idx in range(len(effective_prompts_mask)):
+                        if effective_prompts_mask[mask_idx] == False:
+                            # Update loss mask.
+                            loss_mask[mask_idx] = False
+                else:
+                    excessive_prompts_idx = torch.empty([0])
         else:
             # If dynamic batching is disabled, we need to use all samples for training.
             need_update = (step_idx + 1) % self.num_microbatches == 0
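
Note on the consumer.py hunks: removing the allow_sync_model flag means the weight broadcast now runs after every non-final step, even when no minibatch produced a loss. A minimal sketch of that control flow follows; run_step and broadcast_weights are hypothetical stand-ins, not names from the repository.

# Sketch only: `broadcast_weights` stands in for ray_broadcast_tensor_dict and
# `minibatches` for the consumer's buffer loop; neither name exists in the repo.
def run_step(minibatches, state_dict, is_last_step, broadcast_weights):
    loss = None
    for batch in minibatches:
        loss = sum(batch) / len(batch)  # placeholder for the real training step
    # Pre-patch: the broadcast only ran when some minibatch set a loss.
    # Post-patch: it runs on every non-final step.
    if not is_last_step:
        broadcast_weights(state_dict)
    return loss


if __name__ == "__main__":
    synced = []
    # No minibatches, so loss stays None; the weights are still broadcast.
    run_step([], {"w": 0.0}, is_last_step=False, broadcast_weights=synced.append)
    assert synced, "weights are synced even when no loss was computed"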
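Note on the grpo_consumer.py hunk: torch.nonzero on an all-False mask returns a tensor with zero elements, and the new numel() guard keeps the masking step from running on such an empty index tensor. Below is a standalone sketch of the same guard; mask_excessive_prompts is an illustrative helper that is not part of the repository, it uses squeeze(-1) to keep a 1-D index tensor, and it omits the loss-mask update for brevity.

import torch


def mask_excessive_prompts(mask: torch.Tensor, excessive: int) -> torch.Tensor:
    """Illustrative version of the patched logic: drop the last `excessive`
    effective prompts, tolerating an all-False mask."""
    true_indices = torch.nonzero(mask)  # shape [N, 1] for a 1-D bool mask
    if true_indices.numel() > 0:
        true_indices = true_indices.squeeze(-1)
        idx = true_indices[-excessive:] if excessive <= len(true_indices) else true_indices
        mask[idx] = False
    # With no True entries there is nothing to drop; the patch returns an
    # empty index tensor (torch.empty([0])) in this branch.
    return mask


if __name__ == "__main__":
    # Normal case: the two trailing effective prompts get masked out.
    print(mask_excessive_prompts(torch.tensor([True, True, True, False]), 2))
    # All-False mask: nonzero() is empty, so the guard makes this a no-op.
    print(mask_excessive_prompts(torch.tensor([False, False, False]), 2))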