diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index 6385df2cf..531cf363c 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -121,14 +121,14 @@ class BaseConsumer:
                             raw_batch = unbind_batch(
                                 ray_broadcast_tensor_dict(None, src=0, device=self.device, group_name=f"sync_data_{r}")
                             )
-                            filtered_batch = [
-                                t
-                                for t in [
-                                    self.prompt_level_filtering(self.calculate_group_reward(group))
-                                    for group in raw_batch
-                                ]
-                                if t is not None
+                            processed_batch = [
+                                self.prompt_level_filtering(self.calculate_group_reward(group)) for group in raw_batch
                             ]
+                            filtered_batch = [t for t in processed_batch if t is not None]
+                            if self.filter_range is not None:
+                                print(
+                                    f"[T{dist.get_rank()}] Filter recv data: {len(processed_batch)} -> {len(filtered_batch)}"
+                                )
                             self.buffer.extend(filtered_batch)

                         while len(self.buffer) >= self.dp_size * self.minibatch_size:
@@ -137,13 +137,8 @@ class BaseConsumer:
                             ]
                             batch = bind_batch(batches)
                             batch = post_recv(batch)
-                            loss, excessive_prompts_idx = self.step(i, pbar, **batch)
-
-                            if excessive_prompts_idx is not None:
-                                excessive_prompts = [self.buffer[idx] for idx in excessive_prompts_idx]
-                                self.buffer = excessive_prompts + self.buffer[self.dp_size * self.minibatch_size :]
-                            else:
-                                self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
+                            loss = self.step(i, pbar, **batch)
+                            self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
                             if loss is not None:
                                 allow_sync_model = True
                                 pbar.set_postfix({"loss": loss})
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index e709c8aed..f73f2d96f 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -9,7 +9,7 @@ from coati.distributed.loss import PolicyLoss
 from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
 from coati.distributed.utils import calc_action_log_probs
-from coati.trainer.utils import all_gather_tensors, all_reduce_mean, all_reduce_sum
+from coati.trainer.utils import all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer

 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@@ -201,10 +201,7 @@ class GRPOConsumer(BaseConsumer):
             reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
             # [minibatch_size x num_generations]
             advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
-            # filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score),
-            group_ans_acc = (
-                ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=0)
-            )
+
             # [minibatch_size x num_of_generation]
             loss_mask = torch.ones(action_mask.size(0), device=action_mask.device).bool()

@@ -214,37 +211,14 @@ class GRPOConsumer(BaseConsumer):
                     loss_mask,
                     action_mask[:, -1] == False,
                 )
-            prompt_level_mask = loss_mask.view(self.minibatch_size, self.num_generations)
-
-            # [minibatch_size] -> calculate the number of effective prompts
-            effective_prompts_mask = prompt_level_mask.any(dim=1)
-            effective_prompts = all_reduce_sum(torch.sum(effective_prompts_mask), self.plugin)
-            self.effective_prompt_count += effective_prompts.item()
-            excessive_prompts_idx = None
+            self.effective_prompt_count += group_reward.size(0) * self.dp_size

             mean_kl, mean_loss = [], []

             if self.grpo_config.get("dynamic_batching", True):
                 need_update = self.effective_prompt_count >= self.batch_size * self.dp_size
                 excessive_prompts = self.effective_prompt_count - self.batch_size * self.dp_size
-
-                if excessive_prompts > 0:
-                    excessive_prompts_per_rank = excessive_prompts // self.dp_size
-                    # Only count excessive prompts if they are greater than 1 per rank.
-                    # TODO: customize excessive prompts calculation.
-                    if excessive_prompts_per_rank != 0:
-                        # Mask excessive prompts to False
-                        true_indices = torch.nonzero(effective_prompts_mask).squeeze()
-                        if excessive_prompts_per_rank <= len(true_indices):
-                            excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
-                        else:
-                            excessive_prompts_idx = true_indices
-                        effective_prompts_mask[excessive_prompts_idx] = False
-
-                        for mask_idx in range(len(effective_prompts_mask)):
-                            if effective_prompts_mask[mask_idx] == False:
-                                # Update loss mask.
-                                loss_mask[mask_idx] = False
+                assert excessive_prompts <= 0, "Debug: Excessive prompts should always be less than or equal to 0. Bug!!!!"
             else:
                 # If dynamic batching is disabled, we need to use all samples for training.
                 need_update = (step_idx + 1) % self.num_microbatches == 0
