mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 02:51:59 +00:00
[pipeline,shardformer] Fix p2p efficiency in pipeline, allow skipping loading weight not in weight_map when strict=False
, fix llama flash attention forward, add flop estimation by megatron in llama benchmark (#5017)
* Use p2p * Cannot bidirectonal send p2p * Refactor tensor creation and serialization in P2P communication * Fix llama forward args in flash attention * Add flop estimate from megatron * Support loading weight not in weight_map when strict=False in hybrid_parallel * Use send_forward_recv_backward, etc in 1f1b * Use dataclass for metdata Remove torch.cuda.synchronize() as suggested * Add comment about the torch.cuda.synchronize for potential error * Typo * Update hybrid_parallel_checkpoint_io.py * Update p2p.py * Update one_f_one_b.py * Update p2p.py --------- Co-authored-by: flybird11111 <1829166702@qq.com>
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import copy
|
||||
from functools import reduce
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
@@ -313,9 +314,13 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
# Keep a record of loaded files so that file will not be repeatedly loaded.
|
||||
loaded_file = set()
|
||||
|
||||
missing_keys = []
|
||||
missing_file_keys = []
|
||||
|
||||
def _load(name: str):
|
||||
if name not in weight_map:
|
||||
raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!")
|
||||
missing_file_keys.append(name)
|
||||
return
|
||||
filename = weight_map[name]
|
||||
|
||||
# If this param/buffer has been loaded before, directly return.
|
||||
@@ -324,7 +329,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
|
||||
file_path = os.path.join(ckpt_root_path, filename)
|
||||
state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
|
||||
missing_keys = []
|
||||
|
||||
load_state_dict_into_model(
|
||||
model, state_dict, missing_keys=missing_keys, strict=strict, load_sub_module=True
|
||||
@@ -357,6 +361,27 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
if self.verbose and self.coordinator.is_master():
|
||||
logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
|
||||
|
||||
if len(missing_keys) == 0:
|
||||
raise RuntimeError(
|
||||
"No weigth is loaded into the model. Please check the checkpoint files and the model structure."
|
||||
)
|
||||
|
||||
remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
|
||||
remain_keys = remain_keys.union(set(missing_file_keys))
|
||||
if len(remain_keys) > 0:
|
||||
if strict:
|
||||
error_msgs = "Missing key(s) in state_dict: {}. ".format(
|
||||
", ".join('"{}"'.format(k) for k in missing_keys)
|
||||
)
|
||||
raise RuntimeError(
|
||||
"Error(s) in loading state_dict for {}:\n\t{}".format(
|
||||
self.__class__.__name__, "\n\t".join(error_msgs)
|
||||
)
|
||||
)
|
||||
else:
|
||||
if self.coordinator.is_master():
|
||||
logging.info(f"The following keys are not loaded from checkpoint: {remain_keys}")
|
||||
|
||||
def save_sharded_optimizer(
|
||||
self,
|
||||
optimizer: OptimizerWrapper,
|
||||
|
Reference in New Issue
Block a user