From 25e0062de677e87fc02d00745addb2b77287bae2 Mon Sep 17 00:00:00 2001
From: YeAnbang
Date: Thu, 24 Apr 2025 17:45:20 +0800
Subject: [PATCH] add dynamic batching control flag

---
 .../coati/distributed/grpo_consumer.py       | 58 ++++++++++++-------
 1 file changed, 36 insertions(+), 22 deletions(-)

diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index d68860b62..490c93b86 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -221,30 +221,44 @@ class GRPOConsumer(BaseConsumer):
 
         mean_kl, mean_loss = [], []
 
-        need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations
-        # to exclude the excessive samples from the last batch, the last num_excessive_samples samples are not used for training and will be kept in buffer for the next iteration.
-        num_excessive_samples = (
-            int(
-                (self.effective_sample_count - self.batch_size * self.dp_size * self.num_generations)
-                / self.num_generations
-                / self.dp_size
-            )
-            * self.num_generations
-        )
-        if num_excessive_samples > 0:
-            data = {
-                k: (
-                    v[: -num_excessive_samples if num_excessive_samples != 0 else v.size(0)]
-                    if k
-                    in ["input_ids", "attention_mask", "action_log_probs", "action_mask", "response_idx", "gt_answer"]
-                    else v
+        if self.grpo_config.get("dynamic_batching", True):
+            need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations
+            # to exclude the excessive samples from the last batch, the last num_excessive_samples samples are not used for training and will be kept in buffer for the next iteration.
+            num_excessive_samples = (
+                int(
+                    (self.effective_sample_count - self.batch_size * self.dp_size * self.num_generations)
+                    / self.num_generations
+                    / self.dp_size
                 )
-                for k, v in data.items()
-            }
-            action_mask = action_mask[: -num_excessive_samples if num_excessive_samples != 0 else action_mask.size(0)]
-            loss_mask = loss_mask[: -num_excessive_samples if num_excessive_samples != 0 else loss_mask.size(0)]
-            advantages = advantages[: -num_excessive_samples if num_excessive_samples != 0 else advantages.size(0)]
+                * self.num_generations
+            )
+            if num_excessive_samples > 0:
+                data = {
+                    k: (
+                        v[: -num_excessive_samples if num_excessive_samples != 0 else v.size(0)]
+                        if k
+                        in [
+                            "input_ids",
+                            "attention_mask",
+                            "action_log_probs",
+                            "action_mask",
+                            "response_idx",
+                            "gt_answer",
+                        ]
+                        else v
+                    )
+                    for k, v in data.items()
+                }
+                action_mask = action_mask[
+                    : -num_excessive_samples if num_excessive_samples != 0 else action_mask.size(0)
+                ]
+                loss_mask = loss_mask[: -num_excessive_samples if num_excessive_samples != 0 else loss_mask.size(0)]
+                advantages = advantages[: -num_excessive_samples if num_excessive_samples != 0 else advantages.size(0)]
+            else:
+                num_excessive_samples = 0
         else:
+            # If dynamic batching is disabled, we need to use all samples for training.
+            need_update = (step_idx + 1) % self.num_microbatches == 0
             num_excessive_samples = 0
 
         pbar.set_postfix(
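
A minimal usage sketch for the new flag, assuming grpo_config is the plain dict of GRPO settings that the consumer reads via self.grpo_config.get("dynamic_batching", True); any other keys or wiring around it are assumptions, not part of this patch.

    # Hypothetical GRPO settings dict handed to GRPOConsumer as grpo_config (assumption).
    grpo_config = {
        # Flag introduced by this patch. Defaults to True (dynamic batching on).
        # When False, the consumer skips the effective-sample accounting and
        # falls back to need_update = (step_idx + 1) % num_microbatches == 0.
        "dynamic_batching": False,
    }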