From cc4faa73008fd6e7401eb1dc72a9e89e7ad2ef6f Mon Sep 17 00:00:00 2001
From: YeAnbang
Date: Wed, 16 Apr 2025 14:10:43 +0800
Subject: [PATCH] fix filtering, still buggy

---
 .../coati/distributed/grpo_consumer.py       | 112 +++++++++---------
 .../ColossalChat/coati/distributed/loss.py   |   9 +-
 .../coati/distributed/reward/reward_fn.py    |  10 +-
 .../distributed/reward/verifiable_reward.py  |   2 -
 .../ColossalChat/coati/trainer/utils.py      |  14 ---
 applications/ColossalChat/rl_example.py      |   5 +-
 6 files changed, 67 insertions(+), 85 deletions(-)

diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 5cd97b20a..dee4e648b 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -63,8 +63,8 @@ class GRPOConsumer(BaseConsumer):
         self.accum_loss = torch.zeros(1, device=self.device)
         self.accum_reward = torch.zeros(1, device=self.device)
         self.accum_kl = torch.zeros(1, device=self.device)
-        self.accum_format_reward = torch.zeros(1, device=self.device)
-        self.accum_acc_reward = torch.zeros(1, device=self.device)
+        self.accum_format_acc = torch.zeros(1, device=self.device)
+        self.accum_ans_acc = torch.zeros(1, device=self.device)
         self.accum_advantages = torch.zeros(1, device=self.device)
         self.accum_response_length = torch.zeros(1, device=self.device)
         self.accum_count = 0
@@ -72,10 +72,19 @@ class GRPOConsumer(BaseConsumer):
         self.grpo_config = grpo_config
         self.project_name = project_name
         self.effective_sample_count = 0
+        self.total_sample_count = 0
+
+        self.policy_loss_fn = PolicyLoss(
+            clip_eps_low=grpo_config.get("clip_eps_low", 0.2),
+            clip_eps_high=grpo_config.get("clip_eps_high", 0.2),
+            beta=grpo_config.get("beta", 0.01),
+            loss_variation=grpo_config.get("loss_variation", "sample_level"),
+        )

         # Reference model is initialized from policy model.
-        self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
-        self.reference_model.eval()
+        if self.policy_loss_fn.beta > 0:
+            self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
+            self.reference_model.eval()

         self.tokenizer = AutoTokenizer.from_pretrained(path)
         self.pad_token_id = self.tokenizer.pad_token_id
@@ -108,14 +117,6 @@ class GRPOConsumer(BaseConsumer):
         self.reward_model = VerifiableReward(
             reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags, **reward_model_kwargs
         )
-
-        self.policy_loss_fn = PolicyLoss(
-            clip_eps_low=grpo_config.get("clip_eps_low", 0.2),
-            clip_eps_high=grpo_config.get("clip_eps_high", 0.2),
-            skip_threshold=grpo_config.get("skip_threshold", 20.0),
-            beta=grpo_config.get("beta", 0.01),
-            loss_variation=grpo_config.get("loss_variation", "sample_level"),
-        )

         self.global_step = 0
         self.use_wandb = use_wandb
@@ -139,7 +140,8 @@ class GRPOConsumer(BaseConsumer):
         self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
             self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
         )
-        self.reference_model, *_ = self.booster.boost(self.reference_model)
+        if self.policy_loss_fn.beta > 0:
+            self.reference_model, *_ = self.booster.boost(self.reference_model)
         self.plugin.logger.set_level("ERROR")

     def step(self, step_idx: int, **kwargs) -> Optional[float]:
@@ -165,15 +167,14 @@ class GRPOConsumer(BaseConsumer):
         forward_batch_size = self.grpo_config.get("train_microbatch_size", data["input_ids"].size(0))

         reward_group = self.reward_model(
-            int(step_idx / self.num_microbatches),
             data["input_ids"],
             gt_answer=data["gt_answer"],
             response_idx=data["response_idx"],
         )

         reward = torch.tensor([value[0] for value in reward_group]).to(data["input_ids"].device)
-        format_reward = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device)
-        acc_reward = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)
+        format_acc = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device)
+        ans_acc = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)

         # [batch_size, num_generations]
