diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index 5d4df9ea4..ba7d882c9 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -217,6 +217,7 @@ class BaseConsumer:
                 effective_group_mask = None
                 if self.filter_range is not None and self.grpo_config.get("dynamic_batching", True):
                     # filter the group based on the reward and accuracy
+                    group_ans_acc_mean = ans_acc.mean(dim=1)
                     effective_group_mask = torch.logical_and(
                         group_ans_acc_mean > self.filter_range[0], group_ans_acc_mean < self.filter_range[1]
                     )
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index e44ed9227..a2c3e03d6 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -91,9 +91,6 @@ class GRPOConsumer(BaseConsumer):
         self.project_name = project_name
         self.effective_sample_count = 0
         self.effective_prompt_count = 0
-        self.total_sample_count = 0
-        self.overlength_samples = 0
-        self.total_overlength_samples = 0
         self.project_name = project_name
         self.run_name = run_name
         self.wandb_group_name = wandb_group_name
@@ -207,7 +204,6 @@ class GRPOConsumer(BaseConsumer):
 
         # filter out overlength samples
         if self.filter_truncated_response and action_mask.size(1) == self.max_length:
-            old_loss_mask = loss_mask.clone()
             loss_mask = torch.logical_and(
                 loss_mask,
                 action_mask[:, -1] == False,
@@ -225,15 +221,7 @@ class GRPOConsumer(BaseConsumer):
                     group_ans_acc_mean < self.filter_range[1],
                 ),
             )
-        self.total_overlength_samples += self.overlength_samples.item()
-
-        prompt_level_mask = loss_mask.view(self.minibatch_size, self.num_generations)
-
-        # [minibatch_size] -> calculate the number of effective prompts
-        effective_prompts_mask = prompt_level_mask.any(dim=1)
-        effective_prompts = all_reduce_sum(torch.sum(effective_prompts_mask), self.plugin)
-        self.effective_prompt_count += effective_prompts.item()
-        excessive_prompts_idx = None
+        self.effective_prompt_count += group_reward.size(0) * self.dp_size
 
         mean_kl, mean_loss = [], []
 
@@ -250,8 +238,7 @@ class GRPOConsumer(BaseConsumer):
         pbar.set_postfix(
             {
                 "Global Step": self.global_step,
-                "Effective prompts": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size}",
-                "Effective samples": f"{self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations}",
+                "Gradient Accumulation on": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size} effective prompts, {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations} effective samples",
             }
         )
 
@@ -477,12 +464,10 @@ class GRPOConsumer(BaseConsumer):
             self.optimizer.step()
             self.optimizer.zero_grad()
             self.global_step += 1
-            sample_utilization = self.effective_sample_count / self.total_sample_count
-            overlength_samples_percentage = self.total_overlength_samples / self.total_sample_count
+            # no need to run all reduce as raw_train_batch_* are not splited across dp rank
+            sample_utilization = self.effective_sample_count / len(self.raw_train_batch_reward) / self.num_generations
             self.effective_prompt_count = 0
             self.effective_sample_count = 0
-            self.total_sample_count = 0
-            self.total_overlength_samples = 0
             loss_scalar = self.accum_loss.item()
             if not self.plugin.pp_size > 1 or (
                 self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
@@ -545,4 +530,4 @@ class GRPOConsumer(BaseConsumer):
         model = self.policy_model.unwrap()
         state_dict = model.state_dict()
         state_dict["consumer_global_step"] = torch.tensor([self.global_step], device=self.device)
-        return state_dict
+        return state_dict
\ No newline at end of file
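For reference, a minimal self-contained sketch (not part of the patch) of the group filtering and the reworked sample_utilization bookkeeping introduced by the hunks above; the toy variables num_prompts, num_generations, and filter_range stand in for the consumer's attributes and config, and num_prompts plays the role of len(self.raw_train_batch_reward):

import torch

# Hypothetical standalone values; in the consumer these come from the config
# and the received raw batch.
num_prompts, num_generations = 4, 8
filter_range = (0.01, 0.99)

# Per-sample answer accuracy, grouped as [num_prompts, num_generations].
ans_acc = torch.rand(num_prompts, num_generations)

# Mirrors the consumer.py hunk: a group (prompt) is kept only when its mean
# accuracy falls strictly inside the filter range.
group_ans_acc_mean = ans_acc.mean(dim=1)
effective_group_mask = torch.logical_and(
    group_ans_acc_mean > filter_range[0], group_ans_acc_mean < filter_range[1]
)

# Mirrors the new sample_utilization formula: effective samples divided by the
# total raw samples (prompts * generations), with no all-reduce because the raw
# batch is not split across dp ranks.
effective_sample_count = int(effective_group_mask.sum().item()) * num_generations
sample_utilization = effective_sample_count / num_prompts / num_generations
print(f"kept groups: {effective_group_mask.tolist()}, sample_utilization={sample_utilization:.2f}")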