[plugin] torch ddp plugin supports sharded model checkpoint (#3775)

* [plugin] torch ddp plugin add save sharded model

* [test] fix torch ddp ckpt io test

* [test] fix torch ddp ckpt io test

* [test] fix low level zero plugin test

* [test] fix low level zero plugin test

* [test] add debug info

* [test] add debug info

* [test] add debug info

* [test] add debug info

* [test] add debug info

* [test] fix low level zero plugin test

* [test] fix low level zero plugin test

* [test] remove debug info
Author: Hongxin Liu, 2023-05-18 20:05:59 +08:00 (committed by GitHub)
parent 2703a37ac9
commit 5452df63c5
5 changed files with 86 additions and 51 deletions

View File

@@ -1,4 +1,4 @@
-from typing import Callable, Iterator, List, Tuple, Union
+from typing import Callable, Iterator, List, Optional, Tuple, Union

 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -50,6 +50,16 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         if self.coordinator.is_master():
             super().save_lr_scheduler(lr_scheduler, checkpoint)

+    def save_sharded_model(self,
+                           model: nn.Module,
+                           checkpoint_path: str,
+                           gather_dtensor: bool = False,
+                           variant: Optional[str] = None,
+                           max_shard_size: int = 1024,
+                           use_safetensors: bool = False):
+        if self.coordinator.is_master():
+            super().save_sharded_model(model, checkpoint_path, gather_dtensor, variant, max_shard_size, use_safetensors)
+

 class TorchDDPModel(ModelWrapper):
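The new hook follows the same pattern as the existing save methods above it: only the master rank delegates to GeneralCheckpointIO.save_sharded_model, so the other ranks skip the write entirely. A minimal usage sketch of reaching this path through the Booster API, mirroring what the updated checkpoint-IO test further below does (the checkpoint directory name is an illustrative placeholder, and a launched process group, e.g. via colossalai.launch or torchrun, is assumed):

# Hedged sketch: saving a DDP-boosted model as a sharded checkpoint.
from torch.optim import SGD
from torchvision.models import resnet18

from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

booster = Booster(plugin=TorchDDPPlugin())
model = resnet18()
optimizer = SGD(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)

# shard=True routes to TorchDDPCheckpointIO.save_sharded_model;
# only the master rank actually writes the shard files.
booster.save_model(model, "ckpt/resnet18", shard=True)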

View File

@@ -1,7 +1,6 @@
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Union
-from typing import Optional
+from typing import Optional, Union

 import torch
 import torch.nn as nn
@@ -84,7 +83,6 @@ class CheckpointIO(ABC):
         # containing no distributed tensors, dtensor -> full tensor conversion
         # should be done offline via our CLI
         # the existence of index file means it is a sharded checkpoint
-        ckpt_path = Path(checkpoint)
         index_file_exists, index_file_path = has_index_file(checkpoint)

         # return the origin model instead of the unwrapped model

View File

@@ -1,10 +1,12 @@
 # coding=utf-8
+import re
 from pathlib import Path
+from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple

 import torch
 import torch.nn as nn
-from typing import List, Mapping, OrderedDict, Optional, Tuple, Iterator
+
 from colossalai.tensor.d_tensor.d_tensor import DTensor
-import re

 SAFE_WEIGHTS_NAME = "model.safetensors"
 WEIGHTS_NAME = "pytorch_model.bin"
@@ -15,6 +17,7 @@ WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
 # General helper functions
 # ======================================

+
 def calculate_tensor_size(tensor: torch.Tensor) -> float:
     """
     Calculate the size of a parameter in MB. Used to compute whether a group of params exceed the shard size.
@@ -28,6 +31,7 @@ def calculate_tensor_size(tensor: torch.Tensor) -> float:
     """
     return tensor.numel() * tensor.element_size() / 1024 / 1024

+
 def is_safetensors_available() -> bool:
     """
     Check whether safetensors is available.
@@ -78,7 +82,6 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
 # Helper functions for saving shard file
 # ======================================
-
 def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
     """
     Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
     given size.
@@ -107,26 +110,30 @@ def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> It
         yield current_block, current_block_size

-def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool =False):
+def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
     """
     load shard state dict into model
     """
     if use_safetensors and not checkpoint_file.suffix == ".safetensors":
         raise Exception("load the model using `safetensors`, but no file endwith .safetensors")
     if use_safetensors:
-        from safetensors.torch import safe_open
         from safetensors.torch import load_file as safe_load_file
+        from safetensors.torch import safe_open
         with safe_open(checkpoint_file, framework="pt") as f:
             metadata = f.metadata()
         if metadata["format"] != "pt":
             raise NotImplementedError(
-                f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet."
-            )
+                f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet.")
         return safe_load_file(checkpoint_file)
     else:
         return torch.load(checkpoint_file)

-def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False, load_sub_module: bool = True):
+
+def load_state_dict_into_model(model: nn.Module,
+                               state_dict: torch.Tensor,
+                               missing_keys: List,
+                               strict: bool = False,
+                               load_sub_module: bool = True):
     r"""Copies parameters and buffers from :attr:`state_dict` into
     this module and its descendants.
@@ -166,11 +173,12 @@ def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missi
     if strict:
         if len(unexpected_keys) > 0:
-            error_msgs = 'Unexpected key(s) in state_dict: {}. '.format(
-                ', '.join('"{}"'.format(k) for k in unexpected_keys))
+            error_msgs = 'Unexpected key(s) in state_dict: {}. '.format(', '.join(
+                '"{}"'.format(k) for k in unexpected_keys))
             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                 model.__class__.__name__, "\n\t".join(error_msgs)))

+
 # ======================================
 # Helper functions for saving state dict
 # ======================================
@@ -350,6 +358,8 @@ def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]:
             return True, index_files[0]
         else:
             return False, None
+    else:
+        raise RuntimeError(f'Invalid checkpoint path {checkpoint_path}. Expected a file or a directory.')


 def load_state_dict(checkpoint_file_path: Path):
@@ -382,7 +392,6 @@ def load_state_dict(checkpoint_file_path: Path):
         return torch.load(checkpoint_file_path)

-
 def add_variant(weights_name: str, variant: Optional[str] = None) -> str:
     if variant is not None and len(variant) > 0:
         splits = weights_name.split(".")
@@ -392,7 +401,7 @@ def add_variant(weights_name: str, variant: Optional[str] = None) -> str:
     return weights_name

-def get_base_filenames(variant: str=None, use_safetensors: bool=False):
+def get_base_filenames(variant: str = None, use_safetensors: bool = False):
     """
     generate base weight filenames
     """
@@ -404,6 +413,7 @@ def get_base_filenames(variant: str=None, use_safetensors: bool=False):
     return weights_name, save_index_file

+
 def get_shard_filename(weights_name: str, idx: int):
     """
     get shard file name
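Taken together, these helpers form the write path for a sharded checkpoint: shard_checkpoint yields blocks of at most max_shard_size MB, get_base_filenames picks the base weight name and the index-file name, and get_shard_filename numbers each shard. A hedged sketch of how a caller could drive them, assuming they are importable from colossalai.checkpoint_io.utils; the output directory, the plain torch.save call, and the index layout below are illustrative assumptions, not taken from this commit:

# Hedged sketch: writing a state dict as numbered shard files plus a simple index.
import json
import os

import torch

from colossalai.checkpoint_io.utils import get_base_filenames, get_shard_filename, shard_checkpoint


def save_sharded_state_dict(state_dict, checkpoint_dir: str, max_shard_size: int = 1024):
    os.makedirs(checkpoint_dir, exist_ok=True)
    weights_name, save_index_file = get_base_filenames(use_safetensors=False)

    # index structure is an assumption for illustration only
    index = {"weight_map": {}, "metadata": {"total_size": 0}}
    for idx, (shard, shard_size) in enumerate(shard_checkpoint(state_dict, max_shard_size=max_shard_size)):
        shard_file = get_shard_filename(weights_name, idx)
        torch.save(shard, os.path.join(checkpoint_dir, shard_file))
        for key in shard:
            index["weight_map"][key] = shard_file
        index["metadata"]["total_size"] += int(shard_size * 1024 * 1024)

    with open(os.path.join(checkpoint_dir, save_index_file), "w") as f:
        json.dump(index, f, indent=2)

With max_shard_size=1024, a roughly 3 GB state dict would come out as three shard files of about 1 GB each plus one index file; shard_checkpoint never splits a single tensor, so shard sizes are approximate.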

View File

@@ -11,9 +11,9 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo

 # These models are not compatible with AMP
-_AMP_ERR_MODELS = ['timm_convit', 'dlrm', 'deepfm_interactionarch', 'deepfm_simpledeepfmnn`']
+_AMP_ERR_MODELS = ['timm_convit', 'dlrm', 'deepfm_interactionarch', 'deepfm_simpledeepfmnn']
 # These models have no parameters
-_LOW_LEVEL_ZERO_ERR_MODELS = ['dlrm_interactionarch']
+_LOW_LEVEL_ZERO_ERR_MODELS = ['dlrm_interactionarch', 'deepfm_overarch', 'deepfm_sparsearch', 'dlrm_sparsearch']
 # These models will get stuck
 _STUCK_MODELS = [
     'diffusers_vq_model', 'transformers_albert', 'transformers_albert_for_pretraining', 'transformers_bert',
@@ -67,6 +67,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True):
             skipped_models.append(name)
             continue
         err = run_fn(stage, model_fn, data_gen_fn, output_transform_fn)
+
         torch.cuda.empty_cache()
         if err is None:
@@ -91,7 +92,7 @@ def run_dist(rank, world_size, port, early_stop: bool = True):
 @rerun_if_address_is_in_use()
 def test_low_level_zero_plugin(early_stop: bool = True):
-    spawn(run_dist, 2, early_stop=early_stop)
+    spawn(run_dist, 4, early_stop=early_stop)


 if __name__ == '__main__':

View File

@@ -1,6 +1,7 @@
 import tempfile

 import torch
+import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import SGD
 from torchvision.models import resnet18
@@ -8,12 +9,12 @@ from torchvision.models import resnet18
 import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin import TorchDDPPlugin
-from colossalai.booster.plugin.torch_ddp_plugin import TorchDDPCheckpointIO
 from colossalai.interface import OptimizerWrapper
-from colossalai.testing import check_state_dict_equal, rerun_if_address_is_in_use, spawn
+from colossalai.testing import check_state_dict_equal, parameterize, rerun_if_address_is_in_use, spawn


-def check_torch_ddp_checkpointIO():
+@parameterize('shard', [True, False])
+def check_torch_ddp_checkpointIO(shard: bool):
     plugin = TorchDDPPlugin()
     booster = Booster(plugin=plugin)
     model = resnet18()
@@ -34,24 +35,39 @@ def check_torch_ddp_checkpointIO():
     optimizer.step()
     scheduler.step()

-    optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile()
-    lr_scheduler_ckpt_tempfile = tempfile.NamedTemporaryFile()
-    ckpt_io = TorchDDPCheckpointIO()
-    ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name)
-    ckpt_io.save_lr_scheduler(scheduler, lr_scheduler_ckpt_tempfile.name)
+    with tempfile.TemporaryDirectory() as tempdir:
+        obj = [tempdir]
+        dist.broadcast_object_list(obj, src=0)
+        tempdir = obj[0]    # use the same directory on all ranks

-    new_model = resnet18()
-    new_optimizer = SGD((new_model.parameters()), lr=0.001)
-    new_scheduler = torch.optim.lr_scheduler.StepLR(new_optimizer, step_size=1, gamma=0.1)
-    _, new_optimizer, _, _, new_scheduler = booster.boost(new_model, new_optimizer, lr_scheduler=new_scheduler)
+        model_ckpt_path = f"{tempdir}/model"
+        optimizer_ckpt_path = f"{tempdir}/optimizer"
+        lr_scheduler_ckpt_path = f"{tempdir}/lr_scheduler"
+        booster.save_model(model, model_ckpt_path, shard=shard)
+        if not shard:
+            # TODO(ver217): optimizer checkpointing is not supported for sharded checkpoint
+            booster.save_optimizer(optimizer, optimizer_ckpt_path)
+            booster.save_lr_scheduler(scheduler, lr_scheduler_ckpt_path)
+        dist.barrier()

-    if ckpt_io.coordinator.is_master():
-        ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name)
-        check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
-        ckpt_io.load_lr_scheduler(new_scheduler, lr_scheduler_ckpt_tempfile.name)
-        check_state_dict_equal(scheduler.state_dict(), new_scheduler.state_dict(), False)
+        new_model = resnet18()
+        new_optimizer = SGD((new_model.parameters()), lr=0.001)
+        new_scheduler = torch.optim.lr_scheduler.StepLR(new_optimizer, step_size=1, gamma=0.1)
+        new_model, new_optimizer, _, _, new_scheduler = booster.boost(new_model,
+                                                                      new_optimizer,
+                                                                      lr_scheduler=new_scheduler)

+        booster.load_model(new_model, model_ckpt_path)
+        check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)
+        if not shard:
+            booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
+            check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
+            booster.load_lr_scheduler(new_scheduler, lr_scheduler_ckpt_path)
+            check_state_dict_equal(scheduler.state_dict(), new_scheduler.state_dict(), False)
+        dist.barrier()


 def run_dist(rank, world_size, port):
     colossalai.launch(config=(dict()), rank=rank, world_size=world_size, port=port, host='localhost')
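One detail of the rewritten test worth calling out: tempfile.TemporaryDirectory() yields a different random path on every rank, so the test broadcasts rank 0's directory name with dist.broadcast_object_list and uses barriers so readers cannot race ahead of the writer. A hedged, standalone sketch of that pattern outside the test harness, assuming torch.distributed is already initialized (the helper name and the mkdtemp-on-rank-0 variant are illustrative):

# Hedged sketch: agree on one temporary directory across all ranks.
import tempfile

import torch.distributed as dist


def shared_tempdir() -> str:
    # every rank prepares a slot, but only rank 0 actually creates a directory
    obj = [tempfile.mkdtemp() if dist.get_rank() == 0 else None]
    dist.broadcast_object_list(obj, src=0)    # all ranks now hold rank 0's path
    tempdir = obj[0]
    dist.barrier()    # extra synchronization before any rank starts writing
    return tempdir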