From 021914c565ed5310671b974cc28b4525bf3a5a86 Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Fri, 16 May 2025 15:56:03 +0800 Subject: [PATCH] support logging rollouts to wandb --- .../coati/distributed/producer.py | 42 +++++++++++++++---- 1 file changed, 33 insertions(+), 9 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py index 82f57094d..0d91f43f1 100644 --- a/applications/ColossalChat/coati/distributed/producer.py +++ b/applications/ColossalChat/coati/distributed/producer.py @@ -49,6 +49,7 @@ class BaseProducer: project_name: str = None, run_name: str = None, wandb_group_name: str = None, + wandb_log_rollout_interval: int = 20, ): self.producer_idx = producer_idx self.num_producers = num_producers @@ -58,7 +59,7 @@ class BaseProducer: self.microbatch_size = microbatch_size assert batch_size % microbatch_size == 0 self.num_microbatches = batch_size // microbatch_size - self.lastest_eval_step = -1 + self.latest_eval_step = -1 self.train_dataset_config = train_dataset_config self.model_config = model_config @@ -68,6 +69,10 @@ class BaseProducer: self.eval_interval = eval_interval self.eval_save_dir = eval_save_dir self.consumer_global_step = 0 + self.eval_mode = False + self.wandb_rollout_data = [] + self.wandb_log_rollout_interval = wandb_log_rollout_interval + self.latest_rollout_log_step = -1 if self.producer_idx == 0: self.wandb_run = wandb.init( project=project_name, @@ -77,7 +82,7 @@ class BaseProducer: group=wandb_group_name, ) - if os.path.exists(self.eval_save_dir): + if os.path.exists(self.eval_save_dir) and self.eval_interval > 0: raise ValueError(f"Eval save dir {self.eval_save_dir} already exists. Please delete it or change the name.") # init tokenizer @@ -180,10 +185,11 @@ class BaseProducer: break if self.eval_interval > 0 and self.eval_dataset_config is not None: if ( - self.consumer_global_step - self.lastest_eval_step >= self.eval_interval - and self.consumer_global_step > self.lastest_eval_step - ): + self.consumer_global_step - self.latest_eval_step >= self.eval_interval + and self.consumer_global_step > self.latest_eval_step + ) or self.latest_eval_step == -1: to_log_msg = {} + self.eval_mode = True for eval_task_name in self.eval_dataloaders: if self.producer_idx == 0: print( @@ -227,7 +233,8 @@ class BaseProducer: if self.producer_idx == 0: self.wandb_run.log(to_log_msg, step=self.consumer_global_step) - self.lastest_eval_step = self.consumer_global_step + self.eval_mode = False + self.latest_eval_step = self.consumer_global_step outputs = self.rollout(**batch) print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}") @@ -345,9 +352,26 @@ class SimpleProducer(BaseProducer): @torch.no_grad() def rollout(self, input_ids, attention_mask, **kwargs): rollouts = self.model.generate(input_ids, attention_mask, **kwargs) - # if self.producer_idx == 1: - # print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True)) - + if self.producer_idx == 0 and not self.eval_mode: + wandb_rollout_data = self.wandb_rollout_data + [ + [ + str(self.consumer_global_step), + str(self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True)), + ] + ] + if ( + self.consumer_global_step - self.latest_rollout_log_step >= self.wandb_log_rollout_interval + or self.latest_rollout_log_step == -1 + ): + self.wandb_rollout_data = wandb_rollout_data + self.latest_rollout_log_step = self.consumer_global_step + self.wandb_run.log( + { + "rollout/rollout_examples": wandb.Table( + columns=["train_step", "rollout_examples"], data=wandb_rollout_data + ) + } + ) return rollouts def load_state_dict(self, state_dict):