support reusing excessive samples

Author: YeAnbang
Date: 2025-04-23 16:56:12 +08:00
parent ca6093a582
commit 807a5a43b2
2 changed files with 37 additions and 8 deletions


@@ -109,17 +109,22 @@ class BaseConsumer:
batches = self.buffer[
self.dp_rank * self.microbatch_size : (self.dp_rank + 1) * self.microbatch_size
]
self.buffer = self.buffer[self.dp_size * self.microbatch_size :]
batch = pad_batch(
batches
)  # when `imbs` is smaller than `tMbs`, samples may differ in size and need to be padded before stacking
batch = bind_batch(batches)
batch = post_recv(batch)
loss = self.step(i, pbar, **batch)
loss, num_excessive_rollouts = self.step(i, pbar, **batch)
self.buffer = (
self.buffer[
(self.dp_rank + 1) * self.microbatch_size
- num_excessive_rollouts : (self.dp_rank + 1) * self.microbatch_size
]
+ self.buffer[self.dp_size * self.microbatch_size :]
)
if loss is not None:
pbar.set_postfix({"loss": loss})
i += 1
assert len(self.buffer) == 0
if self.lr_scheduler is not None:
self.lr_scheduler.step()
if (step + 1) % self.save_interval == 0 or (step + 1) == self.num_update_per_episode:
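
The buffer re-slicing above replaces the old `self.buffer = self.buffer[self.dp_size * self.microbatch_size :]`: instead of discarding its whole micro-batch slice, each rank keeps the last `num_excessive_rollouts` entries so they can be trained on in the next iteration. A minimal sketch of the slice with a toy list and hypothetical sizes (not the repository code):

# toy illustration of the buffer-retention slice; all sizes are hypothetical
buffer = list(range(8))                      # 8 queued rollout entries
dp_rank, dp_size, microbatch_size = 1, 2, 3
num_excessive_rollouts = 2                   # value returned by step() in the new code

kept_from_own_slice = buffer[
    (dp_rank + 1) * microbatch_size - num_excessive_rollouts : (dp_rank + 1) * microbatch_size
]                                                        # [4, 5]: unused tail of this rank's micro-batch
kept_unconsumed = buffer[dp_size * microbatch_size :]    # [6, 7]: not handed out this step
buffer = kept_from_own_slice + kept_unconsumed           # [4, 5, 6, 7] carried into the next step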


@@ -213,7 +213,6 @@ class GRPOConsumer(BaseConsumer):
action_mask[:, -1] == False,
)
effective_tokens_count = torch.sum(action_mask, dim=-1) * loss_mask
effective_samples = all_reduce_sum(torch.sum(loss_mask), self.plugin)
total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin)
total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
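
For context, `effective_tokens_count` above is the per-sample number of response tokens, zeroed for samples the loss mask filters out; the all-reduce then totals it across data-parallel ranks. A hedged single-rank sketch with toy tensors (not the repository code):

import torch

action_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # valid response tokens per sample
loss_mask = torch.tensor([1, 0])                           # second sample filtered out
effective_tokens_count = torch.sum(action_mask, dim=-1) * loss_mask  # tensor([3, 0])
total_effective_tokens = torch.sum(effective_tokens_count)           # tensor(3); all_reduce_sum-ed across ranks in the diff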
@@ -222,9 +221,32 @@ class GRPOConsumer(BaseConsumer):
mean_kl, mean_loss = [], []
# update the gradient only once at least batch_size * dp_size * num_generations valid samples have been collected, since many samples may be invalid and get filtered out.
# balance between efficiency and accuracy
need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations
# exclude the excess samples from this batch: the last num_excessive_samples samples are not used for training and are kept in the buffer for the next iteration.
num_excessive_samples = (
int(
(self.effective_sample_count - self.batch_size * self.dp_size * self.num_generations)
/ self.num_generations
/ self.dp_size
)
* self.num_generations
)
if num_excessive_samples > 0:
data = {
k: (
v[: -num_excessive_samples if num_excessive_samples != 0 else v.size(0)]
if k
in ["input_ids", "attention_mask", "action_log_probs", "action_mask", "response_idx", "gt_answer"]
else v
)
for k, v in data.items()
}
action_mask = action_mask[: -num_excessive_samples if num_excessive_samples != 0 else action_mask.size(0)]
loss_mask = loss_mask[: -num_excessive_samples if num_excessive_samples != 0 else loss_mask.size(0)]
advantages = advantages[: -num_excessive_samples if num_excessive_samples != 0 else advantages.size(0)]
else:
num_excessive_samples = 0
pbar.set_postfix(
{
"Step": self.global_step + 1,
@@ -338,7 +360,7 @@ class GRPOConsumer(BaseConsumer):
loss_mask=inputs["loss_mask"],
total_effective_tokens_in_batch=total_effective_tokens_count,
)
return loss
return loss, num_excessive_samples // self.num_generations
policy_model_outputs = self.booster.execute_pipeline(
iter([data_policy_forward]),
@@ -477,7 +499,9 @@ class GRPOConsumer(BaseConsumer):
self.accum_advantages.zero_()
self.accum_response_length.zero_()
self.accum_count = 0
return loss_scalar
return loss_scalar, num_excessive_samples // self.num_generations
else:
return None, num_excessive_samples // self.num_generations
def state_dict(self):
self.policy_model._force_wait_all_gather()
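
Note that the second return value is `num_excessive_samples // self.num_generations`: the trimmed sample count is converted into a number of prompts, which appears to match the per-prompt entries `BaseConsumer` keeps in `self.buffer`. A tiny illustration with hypothetical numbers:

# hypothetical numbers: converting trimmed samples back into buffer entries
num_generations = 4
num_excessive_samples = 8                                            # two full generation groups trimmed
num_excessive_rollouts = num_excessive_samples // num_generations    # 2 prompts returned to the buffer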