Tong Li
2025-02-28 10:16:42 +08:00
parent f736d747e3
commit 070907dd7f
6 changed files with 74 additions and 26 deletions


@@ -9,7 +9,7 @@ from coati.distributed.loss import PolicyLoss
 from coati.distributed.reward.reward_fn import math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
 from coati.distributed.utils import calc_action_log_probs
-from coati.trainer.utils import all_reduce_mean, is_rank_0
+from coati.trainer.utils import all_reduce_mean
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from colossalai.nn.optimizer import HybridAdam
@@ -77,8 +77,15 @@ class GRPOConsumer(BaseConsumer):
         )
         self.policy_loss_fn = PolicyLoss()
-        if is_rank_0():
-            self.run = wandb.init(project="Colossal-GRPO-Test4")
+        self.global_step = 0
+        if self.rank == 0:
+            self.wandb_run = wandb.init(project="Colossal-GRPO-Test6", sync_tensorboard=True)
+            # import os
+            # import time
+            # log_dir = self.wandb_run.dir
+            # # log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
+            # # self.writer = SummaryWriter(log_dir=log_dir)
 
     def setup(self):
         super().setup()
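Aside: this hunk swaps the `is_rank_0()` helper for a direct `self.rank == 0` check and renames `self.run` to `self.wandb_run`, passing `sync_tensorboard=True` so the commented-out `SummaryWriter` path could later feed the same run. A minimal standalone sketch of the rank-0-only init pattern follows; the function name and arguments are illustrative, not from this repo:

```python
import wandb

def maybe_init_wandb(rank: int, project: str):
    """Create a W&B run on rank 0 only; all other ranks skip logging.

    sync_tensorboard=True makes wandb mirror any TensorBoard scalars
    written in this process, so a SummaryWriter can be added later
    without duplicating the explicit .log() calls.
    """
    if rank != 0:
        return None
    return wandb.init(project=project, sync_tensorboard=True)
```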
@@ -115,10 +122,11 @@ class GRPOConsumer(BaseConsumer):
             )["logits"]
             action_log_probs = calc_action_log_probs(policy_model_logits, data["input_ids"], num_action)
 
-            reference_model_logits = self.reference_model(
-                input_ids=data["input_ids"],
-                attention_mask=data["attention_mask"],
-            )["logits"]
+            with torch.no_grad():
+                reference_model_logits = self.reference_model(
+                    input_ids=data["input_ids"],
+                    attention_mask=data["attention_mask"],
+                )["logits"]
             reference_action_log_probs = calc_action_log_probs(reference_model_logits, data["input_ids"], num_action)
 
             # GRPO advantage calculation
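The `torch.no_grad()` wrapper matters because the reference model is frozen: its logits feed only the KL term, so tracking gradients through that forward pass would waste activation memory. A sketch of the pattern with a simplified stand-in for `calc_action_log_probs`; the real helper lives in `coati.distributed.utils`, so treat this version as an approximation:

```python
import torch
import torch.nn.functional as F

def reference_action_log_probs(reference_model, input_ids, attention_mask, num_actions):
    """Log-probs of the generated tokens under the frozen reference policy."""
    with torch.no_grad():  # no gradients needed on the KL reference side
        logits = reference_model(input_ids=input_ids, attention_mask=attention_mask)["logits"]
    # next-token prediction: logits at position t score the token at t+1
    log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)
    labels = input_ids[:, 1:].unsqueeze(-1)
    per_token = log_probs.gather(dim=-1, index=labels).squeeze(-1)
    return per_token[:, -num_actions:]  # keep only the response tokens
```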
@@ -126,7 +134,9 @@
                 action_mask, dim=-1
             )
-            reward = self.reward_model(data["input_ids"], gt_answer=data["gt_answer"])
+            reward = self.reward_model(
+                data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
+            )
             reward = kl + reward
             # [batch_size, num_generations]
             group_reward = reward.view(-1, self.num_generations)
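The `view(-1, self.num_generations)` reshape is the heart of GRPO's advantage estimate: each prompt's G sampled completions are compared against their own group rather than a learned value baseline. The normalization itself falls outside this hunk, so the sketch below assumes the standard group mean/std form:

```python
import torch

def grpo_advantages(reward: torch.Tensor, num_generations: int, eps: float = 1e-4) -> torch.Tensor:
    """Group-relative advantages: center and scale each reward against
    the other completions sampled for the same prompt."""
    group = reward.view(-1, num_generations)        # [num_prompts, G]
    mean = group.mean(dim=1, keepdim=True)
    std = group.std(dim=1, keepdim=True)
    return ((group - mean) / (std + eps)).view(-1)  # flatten back to [batch]

# e.g. 4 prompts with 8 sampled completions each
adv = grpo_advantages(torch.randn(32), num_generations=8)
```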
@@ -163,11 +173,19 @@
                 self.optimizer.step()
                 self.optimizer.zero_grad()
                 loss_scalar = self.accum_loss.item()
-                if is_rank_0():
+                if self.rank == 0:
                     print("Loss:", self.accum_loss.item(), "Reward:", self.accum_reward.item(), "KL:", self.accum_kl.item())
-                    self.run.log(
-                        {"loss": self.accum_loss.item(), "reward": self.accum_reward.item(), "kl": self.accum_kl.item()}
+                    self.wandb_run.log(
+                        {
+                            "train/loss": self.accum_loss.item(),
+                            "train/reward": self.accum_reward.item(),
+                            "train/kl": self.accum_kl.item(),
+                        }
                     )
+                    # self.writer.add_scalar("train/loss", self.accum_loss.item(), self.global_step)
+                    # self.writer.add_scalar("train/reward", self.accum_reward.item(), self.global_step)
+                    # self.writer.add_scalar("train/kl", self.accum_kl.item(), self.global_step)
+                    # self.global_step += 1
                 self.accum_loss.zero_()
                 self.accum_reward.zero_()
                 self.accum_kl.zero_()
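This last hunk shows the per-update logging rhythm: accumulate scalars over micro-batches, log once from rank 0 under a `train/` prefix, then zero the accumulators in place. A self-contained sketch of that cycle; the class and method names here are illustrative, not the repo's:

```python
import torch
import wandb

class TrainMetrics:
    """Accumulate loss/reward/KL across micro-batches; flush once per step."""

    def __init__(self, wandb_run=None):
        self.wandb_run = wandb_run  # None on ranks other than 0
        self.loss = torch.zeros(1)
        self.reward = torch.zeros(1)
        self.kl = torch.zeros(1)

    def update(self, loss, reward, kl):
        self.loss += loss.detach().cpu()
        self.reward += reward.detach().cpu()
        self.kl += kl.detach().cpu()

    def flush(self):
        if self.wandb_run is not None:  # rank 0 only
            self.wandb_run.log({
                "train/loss": self.loss.item(),
                "train/reward": self.reward.item(),
                "train/kl": self.kl.item(),
            })
        # zero_() resets in place, so the tensors keep their device and dtype
        self.loss.zero_()
        self.reward.zero_()
        self.kl.zero_()
```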