@@ -186,18 +187,13 @@ class GRPOConsumer(BaseConsumer):
         # [batch_size x num_generations]
         advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
         # filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score),
-        reward_mean_no_length_penalty = (
-            (format_reward + acc_reward)
-            .view(-1, self.num_generations)
-            .mean(dim=1)
-            .repeat_interleave(self.num_generations, dim=0)
+        group_ans_acc = (
+            ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=0)
         )
         loss_mask = (
             torch.ones(action_mask.size(0), device=action_mask.device).bool()
             if self.filter_range is None
-            else torch.logical_and(
-                reward_mean_no_length_penalty > self.filter_range[0], reward_mean < self.filter_range[1]
-            )
+            else torch.logical_and(group_ans_acc > self.filter_range[0], group_ans_acc < self.filter_range[1])
         )
         # filter out overlength samples
         if self.filter_truncated_response and action_mask.size(1) == self.max_length:
             loss_mask = torch.logical_and(
                 loss_mask,
                 action_mask[:, -1] == False,
             )
-            # for i in range(loss_mask.size(0)):
-            #     if loss_mask[i] == False:
-            #         print(data["input_ids"].size(), data["input_ids"][i], action_mask[i], "mean reward", reward_mean_no_length_penalty.size(), reward_mean_no_length_penalty[i])
         effective_samples = all_reduce_sum(torch.sum(loss_mask), self.plugin)
+        total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
         self.effective_sample_count += effective_samples.item()
+        self.total_sample_count += total_samples.item()
+        print(
+            loss_mask,
+            self.effective_sample_count,
+            self.total_sample_count,
+            self.batch_size * self.dp_size * self.num_generations * 0.75,
+        )
         mean_kl, mean_loss = [], []

         # update gradient only if at least 0.7*batch_size*num_generation valid samples are collected in case a lot of samples are invalid and got filtered out.
         # balance between efficiency and accuracy
         need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations * 0.75
