Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-07-16 16:32:52 +00:00)

Commit 807a5a43b2: support reusing excessive samples
Parent: ca6093a582

@@ -109,17 +109,22 @@ class BaseConsumer:
                        batches = self.buffer[
                            self.dp_rank * self.microbatch_size : (self.dp_rank + 1) * self.microbatch_size
                        ]
-                       self.buffer = self.buffer[self.dp_size * self.microbatch_size :]
                        batch = pad_batch(
                            batches
                        )  # when `imbs` is smaller than `tMbs`, samples may have differ in size, need to pad before stacking
                        batch = bind_batch(batches)
                        batch = post_recv(batch)
-                       loss = self.step(i, pbar, **batch)
+                       loss, num_excessive_rollouts = self.step(i, pbar, **batch)
+                       self.buffer = (
+                           self.buffer[
+                               (self.dp_rank + 1) * self.microbatch_size
+                               - num_excessive_rollouts : (self.dp_rank + 1) * self.microbatch_size
+                           ]
+                           + self.buffer[self.dp_size * self.microbatch_size :]
+                       )
                        if loss is not None:
                            pbar.set_postfix({"loss": loss})
                        i += 1
                assert len(self.buffer) == 0
                if self.lr_scheduler is not None:
                    self.lr_scheduler.step()
                if (step + 1) % self.save_interval == 0 or (step + 1) == self.num_update_per_episode:

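The new buffer bookkeeping in BaseConsumer is easier to see outside diff form. Below is a minimal sketch of the same slicing, assuming the buffer is a flat Python list of opaque rollout entries; the function name reuse_excessive and the toy values are illustrative only and do not appear in the repository.

def reuse_excessive(buffer, dp_rank, dp_size, microbatch_size, num_excessive_rollouts):
    # Each DP rank trains on its own microbatch-sized slice of the shared buffer.
    my_batch = buffer[dp_rank * microbatch_size : (dp_rank + 1) * microbatch_size]
    # After step() reports how many of those rollouts were excessive, the tail of
    # this rank's slice is kept and prepended to the samples no rank consumed yet.
    kept = buffer[(dp_rank + 1) * microbatch_size - num_excessive_rollouts : (dp_rank + 1) * microbatch_size]
    leftover = buffer[dp_size * microbatch_size :]
    return my_batch, kept + leftover

# Toy run: rank 0 of 2 ranks, 4 entries per microbatch, 1 excessive rollout.
batch, new_buffer = reuse_excessive(list(range(10)), 0, 2, 4, 1)
print(batch)       # [0, 1, 2, 3]  consumed this iteration
print(new_buffer)  # [3, 8, 9]     entry 3 is reused in the next iteration
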
@@ -213,7 +213,6 @@ class GRPOConsumer(BaseConsumer):
            action_mask[:, -1] == False,
        )
        effective_tokens_count = torch.sum(action_mask, dim=-1) * loss_mask

        effective_samples = all_reduce_sum(torch.sum(loss_mask), self.plugin)
        total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin)
        total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)

@@ -222,9 +221,32 @@ class GRPOConsumer(BaseConsumer):

        mean_kl, mean_loss = [], []

        # update gradient only if at least 0.7*batch_size*num_generation valid samples are collected in case a lot of samples are invalid and got filtered out.
        # balance between efficiency and accuracy
        need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations
+       # to exclude the excessive samples from the last batch, the last num_excessive_samples samples are not used for training and will be kept in buffer for the next iteration.
+       num_excessive_samples = (
+           int(
+               (self.effective_sample_count - self.batch_size * self.dp_size * self.num_generations)
+               / self.num_generations
+               / self.dp_size
+           )
+           * self.num_generations
+       )
+       if num_excessive_samples > 0:
+           data = {
+               k: (
+                   v[: -num_excessive_samples if num_excessive_samples != 0 else v.size(0)]
+                   if k
+                   in ["input_ids", "attention_mask", "action_log_probs", "action_mask", "response_idx", "gt_answer"]
+                   else v
+               )
+               for k, v in data.items()
+           }
+           action_mask = action_mask[: -num_excessive_samples if num_excessive_samples != 0 else action_mask.size(0)]
+           loss_mask = loss_mask[: -num_excessive_samples if num_excessive_samples != 0 else loss_mask.size(0)]
+           advantages = advantages[: -num_excessive_samples if num_excessive_samples != 0 else advantages.size(0)]
+       else:
+           num_excessive_samples = 0

        pbar.set_postfix(
            {
                "Step": self.global_step + 1,

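As a concrete reading of the arithmetic added above: the surplus over batch_size * dp_size * num_generations is rounded down to whole prompts per data-parallel rank and then converted back into a per-rank sample count. The values below are made up for illustration; only the formula itself mirrors the diff.

# Hypothetical values; only the formula comes from the hunk above.
effective_sample_count = 768   # valid samples gathered across all DP ranks so far
batch_size = 4                 # prompts per rank consumed per update (assumed meaning)
dp_size = 8                    # data-parallel world size
num_generations = 16           # rollouts generated per prompt

target = batch_size * dp_size * num_generations            # 512 samples trigger an update
num_excessive_samples = (
    int((effective_sample_count - target) / num_generations / dp_size) * num_generations
)
print(num_excessive_samples)                     # 32 -> samples trimmed from the tail of each rank's tensors
print(num_excessive_samples // num_generations)  # 2  -> prompt groups each rank keeps for the next iteration
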
@@ -338,7 +360,7 @@ class GRPOConsumer(BaseConsumer):
                        loss_mask=inputs["loss_mask"],
                        total_effective_tokens_in_batch=total_effective_tokens_count,
                    )
-                   return loss
+                   return loss, num_excessive_samples // self.num_generations

                policy_model_outputs = self.booster.execute_pipeline(
                    iter([data_policy_forward]),

@@ -477,7 +499,9 @@ class GRPOConsumer(BaseConsumer):
            self.accum_advantages.zero_()
            self.accum_response_length.zero_()
            self.accum_count = 0
-           return loss_scalar
+           return loss_scalar, num_excessive_samples // self.num_generations
+       else:
+           return None, num_excessive_samples // self.num_generations

    def state_dict(self):
        self.policy_model._force_wait_all_gather()

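After this change every exit path of GRPOConsumer.step returns a pair, so the BaseConsumer loop can always unpack (loss, num_excessive_rollouts). A hedged stub of that contract, assuming one buffer entry corresponds to one prompt group of num_generations samples; the names and values here are illustrative, not taken from the repository.

NUM_GENERATIONS = 8  # hypothetical rollouts per prompt

def step_stub(did_update, num_excessive_samples):
    # Mirrors the new contract: a loss (or None when no optimizer update happened)
    # plus the number of excessive prompt groups to keep for the next iteration.
    loss_scalar = 0.42 if did_update else None  # dummy loss value
    return loss_scalar, num_excessive_samples // NUM_GENERATIONS

loss, num_excessive_rollouts = step_stub(did_update=True, num_excessive_samples=16)
print(loss, num_excessive_rollouts)  # 0.42 2 -> the caller keeps the last 2 buffer entries of its slice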