add dynamic batching control flag

YeAnbang 2025-04-24 17:45:20 +08:00
parent 807a5a43b2
commit 25e0062de6


```diff
@@ -221,30 +221,44 @@ class GRPOConsumer(BaseConsumer):
         mean_kl, mean_loss = [], []
-        need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations
-        # to exclude the excessive samples from the last batch, the last num_excessive_samples samples are not used for training and will be kept in buffer for the next iteration.
-        num_excessive_samples = (
-            int(
-                (self.effective_sample_count - self.batch_size * self.dp_size * self.num_generations)
-                / self.num_generations
-                / self.dp_size
-            )
-            * self.num_generations
-        )
-        if num_excessive_samples > 0:
-            data = {
-                k: (
-                    v[: -num_excessive_samples if num_excessive_samples != 0 else v.size(0)]
-                    if k
-                    in ["input_ids", "attention_mask", "action_log_probs", "action_mask", "response_idx", "gt_answer"]
-                    else v
-                )
-                for k, v in data.items()
-            }
-            action_mask = action_mask[: -num_excessive_samples if num_excessive_samples != 0 else action_mask.size(0)]
-            loss_mask = loss_mask[: -num_excessive_samples if num_excessive_samples != 0 else loss_mask.size(0)]
-            advantages = advantages[: -num_excessive_samples if num_excessive_samples != 0 else advantages.size(0)]
+        if self.grpo_config.get("dynamic_batching", True):
+            need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations
+            # to exclude the excessive samples from the last batch, the last num_excessive_samples samples are not used for training and will be kept in buffer for the next iteration.
+            num_excessive_samples = (
+                int(
+                    (self.effective_sample_count - self.batch_size * self.dp_size * self.num_generations)
+                    / self.num_generations
+                    / self.dp_size
+                )
+                * self.num_generations
+            )
+            if num_excessive_samples > 0:
+                data = {
+                    k: (
+                        v[: -num_excessive_samples if num_excessive_samples != 0 else v.size(0)]
+                        if k
+                        in [
+                            "input_ids",
+                            "attention_mask",
+                            "action_log_probs",
+                            "action_mask",
+                            "response_idx",
+                            "gt_answer",
+                        ]
+                        else v
+                    )
+                    for k, v in data.items()
+                }
+                action_mask = action_mask[
+                    : -num_excessive_samples if num_excessive_samples != 0 else action_mask.size(0)
+                ]
+                loss_mask = loss_mask[: -num_excessive_samples if num_excessive_samples != 0 else loss_mask.size(0)]
+                advantages = advantages[: -num_excessive_samples if num_excessive_samples != 0 else advantages.size(0)]
+            else:
+                num_excessive_samples = 0
         else:
+            # If dynamic batching is disabled, we need to use all samples for training.
+            need_update = (step_idx + 1) % self.num_microbatches == 0
             num_excessive_samples = 0
         pbar.set_postfix(
```
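
The trimming arithmetic in the dynamic-batching branch is easy to misread, so here is a minimal standalone sketch of it. The helper below is hypothetical (it is not part of the commit); it reproduces the rounding expression under the reading that the surplus is floored to a whole number of prompt groups per data-parallel rank, which is what the division by `dp_size` without a matching multiply suggests.

```python
def num_excessive_samples(
    effective_sample_count: int, batch_size: int, dp_size: int, num_generations: int
) -> int:
    """Hypothetical standalone version of the commit's rounding expression.

    Samples beyond one full training batch are trimmed, but only in whole
    prompt groups (multiples of num_generations); anything smaller is kept
    in the buffer for the next iteration.
    """
    target = batch_size * dp_size * num_generations
    surplus = effective_sample_count - target
    # int() truncates, so a surplus smaller than one prompt group per DP rank
    # rounds down to zero and nothing is trimmed.
    return int(surplus / num_generations / dp_size) * num_generations


# Assumed values for illustration: batch_size=4, dp_size=2, num_generations=8,
# so one full batch is 4 * 2 * 8 = 64 samples.
print(num_excessive_samples(80, 4, 2, 8))  # surplus 16 -> trim 8 (one group) per rank
print(num_excessive_samples(76, 4, 2, 8))  # surplus 12 < one group per rank -> trim 0
```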
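A note on the flag itself: it is read with `.get("dynamic_batching", True)`, so existing configs keep the previous behavior by default. The snippet below is an assumed usage sketch, showing the fixed microbatch cadence that replaces the effective-sample threshold when the flag is off.

```python
# Assumed config usage: the flag defaults to True, so only opting out
# requires an explicit entry.
grpo_config = {"dynamic_batching": False}

# With dynamic batching off, updates fire every num_microbatches steps
# rather than when enough effective samples have accumulated.
num_microbatches = 4
for step_idx in range(8):
    need_update = (step_idx + 1) % num_microbatches == 0
    # need_update is True at step_idx 3 and 7
```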