fix evaluation

This commit is contained in:
YeAnbang 2025-05-16 09:42:35 +08:00
parent 1644adf684
commit 6abffb9100

View File

@ -180,7 +180,7 @@ class BaseProducer:
break
if self.eval_interval > 0 and self.eval_dataset_config is not None:
if (
self.consumer_global_step % self.eval_interval == 0
self.consumer_global_step - self.lastest_eval_step >= self.eval_interval
and self.consumer_global_step > self.lastest_eval_step
):
to_log_msg = {}
@ -256,6 +256,8 @@ class BaseProducer:
state_dict = ray_broadcast_tensor_dict(
None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}"
)
if "consumer_global_step" in state_dict:
self.consumer_global_step = state_dict.pop("consumer_global_step").item()
self.load_state_dict(state_dict)
else:
print(