Fix rollout logging: write rollouts to a JSONL file instead of a wandb table

This commit is contained in:
YeAnbang
2025-05-17 21:12:58 +08:00
parent 03b41d6fb5
commit 107470a360
5 changed files with 56 additions and 24 deletions

View File

@@ -1,4 +1,5 @@
import copy
import json
import os
from typing import Any, Dict, Optional
@@ -49,7 +50,8 @@ class BaseProducer:
project_name: str = None,
run_name: str = None,
wandb_group_name: str = None,
wandb_log_rollout_interval: int = 20,
log_rollout_interval: int = 20,
rollout_log_file: str = "./rollout_log.jsonl",
):
self.producer_idx = producer_idx
self.num_producers = num_producers
@@ -70,9 +72,16 @@ class BaseProducer:
self.eval_save_dir = eval_save_dir
self.consumer_global_step = 0
self.eval_mode = False
self.wandb_rollout_data = []
self.wandb_log_rollout_interval = wandb_log_rollout_interval
self.log_rollout_interval = log_rollout_interval
self.latest_rollout_log_step = -1
if producer_idx == 0:
if os.path.exists(rollout_log_file):
raise ValueError(
f"Rollout log file {rollout_log_file} already exists. Please delete it or change the name."
)
else:
os.makedirs(os.path.dirname(rollout_log_file), exist_ok=True)
self.rollout_log_file = open(rollout_log_file, "w", encoding="utf8")
if self.producer_idx == 0:
self.wandb_run = wandb.init(
project=project_name,
@@ -320,6 +329,8 @@ class SimpleProducer(BaseProducer):
project_name: str = None,
run_name: str = None,
wandb_group_name: str = None,
log_rollout_interval: int = 20,
rollout_log_file: str = "./rollout_log.jsonl",
):
super().__init__(
producer_idx,
@@ -342,6 +353,8 @@ class SimpleProducer(BaseProducer):
project_name=project_name,
run_name=run_name,
wandb_group_name=wandb_group_name,
log_rollout_interval=log_rollout_interval,
rollout_log_file=rollout_log_file,
)
self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
self.eval_generation_config = copy.deepcopy(self.model.generate_config)
@@ -353,26 +366,31 @@ class SimpleProducer(BaseProducer):
def rollout(self, input_ids, attention_mask, **kwargs):
rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
if self.producer_idx == 0 and not self.eval_mode:
wandb_rollout_data = self.wandb_rollout_data + [
[
str(self.consumer_global_step),
str(self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True)),
]
]
if (
self.consumer_global_step - self.latest_rollout_log_step >= self.wandb_log_rollout_interval
self.consumer_global_step - self.latest_rollout_log_step >= self.log_rollout_interval
or self.latest_rollout_log_step == -1
):
self.wandb_rollout_data = wandb_rollout_data
self.latest_rollout_log_step = self.consumer_global_step
self.wandb_run.log(
{
"rollout/rollout_examples": wandb.Table(
columns=["train_step", "rollout_examples"], data=wandb_rollout_data
new_record = (
json.dumps(
{
"train_step": self.consumer_global_step,
"rollout": self.tokenizer.batch_decode(
rollouts["input_ids"][:, 0], skip_special_tokens=True
),
}
)
}
)
+ "\n"
)
self.rollout_log_file.write(new_record)
self.rollout_log_file.flush()
self.latest_rollout_log_step = self.consumer_global_step
return rollouts
def __del__(self):
    """Release the wandb run and the rollout log file when the producer is collected.

    ``__del__`` is also invoked on instances whose ``__init__`` raised
    part-way through, so every attribute is looked up defensively rather
    than assumed to exist.
    """
    try:
        # Only producer 0 finishes a wandb run; the run may be missing if
        # construction failed before it was created.
        if getattr(self, "producer_idx", None) == 0:
            wandb_run = getattr(self, "wandb_run", None)
            if wandb_run is not None:
                wandb_run.finish()
    finally:
        # Close the log file even when finishing the wandb run raised.
        log_file = getattr(self, "rollout_log_file", None)
        if log_file is not None:
            log_file.close()
def load_state_dict(self, state_dict):
self.model.load_state_dict(state_dict)