From 807a5a43b2e4135b6295d6ef598220928c0cc202 Mon Sep 17 00:00:00 2001
From: YeAnbang
Date: Wed, 23 Apr 2025 16:56:12 +0800
Subject: [PATCH] support reusing excessive samples

---
 .../coati/distributed/consumer.py          | 11 ++++--
 .../coati/distributed/grpo_consumer.py     | 34 ++++++++++++++++---
 2 files changed, 37 insertions(+), 8 deletions(-)

diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index 50871e369..0cb35b94b 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -109,17 +109,22 @@ class BaseConsumer:
                         batches = self.buffer[
                             self.dp_rank * self.microbatch_size : (self.dp_rank + 1) * self.microbatch_size
                         ]
-                        self.buffer = self.buffer[self.dp_size * self.microbatch_size :]
                         batch = pad_batch(
                             batches
                         )  # when `imbs` is smaller than `tMbs`, samples may have differ in size, need to pad before stacking
                         batch = bind_batch(batches)
                         batch = post_recv(batch)
-                        loss = self.step(i, pbar, **batch)
+                        loss, num_excessive_rollouts = self.step(i, pbar, **batch)
+                        self.buffer = (
+                            self.buffer[
+                                (self.dp_rank + 1) * self.microbatch_size
+                                - num_excessive_rollouts : (self.dp_rank + 1) * self.microbatch_size
+                            ]
+                            + self.buffer[self.dp_size * self.microbatch_size :]
+                        )
                         if loss is not None:
                             pbar.set_postfix({"loss": loss})
                         i += 1
-                assert len(self.buffer) == 0
                 if self.lr_scheduler is not None:
                     self.lr_scheduler.step()
                 if (step + 1) % self.save_interval == 0 or (step + 1) == self.num_update_per_episode:
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 34ac2eec8..d68860b62 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -213,7 +213,6 @@ class GRPOConsumer(BaseConsumer):
             action_mask[:, -1] == False,
         )
         effective_tokens_count = torch.sum(action_mask, dim=-1) * loss_mask
-        effective_samples = all_reduce_sum(torch.sum(loss_mask), self.plugin)
         total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin)
         total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
 
@@ -222,9 +221,32 @@
         mean_kl, mean_loss = [], []
 
-        # update gradient only if at least 0.7*batch_size*num_generation valid samples are collected in case a lot of samples are invalid and got filtered out.
-        # balance between efficiency and accuracy
         need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations
+        # to exclude the excessive samples from the last batch, the last num_excessive_samples samples are not used for training and will be kept in buffer for the next iteration.
+        num_excessive_samples = (
+            int(
+                (self.effective_sample_count - self.batch_size * self.dp_size * self.num_generations)
+                / self.num_generations
+                / self.dp_size
+            )
+            * self.num_generations
+        )
+        if num_excessive_samples > 0:
+            data = {
+                k: (
+                    v[: -num_excessive_samples if num_excessive_samples != 0 else v.size(0)]
+                    if k
+                    in ["input_ids", "attention_mask", "action_log_probs", "action_mask", "response_idx", "gt_answer"]
+                    else v
+                )
+                for k, v in data.items()
+            }
+            action_mask = action_mask[: -num_excessive_samples if num_excessive_samples != 0 else action_mask.size(0)]
+            loss_mask = loss_mask[: -num_excessive_samples if num_excessive_samples != 0 else loss_mask.size(0)]
+            advantages = advantages[: -num_excessive_samples if num_excessive_samples != 0 else advantages.size(0)]
+        else:
+            num_excessive_samples = 0
+
         pbar.set_postfix(
             {
                 "Step": self.global_step + 1,
@@ -338,7 +360,7 @@ class GRPOConsumer(BaseConsumer):
                     loss_mask=inputs["loss_mask"],
                     total_effective_tokens_in_batch=total_effective_tokens_count,
                 )
-                return loss
+                return loss, num_excessive_samples // self.num_generations
 
             policy_model_outputs = self.booster.execute_pipeline(
                 iter([data_policy_forward]),
@@ -477,7 +499,9 @@ class GRPOConsumer(BaseConsumer):
             self.accum_advantages.zero_()
             self.accum_response_length.zero_()
             self.accum_count = 0
-            return loss_scalar
+            return loss_scalar, num_excessive_samples // self.num_generations
+        else:
+            return None, num_excessive_samples // self.num_generations
 
     def state_dict(self):
         self.policy_model._force_wait_all_gather()
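Note (not part of the patch): the following is a minimal, standalone Python sketch of the bookkeeping this change introduces, using hypothetical toy values (num_generations=8, batch_size=4, dp_size=2, dp_rank=0, microbatch_size=4, a 12-entry buffer). Variable names mirror the patch; the concrete numbers are illustrative only.

# Sketch of the excessive-sample handling above, under the toy assumptions stated in the note.
num_generations = 8          # responses sampled per prompt
batch_size = 4               # prompts per DP rank per optimizer step
dp_size = 2                  # data-parallel world size

# Suppose filtering left more valid samples than one global batch needs.
effective_sample_count = 80  # target is batch_size * dp_size * num_generations = 64

# Same rounding as the patch: the surplus is truncated to whole prompt groups per rank,
# then scaled back to a sample count (always a multiple of num_generations).
num_excessive_samples = (
    int((effective_sample_count - batch_size * dp_size * num_generations) / num_generations / dp_size)
    * num_generations
)  # -> 8 samples, i.e. one prompt group

# In the patch, the trailing num_excessive_samples rows of every per-sample tensor
# (input_ids, action_mask, advantages, ...) are dropped before the gradient update:
#     tensor = tensor[:-num_excessive_samples]

# step() reports the surplus back to the consumer in units of rollouts (prompt groups) ...
num_excessive_rollouts = num_excessive_samples // num_generations  # -> 1

# ... and the consumer re-queues exactly those rollouts at the front of its buffer,
# so they are reused in the next iteration instead of being discarded (previously the
# buffer was simply cut at dp_size * microbatch_size and asserted empty at the end).
microbatch_size, dp_rank = 4, 0
buffer = list(range(12))     # stand-in for buffered rollouts
buffer = (
    buffer[(dp_rank + 1) * microbatch_size - num_excessive_rollouts : (dp_rank + 1) * microbatch_size]
    + buffer[dp_size * microbatch_size :]
)
print(num_excessive_samples, num_excessive_rollouts, buffer)  # 8 1 [3, 8, 9, 10, 11]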