fix: allow for print for fns that are used in both dist and single

This commit is contained in:
Zach Nussbaum 2023-04-25 20:34:12 +00:00
parent 7832707c37
commit 586a8abc06

View File

@ -2,5 +2,8 @@ import torch.distributed as dist
def rank0_print(msg):
if dist.get_rank() == 0:
if dist.is_initialized():
if dist.get_rank() == 0:
print(msg)
else:
print(msg)