From f736d747e3b3c2601a15d240db669607e8aacba9 Mon Sep 17 00:00:00 2001
From: Tong Li <tong.li352711588@gmail.com>
Date: Tue, 25 Feb 2025 18:12:04 +0800
Subject: [PATCH] update grpo

---
 .../coati/distributed/grpo_consumer.py        | 70 +++++++++++++------
 1 file changed, 50 insertions(+), 20 deletions(-)

diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 79128b89e..d88df2360 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -3,11 +3,13 @@ from typing import Optional
 
 import ray
 import torch
+import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
 from coati.distributed.reward.reward_fn import math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
 from coati.distributed.utils import calc_action_log_probs
+from coati.trainer.utils import all_reduce_mean, is_rank_0
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from colossalai.nn.optimizer import HybridAdam
@@ -29,6 +31,8 @@ class GRPOConsumer(BaseConsumer):
         model_config,
         plugin_config,
         microbatch_size=1,
+        num_generations=4,
+        use_wandb=False,
     ):
         super().__init__(
             num_producers,
@@ -50,6 +54,8 @@ class GRPOConsumer(BaseConsumer):
         self.policy_model.gradient_checkpointing_enable()
         self.optimizer = HybridAdam(self.policy_model.parameters(), lr=1e-4)
         self.accum_loss = torch.zeros(1, device=self.device)
+        self.accum_reward = torch.zeros(1, device=self.device)
+        self.accum_kl = torch.zeros(1, device=self.device)
 
         # Reference model is initialized from policy model.
         self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -57,6 +63,7 @@ class GRPOConsumer(BaseConsumer):
 
         self.tokenizer = AutoTokenizer.from_pretrained(path)
         self.pad_token_id = self.tokenizer.pad_token_id
+        self.num_generations = num_generations
 
         # Initialize verifiable reward.
         response_format_tags = {
@@ -70,6 +77,8 @@ class GRPOConsumer(BaseConsumer):
         )
 
         self.policy_loss_fn = PolicyLoss()
+        if is_rank_0():
+            self.run = wandb.init(project="Colossal-GRPO-Test4")
 
     def setup(self):
         super().setup()
@@ -87,43 +96,52 @@ class GRPOConsumer(BaseConsumer):
             },
             ...]
         Format:
-            [batch_size, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>.
+            [batch_size, num_of_generation, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>.
         """
-        labels = kwargs["input_ids"].clone()
-        labels[kwargs["attention_mask"] == 0] = -100
-        kwargs["labels"] = labels
-        sequences = kwargs["input_ids"]
-        action_mask = kwargs["action_mask"]
+
+        # Reshape to [batch_size x num_of_generation, prompt_length + response_length]
+        data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()}
+        action_mask = data["action_mask"]
         num_action = action_mask.shape[1]
-        old_action_log_probs = kwargs["action_log_probs"]
-        assert kwargs.pop("action_mask").shape == kwargs.pop("action_log_probs").shape
+        old_action_log_probs = data["action_log_probs"]
 
         need_update = (step_idx + 1) % self.num_microbatches == 0
 
         ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer)
         with ctx:
             policy_model_logits = self.policy_model(
-                input_ids=kwargs["input_ids"],
-                attention_mask=kwargs["attention_mask"],
+                input_ids=data["input_ids"],
+                attention_mask=data["attention_mask"],
             )["logits"]
-            action_log_probs = calc_action_log_probs(policy_model_logits, sequences, num_action)
+            action_log_probs = calc_action_log_probs(policy_model_logits, data["input_ids"], num_action)
 
             reference_model_logits = self.reference_model(
-                input_ids=sequences,
-                attention_mask=kwargs["attention_mask"],
+                input_ids=data["input_ids"],
+                attention_mask=data["attention_mask"],
             )["logits"]
-            reference_action_log_probs = calc_action_log_probs(reference_model_logits, sequences, num_action)
+            reference_action_log_probs = calc_action_log_probs(reference_model_logits, data["input_ids"], num_action)
+
+            # GRPO advantage calculation
+            kl = torch.sum(-0.1 * (action_log_probs - reference_action_log_probs) * action_mask, dim=-1) / torch.sum(
+                action_mask, dim=-1
+            )
+
+            reward = self.reward_model(data["input_ids"], gt_answer=data["gt_answer"])
+            reward = kl + reward
+            # [batch_size, num_generations]
+            group_reward = reward.view(-1, self.num_generations)
+
+            # [batch_size x num_generations]
+            reward_mean = group_reward.mean(dim=1).repeat_interleave(self.num_generations, dim=0)
+            reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
+            # [batch_size x num_generations]
+            advantages = (group_reward.view(-1) - reward_mean) / (reward_std + 1e-4)
 
             # GRPO advantage calculation
             kl = torch.sum(-0.01 * (action_log_probs - reference_action_log_probs) * action_mask, dim=-1) / torch.sum(
                 action_mask, dim=-1
             )
 
-            reward = self.reward_model(sequences, gt_answer=kwargs["gt_answer"])
-            reward = reward + kl
-            mean = reward.view(-1, reward.size(0)).mean(dim=1)
-            std = reward.view(-1, reward.size(0)).std(dim=1)
-            advantages = (reward - mean) / (std + 1e-4)
             # Calculate Loss
             loss, skip_update, _ = self.policy_loss_fn(
                 action_log_probs,
@@ -133,14 +151,26 @@ class GRPOConsumer(BaseConsumer):
             )
 
             loss = loss / self.num_microbatches
-            self.accum_loss.add_(loss.data)
             if not skip_update:
                 self.booster.backward(loss, self.optimizer)
+            loss = all_reduce_mean(loss)
+            reward = all_reduce_mean(reward.mean())
+            kl = all_reduce_mean(kl.mean())
+            self.accum_loss.add_(loss.data)
+            self.accum_reward.add_(reward.data)
+            self.accum_kl.add_(kl.data)
         if need_update:
             self.optimizer.step()
             self.optimizer.zero_grad()
             loss_scalar = self.accum_loss.item()
+            if is_rank_0():
+                print("Loss:", self.accum_loss.item(), "Reward:", self.accum_reward.item(), "KL:", self.accum_kl.item())
+                self.run.log(
+                    {"loss": self.accum_loss.item(), "reward": self.accum_reward.item(), "kl": self.accum_kl.item()}
+                )
             self.accum_loss.zero_()
+            self.accum_reward.zero_()
+            self.accum_kl.zero_()
             return loss_scalar
 
     def state_dict(self):