diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py index 8b9452160..8b8be8e16 100644 --- a/applications/ColossalChat/coati/distributed/producer.py +++ b/applications/ColossalChat/coati/distributed/producer.py @@ -180,7 +180,7 @@ class BaseProducer: break if self.eval_interval > 0 and self.eval_dataset_config is not None: if ( - self.consumer_global_step % self.eval_interval == 0 + self.consumer_global_step - self.lastest_eval_step >= self.eval_interval and self.consumer_global_step > self.lastest_eval_step ): to_log_msg = {} @@ -256,6 +256,8 @@ class BaseProducer: state_dict = ray_broadcast_tensor_dict( None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}" ) + if "consumer_global_step" in state_dict: + self.consumer_global_step = state_dict.pop("consumer_global_step").item() self.load_state_dict(state_dict) else: print(