Reuse the existing `sync_data` comm group for eval statistics

This commit is contained in:
YeAnbang
2025-04-30 21:36:11 +08:00
parent 57a88395fe
commit bd61918dcf
2 changed files with 10 additions and 8 deletions

View File

@@ -94,9 +94,6 @@ class BaseConsumer:
if self.rank == 0:
cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
for i in range(self.num_producers):
cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_eval_statistics_{i}")
self.buffer = []
self.recv_cnt = 0
@@ -116,11 +113,14 @@ class BaseConsumer:
i = 0
if self.eval_interval > 0 and step % self.eval_interval == 0:
eval_statistics = None
eval_global_step = None
for r in range(self.num_producers):
print(f"[T{dist.get_rank()}] Recv eval result episode {episode} step {step} from {r}")
local_eval_result = ray_broadcast_tensor_dict(
None, src=0, device=self.device, group_name=f"sync_eval_statistics_{r}"
None, src=0, device=self.device, group_name=f"sync_data_{r}"
)
assert "consumer_global_step" in local_eval_result
eval_global_step = local_eval_result.pop("consumer_global_step").item()
if eval_statistics is None:
eval_statistics = local_eval_result
else:
@@ -129,8 +129,8 @@ class BaseConsumer:
}
eval_statistics = {k: (v[0] / v[1]).item() for k, v in eval_statistics.items()}
if dist.get_rank() == 0:
if hasattr(self, "wandb_run") and hasattr(self, "global_step"):
self.wandb_run.log(eval_statistics, step=self.global_step)
if hasattr(self, "wandb_run"):
self.wandb_run.log(eval_statistics, step=eval_global_step)
print(f"Eval statistics: {eval_statistics}")
for _ in range(self.num_recv_per_update):
# receive data from producers