From bd61918dcfe7f4028f8ade4e8be351c4df991bf9 Mon Sep 17 00:00:00 2001
From: YeAnbang
Date: Wed, 30 Apr 2025 21:36:11 +0800
Subject: [PATCH] reuse comm-group

---
 .../ColossalChat/coati/distributed/consumer.py | 12 ++++++------
 .../ColossalChat/coati/distributed/producer.py |  6 ++++--
 2 files changed, 10 insertions(+), 8 deletions(-)

diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index 38e55a65c..31bd73e88 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -94,9 +94,6 @@ class BaseConsumer:
 
         if self.rank == 0:
             cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
-        for i in range(self.num_producers):
-            cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_eval_statistics_{i}")
-
         self.buffer = []
         self.recv_cnt = 0
 
@@ -116,11 +113,14 @@ class BaseConsumer:
                 i = 0
                 if self.eval_interval > 0 and step % self.eval_interval == 0:
                     eval_statistics = None
+                    eval_global_step = None
                     for r in range(self.num_producers):
                         print(f"[T{dist.get_rank()}] Recv eval result episode {episode} step {step} from {r}")
                         local_eval_result = ray_broadcast_tensor_dict(
-                            None, src=0, device=self.device, group_name=f"sync_eval_statistics_{r}"
+                            None, src=0, device=self.device, group_name=f"sync_data_{r}"
                         )
+                        assert "consumer_global_step" in local_eval_result
+                        eval_global_step = local_eval_result.pop("consumer_global_step").item()
                         if eval_statistics is None:
                             eval_statistics = local_eval_result
                         else:
@@ -129,8 +129,8 @@ class BaseConsumer:
                             }
                     eval_statistics = {k: (v[0] / v[1]).item() for k, v in eval_statistics.items()}
                     if dist.get_rank() == 0:
-                        if hasattr(self, "wandb_run") and hasattr(self, "global_step"):
-                            self.wandb_run.log(eval_statistics, step=self.global_step)
+                        if hasattr(self, "wandb_run"):
+                            self.wandb_run.log(eval_statistics, step=eval_global_step)
                         print(f"Eval statistics: {eval_statistics}")
             for _ in range(self.num_recv_per_update):
                 # receive data from producers
diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
index 529e19bf4..b7e1d8f2a 100644
--- a/applications/ColossalChat/coati/distributed/producer.py
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -138,7 +138,6 @@ class BaseProducer:
                 cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name=f"sync_model_{i}")
         else:
             cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model")
-        cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_eval_statistics_{self.producer_idx}")
 
     def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
         raise NotImplementedError
@@ -194,11 +193,14 @@ class BaseProducer:
                             # delete the file if it exists
                             safe_write_jsonl(result_file_name, eval_results)
                     print(f"[P{self.producer_idx}] Send eval statistics episode {episode} step {i}")
+                    eval_statistics["consumer_global_step"] = torch.tensor(
+                        [self.consumer_global_step], device=self.device
+                    )
                     ray_broadcast_tensor_dict(
                         eval_statistics,
                         src=0,
                         device=self.device,
-                        group_name=f"sync_eval_statistics_{self.producer_idx}",
+                        group_name=f"sync_data_{self.producer_idx}",
                     )
 
             outputs = self.rollout(**batch)