Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-04-28 11:45:23 +00:00)
grpo consumer
commit b96d69055e
parent c15225bc52
```diff
@@ -52,10 +52,11 @@ class GRPOConsumer(BaseConsumer):
         self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
         self.policy_model.train()
         self.policy_model.gradient_checkpointing_enable()
-        self.optimizer = HybridAdam(self.policy_model.parameters(), lr=1e-4)
+        self.optimizer = HybridAdam(self.policy_model.parameters(), lr=1e-6)
         self.accum_loss = torch.zeros(1, device=self.device)
         self.accum_reward = torch.zeros(1, device=self.device)
         self.accum_kl = torch.zeros(1, device=self.device)
+        self.accum_count = 0

         # Reference model is initialized from policy model.
         self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
```
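The `accum_count` added here pairs with the three accumulator tensors above it: later in this commit, loss, reward, and KL are summed per micro-batch and divided by `accum_count` at logging time (see the last hunk below). A minimal standalone sketch of that accumulate-then-average pattern; the `MetricAccumulator` name is illustrative, not from the repo:

```python
import torch

class MetricAccumulator:
    """Running sums plus a micro-batch count, mirroring accum_loss /
    accum_reward / accum_kl and accum_count in GRPOConsumer.
    A sketch under assumed names, not the repo's code."""

    def __init__(self, device: str = "cpu"):
        self.sums = {k: torch.zeros(1, device=device) for k in ("loss", "reward", "kl")}
        self.count = 0

    def add(self, **values: torch.Tensor) -> None:
        # Accumulate one micro-batch worth of scalar metrics.
        for key, value in values.items():
            self.sums[key].add_(value.detach())
        self.count += 1

    def averages(self) -> dict:
        # Divide by the number of accumulated micro-batches, as the
        # commit now does before printing and logging to wandb.
        return {key: value.item() / self.count for key, value in self.sums.items()}

    def reset(self) -> None:
        for value in self.sums.values():
            value.zero_()
        self.count = 0

acc = MetricAccumulator()
acc.add(loss=torch.tensor(0.5), reward=torch.tensor(1.0), kl=torch.tensor(0.02))
acc.add(loss=torch.tensor(0.3), reward=torch.tensor(0.0), kl=torch.tensor(0.04))
print(acc.averages())  # {'loss': 0.4, 'reward': 0.5, 'kl': ~0.03}
acc.reset()
```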
```diff
@@ -79,13 +80,7 @@ class GRPOConsumer(BaseConsumer):
         self.policy_loss_fn = PolicyLoss()
         self.global_step = 0
         if self.rank == 0:
-            self.wandb_run = wandb.init(project="Colossal-GRPO-Test6", sync_tensorboard=True)
-            # import os
-            # import time
-
-            # log_dir = self.wandb_run.dir
-            # # log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
-            # # self.writer = SummaryWriter(log_dir=log_dir)
+            self.wandb_run = wandb.init(project="GRPO-Test", sync_tensorboard=True)

     def setup(self):
         super().setup()
```
```diff
@@ -129,15 +124,16 @@ class GRPOConsumer(BaseConsumer):
                 )["logits"]
             reference_action_log_probs = calc_action_log_probs(reference_model_logits, data["input_ids"], num_action)

-            # GRPO advantage calculation
-            kl = torch.sum(-0.1 * (action_log_probs - reference_action_log_probs) * action_mask, dim=-1) / torch.sum(
-                action_mask, dim=-1
-            )
+            per_token_kl = (
+                torch.exp(reference_action_log_probs - action_log_probs)
+                - (reference_action_log_probs - action_log_probs)
+                - 1
+            )
+            kl = torch.sum(per_token_kl * action_mask, dim=-1) / torch.sum(action_mask, dim=-1)

             reward = self.reward_model(
                 data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
             )
-            reward = kl + reward
             # [batch_size, num_generations]
             group_reward = reward.view(-1, self.num_generations)

```
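The new `per_token_kl` replaces the old scaled log-ratio (which was folded directly into the reward) with the low-variance Monte-Carlo KL estimator often called "k3", the form used in the GRPO paper: with `r = log π_ref − log π_θ` per token, `exp(r) − r − 1` is non-negative everywhere and unbiased for KL(π_θ ‖ π_ref) under samples from π_θ. A self-contained sketch; the function name is illustrative, not from the repo:

```python
import torch

def per_token_kl_estimate(action_log_probs: torch.Tensor,
                          reference_action_log_probs: torch.Tensor) -> torch.Tensor:
    """k3 estimator of KL(pi_theta || pi_ref) per token.

    With log_ratio = log pi_ref - log pi_theta, the value
    exp(log_ratio) - log_ratio - 1 is >= 0 for every token (since
    e^x - x - 1 >= 0) and its expectation under pi_theta is the KL.
    """
    log_ratio = reference_action_log_probs - action_log_probs
    return torch.exp(log_ratio) - log_ratio - 1

# Toy check: identical policies give exactly zero KL per token.
lp = torch.log(torch.tensor([0.5, 0.25, 0.25]))
assert torch.allclose(per_token_kl_estimate(lp, lp), torch.zeros(3))
```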
```diff
@@ -145,50 +141,50 @@ class GRPOConsumer(BaseConsumer):
             reward_mean = group_reward.mean(dim=1).repeat_interleave(self.num_generations, dim=0)
             reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
             # [batch_size x num_generations]
-            advantages = (group_reward.view(-1) - reward_mean) / (reward_std + 1e-4)
-
-            # GRPO advantage calculation
-            kl = torch.sum(-0.01 * (action_log_probs - reference_action_log_probs) * action_mask, dim=-1) / torch.sum(
-                action_mask, dim=-1
-            )
+            advantages = (reward - reward_mean) / (reward_std + 1e-4)

             # Calculate Loss
             loss, skip_update, _ = self.policy_loss_fn(
                 action_log_probs,
                 old_action_log_probs,
                 advantages.unsqueeze(dim=-1).repeat_interleave(action_log_probs.size(-1), dim=-1),
+                per_token_kl,
                 action_mask,
             )

             loss = loss / self.num_microbatches
             if not skip_update:
                 self.booster.backward(loss, self.optimizer)
-            loss = all_reduce_mean(loss)
-            reward = all_reduce_mean(reward.mean())
-            kl = all_reduce_mean(kl.mean())
+            loss = all_reduce_mean(loss, self.plugin)
+            reward = all_reduce_mean(reward.mean(), self.plugin)
+            kl = all_reduce_mean(kl.mean(), self.plugin)
             self.accum_loss.add_(loss.data)
             self.accum_reward.add_(reward.data)
             self.accum_kl.add_(kl.data)
+            self.accum_count += 1
         if need_update:
             self.optimizer.step()
             self.optimizer.zero_grad()
             loss_scalar = self.accum_loss.item()
             if self.rank == 0:
-                print("Loss:", self.accum_loss.item(), "Reward:", self.accum_reward.item(), "KL:", self.accum_kl.item())
+                print(
+                    "Loss:",
+                    self.accum_loss.item() / self.accum_count,
+                    "Reward:",
+                    self.accum_reward.item() / self.accum_count,
+                    "KL:",
+                    self.accum_kl.item() / self.accum_count,
+                )
                 self.wandb_run.log(
                     {
-                        "train/loss": self.accum_loss.item(),
-                        "train/reward": self.accum_reward.item(),
-                        "train/kl": self.accum_kl.item(),
+                        "train/loss": self.accum_loss.item() / self.accum_count,
+                        "train/reward": self.accum_reward.item() / self.accum_count,
+                        "train/kl": self.accum_kl.item() / self.accum_count,
                     }
                 )
-                # self.writer.add_scalar("train/loss", self.accum_loss.item(), self.global_step)
-                # self.writer.add_scalar("train/reward", self.accum_reward.item(), self.global_step)
-                # self.writer.add_scalar("train/kl", self.accum_kl.item(), self.global_step)
-                # self.global_step += 1
             self.accum_loss.zero_()
             self.accum_reward.zero_()
             self.accum_kl.zero_()
+            self.accum_count = 0
             return loss_scalar

     def state_dict(self):
```
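On the advantage line itself: `group_reward.view(-1)` and `reward` hold the same flat `[batch_size * num_generations]` values (since `group_reward` is `reward.view(-1, self.num_generations)`), so that rewrite is cosmetic. The substantive changes are that KL now enters the loss through the `per_token_kl` argument of `policy_loss_fn` instead of being added to the reward, and that printed/logged metrics are averaged over `accum_count`. The group normalization works as follows; a self-contained sketch whose function name is not from the repo:

```python
import torch

def group_normalized_advantages(reward: torch.Tensor, num_generations: int,
                                eps: float = 1e-4) -> torch.Tensor:
    """GRPO advantages: normalize each reward by the mean/std of its
    generation group (the num_generations samples for one prompt).

    reward: flat tensor of shape [batch_size * num_generations],
    with each prompt's generations stored contiguously.
    """
    group = reward.view(-1, num_generations)                      # [B, G]
    mean = group.mean(dim=1).repeat_interleave(num_generations)   # [B*G]
    std = group.std(dim=1).repeat_interleave(num_generations)     # [B*G]
    return (reward - mean) / (std + eps)

# Toy example: two prompts, three generations each.
r = torch.tensor([1.0, 0.0, 0.0, 2.0, 2.0, 2.0])
adv = group_normalized_advantages(r, num_generations=3)
# The second group has zero variance; eps keeps the division finite,
# and its advantages come out as exactly zero.
```

The `1e-4` in the denominator matches the commit's epsilon and guards against degenerate groups where every generation for a prompt receives the same reward.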
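Separately, `all_reduce_mean` now receives `self.plugin`, so the reduction runs over the booster plugin's data-parallel process group rather than an implicit default. A generic version over a plain `torch.distributed` group looks like this; a sketch, not ColossalAI's implementation, and it must run under an initialized process group:

```python
from typing import Optional

import torch
import torch.distributed as dist

def all_reduce_mean(tensor: torch.Tensor,
                    group: Optional[dist.ProcessGroup] = None) -> torch.Tensor:
    """Average a tensor in place across all ranks of `group`.

    ColossalAI's helper additionally accepts a plugin so it can pick the
    correct data-parallel group; this sketch uses the default group.
    """
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
    tensor.div_(dist.get_world_size(group=group))
    return tensor
```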