mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-13 05:35:38 +00:00
fix evaluation
This commit is contained in:
parent
1644adf684
commit
6abffb9100
@ -180,7 +180,7 @@ class BaseProducer:
|
|||||||
break
|
break
|
||||||
if self.eval_interval > 0 and self.eval_dataset_config is not None:
|
if self.eval_interval > 0 and self.eval_dataset_config is not None:
|
||||||
if (
|
if (
|
||||||
self.consumer_global_step % self.eval_interval == 0
|
self.consumer_global_step - self.lastest_eval_step >= self.eval_interval
|
||||||
and self.consumer_global_step > self.lastest_eval_step
|
and self.consumer_global_step > self.lastest_eval_step
|
||||||
):
|
):
|
||||||
to_log_msg = {}
|
to_log_msg = {}
|
||||||
@ -256,6 +256,8 @@ class BaseProducer:
|
|||||||
state_dict = ray_broadcast_tensor_dict(
|
state_dict = ray_broadcast_tensor_dict(
|
||||||
None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}"
|
None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}"
|
||||||
)
|
)
|
||||||
|
if "consumer_global_step" in state_dict:
|
||||||
|
self.consumer_global_step = state_dict.pop("consumer_global_step").item()
|
||||||
self.load_state_dict(state_dict)
|
self.load_state_dict(state_dict)
|
||||||
else:
|
else:
|
||||||
print(
|
print(
|
||||||
|
Loading…
Reference in New Issue
Block a user