Mirror of https://github.com/hpcaitech/ColossalAI.git
add response length

commit 47d6493778 (parent abca66e69f)
@@ -59,6 +59,7 @@ class GRPOConsumer(BaseConsumer):
         self.accum_format_reward = torch.zeros(1, device=self.device)
         self.accum_acc_reward = torch.zeros(1, device=self.device)
         self.accum_advantages = torch.zeros(1, device=self.device)
+        self.accum_response_length = torch.zeros(1, device=self.device)
         self.accum_count = 0

         # Reference model is initialized from policy model.
@@ -83,7 +84,7 @@ class GRPOConsumer(BaseConsumer):
         self.policy_loss_fn = PolicyLoss()
         self.global_step = 0
         if use_wandb and self.rank == 0:
-            self.wandb_run = wandb.init(project="GRPO-Test", sync_tensorboard=True)
+            self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True)

     def setup(self):
         super().setup()
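The project name change only affects how runs are grouped in Weights & Biases; the rank-0 guard keeps one run per training job instead of one per GPU. A minimal sketch of the same pattern, assuming a plain torch.distributed setup rather than the consumer's own rank bookkeeping:

import torch.distributed as dist
import wandb

def maybe_init_wandb(project: str = "GRPO-V1"):
    # Only rank 0 creates a W&B run so metrics are not logged once per worker.
    if not dist.is_initialized() or dist.get_rank() == 0:
        return wandb.init(project=project, sync_tensorboard=True)
    return None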
@@ -109,6 +110,7 @@ class GRPOConsumer(BaseConsumer):
         action_mask = data["action_mask"]
         num_action = action_mask.shape[1]
         old_action_log_probs = data["action_log_probs"]
+        response_length = torch.sum(action_mask, dim=1).to(torch.float32)

         need_update = (step_idx + 1) % self.num_microbatches == 0

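Because action_mask is 1 for generated response tokens and 0 for padding, summing it along the sequence dimension gives each sample's response length in tokens. A tiny illustration, with the mask values invented for the example:

import torch

# Hypothetical action mask for a batch of 2 samples, padded to 5 response tokens.
action_mask = torch.tensor([[1, 1, 1, 0, 0],
                            [1, 1, 1, 1, 1]])
response_length = torch.sum(action_mask, dim=1).to(torch.float32)
print(response_length)  # tensor([3., 5.])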
@@ -168,6 +170,7 @@ class GRPOConsumer(BaseConsumer):
             format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
             acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin)
             advantages = all_reduce_mean(advantages.mean(), self.plugin)
+            response_length = all_reduce_mean(response_length.mean(), self.plugin)
             # Calculate accumulate value.
             self.accum_loss.add_(loss.data)
             self.accum_reward.add_(reward.data)
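all_reduce_mean is the repository's helper for averaging a scalar metric across data-parallel workers via the booster plugin; its implementation is not part of this diff. A rough sketch of the underlying operation, assuming a plain torch.distributed process group instead of the plugin argument used above:

import torch
import torch.distributed as dist

def all_reduce_mean_sketch(value: torch.Tensor) -> torch.Tensor:
    # Sum the local scalar across all ranks, then divide by world size
    # so every rank ends up holding the global mean.
    if dist.is_initialized():
        dist.all_reduce(value, op=dist.ReduceOp.SUM)
        value = value / dist.get_world_size()
    return value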
@@ -175,6 +178,7 @@ class GRPOConsumer(BaseConsumer):
             self.accum_format_reward.add_(format_reward.data)
             self.accum_acc_reward.add_(acc_reward.data)
             self.accum_advantages.add_(advantages.data)
+            self.accum_response_length.add_(response_length.data)
             self.accum_count += 1
         if need_update:
             self.optimizer.step()
@@ -184,32 +188,38 @@ class GRPOConsumer(BaseConsumer):
                 print(
                     "Loss:",
                     self.accum_loss.item() / self.accum_count,
-                    "Reward:",
+                    "\nReward:",
                     self.accum_reward.item() / self.accum_count,
-                    "KL:",
-                    self.accum_kl.item() / self.accum_count,
-                    "Format Reward:",
+                    "\nFormat Reward:",
                     self.accum_format_reward.item() / self.accum_count,
-                    "Acc Reward:",
+                    "\nAcc Reward:",
                     self.accum_acc_reward.item() / self.accum_count,
-                    "Advantages:",
+                    "\nKL:",
+                    self.accum_kl.item() / self.accum_count,
+                    "\nAdvantages:",
                     self.accum_advantages.item() / self.accum_count,
+                    "\nResponse Length:",
+                    self.accum_response_length.item() / self.accum_count,
                 )
                 self.wandb_run.log(
                     {
                         "train/loss": self.accum_loss.item() / self.accum_count,
                         "train/reward": self.accum_reward.item() / self.accum_count,
-                        "train/kl": self.accum_kl.item() / self.accum_count,
                         "train/format_reward": self.accum_format_reward.item() / self.accum_count,
                         "train/acc_reward": self.accum_acc_reward.item() / self.accum_count,
+                        "train/kl": self.accum_kl.item() / self.accum_count,
                         "train/advantages": self.accum_advantages.item() / self.accum_count,
+                        "train/response_length": self.accum_response_length.item() / self.accum_count,
                     }
                 )
                 self.accum_loss.zero_()
                 self.accum_reward.zero_()
-                self.accum_kl.zero_()
                 self.accum_acc_reward.zero_()
                 self.accum_format_reward.zero_()
+                self.accum_kl.zero_()
+                self.accum_advantages.zero_()
+                self.accum_response_length.zero_()

                 self.accum_count = 0
         return loss_scalar
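Each value printed and logged here is a running sum of per-step means divided by accum_count, so the reported numbers are averages over the microbatches since the last optimizer update; the buffers are then zeroed for the next window. A toy illustration of that accumulate, average, and reset cycle for the new response-length metric, with made-up per-step values:

import torch

accum_response_length = torch.zeros(1)
accum_count = 0

# Pretend these are the per-microbatch mean response lengths.
for step_mean in [120.0, 135.0, 128.0]:
    accum_response_length.add_(torch.tensor([step_mean]))
    accum_count += 1

print(accum_response_length.item() / accum_count)  # ~127.67, the window average
accum_response_length.zero_()
accum_count = 0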