mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 10:06:44 +00:00
[pipeline,shardformer] Fix p2p efficiency in pipeline, allow skipping loading weight not in weight_map when strict=False
, fix llama flash attention forward, add flop estimation by megatron in llama benchmark (#5017)
* Use p2p * Cannot bidirectonal send p2p * Refactor tensor creation and serialization in P2P communication * Fix llama forward args in flash attention * Add flop estimate from megatron * Support loading weight not in weight_map when strict=False in hybrid_parallel * Use send_forward_recv_backward, etc in 1f1b * Use dataclass for metdata Remove torch.cuda.synchronize() as suggested * Add comment about the torch.cuda.synchronize for potential error * Typo * Update hybrid_parallel_checkpoint_io.py * Update p2p.py * Update one_f_one_b.py * Update p2p.py --------- Co-authored-by: flybird11111 <1829166702@qq.com>
This commit is contained in:
@@ -413,6 +413,7 @@ def get_llama_flash_attention_forward():
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."
|
||||
|
Reference in New Issue
Block a user