From e8cc51066be6eb9fc42eb806b47a2bb66011768f Mon Sep 17 00:00:00 2001 From: Tong Li Date: Wed, 30 Apr 2025 11:59:44 +0800 Subject: [PATCH] update producer --- applications/ColossalChat/coati/distributed/producer.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py index 97e9a4183..b13747a4b 100644 --- a/applications/ColossalChat/coati/distributed/producer.py +++ b/applications/ColossalChat/coati/distributed/producer.py @@ -39,6 +39,7 @@ class BaseProducer: self.microbatch_size = microbatch_size assert batch_size % microbatch_size == 0 self.num_microbatches = batch_size // microbatch_size + print("Num micro batches: ", self.num_microbatches) self.dataset_config = dataset_config self.model_config = model_config @@ -135,12 +136,12 @@ class BaseProducer: torch.cuda.empty_cache() if self.consumer_pp_size > 1: - for i in range(self.consumer_pp_size): + for pp_idx in range(self.consumer_pp_size): print( - f"[P{self.producer_idx}] Sync model PP stage {i} episode {episode} step {(i + 1) // self.num_microbatches - 1}" + f"[P{self.producer_idx}] Sync model PP stage {pp_idx} episode {episode} step {(pp_idx + 1) // self.num_microbatches - 1}" ) state_dict = ray_broadcast_tensor_dict( - None, self.num_producers, device=self.device, group_name=f"sync_model_{i}" + None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}" ) self.load_state_dict(state_dict) else: