mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-21 21:22:04 +00:00
reuse comm-group
This commit is contained in:
parent 57a88395fe
commit bd61918dcf
@@ -94,9 +94,6 @@ class BaseConsumer:
         if self.rank == 0:
             cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
-        for i in range(self.num_producers):
-            cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_eval_statistics_{i}")
-
         self.buffer = []
         self.recv_cnt = 0
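The hunk above drops the consumer-side sync_eval_statistics_{i} groups: producer i and the consumers are already connected through sync_data_{i}, and a named Ray collective group is a persistent communicator that any later broadcast can reuse. A minimal sketch of that reuse with the public ray.util.collective API (gloo backend, CPU tensors, hypothetical actor name; not the ColossalAI source):

import ray
import ray.util.collective as cc
import torch

@ray.remote
class Peer:
    def setup(self, world_size: int, rank: int, group_name: str) -> None:
        # One init call per process per named group; the group persists
        # until destroy_collective_group is called.
        cc.init_collective_group(world_size, rank, backend="gloo", group_name=group_name)

    def exchange(self, group_name: str, is_src: bool) -> list:
        t = torch.ones(4) if is_src else torch.zeros(4)
        # Rollout data and eval statistics can travel over the same group.
        cc.broadcast(t, src_rank=0, group_name=group_name)
        return t.tolist()

Each peer calls setup once per group name; every later broadcast on that name reuses the same communicator, which is why a second, dedicated eval group buys nothing.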
@@ -116,11 +113,14 @@ class BaseConsumer:
                 i = 0
                 if self.eval_interval > 0 and step % self.eval_interval == 0:
                     eval_statistics = None
+                    eval_global_step = None
                     for r in range(self.num_producers):
                         print(f"[T{dist.get_rank()}] Recv eval result episode {episode} step {step} from {r}")
                         local_eval_result = ray_broadcast_tensor_dict(
-                            None, src=0, device=self.device, group_name=f"sync_eval_statistics_{r}"
+                            None, src=0, device=self.device, group_name=f"sync_data_{r}"
                         )
+                        assert "consumer_global_step" in local_eval_result
+                        eval_global_step = local_eval_result.pop("consumer_global_step").item()
                         if eval_statistics is None:
                             eval_statistics = local_eval_result
                         else:
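Because the eval broadcast now shares sync_data_{r}, the producer piggybacks its view of the consumer's global step onto the same tensor dict, and the consumer pops it back out before aggregating. A small sketch of that pattern (the helper names pack/unpack are hypothetical):

import torch

def pack(stats: dict, consumer_global_step: int, device: str = "cpu") -> dict:
    # Fold the scalar counter into the dict so it rides the same broadcast.
    stats = dict(stats)  # avoid mutating the caller's dict
    stats["consumer_global_step"] = torch.tensor([consumer_global_step], device=device)
    return stats

def unpack(received: dict) -> tuple:
    # Pop the counter back out, leaving only the real statistics behind.
    assert "consumer_global_step" in received
    step = int(received.pop("consumer_global_step").item())
    return received, step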
@@ -129,8 +129,8 @@ class BaseConsumer:
                             }
                     eval_statistics = {k: (v[0] / v[1]).item() for k, v in eval_statistics.items()}
                     if dist.get_rank() == 0:
-                        if hasattr(self, "wandb_run") and hasattr(self, "global_step"):
-                            self.wandb_run.log(eval_statistics, step=self.global_step)
+                        if hasattr(self, "wandb_run"):
+                            self.wandb_run.log(eval_statistics, step=eval_global_step)
                     print(f"Eval statistics: {eval_statistics}")
                 for _ in range(self.num_recv_per_update):
                     # receive data from producers
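The reduction (v[0] / v[1]).item() implies each metric arrives as a two-element [sum, count] tensor per producer; summing those pairs across producers and dividing once at the end yields a sample-weighted global mean. A toy illustration with invented numbers:

import torch

# Two producers report "accuracy" as [sum, count] pairs.
reports = [
    {"accuracy": torch.tensor([3.0, 4.0])},  # 3 correct out of 4 samples
    {"accuracy": torch.tensor([5.0, 6.0])},  # 5 correct out of 6 samples
]

eval_statistics = None
for local in reports:
    if eval_statistics is None:
        eval_statistics = local
    else:
        # Elementwise add keeps sums and counts aligned per metric.
        eval_statistics = {k: eval_statistics[k] + local[k] for k in eval_statistics}

# (3 + 5) / (4 + 6) = 0.8, not the mean of the two per-producer means.
print({k: (v[0] / v[1]).item() for k, v in eval_statistics.items()})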
@@ -138,7 +138,6 @@ class BaseProducer:
                 cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name=f"sync_model_{i}")
         else:
             cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model")
-        cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_eval_statistics_{self.producer_idx}")
 
     def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
         raise NotImplementedError
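For reference, the rank numbering both sides must agree on: in sync_data_{i} the producer registers as rank 0 in a group of size 1 + num_consumer_procs and each consumer joins as rank + 1 (the same numbering the removed eval groups duplicated), while in sync_model the producers take ranks 0..num_producers-1 and consumer rank 0 joins as rank num_producers. A small script that prints this assumed layout (sizes invented):

num_producers, num_consumer_procs = 2, 4  # hypothetical sizes

for producer_idx in range(num_producers):
    layout = {0: f"producer_{producer_idx}"}  # producer is rank 0 of its data group
    layout.update({r + 1: f"consumer_{r}" for r in range(num_consumer_procs)})
    print(f"sync_data_{producer_idx}:", layout)

# sync_model: producers occupy ranks 0..num_producers-1; consumer rank 0
# joins last as rank num_producers, matching the init calls above.
model_group = {i: f"producer_{i}" for i in range(num_producers)}
model_group[num_producers] = "consumer_0"
print("sync_model:", model_group)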
@@ -194,11 +193,14 @@ class BaseProducer:
                     # delete the file if it exists
                     safe_write_jsonl(result_file_name, eval_results)
                 print(f"[P{self.producer_idx}] Send eval statistics episode {episode} step {i}")
+                eval_statistics["consumer_global_step"] = torch.tensor(
+                    [self.consumer_global_step], device=self.device
+                )
                 ray_broadcast_tensor_dict(
                     eval_statistics,
                     src=0,
                     device=self.device,
-                    group_name=f"sync_eval_statistics_{self.producer_idx}",
+                    group_name=f"sync_data_{self.producer_idx}",
                 )
 
             outputs = self.rollout(**batch)
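ray_broadcast_tensor_dict is ColossalAI's own helper, but its contract is visible from the call sites: the source rank passes the dict, receivers pass None and get the dict back over the named group. A simplified stand-in under the extra assumption that every rank already knows the keys and shapes (the real helper exchanges that metadata itself):

import ray.util.collective as cc
import torch

def broadcast_tensor_dict(tensor_dict, src, device, group_name, schema=None):
    # Sketch only: `schema` maps key -> shape and stands in for the
    # metadata exchange the real helper performs.
    if tensor_dict is None:
        # Receiver side: allocate empty buffers for broadcast to fill.
        tensor_dict = {k: torch.empty(shape, device=device) for k, shape in schema.items()}
    for k in sorted(tensor_dict):  # identical iteration order on every rank
        cc.broadcast(tensor_dict[k], src_rank=src, group_name=group_name)
    return tensor_dict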