[plugin] torch ddp plugin supports sharded model checkpoint (#3775)

* [plugin] torch ddp plugin add save sharded model

* [test] fix torch ddp ckpt io test

* [test] fix low level zero plugin test

* [test] add debug info

* [test] remove debug info
Author: Hongxin Liu
Date: 2023-05-18 20:05:59 +08:00 (committed by GitHub)
Parent: 2703a37ac9
Commit: 5452df63c5
5 changed files with 86 additions and 51 deletions
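For context, a minimal usage sketch of the feature this commit adds. It assumes ColossalAI's Booster API as it stood around this revision; the exact save_model signature (the shard= and size_per_shard= parameters) is an assumption inferred from the PR title, not verified against this commit.

import torch
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

# distributed setup (one process per GPU, launched via torchrun)
colossalai.launch_from_torch(config={})

# boost a toy model with the torch DDP plugin
plugin = TorchDDPPlugin()
booster = Booster(plugin=plugin)
model = nn.Linear(16, 8).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
model, optimizer, *_ = booster.boost(model, optimizer)

# save as a sharded checkpoint: a directory of shard files plus an
# index file (shard=/size_per_shard= are assumptions, see above)
booster.save_model(model, "./ckpt", shard=True, size_per_shard=1024)

# loading detects the index file and reassembles the shards
booster.load_model(model, "./ckpt")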


@@ -1,7 +1,6 @@
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Union
-from typing import Optional
+from typing import Optional, Union

 import torch
 import torch.nn as nn
@@ -84,9 +83,8 @@ class CheckpointIO(ABC):
         # containing no distributed tensors, dtensor -> full tensor conversion
         # should be done offline via our CLI
         # the existence of index file means it is a sharded checkpoint
-        ckpt_path = Path(checkpoint)
         index_file_exists, index_file_path = has_index_file(checkpoint)

         # return the origin model instead of the unwrapped model
         origin_model = model
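The comment in this hunk is the crux of the loading path: the presence of an index file is what marks a checkpoint as sharded. Below is a minimal sketch of that check; it is a hypothetical reimplementation, not the actual has_index_file from this repository, and the "*.index.json" naming convention is an assumption.

from pathlib import Path
from typing import Optional, Tuple

def has_index_file(checkpoint: str) -> Tuple[bool, Optional[Path]]:
    """Report whether `checkpoint` is a sharded checkpoint.

    A sharded checkpoint is a directory containing weight shards plus
    an index file mapping each parameter name to the shard holding it.
    """
    ckpt = Path(checkpoint)
    if ckpt.is_file():
        # a single checkpoint file is never sharded
        return False, None
    if ckpt.is_dir():
        # assumed naming convention, e.g. "pytorch_model.bin.index.json"
        index_files = list(ckpt.glob("*.index.json"))
        if index_files:
            return True, index_files[0]
    return False, None

load_model can then branch on index_file_exists: the sharded path iterates over the files listed in the index, while the single-file path loads the whole state dict at once.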