grpo consumer

Tong Li 2025-03-06 10:51:27 +08:00
parent c15225bc52
commit b96d69055e


@@ -52,10 +52,11 @@ class GRPOConsumer(BaseConsumer):
         self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
         self.policy_model.train()
         self.policy_model.gradient_checkpointing_enable()
-        self.optimizer = HybridAdam(self.policy_model.parameters(), lr=1e-4)
+        self.optimizer = HybridAdam(self.policy_model.parameters(), lr=1e-6)
         self.accum_loss = torch.zeros(1, device=self.device)
         self.accum_reward = torch.zeros(1, device=self.device)
         self.accum_kl = torch.zeros(1, device=self.device)
+        self.accum_count = 0

         # Reference model is initialized from policy model.
         self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
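The new `accum_count` field backs the logging change later in this commit: per-microbatch metrics are summed in place, then divided by the count only at log time. A minimal standalone sketch of that pattern (loss values invented for illustration):

    import torch

    accum_loss, accum_count = torch.zeros(1), 0
    for loss in (torch.tensor(0.9), torch.tensor(0.7)):  # per-microbatch losses
        accum_loss.add_(loss)  # in-place running sum, as in the diff
        accum_count += 1

    print(accum_loss.item() / accum_count)  # logged average: 0.8
    accum_loss.zero_()                      # reset after each optimizer step
    accum_count = 0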
@@ -79,13 +80,7 @@ class GRPOConsumer(BaseConsumer):
         self.policy_loss_fn = PolicyLoss()
         self.global_step = 0
         if self.rank == 0:
-            self.wandb_run = wandb.init(project="Colossal-GRPO-Test6", sync_tensorboard=True)
-            # import os
-            # import time
-            # log_dir = self.wandb_run.dir
-            # # log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
-            # # self.writer = SummaryWriter(log_dir=log_dir)
+            self.wandb_run = wandb.init(project="GRPO-Test", sync_tensorboard=True)

     def setup(self):
         super().setup()
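Side note on the surviving call: `sync_tensorboard=True` makes wandb mirror anything written through TensorBoard into the same run, which is presumably why the commented-out `SummaryWriter` scaffolding could simply be deleted rather than revived. A minimal sketch of that option, using the project name from the diff:

    import wandb

    # With sync_tensorboard=True, scalars written via TensorBoard's
    # SummaryWriter are forwarded to the wandb run automatically.
    run = wandb.init(project="GRPO-Test", sync_tensorboard=True)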
@@ -129,15 +124,16 @@ class GRPOConsumer(BaseConsumer):
                 )["logits"]
             reference_action_log_probs = calc_action_log_probs(reference_model_logits, data["input_ids"], num_action)
-            # GRPO advantage calculation
-            kl = torch.sum(-0.1 * (action_log_probs - reference_action_log_probs) * action_mask, dim=-1) / torch.sum(
-                action_mask, dim=-1
-            )
+            per_token_kl = (
+                torch.exp(reference_action_log_probs - action_log_probs)
+                - (reference_action_log_probs - action_log_probs)
+                - 1
+            )
+            kl = torch.sum(per_token_kl * action_mask, dim=-1) / torch.sum(action_mask, dim=-1)

             reward = self.reward_model(
                 data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
             )
-            reward = kl + reward

             # [batch_size, num_generations]
             group_reward = reward.view(-1, self.num_generations)
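For readers following the math: the removed term `-0.1 * (action_log_probs - reference_action_log_probs)` is the naive single-sample KL estimate, which can go negative for individual tokens. The new `per_token_kl` is the k3 estimator of KL(pi || pi_ref), exp(r) - r - 1 with r = log pi_ref - log pi, which is non-negative elementwise and in practice lower variance. A standalone sketch of the estimator (the function name is ours, not the repo's):

    import torch

    def k3_per_token_kl(action_log_probs: torch.Tensor,
                        reference_action_log_probs: torch.Tensor) -> torch.Tensor:
        # r = log pi_ref(a|s) - log pi(a|s); exp(r) - r - 1 >= 0 for all r.
        r = reference_action_log_probs - action_log_probs
        return torch.exp(r) - r - 1

The commit also stops folding the KL term into the scalar reward (`reward = kl + reward` is gone); the penalty now enters the objective per token via the new `per_token_kl` argument passed to `PolicyLoss` in the next hunk.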
@@ -145,50 +141,50 @@ class GRPOConsumer(BaseConsumer):
             reward_mean = group_reward.mean(dim=1).repeat_interleave(self.num_generations, dim=0)
             reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
             # [batch_size x num_generations]
-            advantages = (group_reward.view(-1) - reward_mean) / (reward_std + 1e-4)
-            # GRPO advantage calculation
-            kl = torch.sum(-0.01 * (action_log_probs - reference_action_log_probs) * action_mask, dim=-1) / torch.sum(
-                action_mask, dim=-1
-            )
+            advantages = (reward - reward_mean) / (reward_std + 1e-4)

             # Calculate Loss
             loss, skip_update, _ = self.policy_loss_fn(
                 action_log_probs,
                 old_action_log_probs,
                 advantages.unsqueeze(dim=-1).repeat_interleave(action_log_probs.size(-1), dim=-1),
+                per_token_kl,
                 action_mask,
             )

+            loss = loss / self.num_microbatches
             if not skip_update:
                 self.booster.backward(loss, self.optimizer)
-            loss = all_reduce_mean(loss)
-            reward = all_reduce_mean(reward.mean())
-            kl = all_reduce_mean(kl.mean())
+            loss = all_reduce_mean(loss, self.plugin)
+            reward = all_reduce_mean(reward.mean(), self.plugin)
+            kl = all_reduce_mean(kl.mean(), self.plugin)
             self.accum_loss.add_(loss.data)
             self.accum_reward.add_(reward.data)
             self.accum_kl.add_(kl.data)
+            self.accum_count += 1
             if need_update:
                 self.optimizer.step()
                 self.optimizer.zero_grad()
                 loss_scalar = self.accum_loss.item()
                 if self.rank == 0:
-                    print("Loss:", self.accum_loss.item(), "Reward:", self.accum_reward.item(), "KL:", self.accum_kl.item())
+                    print(
+                        "Loss:",
+                        self.accum_loss.item() / self.accum_count,
+                        "Reward:",
+                        self.accum_reward.item() / self.accum_count,
+                        "KL:",
+                        self.accum_kl.item() / self.accum_count,
+                    )
                     self.wandb_run.log(
                         {
-                            "train/loss": self.accum_loss.item(),
-                            "train/reward": self.accum_reward.item(),
-                            "train/kl": self.accum_kl.item(),
+                            "train/loss": self.accum_loss.item() / self.accum_count,
+                            "train/reward": self.accum_reward.item() / self.accum_count,
+                            "train/kl": self.accum_kl.item() / self.accum_count,
                         }
                     )
-                    # self.writer.add_scalar("train/loss", self.accum_loss.item(), self.global_step)
-                    # self.writer.add_scalar("train/reward", self.accum_reward.item(), self.global_step)
-                    # self.writer.add_scalar("train/kl", self.accum_kl.item(), self.global_step)
-                    # self.global_step += 1
                 self.accum_loss.zero_()
                 self.accum_reward.zero_()
                 self.accum_kl.zero_()
+                self.accum_count = 0
                 return loss_scalar

     def state_dict(self):
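The advantage lines above are GRPO's group-relative normalization: the `num_generations` completions sampled for one prompt are z-scored against their own group mean and standard deviation, so no learned value baseline is needed. A toy sketch with invented rewards (shapes and the 1e-4 epsilon follow the diff):

    import torch

    num_generations = 4
    # 2 prompts x 4 completions each, flattened
    reward = torch.tensor([1.0, 0.0, 1.0, 0.0, 0.5, 0.5, 1.0, 0.0])

    group = reward.view(-1, num_generations)                     # [batch, num_generations]
    mean = group.mean(dim=1).repeat_interleave(num_generations)  # broadcast back to flat shape
    std = group.std(dim=1).repeat_interleave(num_generations)
    advantages = (reward - mean) / (std + 1e-4)                  # z-score within each group

The new `loss = loss / self.num_microbatches` line serves gradient accumulation: each microbatch's backward pass contributes 1/N of the gradient, so the update applied at `optimizer.step()` matches a single step over the full batch.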