mirror of https://github.com/hpcaitech/ColossalAI.git
[plugin] a workaround for zero plugins' optimizer checkpoint (#3780)
* [test] refactor torch ddp checkpoint test
* [plugin] update low level zero optim checkpoint
* [plugin] update gemini optim checkpoint
tests/test_checkpoint_io/utils.py (new file, +21 lines)
@@ -0,0 +1,21 @@
import tempfile
from contextlib import contextmanager, nullcontext
from typing import Iterator

import torch.distributed as dist


@contextmanager
def shared_tempdir() -> Iterator[str]:
    """
    A temporary directory that is shared across all processes.
    """
    ctx_fn = tempfile.TemporaryDirectory if dist.get_rank() == 0 else nullcontext
    with ctx_fn() as tempdir:
        try:
            obj = [tempdir]
            dist.broadcast_object_list(obj, src=0)
            tempdir = obj[0]  # use the same directory on all ranks
            yield tempdir
        finally:
            dist.barrier()
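For context, a minimal usage sketch of how a distributed checkpoint test might use this helper: rank 0 creates the temporary directory, the path is broadcast so every rank saves to and loads from the same location, and the barrier at the end of the context keeps rank 0 from deleting the directory while other ranks are still reading it. The function run_checkpoint_roundtrip and the plain torch.save/torch.load round trip below are illustrative placeholders, not part of this commit; the real tests exercise the plugins' checkpoint IO instead.

# Hypothetical usage sketch; helper names here are placeholders, not from the commit.
import os

import torch
import torch.distributed as dist

from tests.test_checkpoint_io.utils import shared_tempdir


def run_checkpoint_roundtrip(model: torch.nn.Module) -> None:
    # All ranks enter the context and receive the same directory path,
    # broadcast from rank 0 inside shared_tempdir.
    with shared_tempdir() as tempdir:
        ckpt_path = os.path.join(tempdir, "model.pt")
        if dist.get_rank() == 0:
            torch.save(model.state_dict(), ckpt_path)
        dist.barrier()  # ensure the checkpoint exists before other ranks read it
        state_dict = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(state_dict)
    # On exit, shared_tempdir's finally-barrier runs before rank 0's
    # TemporaryDirectory cleanup, so the directory is removed only after
    # every rank has finished using it.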