[plugin] torch ddp plugin supports sharded model checkpoint (#3775)

* [plugin] torch ddp plugin add save sharded model

* [test] fix torch ddp ckpt io test

* [test] fix torch ddp ckpt io test

* [test] fix low level zero plugin test

* [test] fix low level zero plugin test

* [test] add debug info

* [test] add debug info

* [test] add debug info

* [test] add debug info

* [test] add debug info

* [test] fix low level zero plugin test

* [test] fix low level zero plugin test

* [test] remove debug info

Hongxin Liu
2023-05-18 20:05:59 +08:00
committed by GitHub
parent 2703a37ac9
commit 5452df63c5
5 changed files with 86 additions and 51 deletions
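From the user's side, the new code path is exercised through the Booster API; only the master rank ends up writing checkpoint files. A minimal sketch, assuming the Booster API of this release and that Booster.save_model exposes `shard` and `size_per_shard` arguments (both names are assumptions, verify against your installed version):

# Minimal sketch (assumption-laden): saving a sharded model checkpoint through
# the TorchDDP plugin. Run under torchrun with one process per GPU.
# The `shard` / `size_per_shard` argument names on Booster.save_model are
# assumptions based on the Booster API of this era.
import colossalai
import torch
import torch.nn as nn
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

colossalai.launch_from_torch(config={})

model = nn.Linear(1024, 1024).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

booster = Booster(plugin=TorchDDPPlugin())
model, optimizer, *_ = booster.boost(model, optimizer)

# Non-master ranks return early inside TorchDDPCheckpointIO.save_sharded_model,
# so only rank 0 writes the shard files.
booster.save_model(model, "ckpt_dir", shard=True, size_per_shard=1024)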


@@ -1,4 +1,4 @@
-from typing import Callable, Iterator, List, Tuple, Union
+from typing import Callable, Iterator, List, Optional, Tuple, Union
 
 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -50,6 +50,16 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         if self.coordinator.is_master():
             super().save_lr_scheduler(lr_scheduler, checkpoint)
 
+    def save_sharded_model(self,
+                           model: nn.Module,
+                           checkpoint_path: str,
+                           gather_dtensor: bool = False,
+                           variant: Optional[str] = None,
+                           max_shard_size: int = 1024,
+                           use_safetensors: bool = False):
+        if self.coordinator.is_master():
+            super().save_sharded_model(model, checkpoint_path, gather_dtensor, variant, max_shard_size, use_safetensors)
+
 
 class TorchDDPModel(ModelWrapper):
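
The diff itself only adds the master-rank gate; the actual splitting is delegated to GeneralCheckpointIO.save_sharded_model. As a rough, standalone illustration of what a size-capped sharded save does (the shard file names, the index layout, and treating `max_shard_size` as megabytes are assumptions for illustration, not the library's exact on-disk format):

# Illustrative sketch only: split a state dict into shards no larger than
# `max_shard_size` MB and write them from the master rank, mirroring the
# plugin's master-only behaviour. File names and index layout here are
# assumptions, not the exact output of GeneralCheckpointIO.save_sharded_model.
import json
import os

import torch
import torch.distributed as dist
import torch.nn as nn


def save_sharded_state_dict(model: nn.Module, checkpoint_path: str, max_shard_size: int = 1024) -> None:
    # Mirror the plugin: non-master ranks do nothing.
    if dist.is_initialized() and dist.get_rank() != 0:
        return

    os.makedirs(checkpoint_path, exist_ok=True)
    limit = max_shard_size * 1024 ** 2    # MB -> bytes

    # Greedily pack tensors into shards capped at `limit` bytes.
    shards, current, current_size = [], {}, 0
    for name, tensor in model.state_dict().items():
        size = tensor.numel() * tensor.element_size()
        if current and current_size + size > limit:
            shards.append(current)
            current, current_size = {}, 0
        current[name] = tensor
        current_size += size
    if current:
        shards.append(current)

    # Write each shard and record which file holds each parameter.
    weight_map = {}
    for i, shard in enumerate(shards):
        fname = f"pytorch_model-{i + 1:05d}-of-{len(shards):05d}.bin"
        torch.save(shard, os.path.join(checkpoint_path, fname))
        weight_map.update({key: fname for key in shard})

    with open(os.path.join(checkpoint_path, "pytorch_model.bin.index.json"), "w") as f:
        json.dump({"weight_map": weight_map}, f, indent=2)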