Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 11:32:10 +00:00)
[plugin] torch ddp plugin supports sharded model checkpoint (#3775)
* [plugin] torch ddp plugin add save sharded model
* [test] fix torch ddp ckpt io test
* [test] fix torch ddp ckpt io test
* [test] fix low level zero plugin test
* [test] fix low level zero plugin test
* [test] add debug info
* [test] add debug info
* [test] add debug info
* [test] add debug info
* [test] add debug info
* [test] fix low level zero plugin test
* [test] fix low level zero plugin test
* [test] remove debug info
@@ -1,4 +1,4 @@
-from typing import Callable, Iterator, List, Tuple, Union
+from typing import Callable, Iterator, List, Optional, Tuple, Union
 
 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP

@@ -50,6 +50,16 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         if self.coordinator.is_master():
             super().save_lr_scheduler(lr_scheduler, checkpoint)
 
+    def save_sharded_model(self,
+                           model: nn.Module,
+                           checkpoint_path: str,
+                           gather_dtensor: bool = False,
+                           variant: Optional[str] = None,
+                           max_shard_size: int = 1024,
+                           use_safetensors: bool = False):
+        if self.coordinator.is_master():
+            super().save_sharded_model(model, checkpoint_path, gather_dtensor, variant, max_shard_size, use_safetensors)
 
 
 class TorchDDPModel(ModelWrapper):
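The new TorchDDPCheckpointIO.save_sharded_model simply gates the GeneralCheckpointIO implementation on the master rank, so non-master processes skip the write. A minimal usage sketch follows; it assumes the distributed context is launched with colossalai.launch_from_torch and that TorchDDPPlugin exposes its checkpoint IO via get_checkpoint_io(), so treat those entry points as illustrative rather than part of this change.

# Minimal sketch (assumptions noted above): save a model as a sharded
# checkpoint through the torch DDP plugin's checkpoint IO.
import colossalai
import torch.nn as nn
from colossalai.booster.plugin import TorchDDPPlugin

colossalai.launch_from_torch(config={})  # e.g. started via torchrun

plugin = TorchDDPPlugin()
ckpt_io = plugin.get_checkpoint_io()  # assumed accessor returning TorchDDPCheckpointIO

model = nn.Linear(1024, 1024)

# Only the master rank writes shard files; defaults mirror the new signature.
ckpt_io.save_sharded_model(model,
                           checkpoint_path='ckpt_dir',
                           gather_dtensor=False,
                           variant=None,
                           max_shard_size=1024,
                           use_safetensors=False)

Guarding on coordinator.is_master() keeps the behavior consistent with save_lr_scheduler above: under DDP every rank holds a full replica of the model, so a single writer is sufficient.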