mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 03:20:52 +00:00
reuse comm-group
This commit is contained in:
@@ -138,7 +138,6 @@ class BaseProducer:
|
||||
cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name=f"sync_model_{i}")
|
||||
else:
|
||||
cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model")
|
||||
cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_eval_statistics_{self.producer_idx}")
|
||||
|
||||
def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
@@ -194,11 +193,14 @@ class BaseProducer:
|
||||
# delete the file if it exists
|
||||
safe_write_jsonl(result_file_name, eval_results)
|
||||
print(f"[P{self.producer_idx}] Send eval statistics episode {episode} step {i}")
|
||||
eval_statistics["consumer_global_step"] = torch.tensor(
|
||||
[self.consumer_global_step], device=self.device
|
||||
)
|
||||
ray_broadcast_tensor_dict(
|
||||
eval_statistics,
|
||||
src=0,
|
||||
device=self.device,
|
||||
group_name=f"sync_eval_statistics_{self.producer_idx}",
|
||||
group_name=f"sync_data_{self.producer_idx}",
|
||||
)
|
||||
outputs = self.rollout(**batch)
|
||||
|
||||
|
Reference in New Issue
Block a user