From b96d69055e693d8c84d62a00dae896e50d3a7e60 Mon Sep 17 00:00:00 2001
From: Tong Li
Date: Thu, 6 Mar 2025 10:51:27 +0800
Subject: [PATCH] grpo consumer

---
 .../coati/distributed/grpo_consumer.py       | 56 +++++++++----------
 1 file changed, 26 insertions(+), 30 deletions(-)

diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 2f230f5ed..49240d8da 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -52,10 +52,11 @@ class GRPOConsumer(BaseConsumer):
         self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
         self.policy_model.train()
         self.policy_model.gradient_checkpointing_enable()
-        self.optimizer = HybridAdam(self.policy_model.parameters(), lr=1e-4)
+        self.optimizer = HybridAdam(self.policy_model.parameters(), lr=1e-6)
         self.accum_loss = torch.zeros(1, device=self.device)
         self.accum_reward = torch.zeros(1, device=self.device)
         self.accum_kl = torch.zeros(1, device=self.device)
+        self.accum_count = 0
 
         # Reference model is initialized from policy model.
         self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -79,13 +80,7 @@ class GRPOConsumer(BaseConsumer):
         self.policy_loss_fn = PolicyLoss()
         self.global_step = 0
         if self.rank == 0:
-            self.wandb_run = wandb.init(project="Colossal-GRPO-Test6", sync_tensorboard=True)
-            # import os
-            # import time
-
-            # log_dir = self.wandb_run.dir
-            # # log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
-            # # self.writer = SummaryWriter(log_dir=log_dir)
+            self.wandb_run = wandb.init(project="GRPO-Test", sync_tensorboard=True)
 
     def setup(self):
         super().setup()
@@ -129,15 +124,16 @@
                 )["logits"]
             reference_action_log_probs = calc_action_log_probs(reference_model_logits, data["input_ids"], num_action)
 
-            # GRPO advantage calculation
-            kl = torch.sum(-0.1 * (action_log_probs - reference_action_log_probs) * action_mask, dim=-1) / torch.sum(
-                action_mask, dim=-1
+            per_token_kl = (
+                torch.exp(reference_action_log_probs - action_log_probs)
+                - (reference_action_log_probs - action_log_probs)
+                - 1
             )
+            kl = torch.sum(per_token_kl * action_mask, dim=-1) / torch.sum(action_mask, dim=-1)
 
             reward = self.reward_model(
                 data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
             )
-            reward = kl + reward
 
             # [batch_size, num_generations]
             group_reward = reward.view(-1, self.num_generations)
@@ -145,50 +141,50 @@
             reward_mean = group_reward.mean(dim=1).repeat_interleave(self.num_generations, dim=0)
             reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
             # [batch_size x num_generations]
-            advantages = (group_reward.view(-1) - reward_mean) / (reward_std + 1e-4)
-
-            # GRPO advantage calculation
-            kl = torch.sum(-0.01 * (action_log_probs - reference_action_log_probs) * action_mask, dim=-1) / torch.sum(
-                action_mask, dim=-1
-            )
+            advantages = (reward - reward_mean) / (reward_std + 1e-4)
 
             # Calculate Loss
             loss, skip_update, _ = self.policy_loss_fn(
                 action_log_probs,
                 old_action_log_probs,
                 advantages.unsqueeze(dim=-1).repeat_interleave(action_log_probs.size(-1), dim=-1),
+                per_token_kl,
                 action_mask,
             )
-            loss = loss / self.num_microbatches
 
             if not skip_update:
                 self.booster.backward(loss, self.optimizer)
-            loss = all_reduce_mean(loss)
-            reward = all_reduce_mean(reward.mean())
-            kl = all_reduce_mean(kl.mean())
+            loss = all_reduce_mean(loss, self.plugin)
+            reward = all_reduce_mean(reward.mean(), self.plugin)
+            kl = all_reduce_mean(kl.mean(), self.plugin)
             self.accum_loss.add_(loss.data)
             self.accum_reward.add_(reward.data)
             self.accum_kl.add_(kl.data)
+            self.accum_count += 1
         if need_update:
             self.optimizer.step()
             self.optimizer.zero_grad()
             loss_scalar = self.accum_loss.item()
             if self.rank == 0:
-                print("Loss:", self.accum_loss.item(), "Reward:", self.accum_reward.item(), "KL:", self.accum_kl.item())
+                print(
+                    "Loss:",
+                    self.accum_loss.item() / self.accum_count,
+                    "Reward:",
+                    self.accum_reward.item() / self.accum_count,
+                    "KL:",
+                    self.accum_kl.item() / self.accum_count,
+                )
                 self.wandb_run.log(
                     {
-                        "train/loss": self.accum_loss.item(),
-                        "train/reward": self.accum_reward.item(),
-                        "train/kl": self.accum_kl.item(),
+                        "train/loss": self.accum_loss.item() / self.accum_count,
+                        "train/reward": self.accum_reward.item() / self.accum_count,
+                        "train/kl": self.accum_kl.item() / self.accum_count,
                     }
                 )
-            # self.writer.add_scalar("train/loss", self.accum_loss.item(), self.global_step)
-            # self.writer.add_scalar("train/reward", self.accum_reward.item(), self.global_step)
-            # self.writer.add_scalar("train/kl", self.accum_kl.item(), self.global_step)
-            # self.global_step += 1
             self.accum_loss.zero_()
             self.accum_reward.zero_()
             self.accum_kl.zero_()
+            self.accum_count = 0
         return loss_scalar
 
     def state_dict(self):
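
Note on the KL change: the signed log-ratio penalty that was previously folded into the reward is replaced by a per-token estimator, exp(ref - logp) - (ref - logp) - 1, whose masked mean is logged as "KL" while the unreduced tensor is passed to PolicyLoss. The following is a minimal, self-contained sketch of that estimator for reference only; the function name per_token_kl_estimate and the standalone tensor arguments are illustrative and do not exist in the repository.

import torch

def per_token_kl_estimate(
    action_log_probs: torch.Tensor,            # [batch, num_actions] log-probs under the current policy
    reference_action_log_probs: torch.Tensor,  # [batch, num_actions] log-probs under the reference model
    action_mask: torch.Tensor,                 # [batch, num_actions] 1 for response tokens, 0 for padding
):
    # Per-token log-ratio between the reference policy and the current policy.
    log_ratio = reference_action_log_probs - action_log_probs
    # Estimator exp(x) - x - 1: non-negative, with expectation equal to the KL divergence.
    per_token_kl = torch.exp(log_ratio) - log_ratio - 1
    # Masked per-sequence mean, as used for the logged "KL" metric; the unreduced
    # per_token_kl is what the patch feeds into the policy loss.
    kl = torch.sum(per_token_kl * action_mask, dim=-1) / torch.sum(action_mask, dim=-1)
    return per_token_kl, kl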
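
The advantage computation also changes: advantages are now computed from the raw reward normalized within each prompt's group of generations, rather than from a KL-augmented reward. Below is a minimal sketch of that group-relative normalization, assuming reward is laid out as [batch_size * num_generations] with generations of the same prompt contiguous (as produced by reward.view(-1, self.num_generations)); the function name group_relative_advantages is illustrative only.

import torch

def group_relative_advantages(reward: torch.Tensor, num_generations: int, eps: float = 1e-4) -> torch.Tensor:
    # Regroup flat rewards into [num_prompts, num_generations].
    group_reward = reward.view(-1, num_generations)
    # Per-group statistics, broadcast back to the flat [batch_size * num_generations] layout.
    reward_mean = group_reward.mean(dim=1).repeat_interleave(num_generations, dim=0)
    reward_std = group_reward.std(dim=1).repeat_interleave(num_generations, dim=0)
    # Normalized advantage, using the same epsilon (1e-4) as the patch.
    return (reward - reward_mean) / (reward_std + eps)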