diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py index 945bb05dc..97e9a4183 100644 --- a/applications/ColossalChat/coati/distributed/producer.py +++ b/applications/ColossalChat/coati/distributed/producer.py @@ -132,19 +132,21 @@ class BaseProducer: ): self.model.llm.sleep() # revict KV_cache to avoid OOM # don't sync model for last iteration - print( - f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}" - ) torch.cuda.empty_cache() if self.consumer_pp_size > 1: for i in range(self.consumer_pp_size): - print(f"[P{self.producer_idx}] Sync model PP stage {i}") + print( + f"[P{self.producer_idx}] Sync model PP stage {i} episode {episode} step {(i + 1) // self.num_microbatches - 1}" + ) state_dict = ray_broadcast_tensor_dict( None, self.num_producers, device=self.device, group_name=f"sync_model_{i}" ) self.load_state_dict(state_dict) else: + print( + f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}" + ) state_dict = ray_broadcast_tensor_dict( None, self.num_producers, device=self.device, group_name="sync_model" )