@@ -460,9 +434,7 @@ class GRPOConsumer(BaseConsumer):
                 self.optimizer.step()
                 self.optimizer.zero_grad()
                 self.global_step += 1
-                self.total_sample_count = all_reduce_sum(
-                    torch.tensor(self.total_sample_count).to(self.accum_loss.device), self.plugin
-                ).item()
+                # no need to run all_reduce_sum on total_sample_count, because all dp ranks receive a complete inference batch from all producers.
                 sample_utilization = self.effective_sample_count / self.total_sample_count
                 self.effective_prompt_count = 0
                 self.effective_sample_count = 0
@@ -507,14 +479,9 @@ class GRPOConsumer(BaseConsumer):
                     self.accum_advantages.zero_()
                     self.accum_response_length.zero_()
                     self.accum_count = 0
-
-                if excessive_prompts_idx is not None:
-                    # All gather excessive prompts index across DP ranks.
-                    excessive_prompts_idx = [idx + self.dp_rank * self.minibatch_size for idx in excessive_prompts_idx]
-                    excessive_prompts_idx = all_gather_tensors(excessive_prompts_idx, self.plugin)
-                return loss_scalar, excessive_prompts_idx
+                return loss_scalar
             else:
-                return None, excessive_prompts_idx
+                return None

     def calculate_group_reward(self, rollout_group: Dict[str, Any]) -> Dict[str, Any]:
         """
diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index 327db1b55..e6367d2d1 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -66,7 +66,7 @@ def launch_distributed(

     dataset_path = train_dataset_config["path"]
     num_samples = get_jsonl_size_fast(dataset_path)
-    global_inference_batch_size = inference_batch_size * num_producers  # TODO: this doesn't support TP on producer
+    global_inference_batch_size = inference_batch_size * num_producers
     num_update_per_episode = num_samples // global_inference_batch_size
     num_recv_per_update = inference_batch_size // inference_microbatch_size

diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
index 8b9452160..d796bff48 100644
--- a/applications/ColossalChat/coati/distributed/producer.py
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -187,7 +187,7 @@ class BaseProducer:
                 for eval_task_name in self.eval_dataloaders:
                     if self.producer_idx == 0:
                         print(
-                            f"[P{self.producer_idx}] Evaluate episode {episode} step {self.consumer_global_step} on task {eval_task_name}"
+                            f"[P{self.producer_idx}] Evaluate consumer step {self.consumer_global_step} on task {eval_task_name}"
                         )
                     eval_results = []
                     eval_statistics_tensor = torch.zeros((2,), dtype=torch.float32).to(self.device)
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index 2ed9ef62c..9f89eb546 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -104,7 +104,13 @@ if __name__ == "__main__":
         choices=["think_answer_tags", "boxed"],
         help="Reward type for GRPO.",
     )
-    parser.add_argument("-ei", "--eval-interval", type=int, default=100, help="Interval for evaluation.")
+    parser.add_argument(
+        "-ei",
+        "--eval-interval",
+        type=int,
+        default=100,
+        help="Interval for evaluation. Evaluate every ei consumer steps.",
+    )

     # Logging/Checkpointing parameters
     parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.")
@@ -125,8 +131,8 @@ if __name__ == "__main__":
         and args.train_microbatch_size > 0
     ), "Train micro batch size must be greater than 0 less than train mini batch size * num generations"
     assert (
-        args.train_minibatch_size <= args.train_batch_size
-    ), "Train mini batch size must be less than or equals to train batch size"
+        args.train_minibatch_size <= args.train_batch_size and args.train_batch_size % args.train_minibatch_size == 0
+    ), "Train mini batch size must be less than or equal to train batch size, and train batch size must be divisible by train mini batch size"

     if args.master_address is None:
         # Default settings: Using single machine