[plugin] torch ddp plugin supports sharded model checkpoint (#3775)

* [plugin] torch ddp plugin add save sharded model

* [test] fix torch ddp ckpt io test

* [test] fix torch ddp ckpt io test

* [test] fix low level zero plugin test

* [test] fix low level zero plugin test

* [test] add debug info

* [test] add debug info

* [test] add debug info

* [test] add debug info

* [test] add debug info

* [test] fix low level zero plugin test

* [test] fix low level zero plugin test

* [test] remove debug info
Author: Hongxin Liu
Date: 2023-05-18 20:05:59 +08:00
Committed by: GitHub
Parent: 2703a37ac9
Commit: 5452df63c5
5 changed files with 86 additions and 51 deletions
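For orientation before the diff: a minimal sketch of how the sharded save added by this PR might be driven through the Booster API. The TorchDDPPlugin and Booster names are ColossalAI's, but the launch call and the save_model(..., shard=True, size_per_shard=...) signature are assumptions about the API of this release, not something the diff below shows.

import torch.nn as nn
from torch.optim import Adam

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

colossalai.launch_from_torch(config={})    # assumes a torchrun launch

model = nn.Linear(32, 32)
optimizer = Adam(model.parameters())
booster = Booster(plugin=TorchDDPPlugin())
model, optimizer, *_ = booster.boost(model, optimizer)

# Save as a sharded checkpoint: each shard stays under `size_per_shard` MB
# and an index file maps parameter names to shard files (hedged signature).
booster.save_model(model, "./ckpt_dir", shard=True, size_per_shard=1024)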


@@ -1,10 +1,12 @@
 # coding=utf-8
+import re
 from pathlib import Path
+from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple
+
 import torch
 import torch.nn as nn
-from typing import List, Mapping, OrderedDict, Optional, Tuple, Iterator
+
 from colossalai.tensor.d_tensor.d_tensor import DTensor
-import re
 
 SAFE_WEIGHTS_NAME = "model.safetensors"
 WEIGHTS_NAME = "pytorch_model.bin"
@@ -15,6 +17,7 @@ WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
 # General helper functions
 # ======================================
 
+
 def calculate_tensor_size(tensor: torch.Tensor) -> float:
     """
     Calculate the size of a parameter in MB. Used to compute whether a group of params exceed the shard size.
@@ -28,6 +31,7 @@ def calculate_tensor_size(tensor: torch.Tensor) -> float:
     """
     return tensor.numel() * tensor.element_size() / 1024 / 1024
 
+
 def is_safetensors_available() -> bool:
     """
     Check whether safetensors is available.
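As a quick worked example of calculate_tensor_size (illustrative, not part of the commit, and assuming the helper is importable from colossalai.checkpoint_io.utils): one million float32 elements occupy 4,000,000 bytes, roughly 3.81 MB.

import torch

from colossalai.checkpoint_io.utils import calculate_tensor_size

t = torch.zeros(1000, 1000, dtype=torch.float32)
# 1_000_000 elements * 4 bytes / 1024 / 1024 ≈ 3.8147
print(calculate_tensor_size(t))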
@@ -78,7 +82,6 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
 # Helper functions for saving shard file
 # ======================================
-
 def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
     """
     Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
     given size.
@@ -100,35 +103,39 @@ def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> It
                 current_block_size = 0
             current_block[key] = weight
             current_block_size += weight_size
 
         if ret_block != None:
             yield ret_block, ret_block_size
 
     yield current_block, current_block_size
 
-def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool =False):
+
+def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
     """
     load shard state dict into model
     """
     if use_safetensors and not checkpoint_file.suffix == ".safetensors":
         raise Exception("load the model using `safetensors`, but no file endwith .safetensors")
     if use_safetensors:
-        from safetensors.torch import safe_open
         from safetensors.torch import load_file as safe_load_file
+        from safetensors.torch import safe_open
         with safe_open(checkpoint_file, framework="pt") as f:
             metadata = f.metadata()
         if metadata["format"] != "pt":
             raise NotImplementedError(
-                f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet."
-            )
+                f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet.")
         return safe_load_file(checkpoint_file)
     else:
         return torch.load(checkpoint_file)
 
-def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False, load_sub_module: bool = True):
+
+def load_state_dict_into_model(model: nn.Module,
+                               state_dict: torch.Tensor,
+                               missing_keys: List,
+                               strict: bool = False,
+                               load_sub_module: bool = True):
     r"""Copies parameters and buffers from :attr:`state_dict` into
-    this module and its descendants. 
+    this module and its descendants.
 
     Args:
         state_dict (dict): a dict containing parameters and
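To make the generator contract of shard_checkpoint concrete, here is a hedged sketch of a caller that writes one file per yielded block and records an index; the real saving logic lives in the plugin's checkpoint I/O classes, and the import path colossalai.checkpoint_io.utils is an assumption.

import json

import torch
import torch.nn as nn

from colossalai.checkpoint_io.utils import (get_base_filenames, get_shard_filename, shard_checkpoint)

model = nn.Linear(1024, 1024)    # stand-in model
weights_name, save_index_file = get_base_filenames(use_safetensors=False)

weight_map = {}
for idx, (shard, shard_size) in enumerate(shard_checkpoint(model.state_dict(), max_shard_size=1024)):
    shard_file = get_shard_filename(weights_name, idx)
    torch.save(shard, shard_file)    # one sub-checkpoint per file
    weight_map.update({key: shard_file for key in shard})

with open(save_index_file, "w") as f:
    json.dump({"metadata": {}, "weight_map": weight_map}, f)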
@@ -166,11 +173,12 @@ def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missi
 
     if strict:
         if len(unexpected_keys) > 0:
-            error_msgs = 'Unexpected key(s) in state_dict: {}. '.format(
-                ', '.join('"{}"'.format(k) for k in unexpected_keys))
+            error_msgs = 'Unexpected key(s) in state_dict: {}. '.format(', '.join(
+                '"{}"'.format(k) for k in unexpected_keys))
             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-                                   model.__class__.__name__, "\n\t".join(error_msgs)))
+                model.__class__.__name__, "\n\t".join(error_msgs)))
+
 
 # ======================================
 # Helper functions for saving state dict
 # ======================================
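Under strict=True, the branch above turns unexpected keys into a RuntimeError instead of silently dropping them. A small hedged illustration (hypothetical key name, same assumed import path):

import torch
import torch.nn as nn

from colossalai.checkpoint_io.utils import load_state_dict_into_model

model = nn.Linear(4, 4)
missing_keys = []
try:
    load_state_dict_into_model(model, {"not_a_real_param": torch.zeros(1)}, missing_keys, strict=True)
except RuntimeError as err:
    print(err)    # reports: Unexpected key(s) in state_dict: "not_a_real_param".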
@@ -350,6 +358,8 @@ def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]:
             return True, index_files[0]
         else:
             return False, None
+    else:
+        raise RuntimeError(f'Invalid checkpoint path {checkpoint_path}. Expected a file or a directory.')
 
 
 def load_state_dict(checkpoint_file_path: Path):
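The new else branch above replaces a silent fall-through with an explicit error for paths that are neither files nor directories. For reference, the intended happy path looks like this (illustrative only, same assumed import path):

from colossalai.checkpoint_io.utils import has_index_file

# A directory holding pytorch_model.bin.index.json is treated as a
# sharded checkpoint; a single .bin or .safetensors file is not.
is_sharded, index_file = has_index_file("./ckpt_dir")
if is_sharded:
    print(f"sharded checkpoint, index at {index_file}")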
@@ -380,7 +390,6 @@ def load_state_dict(checkpoint_file_path: Path):
     else:
         # load with torch
         return torch.load(checkpoint_file_path)
-
 
 
 def add_variant(weights_name: str, variant: Optional[str] = None) -> str:
@@ -392,17 +401,18 @@ def add_variant(weights_name: str, variant: Optional[str] = None) -> str:
 
     return weights_name
 
-def get_base_filenames(variant: str=None, use_safetensors: bool=False):
-    """
-    generate base weight filenames
-    """
-    weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
-    weights_name = add_variant(weights_name, variant)
+
+def get_base_filenames(variant: str = None, use_safetensors: bool = False):
+    """
+    generate base weight filenames
+    """
+    weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
+    weights_name = add_variant(weights_name, variant)
 
-    save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
-    save_index_file = add_variant(save_index_file, variant)
+    save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
+    save_index_file = add_variant(save_index_file, variant)
 
-    return weights_name, save_index_file
+    return weights_name, save_index_file
 
 
 def get_shard_filename(weights_name: str, idx: int):
     """
@@ -410,4 +420,4 @@ def get_shard_filename(weights_name: str, idx: int):
     """
     shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin")
     shard_file = shard_file.replace(".safetensors", f"-{idx + 1:05d}.safetensors")
-    return shard_file
\ No newline at end of file
+    return shard_file
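Concretely, the two filename helpers above produce names like these (worked example, no variant, same assumed import path):

from colossalai.checkpoint_io.utils import get_base_filenames, get_shard_filename

weights_name, save_index_file = get_base_filenames(use_safetensors=False)
# weights_name    -> "pytorch_model.bin"
# save_index_file -> "pytorch_model.bin.index.json"

print(get_shard_filename(weights_name, 0))     # pytorch_model-00001.bin
print(get_shard_filename(weights_name, 11))    # pytorch_model-00012.bin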