From ab956249155408559b3930634a6ebe21c0f60c32 Mon Sep 17 00:00:00 2001
From: Tong Li
Date: Fri, 16 May 2025 10:00:10 +0800
Subject: [PATCH 1/3] handle empty index (#6311)

Co-authored-by: Tong Li
---
 .../coati/distributed/consumer.py      | 48 +++++++++----------
 .../coati/distributed/grpo_consumer.py | 25 ++++++----
 2 files changed, 37 insertions(+), 36 deletions(-)

diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index 38b31398d..b07f8d7a1 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -114,7 +114,6 @@ class BaseConsumer:
         ) as pbar:
             for step in pbar:
                 i = 0
-                allow_sync_model = False
                 for _ in range(self.num_recv_per_update):
                     # receive data from producers
                     for r in range(self.num_producers):
@@ -140,7 +139,6 @@ class BaseConsumer:
                         else:
                             self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
                         if loss is not None:
-                            allow_sync_model = True
                             pbar.set_postfix({"loss": loss})
                         i += 1
                 if self.lr_scheduler is not None:
@@ -154,31 +152,29 @@ class BaseConsumer:
                     print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")
 
                 if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
-                    if allow_sync_model:
-                        if self.pp_size > 1:
-                            print(
-                                f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
-                            )
-                        else:
-                            print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
-                        torch.cuda.empty_cache()
-                        state_dict = self.state_dict()
-                        if self.pp_size > 1:
-                            if self.tp_rank == 0 and self.dp_rank == 0:
-                                ray_broadcast_tensor_dict(
-                                    state_dict,
-                                    src=self.num_producers,
-                                    device=self.device,
-                                    group_name=f"sync_model_{self.pp_rank}",
-                                )
-                        else:
-                            if self.rank == 0:
-                                ray_broadcast_tensor_dict(
-                                    state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
-                                )
-                        del state_dict
-                        torch.cuda.empty_cache()
-                        allow_sync_model = False
+                    if self.pp_size > 1:
+                        print(
+                            f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
+                        )
+                    else:
+                        print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
+                    torch.cuda.empty_cache()
+                    state_dict = self.state_dict()
+                    if self.pp_size > 1:
+                        if self.tp_rank == 0 and self.dp_rank == 0:
+                            ray_broadcast_tensor_dict(
+                                state_dict,
+                                src=self.num_producers,
+                                device=self.device,
+                                group_name=f"sync_model_{self.pp_rank}",
+                            )
+                    else:
+                        if self.rank == 0:
+                            ray_broadcast_tensor_dict(
+                                state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
+                            )
+                    del state_dict
+                    torch.cuda.empty_cache()
 
 
 @ray.remote
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 31b687639..72e54b0de 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -239,17 +239,22 @@ class GRPOConsumer(BaseConsumer):
         # TODO: customize excessive prompts calculation.
         if excessive_prompts_per_rank != 0:
             # Mask excessive prompts to False
-            true_indices = torch.nonzero(effective_prompts_mask).squeeze()
-            if excessive_prompts_per_rank <= len(true_indices):
-                excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
-            else:
-                excessive_prompts_idx = true_indices
-            effective_prompts_mask[excessive_prompts_idx] = False
+            true_indices = torch.nonzero(effective_prompts_mask)
+            # Make sure the indices are not empty.
+            if true_indices.numel() > 0:
+                true_indices = true_indices.squeeze()
+                if excessive_prompts_per_rank <= len(true_indices):
+                    excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
+                else:
+                    excessive_prompts_idx = true_indices
+                effective_prompts_mask[excessive_prompts_idx] = False
 
-            for mask_idx in range(len(effective_prompts_mask)):
-                if effective_prompts_mask[mask_idx] == False:
-                    # Update loss mask.
-                    loss_mask[mask_idx] = False
+                for mask_idx in range(len(effective_prompts_mask)):
+                    if effective_prompts_mask[mask_idx] == False:
+                        # Update loss mask.
+                        loss_mask[mask_idx] = False
+            else:
+                excessive_prompts_idx = torch.empty([0])
         else:
             # If dynamic batching is disabled, we need to use all samples for training.
             need_update = (step_idx + 1) % self.num_microbatches == 0

From f8bd2db33fa3b42ef8b1ddcf2c96ce6535ab672c Mon Sep 17 00:00:00 2001
From: YeAnbang
Date: Tue, 20 May 2025 09:45:56 +0800
Subject: [PATCH 2/3] add uuid to rollout log

---
 applications/ColossalChat/coati/distributed/launch.py | 6 +++++-
 applications/ColossalChat/rl_example.py               | 2 +-
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index 6eeb5d379..ef81bcbdd 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -56,7 +56,7 @@ def launch_distributed(
     eval_save_dir: Optional[str] = None,
     eval_generation_config: Optional[Dict[str, Any]] = None,
     log_rollout_interval: int = 20,
-    rollout_log_file: str = "./rollout_log.jsonl",
+    rollout_save_dir: str = "./rollout",
 ):
     if core_algo not in ALGO_MAP:
         raise NotImplementedError(f"{core_algo} is not supported yet.")
@@ -74,6 +74,10 @@ def launch_distributed(
     run_name = f"{inference_backend}_bs_{train_batch_size * train_dp_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
     wandb_group_name = str(uuid.uuid4())
+    rollout_log_file = os.path.join(
+        rollout_save_dir,
+        f"{project_name.replace(' ','_')}_run_{wandb_group_name}.jsonl",
+    )
 
     procs = []
     for i in range(num_producers):
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index 98c139f14..bfa0ab7d0 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -273,5 +273,5 @@ if __name__ == "__main__":
         eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
         eval_generation_config=eval_generation_config,
         log_rollout_interval=20,
-        rollout_log_file=os.path.join(args.rollout_save_dir, args.project.replace(" ", "_") + ".jsonl"),
+        rollout_save_dir=args.rollout_save_dir,
     )

From 32afa7bf29387045999df227faea2da8eed2faee Mon Sep 17 00:00:00 2001
From: Tong Li
Date: Tue, 20 May 2025 17:41:44 +0800
Subject: [PATCH 3/3] fix empty tensor (#6319)

Co-authored-by: Tong Li
---
 applications/ColossalChat/coati/distributed/grpo_consumer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 70e2201fe..eaf3521b6 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -254,7 +254,7 @@ class GRPOConsumer(BaseConsumer):
                 true_indices = torch.nonzero(effective_prompts_mask)
                 # Make sure the indices are not empty.
                 if true_indices.numel() > 0:
-                    true_indices = true_indices.squeeze()
+                    true_indices = true_indices.squeeze(-1)
                     if excessive_prompts_per_rank <= len(true_indices):
                         excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
                     else:
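
Note on the failure mode behind [PATCH 1/3] and [PATCH 3/3]: for a 1-D boolean mask, torch.nonzero(mask) returns an index tensor of shape [K, 1]. When no entry is True the result is empty (hence the numel() guard), and when exactly one entry is True, a bare squeeze() collapses the [1, 1] result to a 0-D scalar, which breaks len() and slicing. The sketch below is illustrative only; the masks are made-up inputs, not values taken from grpo_consumer.py.

import torch

# Hypothetical masks standing in for effective_prompts_mask.
no_true = torch.tensor([False, False, False])
one_true = torch.tensor([False, True, False])

print(torch.nonzero(no_true).shape)   # torch.Size([0, 1]) -> numel() == 0, nothing to mask
print(torch.nonzero(one_true).shape)  # torch.Size([1, 1])

# Bare squeeze() on the single-match case yields a 0-D tensor; len() on it would raise TypeError.
print(torch.nonzero(one_true).squeeze().dim())  # 0

# squeeze(-1) removes only the trailing dimension, keeping a 1-D index tensor
# that still supports len() and negative slicing.
idx = torch.nonzero(one_true).squeeze(-1)
print(idx.shape)   # torch.Size([1])
print(idx[-1:])    # tensor([1])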