reuse comm-group

This commit is contained in:
YeAnbang
2025-04-30 21:36:11 +08:00
parent 57a88395fe
commit bd61918dcf
2 changed files with 10 additions and 8 deletions

View File

@@ -138,7 +138,6 @@ class BaseProducer:
cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name=f"sync_model_{i}")
else:
cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model")
cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_eval_statistics_{self.producer_idx}")
def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
raise NotImplementedError
@@ -194,11 +193,14 @@ class BaseProducer:
# delete the file if it exists
safe_write_jsonl(result_file_name, eval_results)
print(f"[P{self.producer_idx}] Send eval statistics episode {episode} step {i}")
eval_statistics["consumer_global_step"] = torch.tensor(
[self.consumer_global_step], device=self.device
)
ray_broadcast_tensor_dict(
eval_statistics,
src=0,
device=self.device,
group_name=f"sync_eval_statistics_{self.producer_idx}",
group_name=f"sync_data_{self.producer_idx}",
)
outputs = self.rollout(**batch)