[shardformer] support from_pretrained when loading model with HybridParallelPlugin (#4575)

* hybrid plugin: support HuggingFace from_pretrained

* add HuggingFace compatibility tests

* add folder cleaning

* fix bugs
Author: Baizhou Zhang
Date: 2023-09-01 17:40:01 +08:00
Committed by: GitHub
Parent: c9625dbb63
Commit: 38ccb8b1a3
5 changed files with 218 additions and 17 deletions
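
For orientation, a minimal sketch of the round trip this commit enables, assuming a ColossalAI version close to this commit; the model name, paths and shard size are illustrative, and the script is meant to be launched with torchrun using enough ranks for the chosen parallelism:

import colossalai
import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from transformers import AutoModelForCausalLM

colossalai.launch_from_torch(config={})

plugin = HybridParallelPlugin(tp_size=2, pp_size=1)
booster = Booster(plugin=plugin)

model = AutoModelForCausalLM.from_pretrained("gpt2")
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
model, optimizer, *_ = booster.boost(model, optimizer)

# Writes sharded weights, the shard index json and (with this commit) config.json into "ckpt/".
booster.save_model(model, "ckpt", shard=True, size_per_shard=1024)

# The folder now looks like a regular HuggingFace checkpoint, so it can be
# reloaded directly with from_pretrained, e.g. on a single process without the plugin.
reloaded = AutoModelForCausalLM.from_pretrained("ckpt")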


@@ -26,6 +26,7 @@ from .utils import (
     load_shard_state_dict,
     load_state_dict_into_model,
     load_states_into_optimizer,
+    save_config_file,
     save_param_groups,
     save_state_dict_shards,
     search_tp_partition_dim,
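
The newly imported save_config_file helper is what writes the HuggingFace config next to the weight shards, which is the missing piece for from_pretrained. A hedged sketch of the idea only (the actual ColossalAI helper also has to unwrap the plugin's model wrapper before reading the config):

from transformers import PreTrainedModel

def save_config_file_sketch(model: PreTrainedModel, checkpoint_dir: str) -> None:
    # Illustrative only: persist config.json (and generation_config.json, if present)
    # so that from_pretrained can rebuild the architecture from checkpoint_dir.
    model.config.save_pretrained(checkpoint_dir)
    generation_config = getattr(model, "generation_config", None)
    if generation_config is not None:
        generation_config.save_pretrained(checkpoint_dir)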
@@ -204,6 +205,7 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO):
             if control_saving:
                 index_file.append_meta_data("total_size", total_size)
                 index_file.write_index_file(save_index_file)
+                save_config_file(model, checkpoint)
                 if self.verbose:
                     logging.info(f"The model is split into checkpoint shards. "
                                  f"You can find where each parameters has been saved in the "
@@ -219,9 +221,9 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO):
             Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
             # Manage filenames of sharded weights and index file for each pipeline stage.
-            weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank:05d}-shard.bin")
-            weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank:05d}-shard.safetensors")
-            save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank:05d}.json")
+            weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
+            weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors")
+            save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
             save_index_file = os.path.join("tmp_index_files", save_index_file)
             total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
@@ -229,7 +231,8 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO):
                                                 index_file=index_file,
                                                 base_filename=weights_name,
                                                 is_master=control_saving,
-                                                use_safetensors=use_safetensors)
+                                                use_safetensors=use_safetensors,
+                                                use_pp_format=True)
             if control_saving:
                 assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0."
                 index_file.append_meta_data("total_size", total_size)
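
Two details in the hunks above: the stage index in the per-stage filenames moves from self.pp_rank to self.pp_rank+1, i.e. from 0-based to 1-based numbering (closer to HuggingFace's -00001-of-0000N style), and the new use_pp_format=True flag presumably tells save_state_dict_shards that the base filename already carries the per-stage suffix. A small illustration of the filename change, assuming the default pytorch_model.bin base name:

# Illustrative only: resulting shard base name for the first pipeline stage (pp_rank == 0).
weights_name = "pytorch_model.bin"
pp_rank = 0
old_name = weights_name.replace(".bin", f"-stage-{pp_rank:05d}-shard.bin")      # pytorch_model-stage-00000-shard.bin
new_name = weights_name.replace(".bin", f"-stage-{pp_rank + 1:05d}-shard.bin")  # pytorch_model-stage-00001-shard.bin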
@@ -251,6 +254,7 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO):
                     final_index_file.append_weight_map(weight, weight_filename)
                 final_index_file.write_index_file(final_index_file_path)
+                save_config_file(model, checkpoint)
                 rmtree(tmp_index_file_folder)
                 if self.verbose:
                     logging.info(f"The model is split into checkpoint shards. "
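
The hunk above sits in the final merge step: each pipeline stage wrote its own index under tmp_index_files/, the controlling rank folds them into one final index, and with this commit it also writes config.json before the temporary folder is removed. A hedged, standalone sketch of that merge using plain json (the real code goes through ColossalAI's CheckpointIndexFile helper):

import glob
import json
import os

def merge_stage_indices(tmp_dir: str, final_path: str) -> None:
    # Fold every per-stage index into one HuggingFace-style index file:
    # {"metadata": {"total_size": ...}, "weight_map": {param_name: shard_file}}.
    weight_map, total_size = {}, 0
    for stage_index in sorted(glob.glob(os.path.join(tmp_dir, "*.json"))):
        with open(stage_index) as f:
            idx = json.load(f)
        weight_map.update(idx.get("weight_map", {}))
        total_size += idx.get("metadata", {}).get("total_size", 0)
    with open(final_path, "w") as f:
        json.dump({"metadata": {"total_size": total_size}, "weight_map": weight_map}, f, indent=2)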
@@ -423,15 +427,16 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO):
             Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
             # Manage filenames of sharded weights and index file for each pipeline stage.
-            states_name = states_name.replace(".bin", f"-stage-{self.pp_rank:05d}-shard.bin")
-            save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank:05d}.json")
+            states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
+            save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
             save_index_file = os.path.join("tmp_index_files", save_index_file)
             total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
                                                 checkpoint=checkpoint,
                                                 index_file=index_file,
                                                 base_filename=states_name,
-                                                is_master=control_saving)
+                                                is_master=control_saving,
+                                                use_pp_format=True)
             if control_saving:
                 assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0."
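
The optimizer path mirrors the model path, with the same 1-based stage naming and use_pp_format flag. Continuing the illustrative snippet from the top of this page, the sharded optimizer states would be written and reloaded through the Booster interface (the checkpoint path is illustrative; keyword arguments may differ slightly across versions):

booster.save_optimizer(optimizer, "ckpt_optim", shard=True, size_per_shard=1024)
booster.load_optimizer(optimizer, "ckpt_optim")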