mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-08 04:24:47 +00:00)
[plugin] torch ddp plugin supports sharded model checkpoint (#3775)
* [plugin] torch ddp plugin add save sharded model
* [test] fix torch ddp ckpt io test
* [test] fix torch ddp ckpt io test
* [test] fix low level zero plugin test
* [test] fix low level zero plugin test
* [test] add debug info
* [test] add debug info
* [test] add debug info
* [test] add debug info
* [test] add debug info
* [test] fix low level zero plugin test
* [test] fix low level zero plugin test
* [test] remove debug info
colossalai/checkpoint_io/utils.py

@@ -1,10 +1,12 @@
 # coding=utf-8
+import re
+from pathlib import Path
+from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple
+
 import torch
 import torch.nn as nn
-from typing import List, Mapping, OrderedDict, Optional, Tuple, Iterator
 
 from colossalai.tensor.d_tensor.d_tensor import DTensor
-import re
 
 SAFE_WEIGHTS_NAME = "model.safetensors"
 WEIGHTS_NAME = "pytorch_model.bin"
@@ -15,6 +17,7 @@ WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
 # General helper functions
 # ======================================
 
+
 def calculate_tensor_size(tensor: torch.Tensor) -> float:
     """
     Calculate the size of a parameter in MB. Used to compute whether a group of params exceed the shard size.
@@ -28,6 +31,7 @@ def calculate_tensor_size(tensor: torch.Tensor) -> float:
     """
     return tensor.numel() * tensor.element_size() / 1024 / 1024
 
+
 def is_safetensors_available() -> bool:
     """
     Check whether safetensors is available.
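A quick worked example of the MB arithmetic above (illustrative, not part of the diff; assumes the helper is in scope, e.g. imported from colossalai's checkpoint_io utils):

import torch

t = torch.zeros(1000, 1000)      # 1,000,000 float32 elements, 4 bytes each
print(calculate_tensor_size(t))  # 4,000,000 bytes / 1024 / 1024 ≈ 3.81 (MB)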
@@ -78,7 +82,6 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
 # Helper functions for saving shard file
 # ======================================
 def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
-
     """
     Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
     given size.
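For context (illustrative; the model and shard size are made up, not from the PR): the signature above makes shard_checkpoint a generator of (sub-state-dict, size-in-MB) pairs, so a caller can stream shards to disk one at a time:

import torch.nn as nn

model = nn.Linear(512, 512)
# Each yielded block stays under max_shard_size (in MB).
for idx, (shard, size_mb) in enumerate(shard_checkpoint(model.state_dict(), max_shard_size=1024)):
    print(idx, sorted(shard.keys()), round(size_mb, 2))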
@@ -100,35 +103,39 @@ def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
                 current_block_size = 0
         current_block[key] = weight
         current_block_size += weight_size
 
         if ret_block != None:
             yield ret_block, ret_block_size
 
     yield current_block, current_block_size
 
 
-def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool =False):
+def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
     """
     load shard state dict into model
     """
     if use_safetensors and not checkpoint_file.suffix == ".safetensors":
         raise Exception("load the model using `safetensors`, but no file endwith .safetensors")
     if use_safetensors:
-        from safetensors.torch import safe_open
         from safetensors.torch import load_file as safe_load_file
+        from safetensors.torch import safe_open
         with safe_open(checkpoint_file, framework="pt") as f:
             metadata = f.metadata()
         if metadata["format"] != "pt":
             raise NotImplementedError(
-                f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet."
-            )
+                f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet.")
         return safe_load_file(checkpoint_file)
     else:
         return torch.load(checkpoint_file)
 
-def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False, load_sub_module: bool = True):
+
+def load_state_dict_into_model(model: nn.Module,
+                               state_dict: torch.Tensor,
+                               missing_keys: List,
+                               strict: bool = False,
+                               load_sub_module: bool = True):
     r"""Copies parameters and buffers from :attr:`state_dict` into
-    this module and its descendants.
+    this module and its descendants.
 
     Args:
         state_dict (dict): a dict containing parameters and
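For context (the shard file names below are hypothetical): load_shard_state_dict dispatches on use_safetensors, so both shard formats load through one entry point; the safetensors path additionally checks the suffix and the "pt" format metadata:

from pathlib import Path

state = load_shard_state_dict(Path("pytorch_model-00001-of-00002.bin"))
state = load_shard_state_dict(Path("model-00001-of-00002.safetensors"), use_safetensors=True)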
@@ -166,11 +173,12 @@ def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False, load_sub_module: bool = True):
 
     if strict:
         if len(unexpected_keys) > 0:
-            error_msgs = 'Unexpected key(s) in state_dict: {}. '.format(
-                ', '.join('"{}"'.format(k) for k in unexpected_keys))
+            error_msgs = 'Unexpected key(s) in state_dict: {}. '.format(', '.join(
+                '"{}"'.format(k) for k in unexpected_keys))
             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-                model.__class__.__name__, "\n\t".join(error_msgs)))
-
+                model.__class__.__name__, "\n\t".join(error_msgs)))
+
+
 # ======================================
 # Helper functions for saving state dict
 # ======================================
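A sketch of what the strict branch above produces for a caller, assuming unexpected keys are collected the same way torch's own load_state_dict does (the model and the injected key are made up, not from the PR):

import torch.nn as nn

model = nn.Linear(4, 4)
state = dict(model.state_dict())
state["extra.weight"] = state["weight"]    # inject a key the model does not have
try:
    load_state_dict_into_model(model, state, missing_keys=[], strict=True)
except RuntimeError as e:
    print(e)    # Error(s) in loading state_dict for Linear: Unexpected key(s) in state_dict: "extra.weight".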
@@ -350,6 +358,8 @@ def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]:
             return True, index_files[0]
         else:
             return False, None
     else:
         raise RuntimeError(f'Invalid checkpoint path {checkpoint_path}. Expected a file or a directory.')
+
+
 def load_state_dict(checkpoint_file_path: Path):
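For context (the directory is hypothetical): per the branches above, has_index_file returns (True, path) when an index file such as pytorch_model.bin.index.json is found, (False, None) when there is none, and raises on a path that is neither a file nor a directory:

is_sharded, index_file = has_index_file("./my_checkpoint")
if is_sharded:
    print(f"sharded checkpoint, index at {index_file}")
else:
    print("single-file checkpoint")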
@@ -380,7 +390,6 @@ def load_state_dict(checkpoint_file_path: Path):
     else:
         # load with torch
         return torch.load(checkpoint_file_path)
 
 
-
 def add_variant(weights_name: str, variant: Optional[str] = None) -> str:
@@ -392,17 +401,18 @@ def add_variant(weights_name: str, variant: Optional[str] = None) -> str:
     return weights_name
 
 
-def get_base_filenames(variant: str=None, use_safetensors: bool=False):
-    """
-    generate base weight filenames
-    """
-    weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
-    weights_name = add_variant(weights_name, variant)
+def get_base_filenames(variant: str = None, use_safetensors: bool = False):
+    """
+    generate base weight filenames
+    """
+    weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
+    weights_name = add_variant(weights_name, variant)
 
-    save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
-    save_index_file = add_variant(save_index_file, variant)
+    save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
+    save_index_file = add_variant(save_index_file, variant)
 
-    return weights_name, save_index_file
+    return weights_name, save_index_file
+
 
 def get_shard_filename(weights_name: str, idx: int):
     """
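For context: with variant=None, add_variant returns the name unchanged (the `return weights_name` context line above), so the defaults resolve directly to the constants defined at the top of the file:

print(get_base_filenames())
# ('pytorch_model.bin', 'pytorch_model.bin.index.json')
print(get_base_filenames(use_safetensors=True))
# ('model.safetensors', 'model.safetensors.index.json')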
@@ -410,4 +420,4 @@ def get_shard_filename(weights_name: str, idx: int):
     """
     shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin")
     shard_file = shard_file.replace(".safetensors", f"-{idx + 1:05d}.safetensors")
-    return shard_file
+    return shard_file
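Following the two replace() calls above, shard numbering is one-based and zero-padded to five digits:

print(get_shard_filename("pytorch_model.bin", 0))     # pytorch_model-00001.bin
print(get_shard_filename("model.safetensors", 11))    # model-00012.safetensors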
|