remove redundant code and fix bugs

YeAnbang
2025-05-16 14:08:23 +08:00
parent a528921944
commit 11a5854b50
5 changed files with 27 additions and 59 deletions


@@ -9,7 +9,7 @@ from coati.distributed.loss import PolicyLoss
from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
from coati.distributed.reward.verifiable_reward import VerifiableReward
from coati.distributed.utils import calc_action_log_probs
-from coati.trainer.utils import all_gather_tensors, all_reduce_mean, all_reduce_sum
+from coati.trainer.utils import all_reduce_mean, all_reduce_sum
from transformers import AutoModelForCausalLM, AutoTokenizer
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@@ -201,10 +201,7 @@ class GRPOConsumer(BaseConsumer):
reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
# [minibatch_size x num_generations]
advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
# filter out rewards that are too high (all samples get full score) or too low (all samples get 0 score),
-group_ans_acc = (
-    ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=0)
-)
# [minibatch_size x num_of_generation]
loss_mask = torch.ones(action_mask.size(0), device=action_mask.device).bool()
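Note: the surviving lines above are GRPO's group-relative advantage normalization. A minimal standalone sketch of that computation follows; the minibatch_size / num_generations values are made up for illustration, while the shapes and the 1e-4 epsilon follow the diff.

    # Standalone illustration of the group-relative advantage normalization above.
    import torch

    minibatch_size, num_generations = 2, 4
    reward = torch.rand(minibatch_size * num_generations)  # one scalar reward per generated response

    group_reward = reward.view(minibatch_size, num_generations)
    reward_mean = group_reward.mean(dim=1).repeat_interleave(num_generations, dim=0)
    reward_std = group_reward.std(dim=1).repeat_interleave(num_generations, dim=0)

    # [minibatch_size * num_generations, 1]: each response is scored relative to its own prompt group
    advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
    print(advantages.shape)  # torch.Size([8, 1])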
@@ -214,37 +211,14 @@ class GRPOConsumer(BaseConsumer):
loss_mask,
action_mask[:, -1] == False,
)
-prompt_level_mask = loss_mask.view(self.minibatch_size, self.num_generations)
-# [minibatch_size] -> calculate the number of effective prompts
-effective_prompts_mask = prompt_level_mask.any(dim=1)
-effective_prompts = all_reduce_sum(torch.sum(effective_prompts_mask), self.plugin)
-self.effective_prompt_count += effective_prompts.item()
-excessive_prompts_idx = None
+self.effective_prompt_count += group_reward.size(0) * self.dp_size
mean_kl, mean_loss = [], []
if self.grpo_config.get("dynamic_batching", True):
need_update = self.effective_prompt_count >= self.batch_size * self.dp_size
excessive_prompts = self.effective_prompt_count - self.batch_size * self.dp_size
-if excessive_prompts > 0:
-    excessive_prompts_per_rank = excessive_prompts // self.dp_size
-    # Only count excessive prompts if they are greater than 1 per rank.
-    # TODO: customize excessive prompts calculation.
-    if excessive_prompts_per_rank != 0:
-        # Mask excessive prompts to False
-        true_indices = torch.nonzero(effective_prompts_mask).squeeze()
-        if excessive_prompts_per_rank <= len(true_indices):
-            excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
-        else:
-            excessive_prompts_idx = true_indices
-        effective_prompts_mask[excessive_prompts_idx] = False
-        for mask_idx in range(len(effective_prompts_mask)):
-            if effective_prompts_mask[mask_idx] == False:
-                # Update loss mask.
-                loss_mask[mask_idx] = False
+assert excessive_prompts <= 0, "Debug: excessive_prompts should never be positive. Bug!!!!"
else:
# If dynamic batching is disabled, we need to use all samples for training.
need_update = (step_idx + 1) % self.num_microbatches == 0
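Note: with this change, dynamic batching no longer tracks per-prompt masks or excessive-prompt indices. Since every DP rank receives the complete inference batch, the global effective prompt count can be advanced locally with no collective, and the removed masking branch collapses into the debug assert. A minimal sketch of the simplified bookkeeping; batch_size, dp_size, and the per-minibatch prompt count are made-up values.

    # Illustration of the simplified dynamic-batching bookkeeping after this change.
    batch_size = 8         # prompts to accumulate before an optimizer step
    dp_size = 2            # data-parallel world size
    minibatch_prompts = 4  # prompts in one received minibatch, i.e. group_reward.size(0)

    effective_prompt_count = 0
    for step in range(1, 5):
        effective_prompt_count += minibatch_prompts * dp_size
        need_update = effective_prompt_count >= batch_size * dp_size
        excessive_prompts = effective_prompt_count - batch_size * dp_size
        assert excessive_prompts <= 0, "excessive_prompts should never be positive"
        if need_update:
            print(f"optimizer update after minibatch {step}")
            effective_prompt_count = 0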
@@ -460,9 +434,7 @@ class GRPOConsumer(BaseConsumer):
self.optimizer.step()
self.optimizer.zero_grad()
self.global_step += 1
-self.total_sample_count = all_reduce_sum(
-    torch.tensor(self.total_sample_count).to(self.accum_loss.device), self.plugin
-).item()
+# no need to run all_reduce_sum on total_sample_count, because all dp ranks receive a complete inference batch from all producers.
sample_utilization = self.effective_sample_count / self.total_sample_count
self.effective_prompt_count = 0
self.effective_sample_count = 0
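Note: the same reasoning removes the all_reduce_sum on total_sample_count; both counters are already global on every rank, so the utilization ratio can be taken locally before the counters are reset. A small sketch with made-up counts:

    # Illustration of the sample-utilization metric after this change.
    total_sample_count = 512       # all generated responses accumulated since the last update
    effective_sample_count = 448   # responses that passed filtering and kept a nonzero loss mask

    sample_utilization = effective_sample_count / total_sample_count
    print(f"sample utilization: {sample_utilization:.2%}")  # 87.50%

    # Counters reset for the next accumulation window, mirroring the diff.
    effective_prompt_count = 0
    effective_sample_count = 0
    total_sample_count = 0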
@@ -507,14 +479,9 @@ class GRPOConsumer(BaseConsumer):
self.accum_advantages.zero_()
self.accum_response_length.zero_()
self.accum_count = 0
-if excessive_prompts_idx is not None:
-    # All gather excessive prompts index across DP ranks.
-    excessive_prompts_idx = [idx + self.dp_rank * self.minibatch_size for idx in excessive_prompts_idx]
-    excessive_prompts_idx = all_gather_tensors(excessive_prompts_idx, self.plugin)
-return loss_scalar, excessive_prompts_idx
+return loss_scalar
else:
-return None, excessive_prompts_idx
+return None
def calculate_group_reward(self, rollout_group: Dict[str, Any]) -> Dict[str, Any]:
"""