[shardformer] support from_pretrained when loading model with HybridParallelPlugin (#4575)
* hybrid plugin support huggingface from_pretrained
* add huggingface compatibility tests
* add folder cleaning
* fix bugs
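For context, here is a minimal sketch of the workflow this PR enables, assuming ColossalAI's Booster API of this era and a torchrun launch; the model name, parallel sizes, and output path are illustrative placeholders, not taken from the diff:

    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import HybridParallelPlugin
    from transformers import AutoModelForSequenceClassification

    colossalai.launch_from_torch(config={})
    booster = Booster(plugin=HybridParallelPlugin(tp_size=2, pp_size=1))

    model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
    model, *_ = booster.boost(model)

    # With this PR, a sharded save also writes config.json/generation_config.json
    # and prunes stale shard files, so the directory stays from_pretrained-ready.
    booster.save_model(model, "ckpt_dir", shard=True)

    # The checkpoint directory can then be reloaded directly through HuggingFace:
    reloaded = AutoModelForSequenceClassification.from_pretrained("ckpt_dir")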
@@ -9,12 +9,12 @@ from pathlib import Path
 from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple
 
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed import ProcessGroup
 from torch.optim import Optimizer
+from transformers.modeling_utils import PreTrainedModel, get_parameter_dtype
+from transformers.modeling_utils import unwrap_model as unwrap_huggingface_model
 
-from colossalai.interface import OptimizerWrapper
+from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.nn.optimizer import ColossalaiOptimizer
 from colossalai.tensor.d_tensor import (
     is_customized_distributed_tensor,
@@ -228,7 +228,8 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]]
                            index_file: "CheckpointIndexFile",
                            base_filename: str,
                            is_master: bool,
-                           use_safetensors: bool = False) -> int:
+                           use_safetensors: bool = False,
+                           use_pp_format: bool = False) -> int:
     '''
     Save sharded state dict only on master rank, this method can be used by both model and optimizer states.
     Args:
@@ -236,14 +237,16 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]]
         checkpoint (str): The path of checkpoint directory as string.
         index_file (CheckpointIndexFile): The index file object to be updated.
         base_filename (str): Decides the prefix of filenames of shards.
-        is_master (bool): Whether current rank is master.
-        use_safetensors (bool): Whether to use safetensors to save checkpoint.
+        is_master (bool): Whether current rank is main process.
+        use_safetensors (bool, optional): Whether to use safetensors to save checkpoint. Defaults to False.
+        use_pp_format (bool, optional): Whether to save the files in pipeline format, including stage information. Defaults to False.
 
     Returns:
         int: the total size of shards
     '''
 
     total_size = 0
+    shard_filenames = []
     for idx, shard_pair in enumerate(sharded_state_dict):
         shard, current_size = shard_pair
         if not is_master:
@@ -257,8 +260,12 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]]
 
         # Only save on master rank.
         save_state_dict(shard, checkpoint_file_path, use_safetensors=use_safetensors)
+        shard_filenames.append(shard_file)
         del shard
 
+    # Clean folder, delete unneeded files.
+    clean_folder(checkpoint, base_filename, shard_filenames, is_master=is_master, use_pp_format=use_pp_format)
+
     return total_size
 
 
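To make the new use_pp_format flag concrete, here is a small illustration (not part of the diff; filenames are made up) of the two shard-name layouts that the cleanup logic distinguishes when deciding what to prune:

    import re

    # Regular sharded save:    pytorch_model-00001.bin
    # Pipeline-parallel save:  pytorch_model-stage-00002-shard-00003.bin
    flat_reg = re.compile(r"(.*?)-\d{5}")
    pp_reg = re.compile(r"(.*?)-stage-\d{5}-shard-\d{5}")

    assert flat_reg.fullmatch("pytorch_model-00001") is not None
    assert pp_reg.fullmatch("pytorch_model-stage-00002-shard-00003") is not None
    assert pp_reg.fullmatch("pytorch_model-00001") is None  # pp pattern rejects flat names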
@@ -335,6 +342,66 @@ def save_param_groups(state_dict: dict, group_file_path: str) -> None:
     torch.save(param_groups, group_file_path)
 
 
+def clean_folder(checkpoint_path: str,
+                 weights_name: str,
+                 shard_filenames: List[str],
+                 is_master: bool = True,
+                 use_pp_format: bool = False):
+    """
+    Clean the unneeded files in checkpoint directory after shards of state_dict have been saved.
+
+    Args:
+        checkpoint_path (str): Path to the checkpoint directory.
+        weights_name (str): Decides the prefix of filenames of weight shards.
+        shard_filenames (List[str]): The list of saved shard filenames which should not be removed.
+        is_master (bool, optional): Whether current rank is main process. Defaults to True.
+        use_pp_format (bool, optional): Whether to save the files in pipeline format, including stage information. Defaults to False.
+
+    """
+    if is_master:
+        for filename in os.listdir(checkpoint_path):
+            full_filename = os.path.join(checkpoint_path, filename)
+            weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
+            filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "")
+            if not use_pp_format:
+                reg = re.compile(r"(.*?)-\d{5}")
+            else:
+                # When this checkpoint is created by a pipeline parallel process, the pattern is a little different.
+                reg = re.compile(r"(.*?)-stage-\d{5}-shard-\d{5}")
+            if (filename.startswith(weights_no_suffix) and os.path.isfile(full_filename)
+                    and filename not in shard_filenames and reg.fullmatch(filename_no_suffix) is not None):
+                os.remove(full_filename)
+
+
+def save_config_file(model: nn.Module, checkpoint_path: str, is_master: bool = True):
+    """
+    Save config.json/generation_config.json if model is a HuggingFace pretrained model.
+    This method can only be called when a model is saved in a sharded way.
+
+    Args:
+        model (nn.Module): The model whose config should be saved if it's a HuggingFace model.
+        checkpoint_path (str): Path to the checkpoint directory.
+        is_master (bool): Whether current rank is main process.
+    """
+    if not isinstance(model, PreTrainedModel):
+        return
+
+    model = unwrap_huggingface_model(model)
+
+    # Save the string version of dtype to the config, e.g. convert torch.float32 => "float32".
+    dtype = get_parameter_dtype(model)
+    model.config.torch_dtype = str(dtype).split(".")[1]
+
+    # Attach architecture to the config.
+    model.config.architectures = [model.__class__.__name__]
+
+    # Save the config.
+    if is_master:
+        model.config.save_pretrained(checkpoint_path)
+        if model.can_generate():
+            model.generation_config.save_pretrained(checkpoint_path)
+
+
 def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFile", use_safetensors: bool) -> None:
     """
     Save distributed tensor to checkpoint. This checkpoint will be a dictionary which contains
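A toy sanity check of the pruning rule (hypothetical filenames; assumes the clean_folder above is importable): a shard left over from an earlier save is removed, while filenames recorded in shard_filenames survive:

    import os
    import tempfile

    with tempfile.TemporaryDirectory() as ckpt:
        for name in ("pytorch_model-00001.bin", "pytorch_model-00002.bin"):
            open(os.path.join(ckpt, name), "w").close()
        # Pretend the latest save produced only shard 00001; shard 00002 is stale.
        clean_folder(ckpt, "pytorch_model.bin", ["pytorch_model-00001.bin"])
        assert sorted(os.listdir(ckpt)) == ["pytorch_model-00001.bin"]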
@@ -709,5 +776,5 @@ def get_shard_filename(weights_name: str, idx: int):
     get shard file name
     """
     shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin")
-    shard_file = shard_file.replace(".safetensors", f"-{idx + 1:05d}.safetensors")
+    shard_file = shard_file.replace(".safetensors", f"-{idx+1:05d}.safetensors")
     return shard_file
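For reference, the shard names this helper produces (interpreter-session illustration, not from the diff):

    >>> get_shard_filename("pytorch_model.bin", 0)
    'pytorch_model-00001.bin'
    >>> get_shard_filename("model.safetensors", 2)
    'model-00003.safetensors'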