Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-04 18:40:28 +00:00)
[shardformer] support from_pretrained when loading model with HybridParallelPlugin (#4575)
* hybrid plugin support huggingface from_pretrained
* add huggingface compatibility tests
* add folder cleaning
* fix bugs
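With this change, a checkpoint written through HybridParallelPlugin also carries the Hugging Face config, so the resulting folder can be reloaded with plain transformers from_pretrained. A minimal sketch of the intended workflow (the model name, paths, and parallel sizes are illustrative; the script assumes it is launched with torchrun on tp_size * pp_size processes):

    from transformers import AutoModelForCausalLM
    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import HybridParallelPlugin

    colossalai.launch_from_torch(config={})

    plugin = HybridParallelPlugin(tp_size=2, pp_size=2)    # illustrative parallel sizes
    booster = Booster(plugin=plugin)

    model = AutoModelForCausalLM.from_pretrained("gpt2")   # illustrative model
    model, *_ = booster.boost(model)

    # Writes sharded weights, an index file, and (with this commit) the HF config.
    booster.save_model(model, "./hybrid_ckpt", shard=True)

    # The folder can now be consumed as a regular Hugging Face checkpoint directory.
    reloaded = AutoModelForCausalLM.from_pretrained("./hybrid_ckpt")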
@@ -26,6 +26,7 @@ from .utils import (
     load_shard_state_dict,
     load_state_dict_into_model,
     load_states_into_optimizer,
+    save_config_file,
     save_param_groups,
     save_state_dict_shards,
     search_tp_partition_dim,
@@ -204,6 +205,7 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO):
             if control_saving:
                 index_file.append_meta_data("total_size", total_size)
                 index_file.write_index_file(save_index_file)
+                save_config_file(model, checkpoint)
                 if self.verbose:
                     logging.info(f"The model is split into checkpoint shards. "
                                  f"You can find where each parameters has been saved in the "
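The new save_config_file call is what makes the checkpoint folder self-describing: besides the weight shards and the index file, the model's Hugging Face config ends up next to them, which is what from_pretrained needs to rebuild the architecture. The real helper is the save_config_file imported above from .utils; the sketch below is only a rough approximation of the idea, assuming the unwrapped module is a transformers PreTrainedModel:

    from transformers import PreTrainedModel

    def save_config_file_sketch(model: PreTrainedModel, checkpoint: str) -> None:
        # Illustrative only: persist config.json (and the generation config when
        # present) so that AutoModel.from_pretrained(checkpoint) can rebuild the model.
        model.config.save_pretrained(checkpoint)
        if model.can_generate():
            model.generation_config.save_pretrained(checkpoint)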
@@ -219,9 +221,9 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO):
             Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
 
             # Manage filenames of sharded weights and index file for each pipeline stage.
-            weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank:05d}-shard.bin")
-            weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank:05d}-shard.safetensors")
-            save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank:05d}.json")
+            weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
+            weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors")
+            save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
             save_index_file = os.path.join("tmp_index_files", save_index_file)
 
             total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
@@ -229,7 +231,8 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO):
                                                 index_file=index_file,
                                                 base_filename=weights_name,
                                                 is_master=control_saving,
-                                                use_safetensors=use_safetensors)
+                                                use_safetensors=use_safetensors,
+                                                use_pp_format=True)
             if control_saving:
                 assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0."
                 index_file.append_meta_data("total_size", total_size)
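Two things change in the pipeline branch: the stage id embedded in the temporary shard filenames is now 1-indexed (pp_rank+1 instead of pp_rank), and save_state_dict_shards is told to keep this pipeline-style naming via use_pp_format=True. A quick illustration of the filenames this produces (base names are made up for the example):

    pp_rank = 0                                         # first pipeline stage
    weights_name = "pytorch_model.bin"                  # illustrative base name
    save_index_file = "pytorch_model.bin.index.json"    # illustrative base name

    weights_name = weights_name.replace(".bin", f"-stage-{pp_rank+1:05d}-shard.bin")
    save_index_file = save_index_file.replace(".json", f"-stage-{pp_rank+1:05d}.json")

    print(weights_name)       # pytorch_model-stage-00001-shard.bin
    print(save_index_file)    # pytorch_model.bin.index-stage-00001.json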
@@ -251,6 +254,7 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO):
                     final_index_file.append_weight_map(weight, weight_filename)
 
                 final_index_file.write_index_file(final_index_file_path)
+                save_config_file(model, checkpoint)
                 rmtree(tmp_index_file_folder)
                 if self.verbose:
                     logging.info(f"The model is split into checkpoint shards. "
@@ -423,15 +427,16 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO):
             Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
 
             # Manage filenames of sharded weights and index file for each pipeline stage.
-            states_name = states_name.replace(".bin", f"-stage-{self.pp_rank:05d}-shard.bin")
-            save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank:05d}.json")
+            states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
+            save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
             save_index_file = os.path.join("tmp_index_files", save_index_file)
 
             total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
                                                 checkpoint=checkpoint,
                                                 index_file=index_file,
                                                 base_filename=states_name,
-                                                is_master=control_saving)
+                                                is_master=control_saving,
+                                                use_pp_format=True)
 
             if control_saving:
                 assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0."
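save_sharded_optimizer gets the mirror-image treatment: 1-indexed stage names for the per-stage optimizer shards and use_pp_format=True when writing them, so optimizer state saved under pipeline parallelism stays consistent with the model checkpoint layout. A hedged usage sketch, reusing the booster from the earlier example and an optimizer that was passed through booster.boost (paths illustrative):

    # Save sharded optimizer state next to the model checkpoint, then restore it.
    booster.save_optimizer(optimizer, "./hybrid_ckpt_optim", shard=True)
    booster.load_optimizer(optimizer, "./hybrid_ckpt_optim")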