-        if need_update:
-            print(f"***** Update gradient based on {self.effective_sample_count} valid samples *****")
-            self.effective_sample_count = 0

         # Gradient must be synchronized if zero2 is enabled. https://github.com/hpcaitech/ColossalAI/blob/44d4053fec005fe0b06b6bc755fdc962463145df/colossalai/booster/plugin/hybrid_parallel_plugin.py#L1500
         ctx = (
@@ -319,7 +317,7 @@ class GRPOConsumer(BaseConsumer):
                             per_token_kl = 0.0
                             kl.append(0.0)

-                        loss, skip_update, _ = self.policy_loss_fn(
+                        loss, _ = self.policy_loss_fn(
                             action_log_probs,
                             action_log_probs,
                             inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1),
@@ -381,7 +379,7 @@ class GRPOConsumer(BaseConsumer):
                         per_token_kl = 0.0
                         kl = None

-                    loss, skip_update, _ = self.policy_loss_fn(
+                    loss, _ = self.policy_loss_fn(
                         action_log_probs,
                         old_action_log_probs,
                         advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
@@ -390,8 +388,7 @@ class GRPOConsumer(BaseConsumer):
                         loss_mask=loss_mask_forward_micro_batch,
                     )

-                    if not skip_update:
-                        self.booster.backward(loss, self.optimizer)
+                    self.booster.backward(loss, self.optimizer)
                     loss = all_reduce_mean(loss, self.plugin)
                     # Calculate accumulate value.
                     if kl is not None:
@@ -402,22 +399,25 @@ class GRPOConsumer(BaseConsumer):
                     self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
                 ):
                     reward = all_reduce_mean(reward.mean(), self.plugin)
-                    format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
-                    acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin)
+                    format_acc = all_reduce_mean(format_acc.mean(), self.plugin)
+                    ans_acc = all_reduce_mean(ans_acc.mean(), self.plugin)
                     advantages = all_reduce_mean(advantages.mean(), self.plugin)
                     response_length = all_reduce_mean(response_length.mean(), self.plugin)
                     self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
                     if self.policy_loss_fn.beta > 0:
                         self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
                     self.accum_reward.add_(reward.data)
-                    self.accum_format_reward.add_(format_reward.data)
-                    self.accum_acc_reward.add_(acc_reward.data)
+                    self.accum_format_acc.add_(format_acc.data)
+                    self.accum_ans_acc.add_(ans_acc.data)
                     self.accum_advantages.add_(advantages.data)
                     self.accum_response_length.add_(response_length.data)
                     self.accum_count += 1
         if need_update:
             self.optimizer.step()
             self.optimizer.zero_grad()
+            sample_utilization = self.effective_sample_count / self.total_sample_count
+            self.effective_sample_count = 0
+            self.total_sample_count = 0
             if not self.plugin.pp_size > 1 or (
                 self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
             ):
@@ -428,8 +428,8 @@ class GRPOConsumer(BaseConsumer):
                     to_log_msg = (
                         f"Loss: {self.accum_loss.item() / self.accum_count:.4f} \
                             Reward: {self.accum_reward.item() / self.accum_count:.4f} \
-                            Format Reward: {self.accum_format_reward.item() / self.accum_count:.4f} \
-                            Acc Reward: {self.accum_acc_reward.item() / self.accum_count:.4f} \
+                            Format Reward: {self.accum_format_acc.item() / self.accum_count:.4f} \
+                            Acc Reward: {self.accum_ans_acc.item() / self.accum_count:.4f} \
                             Advantages: {self.accum_advantages.item() / self.accum_count:.4f} \
                             Response Length: {self.accum_response_length.item() / self.accum_count:.4f}"
                         + f" KL: {self.accum_kl.item() / self.accum_count:.4f}"
@@ -439,12 +439,13 @@ class GRPOConsumer(BaseConsumer):
                     print(to_log_msg)
                     metrics = {
                         "metrics/reward": self.accum_reward.item() / self.accum_count,
-                        "metrics/format_reward": self.accum_format_reward.item() / self.accum_count,
-                        "metrics/acc_reward": self.accum_acc_reward.item() / self.accum_count,
+                        "metrics/format_acc": self.accum_format_acc.item() / self.accum_count,
+                        "metrics/ans_acc": self.accum_ans_acc.item() / self.accum_count,
                         "metrics/response_length": self.accum_response_length.item() / self.accum_count,
                         "train/loss": self.accum_loss.item() / self.accum_count,
                         "train/advantages": self.accum_advantages.item() / self.accum_count,
                         "train/learning_rate": self.lr_scheduler.get_last_lr()[0],
+                        "train/sample_utilization": sample_utilization,
                         "rollout/temperature": data["temperature"].cpu().numpy()[0][0],
                     }
                     if self.policy_loss_fn.beta > 0:
@@ -453,12 +454,11 @@ class GRPOConsumer(BaseConsumer):
                         self.wandb_run.log(metrics)
                     self.accum_loss.zero_()
                     self.accum_reward.zero_()
-                    self.accum_acc_reward.zero_()
-                    self.accum_format_reward.zero_()
+                    self.accum_ans_acc.zero_()
+                    self.accum_format_acc.zero_()
                     self.accum_kl.zero_()
                     self.accum_advantages.zero_()
                     self.accum_response_length.zero_()
-                    self.accum_count = 0

         return loss_scalar

@@ -507,8 +507,8 @@ class GRPOEvalConsumer(BaseConsumer):
         self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
         self.policy_model.train()
         self.accum_reward = torch.zeros(1, device=self.device)
-        self.accum_format_reward = torch.zeros(1, device=self.device)
-        self.accum_acc_reward = torch.zeros(1, device=self.device)
+        self.accum_format_acc = torch.zeros(1, device=self.device)
+        self.accum_ans_acc = torch.zeros(1, device=self.device)
         self.accum_response_length = torch.zeros(1, device=self.device)
         self.accum_count = torch.zeros(1, device=self.device)

@@ -545,8 +545,8 @@ class GRPOEvalConsumer(BaseConsumer):
             data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
         )
         reward = [value[0].item() for value in reward_group]
-        format_reward = [value[1].item() for value in reward_group]
-        acc_reward = [value[2].item() for value in reward_group]
+        format_acc = [value[1].item() for value in reward_group]
+        ans_acc = [value[2].item() for value in reward_group]
         response_length = [(data["response_idx"][i][1] - data["response_idx"][i][0]).item() for i in range(len(reward))]

         response = self.tokenizer.batch_decode(data["input_ids"], skip_special_tokens=True)
@@ -557,8 +557,8 @@ class GRPOEvalConsumer(BaseConsumer):
                     {
                         "response": response[i],
                         "reward": reward[i],
-                        "format_reward": format_reward[i],
-                        "acc_reward": acc_reward[i],
+                        "format_acc": format_acc[i],
+                        "ans_acc": ans_acc[i],
                         "response_length": response_length[i],
                     },
                     ensure_ascii=False,
@@ -567,20 +567,20 @@ class GRPOEvalConsumer(BaseConsumer):
         )
         self.accum_reward += sum(reward)
