[plugin] a workaround for zero plugins' optimizer checkpoint (#3780)

* [test] refactor torch ddp checkpoint test

* [plugin] update low level zero optim checkpoint

* [plugin] update gemini optim checkpoint
Author: Hongxin Liu
Date: 2023-05-19 19:42:31 +08:00
Committed by: GitHub
Parent: 60e6a154bc
Commit: 3c07a2846e
6 changed files with 128 additions and 82 deletions


@@ -0,0 +1,21 @@
import tempfile
from contextlib import contextmanager, nullcontext
from typing import Iterator

import torch.distributed as dist


@contextmanager
def shared_tempdir() -> Iterator[str]:
    """
    A temporary directory that is shared across all processes.
    """
    # Only rank 0 actually creates (and later removes) the directory;
    # the other ranks enter a no-op context and receive its path below.
    ctx_fn = tempfile.TemporaryDirectory if dist.get_rank() == 0 else nullcontext
    with ctx_fn() as tempdir:
        try:
            obj = [tempdir]
            dist.broadcast_object_list(obj, src=0)
            tempdir = obj[0]  # use the same directory on all ranks
            yield tempdir
        finally:
            # Keep rank 0 from cleaning up the directory before the
            # other ranks have finished using it.
            dist.barrier()
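
A minimal usage sketch, not part of this diff: the helper name and the model argument below are placeholders, but they show the intended pattern for the refactored checkpoint tests, where every rank enters the same directory, rank 0 writes a checkpoint, and all ranks load it back before the directory is removed on exit.

import os

import torch
import torch.distributed as dist


def _checkpoint_roundtrip(model: torch.nn.Module) -> None:
    # All ranks receive the same path from shared_tempdir.
    with shared_tempdir() as tempdir:
        ckpt_path = os.path.join(tempdir, "model.pt")
        if dist.get_rank() == 0:
            torch.save(model.state_dict(), ckpt_path)
        dist.barrier()  # ensure the file exists before other ranks load it
        model.load_state_dict(torch.load(ckpt_path, map_location="cpu"))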