mirror of https://github.com/hpcaitech/ColossalAI.git
Commit: polish
@@ -9,7 +9,7 @@ from coati.distributed.loss import PolicyLoss
 from coati.distributed.reward.reward_fn import math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
 from coati.distributed.utils import calc_action_log_probs
-from coati.trainer.utils import all_reduce_mean, is_rank_0
+from coati.trainer.utils import all_reduce_mean
 from transformers import AutoModelForCausalLM, AutoTokenizer

 from colossalai.nn.optimizer import HybridAdam
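Context for the import change: `is_rank_0` is dropped because the consumer now checks `self.rank == 0` directly in the later hunks. A minimal sketch of that gating pattern using plain torch.distributed; `is_main_process` is an illustrative helper, not part of coati:

import torch.distributed as dist

def is_main_process() -> bool:
    # Treat the process as "main" when torch.distributed is not initialized;
    # otherwise defer to the global rank it reports.
    return (not dist.is_initialized()) or dist.get_rank() == 0

# Usage: gate printing and logger setup so only one process emits output.
if is_main_process():
    print("logging from the main process only")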
@@ -77,8 +77,15 @@ class GRPOConsumer(BaseConsumer):
         )

         self.policy_loss_fn = PolicyLoss()
-        if is_rank_0():
-            self.run = wandb.init(project="Colossal-GRPO-Test4")
+        self.global_step = 0
+        if self.rank == 0:
+            self.wandb_run = wandb.init(project="Colossal-GRPO-Test6", sync_tensorboard=True)
+            # import os
+            # import time
+
+            # log_dir = self.wandb_run.dir
+            # # log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
+            # # self.writer = SummaryWriter(log_dir=log_dir)

     def setup(self):
         super().setup()
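The hunk above initializes Weights & Biases only on rank 0 and leaves a commented-out TensorBoard path. A standalone sketch of that setup, assuming wandb and tensorboard are installed; the project name, offline mode, and writer usage here are illustrative, not taken from the commit:

import wandb
from torch.utils.tensorboard import SummaryWriter

rank = 0  # in the trainer this comes from the distributed runtime

wandb_run, writer = None, None
if rank == 0:
    # sync_tensorboard=True lets wandb mirror scalars written through a SummaryWriter.
    wandb_run = wandb.init(project="grpo-demo", sync_tensorboard=True, mode="offline")
    writer = SummaryWriter(log_dir=wandb_run.dir)

if writer is not None:
    writer.add_scalar("train/loss", 0.5, global_step=0)
    writer.close()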
@@ -115,10 +122,11 @@ class GRPOConsumer(BaseConsumer):
             )["logits"]
             action_log_probs = calc_action_log_probs(policy_model_logits, data["input_ids"], num_action)

-            reference_model_logits = self.reference_model(
-                input_ids=data["input_ids"],
-                attention_mask=data["attention_mask"],
-            )["logits"]
+            with torch.no_grad():
+                reference_model_logits = self.reference_model(
+                    input_ids=data["input_ids"],
+                    attention_mask=data["attention_mask"],
+                )["logits"]
             reference_action_log_probs = calc_action_log_probs(reference_model_logits, data["input_ids"], num_action)

             # GRPO advantage calculation
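The change above wraps the frozen reference model's forward pass in torch.no_grad(), so no autograd graph is built for it and activation memory is saved without affecting the policy gradients. A generic sketch of the same pattern for a causal LM; this is not coati's calc_action_log_probs, which additionally restricts the result to the last num_action tokens:

import torch

def reference_log_probs(reference_model, input_ids, attention_mask):
    # The reference model is frozen, so disabling grad tracking avoids
    # storing activations for a backward pass that never happens.
    with torch.no_grad():
        logits = reference_model(input_ids=input_ids, attention_mask=attention_mask)["logits"]
    # Logits at position t predict token t + 1: shift, take log-softmax,
    # and gather the log-probability of each realized token.
    log_probs = torch.log_softmax(logits[:, :-1, :], dim=-1)
    return log_probs.gather(-1, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)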
@@ -126,7 +134,9 @@ class GRPOConsumer(BaseConsumer):
                 action_mask, dim=-1
             )

-            reward = self.reward_model(data["input_ids"], gt_answer=data["gt_answer"])
+            reward = self.reward_model(
+                data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
+            )
             reward = kl + reward
             # [batch_size, num_generations]
             group_reward = reward.view(-1, self.num_generations)
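Here the per-sample reward is viewed as [batch_size, num_generations] so each completion can be compared against the other samples for the same prompt. The exact normalization used downstream is not shown in this hunk; the sketch below assumes the usual GRPO group-relative baseline (mean/std per group, with an illustrative epsilon):

import torch

def grpo_advantages(reward: torch.Tensor, num_generations: int, eps: float = 1e-4) -> torch.Tensor:
    # reward: flat tensor of shape [batch_size * num_generations]
    group = reward.view(-1, num_generations)      # [batch_size, num_generations]
    baseline = group.mean(dim=1, keepdim=True)    # per-prompt group mean
    scale = group.std(dim=1, keepdim=True) + eps  # per-prompt group std
    return ((group - baseline) / scale).view(-1)  # group-relative advantage, flattened back

# Example: 2 prompts with 4 sampled completions each.
rewards = torch.tensor([1.0, 0.0, 0.0, 1.0, 0.5, 0.5, 1.0, 0.0])
print(grpo_advantages(rewards, num_generations=4))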
@@ -163,11 +173,19 @@ class GRPOConsumer(BaseConsumer):
                 self.optimizer.step()
                 self.optimizer.zero_grad()
                 loss_scalar = self.accum_loss.item()
-                if is_rank_0():
+                if self.rank == 0:
                     print("Loss:", self.accum_loss.item(), "Reward:", self.accum_reward.item(), "KL:", self.accum_kl.item())
-                    self.run.log(
-                        {"loss": self.accum_loss.item(), "reward": self.accum_reward.item(), "kl": self.accum_kl.item()}
+                    self.wandb_run.log(
+                        {
+                            "train/loss": self.accum_loss.item(),
+                            "train/reward": self.accum_reward.item(),
+                            "train/kl": self.accum_kl.item(),
+                        }
                     )
+                    # self.writer.add_scalar("train/loss", self.accum_loss.item(), self.global_step)
+                    # self.writer.add_scalar("train/reward", self.accum_reward.item(), self.global_step)
+                    # self.writer.add_scalar("train/kl", self.accum_kl.item(), self.global_step)
+                    # self.global_step += 1
                 self.accum_loss.zero_()
                 self.accum_reward.zero_()
                 self.accum_kl.zero_()
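The last hunk follows an accumulate-then-report pattern: metrics are summed into the accum_* tensors across micro-batches, logged once per optimizer step from rank 0 under a train/ prefix, and then zeroed in place. A minimal sketch of that loop; the project name, offline mode, and metric values are hypothetical:

import torch
import wandb

accum_loss = torch.zeros(1)
accum_reward = torch.zeros(1)

run = wandb.init(project="grpo-demo", mode="offline")  # offline keeps the sketch self-contained
for step in range(3):
    # Pretend several micro-batches accumulated metrics into these tensors.
    accum_loss += 0.1 * (step + 1)
    accum_reward += 1.0
    run.log({"train/loss": accum_loss.item(), "train/reward": accum_reward.item()})
    # Reset in place so the next optimizer step starts from zero, as the hunk does.
    accum_loss.zero_()
    accum_reward.zero_()
run.finish()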