diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py index 593e0f4ec..690a10608 100644 --- a/applications/ColossalChat/coati/distributed/consumer.py +++ b/applications/ColossalChat/coati/distributed/consumer.py @@ -124,17 +124,17 @@ class BaseConsumer: raw_batch_with_reward = self.calculate_reward({k:v.view(-1, v.size(-1)) if k!='temperature' else v for k, v in raw_batch.items()}) raw_batch_with_reward = {k: v.view(-1, self.num_generations, v.size(-1)) if k!='temperature' else v for k, v in raw_batch_with_reward.items()} # [batch_size, num_generations] -> [batch_size] - group_reward_mean = raw_batch_with_reward["reward"][:,:,0].mean(dim=-1) - group_format_acc_mean = raw_batch_with_reward["format_acc"][:,:,0].mean(dim=-1) - group_ans_acc_mean = raw_batch_with_reward["ans_acc"][:,:,0].mean(dim=-1) - group_response_len = ( + reward = raw_batch_with_reward["reward"][:,:,0] + format_acc = raw_batch_with_reward["format_acc"][:,:,0] + ans_acc = raw_batch_with_reward["ans_acc"][:,:,0] + response_len = ( (raw_batch_with_reward["response_idx"][:, :, 1] - raw_batch_with_reward["response_idx"][:, :, 0] + 1) .type(torch.float32) - .mean(dim=-1) ) effective_group_mask = None if self.filter_range is not None and self.grpo_config.get("dynamic_batching", True): # filter the group based on the reward and accuracy + group_ans_acc_mean = ans_acc.mean(dim=1) effective_group_mask = torch.logical_and( group_ans_acc_mean > self.filter_range[0], group_ans_acc_mean < self.filter_range[1] ) @@ -143,15 +143,15 @@ class BaseConsumer: self.buffer.append( [ group_with_reward if effective_group_mask is None or effective_group_mask[group_idx] else None, - group_reward_mean[group_idx], - group_format_acc_mean[group_idx], - group_ans_acc_mean[group_idx], - group_response_len[group_idx], + reward[group_idx], + format_acc[group_idx], + ans_acc[group_idx], + response_len[group_idx], ] ) if effective_group_mask is not None: print( - f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups" + f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch_with_reward)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups" ) # mapping the effective group to the raw group for indexing effective_group_to_raw_group_mapping = {} @@ -160,7 +160,7 @@ class BaseConsumer: effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = ( buffer_idx ) - pbar.set_postfix({"Collect Effective Prompt": f"{len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}"}) + print(f"[T{dist.get_rank()}] Collect Effective Prompt: {len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}") while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size: # on each dp_rank, we use minibatch_size effective samples to form a batch diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index 50702683f..451947b44 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -211,6 +211,17 @@ class GRPOConsumer(BaseConsumer): loss_mask, action_mask[:, -1] == False, ) + if self.filter_range is not None and self.grpo_config.get("dynamic_batching", False)==False: + # filter out samples with reward outside the range + # if dynamic batching is enabled, we filter out out of range groups before training + group_ans_acc_mean = ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=-1) + loss_mask = torch.logical_and( + loss_mask, + torch.logical_and( + group_ans_acc_mean > self.filter_range[0], + group_ans_acc_mean < self.filter_range[1], + ), + ) self.effective_prompt_count += group_reward.size(0) * self.dp_size mean_kl, mean_loss = [], [] @@ -229,8 +240,7 @@ class GRPOConsumer(BaseConsumer): pbar.set_postfix( { "Global Step": self.global_step, - "Effective prompts": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size}", - "Effective samples": f"{self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations}", + "Gradient Accumulation on": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size} effective prompts, {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations} effective samples", } ) @@ -428,6 +438,7 @@ class GRPOConsumer(BaseConsumer): self.optimizer.step() self.optimizer.zero_grad() self.global_step += 1 + # no need to run all reduce as raw_train_batch_* are not splited across dp rank sample_utilization = self.effective_sample_count / len(self.raw_train_batch_reward) / self.num_generations self.effective_prompt_count = 0 self.effective_sample_count = 0 @@ -438,14 +449,12 @@ class GRPOConsumer(BaseConsumer): if (not self.plugin.pp_size > 1 and self.rank == 0) or ( self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0 ): - raw_batch_reward_mean = sum(self.raw_train_batch_reward) / len(self.raw_train_batch_reward) - raw_batch_format_acc_mean = sum(self.raw_train_batch_format_acc) / len( - self.raw_train_batch_format_acc - ) - raw_batch_ans_acc_mean = sum(self.raw_train_batch_ans_acc) / len(self.raw_train_batch_ans_acc) - raw_batch_response_len_mean = sum(self.raw_train_batch_response_len) / len( - self.raw_train_batch_response_len - ) + raw_batch_reward_mean = torch.cat(self.raw_train_batch_reward, dim=0).mean().cpu().item() + raw_batch_format_acc_mean = torch.cat(self.raw_train_batch_format_acc, dim=0).mean().cpu().item() + raw_batch_ans_acc_mean = torch.cat(self.raw_train_batch_ans_acc, dim=0).mean().cpu().item() + raw_batch_response_len = torch.cat(self.raw_train_batch_response_len, dim=0) + raw_batch_response_len_mean = raw_batch_response_len.mean().cpu().item() + overlength_samples_ratio = (raw_batch_response_len >= action_mask.size(-1)).to(float).mean().cpu().item() # not an exact figure, but a close estimate self.raw_train_batch_reward = [] self.raw_train_batch_format_acc = [] self.raw_train_batch_ans_acc = [] @@ -458,6 +467,7 @@ class GRPOConsumer(BaseConsumer): f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}", f"Response Length: {raw_batch_response_len_mean:.4f}", f"Sample_utilization: {sample_utilization:.4f}", + f"Overlength samples ratio: {overlength_samples_ratio:.4f}", ] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else []) print("\n".join(to_log_msg)) metrics = { @@ -469,6 +479,7 @@ class GRPOConsumer(BaseConsumer): "train/advantages": self.accum_advantages.item() / self.accum_count, "train/learning_rate": self.lr_scheduler.get_last_lr()[0], "train/sample_utilization": sample_utilization, + "train/overlength_samples_ratio": overlength_samples_ratio, "rollout/temperature": data["temperature"].cpu().numpy()[0][0], } if self.policy_loss_fn.beta > 0: