[pipeline,shardformer] Fix p2p efficiency in pipeline, allow skipping weights not in weight_map when strict=False, fix llama flash attention forward, add FLOP estimation from Megatron in llama benchmark (#5017)

* Use p2p

* Cannot send p2p bidirectionally

* Refactor tensor creation and serialization in P2P communication

* Fix llama forward args in flash attention

* Add FLOP estimate from Megatron (see the sketch below)

* Support loading weights not in weight_map when strict=False in hybrid_parallel

* Use send_forward_recv_backward, etc. in 1F1B (see the sketch below)

* Use dataclass for metadata; remove torch.cuda.synchronize() as suggested

* Add a comment about torch.cuda.synchronize() and a potential error

* Typo

* Update hybrid_parallel_checkpoint_io.py

* Update p2p.py

* Update one_f_one_b.py

* Update p2p.py

---------

Co-authored-by: flybird11111 <1829166702@qq.com>
Author: Elsa Granger
Date: 2023-11-16 20:15:59 +08:00
Committed by: GitHub
Parent commit: 28052a71fb
Commit: b2ad0d9e8f
6 changed files with 415 additions and 14 deletions
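
The p2p bullets above replace pairs of blocking send/recv calls in the 1F1B schedule with fused operations such as send_forward_recv_backward. The ColossalAI helpers themselves are not shown in this excerpt; the sketch below only illustrates the underlying idea with plain torch.distributed, posting the forward-activation send and the backward-gradient receive as one batched p2p transaction so neither rank blocks on the other. The function name and the assumption that the gradient buffer matches the activation shape are hypothetical.

import torch
import torch.distributed as dist

def send_forward_recv_backward(output: torch.Tensor, next_rank: int) -> torch.Tensor:
    # Hypothetical sketch: send this stage's activation to the next pipeline
    # stage while receiving the gradient flowing back from it, in one batched
    # p2p call instead of two blocking ones.
    grad_buffer = torch.empty_like(output)  # assumes grad matches the activation shape
    ops = [
        dist.P2POp(dist.isend, output, next_rank),
        dist.P2POp(dist.irecv, grad_buffer, next_rank),
    ]
    reqs = dist.batch_isend_irecv(ops)  # post both operations together
    for req in reqs:
        req.wait()
    return grad_buffer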
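
The FLOP-estimate bullet refers to the per-iteration hardware-FLOPs formula popularized by the Megatron-LM papers. The benchmark code added by the commit is not part of this excerpt, so the sketch below is only one common form of that estimate (dense GPT-style blocks, no SwiGLU or grouped-query-attention correction); treat the constants and function name as assumptions rather than the merged code.

def megatron_flops_per_iteration(
    batch_size: int,
    seq_len: int,
    num_layers: int,
    hidden_size: int,
    vocab_size: int,
) -> float:
    # Megatron-LM style estimate without activation recomputation:
    #   72 * B * s * L * h^2 * (1 + s / (6h) + V / (12 * L * h))
    # With full activation recomputation the paper uses
    #   96 * B * s * L * h^2 * (1 + s / (6h) + V / (16 * L * h)).
    return (
        72
        * batch_size
        * seq_len
        * num_layers
        * hidden_size**2
        * (1 + seq_len / (6 * hidden_size) + vocab_size / (12 * num_layers * hidden_size))
    )

# Example: a LLaMA-7B-like configuration (values are illustrative only)
flops = megatron_flops_per_iteration(
    batch_size=8, seq_len=4096, num_layers=32, hidden_size=4096, vocab_size=32000
)
achieved_tflops = flops / 1e12  # divide further by step time (s) and GPU count for per-GPU TFLOPS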

hybrid_parallel_checkpoint_io.py

@@ -1,4 +1,5 @@
 import copy
+from functools import reduce
 import logging
 import os
 from pathlib import Path
@@ -313,9 +314,13 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         # Keep a record of loaded files so that file will not be repeatedly loaded.
         loaded_file = set()
 
+        missing_keys = []
+        missing_file_keys = []
+
         def _load(name: str):
             if name not in weight_map:
-                raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!")
+                missing_file_keys.append(name)
+                return
             filename = weight_map[name]
 
             # If this param/buffer has been loaded before, directly return.
@@ -324,7 +329,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
 
             file_path = os.path.join(ckpt_root_path, filename)
             state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
-            missing_keys = []
 
             load_state_dict_into_model(
                 model, state_dict, missing_keys=missing_keys, strict=strict, load_sub_module=True
@@ -357,6 +361,27 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         if self.verbose and self.coordinator.is_master():
             logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
 
+        if len(missing_keys) == 0:
+            raise RuntimeError(
+                "No weight is loaded into the model. Please check the checkpoint files and the model structure."
+            )
+
+        remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
+        remain_keys = remain_keys.union(set(missing_file_keys))
+        if len(remain_keys) > 0:
+            if strict:
+                error_msgs = "Missing key(s) in state_dict: {}. ".format(
+                    ", ".join('"{}"'.format(k) for k in missing_keys)
+                )
+                raise RuntimeError(
+                    "Error(s) in loading state_dict for {}:\n\t{}".format(
+                        self.__class__.__name__, "\n\t".join(error_msgs)
+                    )
+                )
+            else:
+                if self.coordinator.is_master():
+                    logging.info(f"The following keys are not loaded from checkpoint: {remain_keys}")
+
     def save_sharded_optimizer(
         self,
         optimizer: OptimizerWrapper,
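
The checkpoint-loading hunks above gather two kinds of misses: per-file missing keys (missing_keys) and parameter names that never appeared in weight_map at all (missing_file_keys). A key is only reported as missing when every loaded shard lacked it, which is what the reduce over set intersection computes, and with strict=False such keys are logged rather than raised. A minimal, self-contained illustration of that rule with hypothetical key names, independent of ColossalAI's classes:

from functools import reduce

# Keys each checkpoint shard failed to provide (hypothetical example data).
missing_per_shard = [
    ["lm_head.weight", "layers.0.bias"],    # shard 1
    ["lm_head.weight", "layers.1.weight"],  # shard 2
]
# Parameter names that were absent from weight_map altogether.
missing_file_keys = ["rotary_emb.inv_freq"]

# A key is genuinely missing only if *every* shard lacked it ...
remain_keys = reduce(lambda a, b: a & b, map(set, missing_per_shard))
# ... or if no shard file was associated with it in the first place.
remain_keys |= set(missing_file_keys)

strict = False
if remain_keys:
    if strict:
        raise RuntimeError(f"Missing key(s) in state_dict: {sorted(remain_keys)}")
    print(f"The following keys are not loaded from checkpoint: {remain_keys}")
# With strict=False this prints {'lm_head.weight', 'rotary_emb.inv_freq'} instead of raising.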