mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 04:50:17 +00:00
move logging to producer
This commit is contained in:
@@ -34,13 +34,13 @@ class GRPOConsumer(BaseConsumer):
|
||||
plugin_config,
|
||||
minibatch_size=1,
|
||||
num_generations=8,
|
||||
use_wandb=True,
|
||||
generate_config=None,
|
||||
grpo_config={},
|
||||
project_name=None,
|
||||
save_interval: int = 100,
|
||||
save_dir="./model",
|
||||
eval_interval: int = -1,
|
||||
project_name: str = None,
|
||||
run_name: str = None,
|
||||
wandb_group_name: str = None,
|
||||
):
|
||||
print(f"Using GRPO config: {grpo_config}")
|
||||
if grpo_config.get("loss_variation", "sample_level") == "token_level":
|
||||
@@ -73,7 +73,6 @@ class GRPOConsumer(BaseConsumer):
|
||||
minibatch_size,
|
||||
save_interval=save_interval,
|
||||
save_dir=save_dir,
|
||||
eval_interval=eval_interval,
|
||||
)
|
||||
path = model_config.pop("path")
|
||||
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
|
||||
@@ -93,6 +92,9 @@ class GRPOConsumer(BaseConsumer):
|
||||
self.project_name = project_name
|
||||
self.effective_sample_count = 0
|
||||
self.total_sample_count = 0
|
||||
self.project_name = project_name
|
||||
self.run_name = run_name
|
||||
self.wandb_group_name = wandb_group_name
|
||||
|
||||
self.policy_loss_fn = PolicyLoss(
|
||||
clip_eps_low=grpo_config.get("clip_eps_low", 0.2),
|
||||
@@ -143,7 +145,6 @@ class GRPOConsumer(BaseConsumer):
|
||||
**reward_model_kwargs,
|
||||
)
|
||||
self.global_step = 0
|
||||
self.use_wandb = use_wandb
|
||||
|
||||
self.lr_scheduler = CosineAnnealingWarmupLR(
|
||||
optimizer=self.optimizer,
|
||||
@@ -154,13 +155,16 @@ class GRPOConsumer(BaseConsumer):
|
||||
|
||||
def setup(self):
|
||||
super().setup()
|
||||
if self.use_wandb and (
|
||||
(not self.plugin.pp_size > 1 and self.rank == 0)
|
||||
or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0)
|
||||
if (not self.plugin.pp_size > 1 and self.rank == 0) or (
|
||||
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
|
||||
):
|
||||
# Initialize wandb.
|
||||
name = f"{self.generate_config['backend']}_bs_{self.batch_size*self.dp_size}_temp_{self.generate_config['temperature']:.01f}_top_p_{self.generate_config['top_p']:.02f}"
|
||||
self.wandb_run = wandb.init(project=self.project_name, sync_tensorboard=True, dir="./wandb", name=name)
|
||||
self.wandb_run = wandb.init(
|
||||
project=self.project_name,
|
||||
sync_tensorboard=True,
|
||||
dir="./wandb",
|
||||
name=self.run_name,
|
||||
group=self.wandb_group_name,
|
||||
)
|
||||
|
||||
self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
|
||||
self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
|
||||
@@ -512,8 +516,8 @@ class GRPOConsumer(BaseConsumer):
|
||||
}
|
||||
if self.policy_loss_fn.beta > 0:
|
||||
metrics["train/kl"] = self.accum_kl.item() / self.accum_count
|
||||
|
||||
self.wandb_run.log(metrics)
|
||||
if self.wandb_run is not None:
|
||||
self.wandb_run.log(metrics)
|
||||
self.accum_loss.zero_()
|
||||
self.accum_reward.zero_()
|
||||
self.accum_ans_acc.zero_()
|
||||
|
Reference in New Issue
Block a user