-        self.accum_format_reward += sum(format_reward)
-        self.accum_acc_reward += sum(acc_reward)
+        self.accum_format_acc += sum(format_acc)
+        self.accum_ans_acc += sum(ans_acc)
         self.accum_response_length += sum(response_length)
         self.accum_count += len(reward)

         # print results
         total_count = all_reduce_mean(self.accum_count, self.plugin)
         mean_reward = all_reduce_mean(self.accum_reward, self.plugin) / total_count
-        mean_format_reward = all_reduce_mean(self.accum_format_reward, self.plugin) / total_count
-        mean_acc_reward = all_reduce_mean(self.accum_acc_reward, self.plugin) / total_count
+        mean_format_acc = all_reduce_mean(self.accum_format_acc, self.plugin) / total_count
+        mean_ans_acc = all_reduce_mean(self.accum_ans_acc, self.plugin) / total_count
         mean_response_length = all_reduce_mean(self.accum_response_length, self.plugin) / total_count
         if rank == 0:
             print(
-                f"Step {step_idx}: Mean Reward: {mean_reward}, Mean Format Reward: {mean_format_reward}, Mean Acc Reward: {mean_acc_reward}, Mean Response Length: {mean_response_length}"
+                f"Step {step_idx}: Mean Reward: {mean_reward}, Mean Format Reward: {mean_format_acc}, Mean Acc Reward: {mean_ans_acc}, Mean Response Length: {mean_response_length}"
             )
         return None

diff --git a/applications/ColossalChat/coati/distributed/loss.py b/applications/ColossalChat/coati/distributed/loss.py
index bdbe64a2a..d00335db0 100644
--- a/applications/ColossalChat/coati/distributed/loss.py
+++ b/applications/ColossalChat/coati/distributed/loss.py
@@ -14,14 +14,12 @@ class PolicyLoss(nn.Module):
         self,
         clip_eps_low: float = 0.2,
         clip_eps_high: float = 0.2,
-        skip_threshold: float = 20.0,
         beta: float = 0.01,
         loss_variation: str = "sample_level",
     ) -> None:
         super().__init__()
         self.clip_eps_low = clip_eps_low
         self.clip_eps_high = clip_eps_high
-        self.skip_threshold = skip_threshold
         self.beta = beta
         self.loss_variation = loss_variation
         assert loss_variation in ["sample_level", "token_level"], f"Unsupported loss variation: {loss_variation}"
@@ -35,7 +33,6 @@ class PolicyLoss(nn.Module):
         action_mask: Optional[torch.Tensor] = None,
         loss_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        skip = False
         if action_mask is None:
             ratio = (log_probs - log_probs.detach()).exp()
         else:
@@ -43,7 +40,7 @@ class PolicyLoss(nn.Module):

         surr1 = ratio * advantages
         surr2 = ratio.clamp(1 - self.clip_eps_low, 1 + self.clip_eps_high) * advantages
-        if self.beta <= 0:
+        if self.beta == 0:
             # skip kl term if kl coefficient is zero
             per_token_kl = 0.0
             loss = -torch.min(surr1, surr2) + self.beta * per_token_kl
