Mirror of https://github.com/hpcaitech/ColossalAI.git
[plugin] torch ddp plugin supports sharded model checkpoint (#3775)
* [plugin] torch ddp plugin add save sharded model
* [test] fix torch ddp ckpt io test
* [test] fix torch ddp ckpt io test
* [test] fix low level zero plugin test
* [test] fix low level zero plugin test
* [test] add debug info
* [test] add debug info
* [test] add debug info
* [test] add debug info
* [test] add debug info
* [test] fix low level zero plugin test
* [test] fix low level zero plugin test
* [test] remove debug info
This commit is contained in: parent 2703a37ac9, commit 5452df63c5
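In practice the new behaviour is exercised through the Booster API, as the updated test at the bottom of this diff shows. A minimal sketch of the intended usage, assuming the distributed environment has already been launched and using a placeholder checkpoint path:

    from torch.optim import SGD
    from torchvision.models import resnet18

    from colossalai.booster import Booster
    from colossalai.booster.plugin import TorchDDPPlugin

    # assumes the distributed environment was already initialised, e.g. via colossalai.launch(...)
    plugin = TorchDDPPlugin()
    booster = Booster(plugin=plugin)
    model = resnet18()
    optimizer = SGD(model.parameters(), lr=0.001)
    model, optimizer, *_ = booster.boost(model, optimizer)

    # ... training steps ...

    # shard=True routes the save through the new TorchDDPCheckpointIO.save_sharded_model,
    # which only writes on the master rank; "ckpt/resnet18" is a placeholder path.
    booster.save_model(model, "ckpt/resnet18", shard=True)
    booster.load_model(model, "ckpt/resnet18")    # loading works for sharded and unsharded checkpoints alike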
@@ -1,4 +1,4 @@
-from typing import Callable, Iterator, List, Tuple, Union
+from typing import Callable, Iterator, List, Optional, Tuple, Union

 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -50,6 +50,16 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         if self.coordinator.is_master():
             super().save_lr_scheduler(lr_scheduler, checkpoint)

+    def save_sharded_model(self,
+                           model: nn.Module,
+                           checkpoint_path: str,
+                           gather_dtensor: bool = False,
+                           variant: Optional[str] = None,
+                           max_shard_size: int = 1024,
+                           use_safetensors: bool = False):
+        if self.coordinator.is_master():
+            super().save_sharded_model(model, checkpoint_path, gather_dtensor, variant, max_shard_size, use_safetensors)
+

 class TorchDDPModel(ModelWrapper):

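For reference, the parameters of the newly added method, annotated with what they appear to control; the comments are an editorial reading of this diff (standalone construction of TorchDDPCheckpointIO, as in the old test, is an assumption rather than documented usage):

    from colossalai.booster.plugin.torch_ddp_plugin import TorchDDPCheckpointIO
    from torchvision.models import resnet18

    model = resnet18()
    ckpt_io = TorchDDPCheckpointIO()    # assumed to be constructible standalone, as in the old test below
    ckpt_io.save_sharded_model(
        model,                    # the (DDP-wrapped) nn.Module to checkpoint
        "ckpt/my_model",          # placeholder target path for the sharded checkpoint
        gather_dtensor=False,     # DTensor gathering is expected to be done offline via the CLI
        variant=None,             # optional filename variant inserted by add_variant (e.g. "fp16")
        max_shard_size=1024,      # per-shard size budget, in MB (see calculate_tensor_size below)
        use_safetensors=False)    # write safetensors files instead of .bin when True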
@@ -1,7 +1,6 @@
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Union
-from typing import Optional
+from typing import Optional, Union

 import torch
 import torch.nn as nn
@@ -84,7 +83,6 @@ class CheckpointIO(ABC):
         # containing no distributed tensors, dtensor -> full tensor conversion
         # should be done offline via our CLI
         # the existence of index file means it is a sharded checkpoint
-        ckpt_path = Path(checkpoint)
         index_file_exists, index_file_path = has_index_file(checkpoint)

         # return the origin model instead of the unwrapped model
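The comments above describe how a checkpoint path is classified. A rough sketch of the dispatch they imply; the helpers load_sharded and load_unsharded are hypothetical names, not code from this PR:

    # Sketch only: the presence of an index file is what marks a checkpoint as sharded.
    def dispatch_load(model, checkpoint):
        index_file_exists, index_file_path = has_index_file(checkpoint)    # helper shown in this diff
        if index_file_exists:
            load_sharded(model, index_file_path)      # hypothetical: read the index, then each shard file
        else:
            load_unsharded(model, checkpoint)         # hypothetical: load a single state-dict file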
@@ -1,10 +1,12 @@
 # coding=utf-8
+import re
 from pathlib import Path
+from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple

 import torch
 import torch.nn as nn
-from typing import List, Mapping, OrderedDict, Optional, Tuple, Iterator
+
 from colossalai.tensor.d_tensor.d_tensor import DTensor
-import re
+
 SAFE_WEIGHTS_NAME = "model.safetensors"
 WEIGHTS_NAME = "pytorch_model.bin"
@@ -15,6 +17,7 @@ WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
 # General helper functions
 # ======================================

+
 def calculate_tensor_size(tensor: torch.Tensor) -> float:
     """
     Calculate the size of a parameter in MB. Used to compute whether a group of params exceed the shard size.
@@ -28,6 +31,7 @@ def calculate_tensor_size(tensor: torch.Tensor) -> float:
     """
     return tensor.numel() * tensor.element_size() / 1024 / 1024

+
 def is_safetensors_available() -> bool:
     """
     Check whether safetensors is available.
@@ -78,7 +82,6 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
 # Helper functions for saving shard file
 # ======================================
 def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
-
     """
     Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
     given size.
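For orientation, a simplified, self-contained sketch of the size-based sharding that shard_checkpoint and calculate_tensor_size are documented to perform; the real generator (only partially visible in this diff) may differ in details such as the exact flush condition:

    from collections import OrderedDict
    from typing import Iterator, Tuple


    def naive_shard(state_dict: dict, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, float]]:
        # Greedily pack tensors into blocks whose total size stays under max_shard_size MB.
        current_block, current_size = OrderedDict(), 0.0
        for name, tensor in state_dict.items():
            size = tensor.numel() * tensor.element_size() / 1024 / 1024    # same formula as calculate_tensor_size
            if current_block and current_size + size > max_shard_size:
                yield current_block, current_size    # flush the block that would otherwise overflow
                current_block, current_size = OrderedDict(), 0.0
            current_block[name] = tensor
            current_size += size
        if current_block:
            yield current_block, current_size

With the default 1024 MB budget, a ResNet-18 state dict (roughly 45 MB) would land in a single shard.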
@@ -107,26 +110,30 @@ def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> It
             yield current_block, current_block_size


-def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool =False):
+def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
     """
     load shard state dict into model
     """
     if use_safetensors and not checkpoint_file.suffix == ".safetensors":
         raise Exception("load the model using `safetensors`, but no file endwith .safetensors")
     if use_safetensors:
-        from safetensors.torch import safe_open
         from safetensors.torch import load_file as safe_load_file
+        from safetensors.torch import safe_open
         with safe_open(checkpoint_file, framework="pt") as f:
             metadata = f.metadata()
         if metadata["format"] != "pt":
             raise NotImplementedError(
-                f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet."
-            )
+                f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet.")
         return safe_load_file(checkpoint_file)
     else:
         return torch.load(checkpoint_file)

-def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False, load_sub_module: bool = True):
+
+def load_state_dict_into_model(model: nn.Module,
+                               state_dict: torch.Tensor,
+                               missing_keys: List,
+                               strict: bool = False,
+                               load_sub_module: bool = True):
     r"""Copies parameters and buffers from :attr:`state_dict` into
     this module and its descendants.

@@ -166,11 +173,12 @@ def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missi

     if strict:
         if len(unexpected_keys) > 0:
-            error_msgs = 'Unexpected key(s) in state_dict: {}. '.format(
-                ', '.join('"{}"'.format(k) for k in unexpected_keys))
+            error_msgs = 'Unexpected key(s) in state_dict: {}. '.format(', '.join(
+                '"{}"'.format(k) for k in unexpected_keys))
             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                 model.__class__.__name__, "\n\t".join(error_msgs)))

+
 # ======================================
 # Helper functions for saving state dict
 # ======================================
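To make the intended use of the two helpers above concrete, a rough sketch of restoring a sharded checkpoint; the import path, the shard-file list, and the accumulation of missing_keys across shards are assumptions, not code from this PR:

    from pathlib import Path
    from typing import List

    import torch.nn as nn

    # import path assumed; these helpers live in the checkpoint IO utils file shown in this diff
    from colossalai.checkpoint_io.utils import load_shard_state_dict, load_state_dict_into_model


    def load_all_shards(model: nn.Module, shard_files: List[str]) -> List[str]:
        """Hypothetical driver: apply every shard file to the model and report missing keys."""
        missing_keys: List[str] = []
        for shard_file in shard_files:    # e.g. the file list recorded in the *.index.json
            shard = load_shard_state_dict(Path(shard_file), use_safetensors=False)
            load_state_dict_into_model(model, shard, missing_keys, strict=False, load_sub_module=True)
        return missing_keys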
@@ -350,6 +358,8 @@ def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]:
             return True, index_files[0]
         else:
             return False, None
+    else:
+        raise RuntimeError(f'Invalid checkpoint path {checkpoint_path}. Expected a file or a directory.')


 def load_state_dict(checkpoint_file_path: Path):
@@ -382,7 +392,6 @@ def load_state_dict(checkpoint_file_path: Path):
         return torch.load(checkpoint_file_path)


-
 def add_variant(weights_name: str, variant: Optional[str] = None) -> str:
     if variant is not None and len(variant) > 0:
         splits = weights_name.split(".")
@@ -392,7 +401,7 @@ def add_variant(weights_name: str, variant: Optional[str] = None) -> str:
     return weights_name


-def get_base_filenames(variant: str=None, use_safetensors: bool=False):
+def get_base_filenames(variant: str = None, use_safetensors: bool = False):
     """
     generate base weight filenames
     """
@@ -404,6 +413,7 @@ def get_base_filenames(variant: str=None, use_safetensors: bool=False):

     return weights_name, save_index_file

+
 def get_shard_filename(weights_name: str, idx: int):
     """
     get shard file name
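For readers unfamiliar with these helpers, their expected behaviour, inferred from the names, the WEIGHTS_NAME/WEIGHTS_INDEX_NAME constants above, and the Hugging Face convention this code follows; the exact return values are assumptions:

    # illustrative only; helpers would be imported from the checkpoint IO utils module
    add_variant("pytorch_model.bin", "fp16")    # presumably -> "pytorch_model.fp16.bin"
    add_variant("pytorch_model.bin", None)      # -> "pytorch_model.bin" (variant is optional)
    get_base_filenames()                        # presumably -> ("pytorch_model.bin", "pytorch_model.bin.index.json")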
@@ -11,9 +11,9 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo

 # These models are not compatible with AMP
-_AMP_ERR_MODELS = ['timm_convit', 'dlrm', 'deepfm_interactionarch', 'deepfm_simpledeepfmnn`']
+_AMP_ERR_MODELS = ['timm_convit', 'dlrm', 'deepfm_interactionarch', 'deepfm_simpledeepfmnn']
 # These models have no parameters
-_LOW_LEVEL_ZERO_ERR_MODELS = ['dlrm_interactionarch']
+_LOW_LEVEL_ZERO_ERR_MODELS = ['dlrm_interactionarch', 'deepfm_overarch', 'deepfm_sparsearch', 'dlrm_sparsearch']
 # These models will get stuck
 _STUCK_MODELS = [
     'diffusers_vq_model', 'transformers_albert', 'transformers_albert_for_pretraining', 'transformers_bert',
@@ -67,6 +67,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True):
             skipped_models.append(name)
             continue
         err = run_fn(stage, model_fn, data_gen_fn, output_transform_fn)
+
         torch.cuda.empty_cache()

         if err is None:
@@ -91,7 +92,7 @@ def run_dist(rank, world_size, port, early_stop: bool = True):

 @rerun_if_address_is_in_use()
 def test_low_level_zero_plugin(early_stop: bool = True):
-    spawn(run_dist, 2, early_stop=early_stop)
+    spawn(run_dist, 4, early_stop=early_stop)


 if __name__ == '__main__':
@@ -1,6 +1,7 @@
 import tempfile

 import torch
+import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import SGD
 from torchvision.models import resnet18
@@ -8,12 +9,12 @@ from torchvision.models import resnet18
 import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin import TorchDDPPlugin
-from colossalai.booster.plugin.torch_ddp_plugin import TorchDDPCheckpointIO
 from colossalai.interface import OptimizerWrapper
-from colossalai.testing import check_state_dict_equal, rerun_if_address_is_in_use, spawn
+from colossalai.testing import check_state_dict_equal, parameterize, rerun_if_address_is_in_use, spawn


-def check_torch_ddp_checkpointIO():
+@parameterize('shard', [True, False])
+def check_torch_ddp_checkpointIO(shard: bool):
     plugin = TorchDDPPlugin()
     booster = Booster(plugin=plugin)
     model = resnet18()
@@ -34,24 +35,39 @@ def check_torch_ddp_checkpointIO():
     optimizer.step()
     scheduler.step()

-    optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile()
-    lr_scheduler_ckpt_tempfile = tempfile.NamedTemporaryFile()
-    ckpt_io = TorchDDPCheckpointIO()
-    ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name)
-    ckpt_io.save_lr_scheduler(scheduler, lr_scheduler_ckpt_tempfile.name)
+    with tempfile.TemporaryDirectory() as tempdir:
+        obj = [tempdir]
+        dist.broadcast_object_list(obj, src=0)
+        tempdir = obj[0]    # use the same directory on all ranks

-    new_model = resnet18()
-    new_optimizer = SGD((new_model.parameters()), lr=0.001)
-    new_scheduler = torch.optim.lr_scheduler.StepLR(new_optimizer, step_size=1, gamma=0.1)
-    _, new_optimizer, _, _, new_scheduler = booster.boost(new_model, new_optimizer, lr_scheduler=new_scheduler)
+        model_ckpt_path = f"{tempdir}/model"
+        optimizer_ckpt_path = f"{tempdir}/optimizer"
+        lr_scheduler_ckpt_path = f"{tempdir}/lr_scheduler"
+        booster.save_model(model, model_ckpt_path, shard=shard)
+        if not shard:
+            # TODO(ver217): optimizer checkpointing is not supported for sharded checkpoint
+            booster.save_optimizer(optimizer, optimizer_ckpt_path)
+            booster.save_lr_scheduler(scheduler, lr_scheduler_ckpt_path)
+        dist.barrier()

-    if ckpt_io.coordinator.is_master():
-        ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name)
-        check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
+        new_model = resnet18()
+        new_optimizer = SGD((new_model.parameters()), lr=0.001)
+        new_scheduler = torch.optim.lr_scheduler.StepLR(new_optimizer, step_size=1, gamma=0.1)
+        new_model, new_optimizer, _, _, new_scheduler = booster.boost(new_model,
+                                                                      new_optimizer,
+                                                                      lr_scheduler=new_scheduler)

-        ckpt_io.load_lr_scheduler(new_scheduler, lr_scheduler_ckpt_tempfile.name)
-        check_state_dict_equal(scheduler.state_dict(), new_scheduler.state_dict(), False)
+        booster.load_model(new_model, model_ckpt_path)
+        check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)
+
+        if not shard:
+            booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
+            check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
+            booster.load_lr_scheduler(new_scheduler, lr_scheduler_ckpt_path)
+            check_state_dict_equal(scheduler.state_dict(), new_scheduler.state_dict(), False)
+
+        dist.barrier()


 def run_dist(rank, world_size, port):
     colossalai.launch(config=(dict()), rank=rank, world_size=world_size, port=port, host='localhost')