mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 19:13:01 +00:00
fix logging rollouts
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
@@ -49,7 +50,8 @@ class BaseProducer:
|
||||
project_name: str = None,
|
||||
run_name: str = None,
|
||||
wandb_group_name: str = None,
|
||||
wandb_log_rollout_interval: int = 20,
|
||||
log_rollout_interval: int = 20,
|
||||
rollout_log_file: str = "./rollout_log.jsonl",
|
||||
):
|
||||
self.producer_idx = producer_idx
|
||||
self.num_producers = num_producers
|
||||
@@ -70,9 +72,16 @@ class BaseProducer:
|
||||
self.eval_save_dir = eval_save_dir
|
||||
self.consumer_global_step = 0
|
||||
self.eval_mode = False
|
||||
self.wandb_rollout_data = []
|
||||
self.wandb_log_rollout_interval = wandb_log_rollout_interval
|
||||
self.log_rollout_interval = log_rollout_interval
|
||||
self.latest_rollout_log_step = -1
|
||||
if producer_idx == 0:
|
||||
if os.path.exists(rollout_log_file):
|
||||
raise ValueError(
|
||||
f"Rollout log file {rollout_log_file} already exists. Please delete it or change the name."
|
||||
)
|
||||
else:
|
||||
os.makedirs(os.path.dirname(rollout_log_file), exist_ok=True)
|
||||
self.rollout_log_file = open(rollout_log_file, "w", encoding="utf8")
|
||||
if self.producer_idx == 0:
|
||||
self.wandb_run = wandb.init(
|
||||
project=project_name,
|
||||
@@ -320,6 +329,8 @@ class SimpleProducer(BaseProducer):
|
||||
project_name: str = None,
|
||||
run_name: str = None,
|
||||
wandb_group_name: str = None,
|
||||
log_rollout_interval: int = 20,
|
||||
rollout_log_file: str = "./rollout_log.jsonl",
|
||||
):
|
||||
super().__init__(
|
||||
producer_idx,
|
||||
@@ -342,6 +353,8 @@ class SimpleProducer(BaseProducer):
|
||||
project_name=project_name,
|
||||
run_name=run_name,
|
||||
wandb_group_name=wandb_group_name,
|
||||
log_rollout_interval=log_rollout_interval,
|
||||
rollout_log_file=rollout_log_file,
|
||||
)
|
||||
self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
|
||||
self.eval_generation_config = copy.deepcopy(self.model.generate_config)
|
||||
@@ -353,26 +366,31 @@ class SimpleProducer(BaseProducer):
|
||||
def rollout(self, input_ids, attention_mask, **kwargs):
|
||||
rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
|
||||
if self.producer_idx == 0 and not self.eval_mode:
|
||||
wandb_rollout_data = self.wandb_rollout_data + [
|
||||
[
|
||||
str(self.consumer_global_step),
|
||||
str(self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True)),
|
||||
]
|
||||
]
|
||||
if (
|
||||
self.consumer_global_step - self.latest_rollout_log_step >= self.wandb_log_rollout_interval
|
||||
self.consumer_global_step - self.latest_rollout_log_step >= self.log_rollout_interval
|
||||
or self.latest_rollout_log_step == -1
|
||||
):
|
||||
self.wandb_rollout_data = wandb_rollout_data
|
||||
self.latest_rollout_log_step = self.consumer_global_step
|
||||
self.wandb_run.log(
|
||||
{
|
||||
"rollout/rollout_examples": wandb.Table(
|
||||
columns=["train_step", "rollout_examples"], data=wandb_rollout_data
|
||||
new_record = (
|
||||
json.dumps(
|
||||
{
|
||||
"train_step": self.consumer_global_step,
|
||||
"rollout": self.tokenizer.batch_decode(
|
||||
rollouts["input_ids"][:, 0], skip_special_tokens=True
|
||||
),
|
||||
}
|
||||
)
|
||||
}
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
self.rollout_log_file.write(new_record)
|
||||
self.rollout_log_file.flush()
|
||||
self.latest_rollout_log_step = self.consumer_global_step
|
||||
return rollouts
|
||||
|
||||
def __del__(self):
|
||||
if self.producer_idx == 0:
|
||||
self.wandb_run.finish()
|
||||
if hasattr(self, "rollout_log_file"):
|
||||
self.rollout_log_file.close()
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self.model.load_state_dict(state_dict)
|
||||
|
Reference in New Issue
Block a user