ColossalAI/tests/test_checkpoint_io/utils.py
Hongxin Liu 3c07a2846e
[plugin] a workaround for zero plugins' optimizer checkpoint (#3780)
* [test] refactor torch ddp checkpoint test

* [plugin] update low level zero optim checkpoint

* [plugin] update gemini optim checkpoint
2023-05-19 19:42:31 +08:00

22 lines
609 B
Python

import tempfile
from contextlib import contextmanager, nullcontext
from typing import Iterator
import torch.distributed as dist
@contextmanager
def shared_tempdir() -> Iterator[str]:
    """Yield the path of a temporary directory that every process agrees on.

    When ``torch.distributed`` is initialized, rank 0 creates the directory
    and broadcasts its path; every other rank receives that path instead of
    creating a directory of its own. All ranks synchronize on a barrier when
    leaving the context, and only then does rank 0 remove the directory.

    When ``torch.distributed`` is unavailable or the default process group has
    not been initialized (e.g. a plain single-process run), this degrades to
    an ordinary ``tempfile.TemporaryDirectory`` instead of raising.

    NOTE(review): the ranks only exchange a path string, so this presumably
    relies on all ranks seeing the same filesystem -- confirm for multi-node
    runs.

    Yields:
        The directory path, identical on all ranks.
    """
    # ``is_initialized`` only exists when the distributed package is
    # available, hence the short-circuiting ``and``.
    if not (dist.is_available() and dist.is_initialized()):
        # No peers to share with: a private temporary directory is enough.
        with tempfile.TemporaryDirectory() as tempdir:
            yield tempdir
        return

    # Only rank 0 owns (creates and later deletes) the directory; on the other
    # ranks ``nullcontext()`` yields ``None`` as a placeholder.
    ctx_fn = tempfile.TemporaryDirectory if dist.get_rank() == 0 else nullcontext
    with ctx_fn() as tempdir:
        try:
            obj = [tempdir]
            dist.broadcast_object_list(obj, src=0)
            tempdir = obj[0]    # use the same directory on all ranks
            yield tempdir
        finally:
            # Runs before rank 0's TemporaryDirectory context exits, so the
            # directory is not removed until every rank has reached this
            # point -- also when the body raised.
            dist.barrier()