Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-06-26 23:38:04 +00:00
Commit 25e0062de6: add dynamic batching control flag
Parent: 807a5a43b2
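The diff below wraps the existing dynamic-batching logic in `GRPOConsumer` behind a `dynamic_batching` key read from `grpo_config` (defaulting to True). A minimal sketch of how a caller might toggle it; every key here other than "dynamic_batching" is a hypothetical placeholder, not taken from this commit:

# Hypothetical sketch: disabling the new flag from a trainer script.
# Only "dynamic_batching" is read by the code in the diff; other keys are assumed.
grpo_config = {
    "beta": 0.01,               # assumed KL-penalty coefficient, for illustration only
    "dynamic_batching": False,  # opt out: update every num_microbatches steps instead
}
# The consumer reads it as: self.grpo_config.get("dynamic_batching", True)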
@@ -221,30 +221,44 @@ class GRPOConsumer(BaseConsumer):
         mean_kl, mean_loss = [], []
 
-        need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations
-        # to exclude the excessive samples from the last batch, the last num_excessive_samples samples are not used for training and will be kept in buffer for the next iteration.
-        num_excessive_samples = (
-            int(
-                (self.effective_sample_count - self.batch_size * self.dp_size * self.num_generations)
-                / self.num_generations
-                / self.dp_size
-            )
-            * self.num_generations
-        )
-        if num_excessive_samples > 0:
-            data = {
-                k: (
-                    v[: -num_excessive_samples if num_excessive_samples != 0 else v.size(0)]
-                    if k
-                    in ["input_ids", "attention_mask", "action_log_probs", "action_mask", "response_idx", "gt_answer"]
-                    else v
-                )
-                for k, v in data.items()
-            }
-            action_mask = action_mask[: -num_excessive_samples if num_excessive_samples != 0 else action_mask.size(0)]
-            loss_mask = loss_mask[: -num_excessive_samples if num_excessive_samples != 0 else loss_mask.size(0)]
-            advantages = advantages[: -num_excessive_samples if num_excessive_samples != 0 else advantages.size(0)]
-        else:
-            num_excessive_samples = 0
+        if self.grpo_config.get("dynamic_batching", True):
+            need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations
+            # to exclude the excessive samples from the last batch, the last num_excessive_samples samples are not used for training and will be kept in buffer for the next iteration.
+            num_excessive_samples = (
+                int(
+                    (self.effective_sample_count - self.batch_size * self.dp_size * self.num_generations)
+                    / self.num_generations
+                    / self.dp_size
+                )
+                * self.num_generations
+            )
+            if num_excessive_samples > 0:
+                data = {
+                    k: (
+                        v[: -num_excessive_samples if num_excessive_samples != 0 else v.size(0)]
+                        if k
+                        in [
+                            "input_ids",
+                            "attention_mask",
+                            "action_log_probs",
+                            "action_mask",
+                            "response_idx",
+                            "gt_answer",
+                        ]
+                        else v
+                    )
+                    for k, v in data.items()
+                }
+                action_mask = action_mask[
+                    : -num_excessive_samples if num_excessive_samples != 0 else action_mask.size(0)
+                ]
+                loss_mask = loss_mask[: -num_excessive_samples if num_excessive_samples != 0 else loss_mask.size(0)]
+                advantages = advantages[: -num_excessive_samples if num_excessive_samples != 0 else advantages.size(0)]
+            else:
+                num_excessive_samples = 0
+        else:
+            # If dynamic batching is disabled, we need to use all samples for training.
+            need_update = (step_idx + 1) % self.num_microbatches == 0
+            num_excessive_samples = 0
 
         pbar.set_postfix(
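When the flag is enabled, the surplus-rounding arithmetic keeps whole prompt groups aligned across data-parallel ranks before truncating the buffer. A minimal sketch tracing it with assumed numbers (none of these values come from the commit):

# Hypothetical numbers to trace the truncation arithmetic from the diff above.
batch_size = 4               # assumed train batch size per update
dp_size = 2                  # assumed data-parallel world size
num_generations = 8          # assumed completions sampled per prompt
effective_sample_count = 90  # assumed samples accumulated in the buffer

target = batch_size * dp_size * num_generations   # 64 samples trigger an update
need_update = effective_sample_count >= target    # True, since 90 >= 64

# Round the 26-sample surplus down to whole prompts per rank (int(26 / 8 / 2) == 1),
# then scale back to samples so each rank drops complete generation groups.
num_excessive_samples = int((effective_sample_count - target) / num_generations / dp_size) * num_generations
print(need_update, num_excessive_samples)  # True 8

# Tensors are then truncated with v[: -num_excessive_samples]; the guard
# "if num_excessive_samples != 0 else v.size(0)" avoids v[:0] (an empty slice)
# when there is no surplus to carry over to the next iteration.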