Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-06-23 22:19:47 +00:00
add dynamic batching control flag
commit 25e0062de6
parent 807a5a43b2
@@ -221,30 +221,44 @@ class GRPOConsumer(BaseConsumer):
         mean_kl, mean_loss = [], []

-        need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations
-        # to exclude the excessive samples from the last batch, the last num_excessive_samples samples are not used for training and will be kept in buffer for the next iteration.
-        num_excessive_samples = (
-            int(
-                (self.effective_sample_count - self.batch_size * self.dp_size * self.num_generations)
-                / self.num_generations
-                / self.dp_size
-            )
-            * self.num_generations
-        )
-        if num_excessive_samples > 0:
-            data = {
-                k: (
-                    v[: -num_excessive_samples if num_excessive_samples != 0 else v.size(0)]
-                    if k
-                    in ["input_ids", "attention_mask", "action_log_probs", "action_mask", "response_idx", "gt_answer"]
-                    else v
-                )
-                for k, v in data.items()
-            }
-            action_mask = action_mask[: -num_excessive_samples if num_excessive_samples != 0 else action_mask.size(0)]
-            loss_mask = loss_mask[: -num_excessive_samples if num_excessive_samples != 0 else loss_mask.size(0)]
-            advantages = advantages[: -num_excessive_samples if num_excessive_samples != 0 else advantages.size(0)]
-        else:
-            num_excessive_samples = 0
+        if self.grpo_config.get("dynamic_batching", True):
+            need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations
+            # to exclude the excessive samples from the last batch, the last num_excessive_samples samples are not used for training and will be kept in buffer for the next iteration.
+            num_excessive_samples = (
+                int(
+                    (self.effective_sample_count - self.batch_size * self.dp_size * self.num_generations)
+                    / self.num_generations
+                    / self.dp_size
+                )
+                * self.num_generations
+            )
+            if num_excessive_samples > 0:
+                data = {
+                    k: (
+                        v[: -num_excessive_samples if num_excessive_samples != 0 else v.size(0)]
+                        if k
+                        in [
+                            "input_ids",
+                            "attention_mask",
+                            "action_log_probs",
+                            "action_mask",
+                            "response_idx",
+                            "gt_answer",
+                        ]
+                        else v
+                    )
+                    for k, v in data.items()
+                }
+                action_mask = action_mask[
+                    : -num_excessive_samples if num_excessive_samples != 0 else action_mask.size(0)
+                ]
+                loss_mask = loss_mask[: -num_excessive_samples if num_excessive_samples != 0 else loss_mask.size(0)]
+                advantages = advantages[: -num_excessive_samples if num_excessive_samples != 0 else advantages.size(0)]
+            else:
+                num_excessive_samples = 0
+        else:
+            # If dynamic batching is disabled, we need to use all samples for training.
+            need_update = (step_idx + 1) % self.num_microbatches == 0
+            num_excessive_samples = 0

         pbar.set_postfix(
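The new behavior hangs on a single key read via self.grpo_config.get("dynamic_batching", True), so dynamic batching stays enabled unless it is explicitly turned off. Below is a minimal sketch of how the flag might be set when assembling the config; it is illustrative only, and anything other than the "dynamic_batching" key is an assumption rather than part of this commit.

# Minimal sketch (illustrative, not from this commit): grpo_config is the
# dict-like object the consumer reads with self.grpo_config.get(...).
grpo_config = {
    # ... other GRPO hyperparameters would go here (assumed placeholders) ...
    "dynamic_batching": False,  # opt out of the sample-count-driven update trigger
}

# With the flag off, the consumer falls back to a fixed schedule,
#   need_update = (step_idx + 1) % self.num_microbatches == 0,
# i.e. an optimizer update every num_microbatches steps, instead of waiting for
# effective_sample_count to reach batch_size * dp_size * num_generations,
# and no samples are trimmed (num_excessive_samples is forced to 0).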
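The trimming arithmetic in the dynamic-batching branch is easier to follow in isolation. The helper below is a standalone sketch (not part of the commit) that reproduces the num_excessive_samples formula with the consumer's attributes as plain parameters; the example numbers are invented for illustration.

def num_excessive_samples(effective_sample_count: int, batch_size: int, dp_size: int, num_generations: int) -> int:
    # Surplus over the update target, rounded down so the amount held back is
    # always a whole multiple of num_generations (presumably so a prompt's
    # group of generations is never split across iterations).
    target = batch_size * dp_size * num_generations
    return int((effective_sample_count - target) / num_generations / dp_size) * num_generations


# Invented numbers: batch_size=4, dp_size=2, num_generations=8 -> target of 64 samples.
print(num_excessive_samples(75, 4, 2, 8))  # 0 -> surplus of 11 is below one full group per rank, nothing trimmed
print(num_excessive_samples(82, 4, 2, 8))  # 8 -> surplus of 18 yields one group; 8 samples stay in the buffer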