@@ -68,5 +65,7 @@ class PolicyLoss(nn.Module):
             loss = loss * loss_mask
             total_tokens = total_tokens * loss_mask
             loss = loss.sum() / (total_tokens.sum() + 1e-8)
+        else:
+            raise ValueError(f"Unsupported loss variation: {self.loss_variation}")

-        return loss, skip, ratio.max()
+        return loss, ratio.max()
diff --git a/applications/ColossalChat/coati/distributed/reward/reward_fn.py b/applications/ColossalChat/coati/distributed/reward/reward_fn.py
index 1260645c9..b1ac02fcd 100644
--- a/applications/ColossalChat/coati/distributed/reward/reward_fn.py
+++ b/applications/ColossalChat/coati/distributed/reward/reward_fn.py
@@ -9,8 +9,8 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     format_score = 0.0
     acc_score = 10.0
     reward = torch.tensor(0.0)
-    format_reward = torch.tensor(0.0)
-    acc_reward = torch.tensor(0.0)
+    format_acc = torch.tensor(0.0)
+    ans_acc = torch.tensor(0.0)
     s, e = response_idx[0], response_idx[1]

     length_reward = 0.0
@@ -32,7 +32,7 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):

     # Check format accuracy
     if format_valid:
-        format_reward += format_score
+        format_acc += 1
         reward += format_score

     # Check answer accuracy
@@ -40,12 +40,12 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
         final_answer is not None
         and gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower()
     ):
-        acc_reward += acc_score
+        ans_acc += 1
         reward += acc_score

     reward = reward + length_reward

-    return torch.tensor([reward, format_reward, acc_reward]).to(input_ids.device)
+    return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)


 def gsm8k_reward_fn(input_ids, **kwargs):
diff --git a/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py b/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py
index 01d8f1663..ba83f7787 100644
--- a/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py
+++ b/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py
@@ -14,7 +14,6 @@ class VerifiableReward:

     def __call__(
         self,
-        step: int,
         input_ids: torch.LongTensor,
         gt_answer: List[torch.Tensor] = None,
         response_idx: List[torch.Tensor] = None,
@@ -30,7 +29,6 @@ class VerifiableReward:
         reward_batch = torch.stack(
             [
                 reward_fn(
-                    step,
                     input_ids[i],
                     gt_answer=gt_answer[i],
                     response_idx=response_idx[i],
diff --git a/applications/ColossalChat/coati/trainer/utils.py b/applications/ColossalChat/coati/trainer/utils.py
index 45dfab588..5153ce3ad 100755
--- a/applications/ColossalChat/coati/trainer/utils.py
+++ b/applications/ColossalChat/coati/trainer/utils.py
@@ -128,20 +128,6 @@ def all_reduce_mean(tensor: torch.Tensor, plugin: Plugin = None) -> torch.Tensor
     return tensor


-# def all_reduce_sum(tensor: torch.Tensor, ) -> torch.Tensor:
-#     """
-#     Performs an all-reduce operation to sum the values of the given tensor across all processes.
-
-#     Args:
-#         tensor (torch.Tensor): The input tensor to be reduced.
-
-#     Returns:
-#         torch.Tensor: The reduced tensor with the sum of values across all processes.
-#     """
-#     dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
-#     return tensor
-
-
 def all_reduce_sum(tensor: torch.Tensor, plugin: Plugin = None) -> torch.Tensor:
     """
     Performs an all-reduce operation to sum the values of the given tensor across all processes.
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index 0dc84045a..ccbb5b297 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -111,7 +111,7 @@ if __name__ == "__main__":

         # DAPO variant settings
         grpo_config = {
-            "filter_range": [0.05, 9.0],
+            "filter_range": [0.01, 0.99],  # only filter out all zero batch and all one batch
             "lr": 1e-6,
             "train_microbatch_size": args.train_microbatch_size,
             "clip_eps_low": 0.2,
@@ -144,8 +144,7 @@ if __name__ == "__main__":
         grpo_config=grpo_config,
         plugin_config={
             "zero_stage": 2,
-        },
-        # for zero
+        },  # for zero
         # plugin_config={
         #     "pp_size": 2,
         #     "tp_size": 2,