support logging rollouts to wandb

YeAnbang 2025-05-16 15:56:03 +08:00
parent 203dfb1536
commit 021914c565


@@ -49,6 +49,7 @@ class BaseProducer:
         project_name: str = None,
         run_name: str = None,
         wandb_group_name: str = None,
+        wandb_log_rollout_interval: int = 20,
     ):
         self.producer_idx = producer_idx
         self.num_producers = num_producers
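The new knob is a constructor argument defaulting to 20 consumer steps. A minimal sketch of passing it at construction time; apart from `wandb_log_rollout_interval`, every argument shown is a placeholder, not part of this commit:

```python
# Hypothetical call site; only wandb_log_rollout_interval comes from this commit.
producer = SimpleProducer(
    producer_idx=0,
    num_producers=4,
    wandb_log_rollout_interval=20,  # log a rollout table every 20 consumer steps
    # ...remaining existing BaseProducer arguments unchanged...
)
```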
@@ -58,7 +59,7 @@ class BaseProducer:
         self.microbatch_size = microbatch_size
         assert batch_size % microbatch_size == 0
         self.num_microbatches = batch_size // microbatch_size
-        self.lastest_eval_step = -1
+        self.latest_eval_step = -1
         self.train_dataset_config = train_dataset_config
         self.model_config = model_config
@@ -68,6 +69,10 @@ class BaseProducer:
         self.eval_interval = eval_interval
         self.eval_save_dir = eval_save_dir
         self.consumer_global_step = 0
+        self.eval_mode = False
+        self.wandb_rollout_data = []
+        self.wandb_log_rollout_interval = wandb_log_rollout_interval
+        self.latest_rollout_log_step = -1
         if self.producer_idx == 0:
             self.wandb_run = wandb.init(
                 project=project_name,
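Together these four attributes form a small throttled-buffer state machine: rows accumulate in `wandb_rollout_data`, and the `-1` sentinel in `latest_rollout_log_step` guarantees the very first rollout is logged immediately. A standalone sketch of the pattern (the class is ours, not the commit's; the field semantics mirror the attributes added above):

```python
# Hypothetical helper illustrating the buffering pattern from this commit.
class RolloutLogBuffer:
    def __init__(self, interval: int = 20):
        self.rows = []               # buffered [train_step, rollout_text] rows
        self.interval = interval     # consumer steps between table logs
        self.latest_log_step = -1    # -1 forces a log on the very first rollout

    def should_log(self, step: int) -> bool:
        return step - self.latest_log_step >= self.interval or self.latest_log_step == -1
```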
@@ -77,7 +82,7 @@ class BaseProducer:
                 group=wandb_group_name,
             )
-        if os.path.exists(self.eval_save_dir):
+        if os.path.exists(self.eval_save_dir) and self.eval_interval > 0:
            raise ValueError(f"Eval save dir {self.eval_save_dir} already exists. Please delete it or change the name.")
        # init tokenizer
@@ -180,10 +185,11 @@ class BaseProducer:
                    break
                if self.eval_interval > 0 and self.eval_dataset_config is not None:
                    if (
-                        self.consumer_global_step - self.lastest_eval_step >= self.eval_interval
-                        and self.consumer_global_step > self.lastest_eval_step
-                    ):
+                        self.consumer_global_step - self.latest_eval_step >= self.eval_interval
+                        and self.consumer_global_step > self.latest_eval_step
+                    ) or self.latest_eval_step == -1:
                        to_log_msg = {}
+                        self.eval_mode = True
                        for eval_task_name in self.eval_dataloaders:
                            if self.producer_idx == 0:
                                print(
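Beyond renaming the misspelled `lastest_eval_step`, the trigger gains an `or self.latest_eval_step == -1` clause, so one evaluation now also runs before any training step has completed. The reworked condition in isolation, as a sketch (argument names are ours for clarity):

```python
# Sketch of the eval trigger added in this hunk.
def should_eval(global_step: int, latest_eval_step: int, interval: int) -> bool:
    first_run = latest_eval_step == -1  # -1 sentinel: never evaluated yet
    due = global_step - latest_eval_step >= interval and global_step > latest_eval_step
    return due or first_run
```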
@@ -227,7 +233,8 @@ class BaseProducer:
                    if self.producer_idx == 0:
                        self.wandb_run.log(to_log_msg, step=self.consumer_global_step)
-                    self.lastest_eval_step = self.consumer_global_step
+                    self.eval_mode = False
+                    self.latest_eval_step = self.consumer_global_step
                outputs = self.rollout(**batch)
                print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
@@ -345,9 +352,26 @@ class SimpleProducer(BaseProducer):
    @torch.no_grad()
    def rollout(self, input_ids, attention_mask, **kwargs):
        rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
-        # if self.producer_idx == 1:
-        #     print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True))
+        if self.producer_idx == 0 and not self.eval_mode:
+            wandb_rollout_data = self.wandb_rollout_data + [
+                [
+                    str(self.consumer_global_step),
+                    str(self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True)),
+                ]
+            ]
+            if (
+                self.consumer_global_step - self.latest_rollout_log_step >= self.wandb_log_rollout_interval
+                or self.latest_rollout_log_step == -1
+            ):
+                self.wandb_rollout_data = wandb_rollout_data
+                self.latest_rollout_log_step = self.consumer_global_step
+                self.wandb_run.log(
+                    {
+                        "rollout/rollout_examples": wandb.Table(
+                            columns=["train_step", "rollout_examples"], data=wandb_rollout_data
+                        )
+                    }
+                )
        return rollouts

    def load_state_dict(self, state_dict):
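
Because a `wandb.Table` cannot be mutated once logged, the hunk above rebuilds and re-logs the whole table each time the interval elapses; rows produced between log points are dropped rather than buffered, so the table keeps one example per interval. A self-contained sketch of that pattern (the project name and rollout strings are placeholders):

```python
import wandb

# Demo of interval-throttled wandb.Table logging, mirroring the hunk above.
run = wandb.init(project="rollout-logging-demo")  # placeholder project name
rows, latest_log_step, interval = [], -1, 20

for step, text in [(1, "rollout A"), (10, "rollout B"), (25, "rollout C")]:
    candidate = rows + [[str(step), text]]  # tentatively extend the table
    if step - latest_log_step >= interval or latest_log_step == -1:
        rows = candidate  # the row is kept only when a log actually fires
        latest_log_step = step
        # a fresh Table holding every kept row is logged each time
        run.log({"rollout/rollout_examples": wandb.Table(columns=["train_step", "rollout_examples"], data=rows)})
run.finish()
```

In this run, "rollout A" logs immediately (the `-1` sentinel), "rollout B" is discarded because only 9 steps have passed, and "rollout C" triggers a re-log of the two kept rows.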