move logging to producer

YeAnbang
2025-05-14 18:10:57 +08:00
parent 47a7dc7142
commit 50070c1e84
7 changed files with 92 additions and 70 deletions


@@ -34,13 +34,13 @@ class GRPOConsumer(BaseConsumer):
         plugin_config,
         minibatch_size=1,
         num_generations=8,
-        use_wandb=True,
         generate_config=None,
         grpo_config={},
-        project_name=None,
         save_interval: int = 100,
         save_dir="./model",
-        eval_interval: int = -1,
+        project_name: str = None,
+        run_name: str = None,
+        wandb_group_name: str = None,
     ):
         print(f"Using GRPO config: {grpo_config}")
         if grpo_config.get("loss_variation", "sample_level") == "token_level":
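The consumer's wandb setup is no longer toggled by a use_wandb flag; instead the run identity (project_name, run_name, wandb_group_name) is passed in explicitly, in line with the commit's move of logging ownership to the producer. For context, wandb's group parameter makes runs started by separate processes appear together in the UI. A minimal sketch of how these three values feed wandb.init; the concrete strings are illustrative, not from this commit, and mode="offline" is only there so the sketch runs without a wandb account:

    import wandb

    # Stand-ins for the project_name / run_name / wandb_group_name
    # arguments added to the GRPOConsumer constructor.
    run = wandb.init(
        project="grpo-demo",    # project_name
        name="consumer-dp0",    # run_name: identifies this process
        group="grpo-exp-001",   # wandb_group_name: ties producer and
                                # consumer runs of one experiment together
        dir="./wandb",
        sync_tensorboard=True,
        mode="offline",         # illustration only; remove to log online
    )
    run.log({"train/loss": 0.0})
    run.finish()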
@@ -73,7 +73,6 @@ class GRPOConsumer(BaseConsumer):
             minibatch_size,
             save_interval=save_interval,
             save_dir=save_dir,
-            eval_interval=eval_interval,
         )
         path = model_config.pop("path")
         self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -93,6 +92,9 @@ class GRPOConsumer(BaseConsumer):
         self.project_name = project_name
         self.effective_sample_count = 0
         self.total_sample_count = 0
+        self.project_name = project_name
+        self.run_name = run_name
+        self.wandb_group_name = wandb_group_name
         self.policy_loss_fn = PolicyLoss(
             clip_eps_low=grpo_config.get("clip_eps_low", 0.2),
@@ -143,7 +145,6 @@ class GRPOConsumer(BaseConsumer):
             **reward_model_kwargs,
         )
         self.global_step = 0
-        self.use_wandb = use_wandb

         self.lr_scheduler = CosineAnnealingWarmupLR(
             optimizer=self.optimizer,
@@ -154,13 +155,16 @@ class GRPOConsumer(BaseConsumer):
     def setup(self):
         super().setup()
-        if self.use_wandb and (
-            (not self.plugin.pp_size > 1 and self.rank == 0)
-            or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0)
+        if (not self.plugin.pp_size > 1 and self.rank == 0) or (
+            self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
         ):
             # Initialize wandb.
-            name = f"{self.generate_config['backend']}_bs_{self.batch_size*self.dp_size}_temp_{self.generate_config['temperature']:.01f}_top_p_{self.generate_config['top_p']:.02f}"
-            self.wandb_run = wandb.init(project=self.project_name, sync_tensorboard=True, dir="./wandb", name=name)
+            self.wandb_run = wandb.init(
+                project=self.project_name,
+                sync_tensorboard=True,
+                dir="./wandb",
+                name=self.run_name,
+                group=self.wandb_group_name,
+            )

         self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
             self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
         )
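With the use_wandb flag removed, the rank guard alone decides which process initializes wandb: rank 0 when pipeline parallelism is off, otherwise tp-rank 0 of the last pipeline stage. A self-contained sketch of that guard; the function and its parameters are illustrative assumptions, since the real code reads these values from the plugin and booster:

    import wandb

    def init_wandb_on_logging_rank(rank, tp_rank, pp_size, is_last_pp_stage,
                                   project, run_name, group):
        """Return a wandb run on the single logging rank, None elsewhere."""
        is_logging_rank = (pp_size <= 1 and rank == 0) or (
            pp_size > 1 and is_last_pp_stage and tp_rank == 0
        )
        if not is_logging_rank:
            return None  # non-logging ranks must check before logging
        return wandb.init(project=project, name=run_name, group=group,
                          sync_tensorboard=True, dir="./wandb")

Leaving every other rank without a run object (presumably self.wandb_run stays None) is what makes the None check in the logging path below necessary.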
@@ -512,8 +516,8 @@ class GRPOConsumer(BaseConsumer):
                 }
                 if self.policy_loss_fn.beta > 0:
                     metrics["train/kl"] = self.accum_kl.item() / self.accum_count
-                self.wandb_run.log(metrics)
+                if self.wandb_run is not None:
+                    self.wandb_run.log(metrics)
                 self.accum_loss.zero_()
                 self.accum_reward.zero_()
                 self.accum_ans_acc.zero_()
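The log call is now guarded so that ranks which never initialized wandb skip logging instead of raising AttributeError on None. The same pattern in isolation, with illustrative metric values:

    # Guarded logging: a rank without a wandb run holds None and skips.
    metrics = {"train/loss": 0.42, "train/reward": 1.3}  # illustrative
    wandb_run = None  # e.g. on a non-logging rank
    if wandb_run is not None:
        wandb_run.log(metrics)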