From 4ec73298b22315ba27ef68df4c4836ce9be0e5ee Mon Sep 17 00:00:00 2001
From: YeAnbang
Date: Thu, 15 May 2025 14:15:40 +0800
Subject: [PATCH] use consumer global step

---
 .../ColossalChat/coati/distributed/grpo_consumer.py     | 2 --
 applications/ColossalChat/coati/distributed/launch.py   | 1 -
 applications/ColossalChat/coati/distributed/producer.py | 9 +++++++--
 3 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 0b0c8a9ad..5ad73f45c 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -266,7 +266,6 @@ class GRPOConsumer(BaseConsumer):
             total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
             self.effective_sample_count += effective_samples.item()
             self.total_sample_count += total_samples.item()
-
             pbar.set_postfix(
                 {
                     "Global Step": self.global_step,
@@ -522,7 +521,6 @@ class GRPOConsumer(BaseConsumer):
                 # All gather excessive prompts index across DP ranks.
                 excessive_prompts_idx = [idx + self.dp_rank * self.minibatch_size for idx in excessive_prompts_idx]
                 excessive_prompts_idx = all_gather_tensors(excessive_prompts_idx, self.plugin)
-
                 return loss_scalar, excessive_prompts_idx
             else:
                 return None, excessive_prompts_idx
diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index cb5b6e4e2..327db1b55 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -56,7 +56,6 @@ def launch_distributed(
     eval_save_dir: Optional[str] = None,
     eval_generation_config: Optional[Dict[str, Any]] = None,
 ):
-
     if core_algo not in ALGO_MAP:
         raise NotImplementedError(f"{core_algo} is not supported yet.")
     else:
diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
index fa2d1cc8d..f7c17bf56 100644
--- a/applications/ColossalChat/coati/distributed/producer.py
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -58,6 +58,7 @@ class BaseProducer:
         self.microbatch_size = microbatch_size
         assert batch_size % microbatch_size == 0
         self.num_microbatches = batch_size // microbatch_size
+        self.lastest_eval_step = -1

         self.train_dataset_config = train_dataset_config
         self.model_config = model_config
@@ -178,12 +179,15 @@ class BaseProducer:
                 if i >= num_valid_microbatches:
                     break
                 if self.eval_interval > 0 and self.eval_dataset_config is not None:
-                    if i % self.eval_interval == 0:
+                    if (
+                        self.consumer_global_step % self.eval_interval == 0
+                        and self.consumer_global_step > self.lastest_eval_step
+                    ):
                         to_log_msg = {}
                         for eval_task_name in self.eval_dataloaders:
                             if self.producer_idx == 0:
                                 print(
-                                    f"[P{self.producer_idx}] Evaluate episode {episode} step {i} on task {eval_task_name}"
+                                    f"[P{self.producer_idx}] Evaluate episode {episode} step {self.consumer_global_step} on task {eval_task_name}"
                                 )
                             eval_results = []
                             eval_statistics_tensor = torch.zeros((2,), dtype=torch.float32).to(self.device)
@@ -223,6 +227,7 @@ class BaseProducer:

                         if self.producer_idx == 0:
                             self.wandb_run.log(to_log_msg, step=self.consumer_global_step)
+                        self.lastest_eval_step = self.consumer_global_step
                 outputs = self.rollout(**batch)

                 print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
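
For reference: before this patch, the producer triggered evaluation off its local microbatch index `i`, which drifts from the trainer's actual progress because the producer processes many microbatches per consumer optimizer step. The patch keys the trigger on `consumer_global_step` instead, and records the last evaluated step so the same step is never evaluated twice while the consumer has not advanced. Below is a minimal standalone sketch of that gating under stated assumptions: `ProducerSketch`, `maybe_evaluate`, and the `run_evaluation` stub are hypothetical stand-ins for `BaseProducer.loop()` and its per-task eval-dataloader loop; only the condition mirrors the patched code.

# Minimal sketch of the evaluation gating this patch introduces.
# Assumption: a bare class with a stub run_evaluation() in place of the
# real BaseProducer eval loop.


class ProducerSketch:
    def __init__(self, eval_interval: int):
        self.eval_interval = eval_interval
        # Pushed to the producer during weight sync in the real code.
        self.consumer_global_step = 0
        # Last consumer step evaluated; -1 so the very first step still
        # runs. (Attribute name kept as spelled in the patch.)
        self.lastest_eval_step = -1

    def run_evaluation(self) -> None:
        # Hypothetical stub for the per-task eval loop in BaseProducer.loop().
        print(f"evaluating at consumer step {self.consumer_global_step}")

    def maybe_evaluate(self) -> None:
        # Old trigger: `if i % self.eval_interval == 0`, keyed on the
        # producer's local microbatch index.
        # New trigger: keyed on the consumer's global step, plus a guard so
        # the same step is not evaluated again while the consumer has not
        # advanced (the producer sees each step across many microbatches).
        if (
            self.consumer_global_step % self.eval_interval == 0
            and self.consumer_global_step > self.lastest_eval_step
        ):
            self.run_evaluation()
            self.lastest_eval_step = self.consumer_global_step


if __name__ == "__main__":
    p = ProducerSketch(eval_interval=1)
    for step in [0, 0, 0, 1, 1, 2]:  # consumer step as seen per microbatch
        p.consumer_global_step = step
        p.maybe_evaluate()  # fires exactly once at steps 0, 1, and 2

Driving the sketch with the repeated step sequence above shows evaluation firing once per consumer step even though the producer revisits each step across several microbatches; without the `lastest_eval_step` guard, the `% eval_interval` test alone would fire on every microbatch of a matching step.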