fix reward tagging bug

This commit is contained in:
YeAnbang
2025-05-03 14:34:04 +08:00
parent da867a4d8f
commit 2999bd4cc8
4 changed files with 29 additions and 32 deletions

View File

@@ -128,9 +128,8 @@ class BaseConsumer:
k: eval_statistics[k] + local_eval_result[k] for k in eval_statistics
}
eval_statistics = {"eval/" + k: (v[0] / v[1]).item() for k, v in eval_statistics.items()}
if dist.get_rank() == 0:
if hasattr(self, "wandb_run"):
self.wandb_run.log(eval_statistics, step=eval_global_step)
if hasattr(self, "wandb_run"):
self.wandb_run.log(eval_statistics, step=eval_global_step)
print(f"Eval statistics: {eval_statistics}")
for _ in range(self.num_recv_per_update):
# receive data from producers