mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-14 06:05:26 +00:00
support logging rollouts to wandb
This commit is contained in:
parent 203dfb1536
commit 021914c565
@@ -49,6 +49,7 @@ class BaseProducer:
         project_name: str = None,
         run_name: str = None,
         wandb_group_name: str = None,
+        wandb_log_rollout_interval: int = 20,
     ):
         self.producer_idx = producer_idx
         self.num_producers = num_producers
@@ -58,7 +59,7 @@ class BaseProducer:
         self.microbatch_size = microbatch_size
         assert batch_size % microbatch_size == 0
         self.num_microbatches = batch_size // microbatch_size
-        self.lastest_eval_step = -1
+        self.latest_eval_step = -1
 
         self.train_dataset_config = train_dataset_config
         self.model_config = model_config
@@ -68,6 +69,10 @@ class BaseProducer:
         self.eval_interval = eval_interval
         self.eval_save_dir = eval_save_dir
         self.consumer_global_step = 0
+        self.eval_mode = False
+        self.wandb_rollout_data = []
+        self.wandb_log_rollout_interval = wandb_log_rollout_interval
+        self.latest_rollout_log_step = -1
         if self.producer_idx == 0:
             self.wandb_run = wandb.init(
                 project=project_name,
@@ -77,7 +82,7 @@ class BaseProducer:
                 group=wandb_group_name,
             )
 
-        if os.path.exists(self.eval_save_dir):
+        if os.path.exists(self.eval_save_dir) and self.eval_interval > 0:
             raise ValueError(f"Eval save dir {self.eval_save_dir} already exists. Please delete it or change the name.")
 
         # init tokenizer
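Note: with this change a pre-existing eval save directory is only fatal when evaluation is actually enabled; runs with evaluation disabled no longer fail on a leftover directory. A minimal sketch of the same check as a hypothetical standalone helper (not part of this commit):

import os

def validate_eval_save_dir(eval_save_dir, eval_interval):
    # A leftover directory is only an error when evaluation will run
    # and would otherwise clash with the earlier results.
    if os.path.exists(eval_save_dir) and eval_interval > 0:
        raise ValueError(f"Eval save dir {eval_save_dir} already exists. Please delete it or change the name.")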
@@ -180,10 +185,11 @@ class BaseProducer:
                     break
                 if self.eval_interval > 0 and self.eval_dataset_config is not None:
                     if (
-                        self.consumer_global_step - self.lastest_eval_step >= self.eval_interval
-                        and self.consumer_global_step > self.lastest_eval_step
-                    ):
+                        self.consumer_global_step - self.latest_eval_step >= self.eval_interval
+                        and self.consumer_global_step > self.latest_eval_step
+                    ) or self.latest_eval_step == -1:
                         to_log_msg = {}
+                        self.eval_mode = True
                         for eval_task_name in self.eval_dataloaders:
                             if self.producer_idx == 0:
                                 print(
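Note: besides the lastest -> latest typo fix, the reworked condition appends "or self.latest_eval_step == -1" so that the first pass always evaluates; -1 acts as a "never evaluated yet" sentinel. A minimal sketch of the same predicate as a hypothetical standalone helper (not part of this commit):

def should_run_eval(consumer_global_step, latest_eval_step, eval_interval):
    # latest_eval_step == -1 means no eval has run yet, so the first
    # check fires regardless of how many steps have elapsed.
    return (
        consumer_global_step - latest_eval_step >= eval_interval
        and consumer_global_step > latest_eval_step
    ) or latest_eval_step == -1

assert should_run_eval(0, -1, 100)       # first pass: sentinel forces an eval
assert not should_run_eval(50, 0, 100)   # only 50 steps since the last eval
assert should_run_eval(100, 0, 100)      # interval reached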
@@ -227,7 +233,8 @@ class BaseProducer:
 
                     if self.producer_idx == 0:
                         self.wandb_run.log(to_log_msg, step=self.consumer_global_step)
-                    self.lastest_eval_step = self.consumer_global_step
+                    self.eval_mode = False
+                    self.latest_eval_step = self.consumer_global_step
                 outputs = self.rollout(**batch)
 
                 print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
@@ -345,9 +352,26 @@ class SimpleProducer(BaseProducer):
     @torch.no_grad()
     def rollout(self, input_ids, attention_mask, **kwargs):
         rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
-        # if self.producer_idx == 1:
-        #     print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True))
-
+        if self.producer_idx == 0 and not self.eval_mode:
+            wandb_rollout_data = self.wandb_rollout_data + [
+                [
+                    str(self.consumer_global_step),
+                    str(self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True)),
+                ]
+            ]
+            if (
+                self.consumer_global_step - self.latest_rollout_log_step >= self.wandb_log_rollout_interval
+                or self.latest_rollout_log_step == -1
+            ):
+                self.wandb_rollout_data = wandb_rollout_data
+                self.latest_rollout_log_step = self.consumer_global_step
+                self.wandb_run.log(
+                    {
+                        "rollout/rollout_examples": wandb.Table(
+                            columns=["train_step", "rollout_examples"], data=wandb_rollout_data
+                        )
+                    }
+                )
         return rollouts
 
     def load_state_dict(self, state_dict):
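Note: taken together, the new constructor state and this rollout hook implement interval-throttled example logging. Producer 0 builds a candidate row (train step plus the first decoded rollout) on every non-eval rollout, but it only commits the row to self.wandb_rollout_data, and re-logs the whole buffer as a wandb.Table, once at least wandb_log_rollout_interval consumer steps have passed since the last flush (or on the very first rollout, via the -1 sentinel). Rows built at intermediate steps are discarded, so the table grows by one example per flush, and the full buffer is re-logged each time rather than appended to in place. A self-contained sketch of the same pattern, with hypothetical names, assuming wandb is installed and a run is active:

import wandb

class RolloutTableLogger:
    # Hypothetical standalone mirror of the throttling logic in the diff.
    def __init__(self, run, log_interval=20):
        self.run = run
        self.log_interval = log_interval
        self.rows = []                 # committed rows only
        self.latest_log_step = -1      # -1 sentinel: nothing flushed yet

    def maybe_log(self, step, rollout_text):
        candidate = self.rows + [[str(step), rollout_text]]
        if step - self.latest_log_step >= self.log_interval or self.latest_log_step == -1:
            self.rows = candidate
            self.latest_log_step = step
            # The wandb.Table is re-created and re-logged whole on every flush.
            self.run.log(
                {
                    "rollout/rollout_examples": wandb.Table(
                        columns=["train_step", "rollout_examples"], data=self.rows
                    )
                }
            )
        # Candidate rows built between flushes are intentionally dropped,
        # matching the behaviour of the code above.

# usage sketch (assumes a configured W&B project):
# run = wandb.init(project="my-project")
# logger = RolloutTableLogger(run, log_interval=20)
# logger.maybe_log(0, "decoded rollout text")    # flushes (sentinel)
# logger.maybe_log(5, "decoded rollout text")    # dropped: interval not reached
# logger.maybe_log(20, "decoded rollout text")   # flushes again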