	remove redundant code and fix bugs
@@ -121,14 +121,14 @@ class BaseConsumer:
                             raw_batch = unbind_batch(
                                 ray_broadcast_tensor_dict(None, src=0, device=self.device, group_name=f"sync_data_{r}")
                             )
-                            filtered_batch = [
-                                t
-                                for t in [
-                                    self.prompt_level_filtering(self.calculate_group_reward(group))
-                                    for group in raw_batch
-                                ]
-                                if t is not None
+                            processed_batch = [
+                                self.prompt_level_filtering(self.calculate_group_reward(group)) for group in raw_batch
                             ]
+                            filtered_batch = [t for t in processed_batch if t is not None]
+                            if self.filter_range is not None:
+                                print(
+                                    f"[T{dist.get_rank()}] Filter recv data: {len(processed_batch)} -> {len(filtered_batch)}"
+                                )

                             self.buffer.extend(filtered_batch)
                         while len(self.buffer) >= self.dp_size * self.minibatch_size:
@@ -137,13 +137,8 @@ class BaseConsumer:
                             ]
                             batch = bind_batch(batches)
                             batch = post_recv(batch)
-                            loss, excessive_prompts_idx = self.step(i, pbar, **batch)
-
-                            if excessive_prompts_idx is not None:
-                                excessive_prompts = [self.buffer[idx] for idx in excessive_prompts_idx]
-                                self.buffer = excessive_prompts + self.buffer[self.dp_size * self.minibatch_size :]
-                            else:
-                                self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
+                            loss = self.step(i, pbar, **batch)
+                            self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
                             if loss is not None:
                                 allow_sync_model = True
                                 pbar.set_postfix({"loss": loss})
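For reference, a minimal standalone sketch of the new recv-side flow: rewards are computed per prompt group, prompt_level_filtering returns the group unchanged or None to drop it, and the before/after counts are logged. The helper bodies and the filter range values below are illustrative stand-ins, not the actual coati implementations.

from typing import Any, Dict, List, Optional

def calculate_group_reward(group: Dict[str, Any]) -> Dict[str, Any]:
    # Stand-in: attach a per-group mean accuracy so the filter can inspect it.
    accs = group["ans_acc"]
    group["group_acc"] = sum(accs) / len(accs)
    return group

def prompt_level_filtering(group: Dict[str, Any],
                           filter_range=(0.01, 0.99)) -> Optional[Dict[str, Any]]:
    # Stand-in: drop groups where every rollout is right or every rollout is wrong.
    lo, hi = filter_range
    return group if lo <= group["group_acc"] <= hi else None

raw_batch: List[Dict[str, Any]] = [
    {"ans_acc": [1.0, 1.0, 1.0, 1.0]},  # all correct  -> filtered out
    {"ans_acc": [1.0, 0.0, 1.0, 0.0]},  # mixed        -> kept
    {"ans_acc": [0.0, 0.0, 0.0, 0.0]},  # all wrong    -> filtered out
]

processed_batch = [prompt_level_filtering(calculate_group_reward(g)) for g in raw_batch]
filtered_batch = [t for t in processed_batch if t is not None]
print(f"Filter recv data: {len(processed_batch)} -> {len(filtered_batch)}")  # 3 -> 1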
@@ -9,7 +9,7 @@ from coati.distributed.loss import PolicyLoss
 from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
 from coati.distributed.utils import calc_action_log_probs
-from coati.trainer.utils import all_gather_tensors, all_reduce_mean, all_reduce_sum
+from coati.trainer.utils import all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer

 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@@ -201,10 +201,7 @@ class GRPOConsumer(BaseConsumer):
         reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
         # [minibatch_size x num_generations]
         advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
-        # filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score),
-        group_ans_acc = (
-            ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=0)
-        )
         # [minibatch_size x num_of_generation]
         loss_mask = torch.ones(action_mask.size(0), device=action_mask.device).bool()

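A hedged sketch of the group-relative advantage normalization this hunk keeps, with made-up rewards; group_reward is reconstructed here by reshaping the flat reward tensor, which is an assumption about shapes inferred from the comments in the diff.

import torch

num_generations = 4
reward = torch.tensor([1.0, 0.0, 1.0, 0.0,   # group 0: mixed outcomes
                       1.0, 1.0, 1.0, 1.0])  # group 1: all correct
group_reward = reward.view(-1, num_generations)  # assumed grouping by prompt

reward_mean = group_reward.mean(dim=1).repeat_interleave(num_generations, dim=0)
reward_std = group_reward.std(dim=1).repeat_interleave(num_generations, dim=0)
advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
print(advantages.squeeze(-1))
# Group 0 gets non-zero advantages; group 1 (identical rewards) collapses to ~0.
# The removed group_ans_acc bookkeeping targeted such all-correct/all-wrong groups;
# per the consumer.py hunk above, that filtering now happens in prompt_level_filtering
# when a batch is received.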
@@ -214,37 +211,14 @@ class GRPOConsumer(BaseConsumer):
                 loss_mask,
                 action_mask[:, -1] == False,
             )
-        prompt_level_mask = loss_mask.view(self.minibatch_size, self.num_generations)
-
-        # [minibatch_size] -> calculate the number of effective prompts
-        effective_prompts_mask = prompt_level_mask.any(dim=1)
-        effective_prompts = all_reduce_sum(torch.sum(effective_prompts_mask), self.plugin)
-        self.effective_prompt_count += effective_prompts.item()
-        excessive_prompts_idx = None
+        self.effective_prompt_count += group_reward.size(0) * self.dp_size

         mean_kl, mean_loss = [], []

         if self.grpo_config.get("dynamic_batching", True):
             need_update = self.effective_prompt_count >= self.batch_size * self.dp_size
             excessive_prompts = self.effective_prompt_count - self.batch_size * self.dp_size
-
-            if excessive_prompts > 0:
-                excessive_prompts_per_rank = excessive_prompts // self.dp_size
-                # Only count excessive prompts if they are greater than 1 per rank.
-                # TODO: customize excessive prompts calculation.
-                if excessive_prompts_per_rank != 0:
-                    # Mask excessive prompts to False
-                    true_indices = torch.nonzero(effective_prompts_mask).squeeze()
-                    if excessive_prompts_per_rank <= len(true_indices):
-                        excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
-                    else:
-                        excessive_prompts_idx = true_indices
-                    effective_prompts_mask[excessive_prompts_idx] = False
-
-                    for mask_idx in range(len(effective_prompts_mask)):
-                        if effective_prompts_mask[mask_idx] == False:
-                            # Update loss mask.
-                            loss_mask[mask_idx] = False
+            assert excessive_prompts <= 0, "Debug: Excessive prompts should always be less than 0. Bug!!!!"
         else:
             # If dynamic batching is disabled, we need to use all samples for training.
             need_update = (step_idx + 1) % self.num_microbatches == 0
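A sketch of the simplified dynamic-batching bookkeeping with illustrative sizes; the sizes and the loop are made up, but the accounting mirrors the new hunk: since every rank now receives only already-filtered prompt groups, the counter advances by group_reward.size(0) * dp_size per step and should never overshoot batch_size * dp_size, which is exactly what the debug assert encodes.

batch_size = 8          # prompts per rank per optimizer step (illustrative)
dp_size = 2
minibatch_size = 4      # prompt groups handled per consumer step per rank (illustrative)

effective_prompt_count = 0
for step in range(4):
    groups_this_step = minibatch_size          # stands in for group_reward.size(0)
    effective_prompt_count += groups_this_step * dp_size
    need_update = effective_prompt_count >= batch_size * dp_size
    excessive_prompts = effective_prompt_count - batch_size * dp_size
    assert excessive_prompts <= 0, "an overshoot would indicate a bug"
    print(step, effective_prompt_count, need_update)
    if need_update:
        effective_prompt_count = 0             # reset after the optimizer step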
@@ -460,9 +434,7 @@ class GRPOConsumer(BaseConsumer):
             self.optimizer.step()
             self.optimizer.zero_grad()
             self.global_step += 1
-            self.total_sample_count = all_reduce_sum(
-                torch.tensor(self.total_sample_count).to(self.accum_loss.device), self.plugin
-            ).item()
+            # no need to run all_reduce_sum on total_sample_count, because all dp ranks receive a complete inference batch from all producers.
             sample_utilization = self.effective_sample_count / self.total_sample_count
             self.effective_prompt_count = 0
             self.effective_sample_count = 0
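Illustrative arithmetic (made-up counts) for why the all_reduce_sum can be dropped: if every DP rank already holds the complete inference batch from all producers, its local total_sample_count is the global figure, and summing it across ranks would inflate the utilization denominator by a factor of dp_size.

dp_size = 2
total_sample_count = 128          # already the global count on each rank (assumed)
effective_sample_count = 96

sample_utilization = effective_sample_count / total_sample_count              # 0.75
inflated = effective_sample_count / (total_sample_count * dp_size)            # 0.375 if summed again
print(sample_utilization, inflated)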
@@ -507,14 +479,9 @@ class GRPOConsumer(BaseConsumer):
                 self.accum_advantages.zero_()
                 self.accum_response_length.zero_()
                 self.accum_count = 0
-
-            if excessive_prompts_idx is not None:
-                # All gather excessive prompts index across DP ranks.
-                excessive_prompts_idx = [idx + self.dp_rank * self.minibatch_size for idx in excessive_prompts_idx]
-                excessive_prompts_idx = all_gather_tensors(excessive_prompts_idx, self.plugin)
-            return loss_scalar, excessive_prompts_idx
+            return loss_scalar
         else:
-            return None, excessive_prompts_idx
+            return None

     def calculate_group_reward(self, rollout_group: Dict[str, Any]) -> Dict[str, Any]:
         """
@@ -66,7 +66,7 @@ def launch_distributed(

     dataset_path = train_dataset_config["path"]
     num_samples = get_jsonl_size_fast(dataset_path)
-    global_inference_batch_size = inference_batch_size * num_producers  # TODO: this doesn't support TP on producer
+    global_inference_batch_size = inference_batch_size * num_producers
     num_update_per_episode = num_samples // global_inference_batch_size
     num_recv_per_update = inference_batch_size // inference_microbatch_size

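A worked example of the batch arithmetic in launch_distributed, with made-up sizes; only the formulas come from the diff.

num_samples = 2000                # lines in the training jsonl (illustrative)
inference_batch_size = 64         # prompts per producer per rollout round (illustrative)
inference_microbatch_size = 16
num_producers = 4

global_inference_batch_size = inference_batch_size * num_producers       # 256
num_update_per_episode = num_samples // global_inference_batch_size      # 7
num_recv_per_update = inference_batch_size // inference_microbatch_size  # 4
print(global_inference_batch_size, num_update_per_episode, num_recv_per_update)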
@@ -187,7 +187,7 @@ class BaseProducer:
                         for eval_task_name in self.eval_dataloaders:
                             if self.producer_idx == 0:
                                 print(
-                                    f"[P{self.producer_idx}] Evaluate episode {episode} step {self.consumer_global_step} on task {eval_task_name}"
+                                    f"[P{self.producer_idx}] Evaluate consumer step {self.consumer_global_step} on task {eval_task_name}"
                                 )
                             eval_results = []
                             eval_statistics_tensor = torch.zeros((2,), dtype=torch.float32).to(self.device)
@@ -104,7 +104,13 @@ if __name__ == "__main__":
         choices=["think_answer_tags", "boxed"],
         help="Reward type for GRPO.",
     )
-    parser.add_argument("-ei", "--eval-interval", type=int, default=100, help="Interval for evaluation.")
+    parser.add_argument(
+        "-ei",
+        "--eval-interval",
+        type=int,
+        default=100,
+        help="Interval for evaluation. Evaluate every ei consumer steps.",
+    )

     # Logging/Checkpointing parameters
     parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.")
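A hedged sketch of the expanded flag parsed in isolation, followed by an illustrative "every N consumer steps" gate; the gating loop is not copied from the producer code, it only shows what the updated help text means.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "-ei", "--eval-interval", type=int, default=100,
    help="Interval for evaluation. Evaluate every ei consumer steps.",
)
args = parser.parse_args(["--eval-interval", "50"])

for consumer_global_step in range(1, 201):
    if consumer_global_step % args.eval_interval == 0:
        print(f"evaluate at consumer step {consumer_global_step}")  # 50, 100, 150, 200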
@@ -125,8 +131,8 @@ if __name__ == "__main__":
         and args.train_microbatch_size > 0
     ), "Train micro batch size must be greater than 0 less than train mini batch size * num generations"
     assert (
-        args.train_minibatch_size <= args.train_batch_size
-    ), "Train mini batch size must be less than or equals to train batch size"
+        args.train_minibatch_size <= args.train_batch_size and args.train_batch_size % args.train_minibatch_size == 0
+    ), "Train mini batch size must be less than or equals to train batch size and train batch size must be divisible by train mini batch size"

     if args.master_address is None:
         # Default settings: Using single machine
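A quick check of the tightened batch-size assertion with example values; any train_batch_size that is not a multiple of train_minibatch_size now fails fast.

train_batch_size = 16

for train_minibatch_size in (4, 8, 16, 6, 32):
    ok = (
        train_minibatch_size <= train_batch_size
        and train_batch_size % train_minibatch_size == 0
    )
    print(train_minibatch_size, ok)   # 4 True, 8 True, 16 True, 6 False, 32 False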